From d2aff862c9b9974f2a9d345ddbbcc6baad76f87b Mon Sep 17 00:00:00 2001
From: Guido Trotter <ultrotter@google.com>
Date: Thu, 11 Sep 2008 09:44:08 +0000
Subject: [PATCH] LockSet: forbid add() on a partially owned set

This patch bans add() on a half-acquired set. This behavior was
previously possible, but created a deadlock if someone tried to acquire
the set-lock in the meantime, and thus is now forbidden. The
testAddRemove unit test is fixed for this new behavior, and includes a
few more lines of testing and a new testConcurrentSetLockAdd function
tests its behavior in the concurrent case.

Reviewed-by: imsnah
---
 lib/locking.py                  |  6 ++---
 test/ganeti.locking_unittest.py | 42 +++++++++++++++++++++++++++++----
 2 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/lib/locking.py b/lib/locking.py
index 783ec542b..fd84fe663 100644
--- a/lib/locking.py
+++ b/lib/locking.py
@@ -560,9 +560,9 @@ class LockSet:
       shared: is the pre-acquisition shared?
 
     """
-
-    assert not self.__lock._is_owned(shared=1), (
-           "Cannot add new elements while sharing the set-lock")
+    # Check we don't already own locks at this level
+    assert not self._is_owned() or self.__lock._is_owned(shared=0), \
+      "Cannot add locks if the set is only partially owned, or shared"
 
     # Support passing in a single resource to add rather than many
     if isinstance(names, basestring):
diff --git a/test/ganeti.locking_unittest.py b/test/ganeti.locking_unittest.py
index a43e1b789..94843fecf 100755
--- a/test/ganeti.locking_unittest.py
+++ b/test/ganeti.locking_unittest.py
@@ -354,12 +354,20 @@ class TestLockSet(unittest.TestCase):
     self.assert_('five' not in self.ls._names())
     self.assert_('six' not in self.ls._names())
     self.assertEquals(self.ls._list_owned(), set(['seven']))
-    self.ls.add('eight', acquired=1, shared=1)
-    self.assert_('eight' in self.ls._names())
-    self.assertEquals(self.ls._list_owned(), set(['seven', 'eight']))
+    self.assertRaises(AssertionError, self.ls.add, 'eight', acquired=1)
     self.ls.remove('seven')
     self.assert_('seven' not in self.ls._names())
-    self.assertEquals(self.ls._list_owned(), set(['eight']))
+    self.assertEquals(self.ls._list_owned(), set([]))
+    self.ls.acquire(None, shared=1)
+    self.assertRaises(AssertionError, self.ls.add, 'eight')
+    self.ls.release()
+    self.ls.acquire(None)
+    self.ls.add('eight', acquired=1)
+    self.assert_('eight' in self.ls._names())
+    self.assert_('eight' in self.ls._list_owned())
+    self.ls.add('nine')
+    self.assert_('nine' in self.ls._names())
+    self.assert_('nine' not in self.ls._list_owned())
     self.ls.release()
     self.ls.remove(['two'])
     self.assert_('two' not in self.ls._names())
@@ -550,6 +558,32 @@ class TestLockSet(unittest.TestCase):
     self.assertEqual(self.done.get(True, 1), 'DONE')
     self.assertEqual(self.done.get(True, 1), 'DONE')
 
+  def testConcurrentSetLockAdd(self):
+    self.ls.acquire('one')
+    # Another thread wants the whole SetLock
+    Thread(target=self._doLockSet, args=(None, 0)).start()
+    Thread(target=self._doLockSet, args=(None, 1)).start()
+    self.assertRaises(Queue.Empty, self.done.get, True, 0.2)
+    self.assertRaises(AssertionError, self.ls.add, 'four')
+    self.ls.release()
+    self.assertEqual(self.done.get(True, 1), 'DONE')
+    self.assertEqual(self.done.get(True, 1), 'DONE')
+    self.ls.acquire(None)
+    Thread(target=self._doLockSet, args=(None, 0)).start()
+    Thread(target=self._doLockSet, args=(None, 1)).start()
+    self.assertRaises(Queue.Empty, self.done.get, True, 0.2)
+    self.ls.add('four')
+    self.ls.add('five', acquired=1)
+    self.ls.add('six', acquired=1, shared=1)
+    self.assertEquals(self.ls._list_owned(),
+      set(['one', 'two', 'three', 'five', 'six']))
+    self.assertEquals(self.ls._is_owned(), True)
+    self.assertEquals(self.ls._names(),
+      set(['one', 'two', 'three', 'four', 'five', 'six']))
+    self.ls.release()
+    self.assertEqual(self.done.get(True, 1), 'DONE')
+    self.assertEqual(self.done.get(True, 1), 'DONE')
+
   def testEmptyLockSet(self):
     # get the set-lock
     self.assertEqual(self.ls.acquire(None), set(['one', 'two', 'three']))
-- 
GitLab