diff --git a/lib/locking.py b/lib/locking.py index 16d2c46ecf01981e6cfd5b9a049f54b23fd9beb3..901fdac80489133ee1fc8343b5fbfe7bac1a7d90 100644 --- a/lib/locking.py +++ b/lib/locking.py @@ -223,6 +223,122 @@ class _SingleActionPipeCondition(object): self._Cleanup() +class _PipeCondition(object): + """Group-only non-polling condition with counters. + + This condition class uses pipes and poll, internally, to be able to wait for + notification with a timeout, without resorting to polling. It is almost + compatible with Python's threading.Condition, but only supports notifyAll and + non-recursive locks. As an additional features it's able to report whether + there are any waiting threads. + + """ + __slots__ = [ + "_lock", + "_nwaiters", + "_pipe", + "acquire", + "release", + ] + + _pipe_class = _SingleActionPipeCondition + + def __init__(self, lock): + """Initializes this class. + + """ + object.__init__(self) + + # Recursive locks are not supported + assert not hasattr(lock, "_acquire_restore") + assert not hasattr(lock, "_release_save") + + self._lock = lock + + # Export the lock's acquire() and release() methods + self.acquire = lock.acquire + self.release = lock.release + + self._nwaiters = 0 + self._pipe = None + + def _is_owned(self): + """Check whether lock is owned by current thread. + + """ + if self._lock.acquire(0): + self._lock.release() + return False + + return True + + def _check_owned(self): + """Raise an exception if the current thread doesn't own the lock. + + """ + if not self._is_owned(): + raise RuntimeError("cannot work with un-aquired lock") + + def wait(self, timeout=None): + """Wait for a notification. + + @type timeout: float or None + @param timeout: Waiting timeout (can be None) + + """ + self._check_owned() + + if not self._pipe: + self._pipe = self._pipe_class() + + # Keep local reference to the pipe. It could be replaced by another thread + # notifying while we're waiting. + pipe = self._pipe + + assert self._nwaiters >= 0 + self._nwaiters += 1 + try: + # Get function to wait on the pipe + wait_fn = pipe.StartWaiting() + try: + # Release lock while waiting + self.release() + try: + # Wait for notification + wait_fn(timeout) + finally: + # Re-acquire lock + self.acquire() + finally: + # Destroy pipe if this was the last waiter and the current pipe is + # still the same. The same pipe cannot be reused after cleanup. + if pipe.DoneWaiting() and pipe == self._pipe: + self._pipe = None + finally: + assert self._nwaiters > 0 + self._nwaiters -= 1 + + def notifyAll(self): + """Notify all currently waiting threads. + + """ + self._check_owned() + + # Notify and forget pipe. A new one will be created on the next call to + # wait. + if self._pipe is not None: + self._pipe.notifyAll() + self._pipe = None + + def has_waiting(self): + """Returns whether there are active waiters. + + """ + self._check_owned() + + return bool(self._nwaiters) + + class _CountingCondition(object): """Wrapper for Python's built-in threading.Condition class. diff --git a/test/ganeti.locking_unittest.py b/test/ganeti.locking_unittest.py index 5b9b2c93e20b671041bcf438bb62f9f6c9b2b911..36d061839f121adffaae897cdbdfc4d5cbe290be 100755 --- a/test/ganeti.locking_unittest.py +++ b/test/ganeti.locking_unittest.py @@ -69,6 +69,131 @@ class _ThreadedTestCase(unittest.TestCase): self.threads = [] +class TestPipeCondition(_ThreadedTestCase): + """_PipeCondition tests""" + + def setUp(self): + _ThreadedTestCase.setUp(self) + self.lock = threading.Lock() + self.cond = locking._PipeCondition(self.lock) + self.done = Queue.Queue(0) + + def testAcquireRelease(self): + self.assert_(not self.cond._is_owned()) + self.assertRaises(RuntimeError, self.cond.wait) + self.assertRaises(RuntimeError, self.cond.notifyAll) + + self.cond.acquire() + self.assert_(self.cond._is_owned()) + self.cond.notifyAll() + self.assert_(self.cond._is_owned()) + self.cond.release() + + self.assert_(not self.cond._is_owned()) + self.assertRaises(RuntimeError, self.cond.wait) + self.assertRaises(RuntimeError, self.cond.notifyAll) + + def testNotification(self): + def _NotifyAll(): + self.cond.acquire() + self.cond.notifyAll() + self.cond.release() + + self.cond.acquire() + self._addThread(target=_NotifyAll) + self.cond.wait() + self.assert_(self.cond._is_owned()) + self.cond.release() + self.assert_(not self.cond._is_owned()) + + def _TestWait(self, fn): + self._addThread(target=fn) + self._addThread(target=fn) + self._addThread(target=fn) + + # Wait for threads to be waiting + self.assertEqual(self.done.get(True, 1), "A") + self.assertEqual(self.done.get(True, 1), "A") + self.assertEqual(self.done.get(True, 1), "A") + + self.assertRaises(Queue.Empty, self.done.get_nowait) + + self.cond.acquire() + self.assertEqual(self.cond._nwaiters, 3) + # This new thread can"t acquire the lock, and thus call wait, before we + # release it + self._addThread(target=fn) + self.cond.notifyAll() + self.assertRaises(Queue.Empty, self.done.get_nowait) + self.cond.release() + + # We should now get 3 W and 1 A (for the new thread) in whatever order + w = 0 + a = 0 + for i in range(4): + got = self.done.get(True, 1) + if got == "W": + w += 1 + elif got == "A": + a += 1 + else: + self.fail("Got %s on the done queue" % got) + + self.assertEqual(w, 3) + self.assertEqual(a, 1) + + self.cond.acquire() + self.cond.notifyAll() + self.cond.release() + self._waitThreads() + self.assertEqual(self.done.get_nowait(), "W") + self.assertRaises(Queue.Empty, self.done.get_nowait) + + def testBlockingWait(self): + def _BlockingWait(): + self.cond.acquire() + self.done.put("A") + self.cond.wait() + self.cond.release() + self.done.put("W") + + self._TestWait(_BlockingWait) + + def testLongTimeoutWait(self): + def _Helper(): + self.cond.acquire() + self.done.put("A") + self.cond.wait(15.0) + self.cond.release() + self.done.put("W") + + self._TestWait(_Helper) + + def _TimeoutWait(self, timeout, check): + self.cond.acquire() + self.cond.wait(timeout) + self.cond.release() + self.done.put(check) + + def testShortTimeoutWait(self): + self._addThread(target=self._TimeoutWait, args=(0.1, "T1")) + self._addThread(target=self._TimeoutWait, args=(0.1, "T1")) + self._waitThreads() + self.assertEqual(self.done.get_nowait(), "T1") + self.assertEqual(self.done.get_nowait(), "T1") + self.assertRaises(Queue.Empty, self.done.get_nowait) + + def testZeroTimeoutWait(self): + self._addThread(target=self._TimeoutWait, args=(0, "T0")) + self._addThread(target=self._TimeoutWait, args=(0, "T0")) + self._addThread(target=self._TimeoutWait, args=(0, "T0")) + self._waitThreads() + self.assertEqual(self.done.get_nowait(), "T0") + self.assertEqual(self.done.get_nowait(), "T0") + self.assertEqual(self.done.get_nowait(), "T0") + self.assertRaises(Queue.Empty, self.done.get_nowait) + + class TestSingleActionPipeCondition(unittest.TestCase): """_SingleActionPipeCondition tests"""