From cdf71b125aa14eeca8e9a5a28bb4d56a531b89ee Mon Sep 17 00:00:00 2001
From: Andrea Spadaccini <spadaccio@google.com>
Date: Mon, 24 Oct 2011 15:01:20 +0100
Subject: [PATCH] Add the JoinDisjointDicts function to utils.algo

Add a function that joins two dictionaries, enforcing the constraint
that the two key sets should be disjoint. Also, add unit tests for this
function.

Signed-off-by: Andrea Spadaccini <spadaccio@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
---
 lib/utils/algo.py                  | 23 +++++++++++++++++++++++
 test/ganeti.utils.algo_unittest.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/lib/utils/algo.py b/lib/utils/algo.py
index 9182e8e9a..0a12ba4c7 100644
--- a/lib/utils/algo.py
+++ b/lib/utils/algo.py
@@ -46,6 +46,29 @@ def UniqueSequence(seq):
   return [i for i in seq if i not in seen and not seen.add(i)]
 
 
+def JoinDisjointDicts(dict_a, dict_b):
+  """Joins dictionaries with no conflicting keys.
+
+  Enforces the constraint that the two key sets must be disjoint, and then
+  merges the two dictionaries in a new dictionary that is returned to the
+  caller.
+
+  @type dict_a: dict
+  @param dict_a: the first dictionary
+  @type dict_b: dict
+  @param dict_b: the second dictionary
+  @rtype: dict
+  @return: a new dictionary containing all the key/value pairs contained in the
+  two dictionaries.
+
+  """
+  assert not (set(dict_a) & set(dict_b)), ("Duplicate keys found while joining"
+                                           " %s and %s" % (dict_a, dict_b))
+  result = dict_a.copy()
+  result.update(dict_b)
+  return result
+
+
 def FindDuplicates(seq):
   """Identifies duplicates in a list.
 
diff --git a/test/ganeti.utils.algo_unittest.py b/test/ganeti.utils.algo_unittest.py
index 96b4ff5a0..6a3b6d6d5 100755
--- a/test/ganeti.utils.algo_unittest.py
+++ b/test/ganeti.utils.algo_unittest.py
@@ -272,5 +272,33 @@ class TestRunningTimeout(unittest.TestCase):
     self.assertRaises(ValueError, algo.RunningTimeout, -1.0, True)
 
 
+class TestJoinDisjointDicts(unittest.TestCase):
+  def setUp(self):
+    self.non_empty_dict = {"a": 1, "b": 2}
+    self.empty_dict = dict()
+
+  def testWithEmptyDicts(self):
+    self.assertEqual(self.empty_dict, algo.JoinDisjointDicts(self.empty_dict,
+      self.empty_dict))
+    self.assertEqual(self.non_empty_dict, algo.JoinDisjointDicts(
+      self.empty_dict, self.non_empty_dict))
+    self.assertEqual(self.non_empty_dict, algo.JoinDisjointDicts(
+      self.non_empty_dict, self.empty_dict))
+
+  def testNonDisjoint(self):
+    self.assertRaises(AssertionError, algo.JoinDisjointDicts,
+      self.non_empty_dict, self.non_empty_dict)
+
+  def testCommonCase(self):
+    dict_a = {"TEST1": 1, "TEST2": 2}
+    dict_b = {"TEST3": 3, "TEST4": 4}
+
+    result = dict_a.copy()
+    result.update(dict_b)
+
+    self.assertEqual(result, algo.JoinDisjointDicts(dict_a, dict_b))
+    self.assertEqual(result, algo.JoinDisjointDicts(dict_b, dict_a))
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
-- 
GitLab