From 7f8900599939375920758167a4d757536bb7c99a Mon Sep 17 00:00:00 2001
From: Guido Trotter <ultrotter@google.com>
Date: Wed, 9 Jun 2010 19:35:57 +0100
Subject: [PATCH] _BaseCondition: allow saving/restoring state

Signed-off-by: Guido Trotter <ultrotter@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
---
 lib/locking.py                  | 31 ++++++++++++++++++++++++-------
 test/ganeti.locking_unittest.py | 10 ++++++++++
 2 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/lib/locking.py b/lib/locking.py
index 19c41eba6..28e640ff0 100644
--- a/lib/locking.py
+++ b/lib/locking.py
@@ -171,6 +171,9 @@ class _BaseCondition(object):
     "_lock",
     "acquire",
     "release",
+    "_is_owned",
+    "_acquire_restore",
+    "_release_save",
     ]
 
   def __init__(self, lock):
@@ -182,9 +185,18 @@ class _BaseCondition(object):
     """
     object.__init__(self)
 
-    # Recursive locks are not supported
-    assert not hasattr(lock, "_acquire_restore")
-    assert not hasattr(lock, "_release_save")
+    try:
+      self._release_save = lock._release_save
+    except AttributeError:
+      self._release_save = self._base_release_save
+    try:
+      self._acquire_restore = lock._acquire_restore
+    except AttributeError:
+      self._acquire_restore = self._base_acquire_restore
+    try:
+      self._is_owned = lock._is_owned
+    except AttributeError:
+      self._is_owned = self._base_is_owned
 
     self._lock = lock
 
@@ -192,16 +204,21 @@ class _BaseCondition(object):
     self.acquire = lock.acquire
     self.release = lock.release
 
-  def _is_owned(self):
+  def _base_is_owned(self):
     """Check whether lock is owned by current thread.
 
     """
     if self._lock.acquire(0):
       self._lock.release()
       return False
-
     return True
 
+  def _base_release_save(self):
+    self._lock.release()
+
+  def _base_acquire_restore(self, _):
+    self._lock.acquire()
+
   def _check_owned(self):
     """Raise an exception if the current thread doesn't own the lock.
 
@@ -280,13 +297,13 @@ class SingleNotifyPipeCondition(_BaseCondition):
         self._poller.register(self._read_fd, select.POLLHUP)
 
       wait_fn = self._waiter_class(self._poller, self._read_fd)
-      self.release()
+      state = self._release_save()
       try:
         # Wait for notification
         wait_fn(timeout)
       finally:
         # Re-acquire lock
-        self.acquire()
+        self._acquire_restore(state)
     finally:
       self._nwaiters -= 1
       if self._nwaiters == 0:
diff --git a/test/ganeti.locking_unittest.py b/test/ganeti.locking_unittest.py
index f767ccec9..9005b09eb 100755
--- a/test/ganeti.locking_unittest.py
+++ b/test/ganeti.locking_unittest.py
@@ -704,6 +704,9 @@ class TestSharedLockInCondition(_ThreadedTestCase):
   def setUp(self):
     _ThreadedTestCase.setUp(self)
     self.sl = locking.SharedLock()
+    self.setCondition()
+
+  def setCondition(self):
     self.cond = threading.Condition(self.sl)
 
   def testKeepMode(self):
@@ -719,6 +722,13 @@ class TestSharedLockInCondition(_ThreadedTestCase):
     self.cond.release()
 
 
+class TestSharedLockInPipeCondition(TestSharedLockInCondition):
+  """SharedLock as a pipe condition lock tests"""
+
+  def setCondition(self):
+    self.cond = locking.PipeCondition(self.sl)
+
+
 class TestSSynchronizedDecorator(_ThreadedTestCase):
   """Shared Lock Synchronized decorator test"""
 
-- 
GitLab