From cf5d9c2943195f891a0bd66ec3a9e415ab10035d Mon Sep 17 00:00:00 2001
From: Giorgos Korfiatis <gkorf@grnet.gr>
Date: Thu, 7 May 2015 10:38:29 +0300
Subject: [PATCH] Context manager for thread safe dict

---
 agkyra/agkyra/syncer/common.py         | 35 --------------------
 agkyra/agkyra/syncer/heartbeat.py      | 44 --------------------------
 agkyra/agkyra/syncer/localfs_client.py | 17 ++++++----
 agkyra/agkyra/syncer/pithos_client.py  | 19 ++++++-----
 agkyra/agkyra/syncer/setup.py          |  5 ++-
 agkyra/agkyra/syncer/syncer.py         | 14 ++++----
 agkyra/agkyra/syncer/utils.py          | 19 +++++++++++
 7 files changed, 50 insertions(+), 103 deletions(-)
 delete mode 100644 agkyra/agkyra/syncer/heartbeat.py

diff --git a/agkyra/agkyra/syncer/common.py b/agkyra/agkyra/syncer/common.py
index d08a996..ebe60f0 100644
--- a/agkyra/agkyra/syncer/common.py
+++ b/agkyra/agkyra/syncer/common.py
@@ -14,7 +14,6 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from collections import namedtuple
-import threading
 
 OBJECT_DIRSEP = '/'
 
@@ -62,37 +61,3 @@ class HardSyncError(SyncError):
 
 class CollisionError(HardSyncError):
     pass
-
-
-class LockedDict(object):
-    def __init__(self, *args, **kwargs):
-        self._Dict = {}
-        self._Lock = threading.Lock()
-
-    def put(self, key, value):
-        self._Lock.acquire()
-        self._Dict[key] = value
-        self._Lock.release()
-
-    def get(self, key, default=None):
-        self._Lock.acquire()
-        value = self._Dict.get(key, default)
-        self._Lock.release()
-        return value
-
-    def pop(self, key, d=None):
-        self._Lock.acquire()
-        value = self._Dict.pop(key, d)
-        self._Lock.release()
-        return value
-
-    def update(self, d):
-        self._Lock.acquire()
-        self._Dict.update(d)
-        self._Lock.release()
-
-    def keys(self):
-        self._Lock.acquire()
-        value = self._Dict.keys()
-        self._Lock.release()
-        return value
diff --git a/agkyra/agkyra/syncer/heartbeat.py b/agkyra/agkyra/syncer/heartbeat.py
deleted file mode 100644
index 5095dcd..0000000
--- a/agkyra/agkyra/syncer/heartbeat.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (C) 2015 GRNET S.A.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <http://www.gnu.org/licenses/>.
-
-import threading
-
-
-class HeartBeat(object):
-    def __init__(self, *args, **kwargs):
-        self._LOG = {}
-        self._LOCK = threading.Lock()
-
-    def lock(self):
-        class Lock(object):
-            def __enter__(this):
-                self._LOCK.acquire()
-                return this
-
-            def __exit__(this, exctype, value, traceback):
-                self._LOCK.release()
-                if value is not None:
-                    raise value
-
-            def get(this, key):
-                return self._LOG.get(key)
-
-            def set(this, key, value):
-                self._LOG[key] = value
-
-            def delete(this, key):
-                self._LOG.pop(key)
-
-        return Lock()
diff --git a/agkyra/agkyra/syncer/localfs_client.py b/agkyra/agkyra/syncer/localfs_client.py
index bd948c4..84bbebe 100644
--- a/agkyra/agkyra/syncer/localfs_client.py
+++ b/agkyra/agkyra/syncer/localfs_client.py
@@ -464,13 +464,14 @@ class LocalfsFileClient(FileClient):
         self.ROOTPATH = settings.local_root_path
         self.CACHEPATH = settings.cache_path
         self.get_db = settings.get_db
-        self.probe_candidates = common.LockedDict()
+        self.probe_candidates = utils.ThreadSafeDict()
 
     def list_candidate_files(self, forced=False):
-        if forced:
-            candidates = self.walk_filesystem()
-            self.probe_candidates.update(candidates)
-        return self.probe_candidates.keys()
+        with self.probe_candidates.lock() as d:
+            if forced:
+                candidates = self.walk_filesystem()
+                d.update(candidates)
+            return d.keys()
 
     def walk_filesystem(self):
         db = self.get_db()
@@ -506,7 +507,8 @@ class LocalfsFileClient(FileClient):
         return exclude_pattern.match(final_part)
 
     def start_probing_file(self, objname, old_state, ref_state, callback=None):
-        cached_info = self.probe_candidates.pop(objname)
+        with self.probe_candidates.lock() as d:
+            cached_info = d.pop(objname, None)
         if self.exclude_file(objname):
             logger.warning("Ignoring probe archive: %s, object: %s" %
                            (old_state.archive, objname))
@@ -530,7 +532,8 @@ class LocalfsFileClient(FileClient):
         def handle_path(path):
             rel_path = os.path.relpath(path, start=self.ROOTPATH)
             objname = utils.to_standard_sep(rel_path)
-            self.probe_candidates.put(objname, None)
+            with self.probe_candidates.lock() as d:
+                d[objname] = None
 
         class EventHandler(FileSystemEventHandler):
             def on_created(this, event):
diff --git a/agkyra/agkyra/syncer/pithos_client.py b/agkyra/agkyra/syncer/pithos_client.py
index ac573ba..14b54b6 100644
--- a/agkyra/agkyra/syncer/pithos_client.py
+++ b/agkyra/agkyra/syncer/pithos_client.py
@@ -39,7 +39,7 @@ def heartbeat_event(settings, heartbeat, objname):
             assert beat is not None
             new_beat = {"ident": beat["ident"],
                         "tstamp": utils.time_stamp()}
-            hb.set(objname, new_beat)
+            hb[objname] = new_beat
             logger.debug("HEARTBEAT '%s' %s" % (objname, new_beat))
 
     def go():
@@ -252,13 +252,14 @@ class PithosFileClient(FileClient):
         self.get_db = settings.get_db
         self.endpoint = settings.endpoint
         self.last_modification = "0000-00-00"
-        self.probe_candidates = common.LockedDict()
+        self.probe_candidates = utils.ThreadSafeDict()
 
     def list_candidate_files(self, forced=False):
-        if forced:
-            candidates = self.get_pithos_candidates()
-            self.probe_candidates.update(candidates)
-        return self.probe_candidates.keys()
+        with self.probe_candidates.lock() as d:
+            if forced:
+                candidates = self.get_pithos_candidates()
+                d.update(candidates)
+            return d.keys()
 
     def get_pithos_candidates(self, last_modified=None):
         db = self.get_db()
@@ -302,7 +303,8 @@ class PithosFileClient(FileClient):
             def run_body(this):
                 candidates = self.get_pithos_candidates(
                     last_modified=self.last_modification)
-                self.probe_candidates.update(candidates)
+                with self.probe_candidates.lock() as d:
+                    d.update(candidates)
                 time.sleep(interval)
         return utils.start_daemon(PollPithosThread)
 
@@ -327,7 +329,8 @@ class PithosFileClient(FileClient):
 
     def start_probing_file(self, objname, old_state, ref_state, callback=None):
         info = old_state.info
-        cached_info = self.probe_candidates.pop(objname)
+        with self.probe_candidates.lock() as d:
+            cached_info = d.pop(objname, None)
         if exclude_pattern.match(objname):
             logger.warning("Ignoring probe archive: %s, object: '%s'" %
                            (old_state.archive, objname))
diff --git a/agkyra/agkyra/syncer/setup.py b/agkyra/agkyra/syncer/setup.py
index b96b210..3bb5f86 100644
--- a/agkyra/agkyra/syncer/setup.py
+++ b/agkyra/agkyra/syncer/setup.py
@@ -17,9 +17,8 @@ import os
 import threading
 import logging
 
-from agkyra.syncer.utils import join_path
+from agkyra.syncer.utils import join_path, ThreadSafeDict
 from agkyra.syncer.database import SqliteFileStateDB
-from agkyra.syncer.heartbeat import HeartBeat
 from agkyra.syncer.messaging import Messager
 
 from kamaki.clients import ClientError
@@ -94,7 +93,7 @@ class SyncerSettings():
                                           self.cache_fetch_name)
         self.create_dir(self.cache_fetch_path)
 
-        self.heartbeat = HeartBeat()
+        self.heartbeat = ThreadSafeDict()
         self.action_max_wait = kwargs.get("action_max_wait",
                                           DEFAULT_ACTION_MAX_WAIT)
         self.pithos_list_interval = kwargs.get("pithos_list_interval",
diff --git a/agkyra/agkyra/syncer/syncer.py b/agkyra/agkyra/syncer/syncer.py
index fd4c9a2..ab0c830 100644
--- a/agkyra/agkyra/syncer/syncer.py
+++ b/agkyra/agkyra/syncer/syncer.py
@@ -45,7 +45,7 @@ class FileSyncer(object):
         self.notifiers = {}
         self.decide_thread = None
         self.sync_threads = []
-        self.failed_serials = common.LockedDict()
+        self.failed_serials = utils.ThreadSafeDict()
         self.messager = settings.messager
         self.heartbeat = self.settings.heartbeat
 
@@ -160,7 +160,7 @@ class FileSyncer(object):
         if states is not None:
             with self.heartbeat.lock() as hb:
                 beat = {"ident": ident, "tstamp": utils.time_stamp()}
-                hb.set(objname, beat)
+                hb[objname] = beat
         return states
 
     def _do_decide_file_sync(self, objname, master, slave, ident):
@@ -193,7 +193,8 @@ class FileSyncer(object):
                                    (objname, beat))
 
         if decision_serial != sync_serial:
-            failed_sync = self.failed_serials.get((decision_serial, objname))
+            with self.failed_serials.lock() as d:
+                failed_sync = d.get((decision_serial, objname))
             if failed_sync is None:
                 logger.warning(
                     "Already decided: '%s', decision: %s, sync: %s" %
@@ -270,8 +271,9 @@ class FileSyncer(object):
             "Marking failed serial %s for archive: %s, object: '%s'" %
             (serial, state.archive, objname))
         with self.heartbeat.lock() as hb:
-            hb.delete(objname)
-        self.failed_serials.put((serial, objname), state)
+            hb.pop(objname)
+        with self.failed_serials.lock() as d:
+            d[(serial, objname)] = state
 
     def update_state(self, old_state, new_state):
         db = self.get_db()
@@ -285,7 +287,7 @@ class FileSyncer(object):
         objname = synced_source_state.objname
         target = synced_target_state.archive
         with self.heartbeat.lock() as hb:
-            hb.delete(objname)
+            hb.pop(objname)
         msg = messaging.AckSyncMessage(
             archive=target, objname=objname, serial=serial,
             logger=logger)
diff --git a/agkyra/agkyra/syncer/utils.py b/agkyra/agkyra/syncer/utils.py
index 0035932..619aaf7 100644
--- a/agkyra/agkyra/syncer/utils.py
+++ b/agkyra/agkyra/syncer/utils.py
@@ -16,6 +16,7 @@
 import os
 import hashlib
 import datetime
+import threading
 import watchdog.utils
 
 from agkyra.syncer.common import OBJECT_DIRSEP
@@ -86,3 +87,21 @@ def start_daemon(threadClass):
     thread.daemon = True
     thread.start()
     return thread
+
+
+class ThreadSafeDict(object):
+    def __init__(self, *args, **kwargs):
+        self._DICT = {}
+        self._LOCK = threading.Lock()
+
+    def lock(self):
+        class Lock(object):
+            def __enter__(this):
+                self._LOCK.acquire()
+                return self._DICT
+
+            def __exit__(this, exctype, value, traceback):
+                self._LOCK.release()
+                if value is not None:
+                    raise value
+        return Lock()
-- 
GitLab