Commit 5c6b038b authored by Giorgos Korfiatis's avatar Giorgos Korfiatis

Wait threads for a limited time when shutting down

parent 194aeb03
...@@ -304,15 +304,15 @@ class WebSocketProtocol(WebSocket): ...@@ -304,15 +304,15 @@ class WebSocketProtocol(WebSocket):
with database.TransactedConnection(self.session_db) as db: with database.TransactedConnection(self.session_db) as db:
db.unregister_heartbeat(self.ui_id) db.unregister_heartbeat(self.ui_id)
def shutdown_syncer(self, syncer_key=0): def shutdown_syncer(self, syncer_key=0, timeout=None):
"""Shutdown the syncer backend object""" """Shutdown the syncer backend object"""
LOGGER.debug('Shutdown syncer') LOGGER.debug('Shutdown syncer')
with SYNCERS.lock() as d: with SYNCERS.lock() as d:
syncer = d.pop(syncer_key, None) syncer = d.pop(syncer_key, None)
if syncer and self.can_sync(): if syncer and self.can_sync():
syncer.stop_all_daemons() remaining = syncer.stop_all_daemons(timeout=timeout)
LOGGER.debug('Wait open syncs to complete') LOGGER.debug('Wait open syncs to complete')
syncer.wait_sync_threads() syncer.wait_sync_threads(timeout=remaining)
def _get_default_sync(self): def _get_default_sync(self):
"""Get global.default_sync or pick the first sync as default """Get global.default_sync or pick the first sync as default
...@@ -619,7 +619,7 @@ class WebSocketProtocol(WebSocket): ...@@ -619,7 +619,7 @@ class WebSocketProtocol(WebSocket):
if action == 'shutdown': if action == 'shutdown':
# Clean db to cause syncer backend to shut down # Clean db to cause syncer backend to shut down
self.set_status(code=STATUS['SHUTTING DOWN']) self.set_status(code=STATUS['SHUTTING DOWN'])
self.shutdown_syncer() self.shutdown_syncer(timeout=5)
self.clean_db() self.clean_db()
return return
{ {
......
...@@ -74,7 +74,7 @@ class FileSyncer(object): ...@@ -74,7 +74,7 @@ class FileSyncer(object):
else: else:
logger.info("Notifier %s already up" % signature) logger.info("Notifier %s already up" % signature)
def stop_notifiers(self): def stop_notifiers(self, timeout=None):
for notifier in self.notifiers.values(): for notifier in self.notifiers.values():
try: try:
notifier.stop() notifier.stop()
...@@ -88,25 +88,24 @@ class FileSyncer(object): ...@@ -88,25 +88,24 @@ class FileSyncer(object):
# when attempting to stop a notifier after the watched # when attempting to stop a notifier after the watched
# directory has been deleted # directory has been deleted
logger.warning("Ignored TypeError: %s" % e) logger.warning("Ignored TypeError: %s" % e)
for notifier in self.notifiers.values(): return utils.wait_joins(self.notifiers.values(), timeout)
notifier.join()
def start_decide(self): def start_decide(self):
if not self.decide_active: if not self.decide_active:
self.decide_thread = self._poll_decide() self.decide_thread = self._poll_decide()
def stop_decide(self): def stop_decide(self, timeout=None):
if self.decide_active: if self.decide_active:
self.decide_thread.stop() self.decide_thread.stop()
self.decide_thread.join() return utils.wait_joins([self.decide_thread], timeout)
return timeout
def stop_all_daemons(self): def stop_all_daemons(self, timeout=None):
self.stop_decide() remaining = self.stop_decide(timeout=timeout)
self.stop_notifiers() return self.stop_notifiers(timeout=remaining)
def wait_sync_threads(self): def wait_sync_threads(self, timeout=None):
for thread in self.sync_threads: return utils.wait_joins(self.sync_threads, timeout=timeout)
thread.join()
def get_next_message(self, block=False): def get_next_message(self, block=False):
return self.messager.get(block=block) return self.messager.get(block=block)
......
...@@ -139,6 +139,22 @@ class StoppableThread(BaseStoppableThread): ...@@ -139,6 +139,22 @@ class StoppableThread(BaseStoppableThread):
self.run_body = target self.run_body = target
def _remaining(timeout, total_elapsed):
return max(0, timeout - total_elapsed) if timeout is not None else None
def wait_joins(threads, timeout=None):
total_elapsed = 0
for thread in threads:
tbefore = datetime.datetime.now()
remaining_timeout = _remaining(timeout, total_elapsed)
thread.join(timeout=remaining_timeout)
tafter = datetime.datetime.now()
elapsed = (tafter - tbefore).total_seconds()
total_elapsed += elapsed
return _remaining(timeout, total_elapsed)
class ThreadSafeDict(object): class ThreadSafeDict(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._DICT = {} self._DICT = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment