From cdf71b125aa14eeca8e9a5a28bb4d56a531b89ee Mon Sep 17 00:00:00 2001 From: Andrea Spadaccini <spadaccio@google.com> Date: Mon, 24 Oct 2011 15:01:20 +0100 Subject: [PATCH] Add the JoinDisjointDicts function to utils.algo Add a function that joins two dictionaries, enforcing the constraint that the two key sets should be disjoint. Also, add unit tests for this function. Signed-off-by: Andrea Spadaccini <spadaccio@google.com> Reviewed-by: Michael Hanselmann <hansmi@google.com> --- lib/utils/algo.py | 23 +++++++++++++++++++++++ test/ganeti.utils.algo_unittest.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/lib/utils/algo.py b/lib/utils/algo.py index 9182e8e9a..0a12ba4c7 100644 --- a/lib/utils/algo.py +++ b/lib/utils/algo.py @@ -46,6 +46,29 @@ def UniqueSequence(seq): return [i for i in seq if i not in seen and not seen.add(i)] +def JoinDisjointDicts(dict_a, dict_b): + """Joins dictionaries with no conflicting keys. + + Enforces the constraint that the two key sets must be disjoint, and then + merges the two dictionaries in a new dictionary that is returned to the + caller. + + @type dict_a: dict + @param dict_a: the first dictionary + @type dict_b: dict + @param dict_b: the second dictionary + @rtype: dict + @return: a new dictionary containing all the key/value pairs contained in the + two dictionaries. + + """ + assert not (set(dict_a) & set(dict_b)), ("Duplicate keys found while joining" + " %s and %s" % (dict_a, dict_b)) + result = dict_a.copy() + result.update(dict_b) + return result + + def FindDuplicates(seq): """Identifies duplicates in a list. diff --git a/test/ganeti.utils.algo_unittest.py b/test/ganeti.utils.algo_unittest.py index 96b4ff5a0..6a3b6d6d5 100755 --- a/test/ganeti.utils.algo_unittest.py +++ b/test/ganeti.utils.algo_unittest.py @@ -272,5 +272,33 @@ class TestRunningTimeout(unittest.TestCase): self.assertRaises(ValueError, algo.RunningTimeout, -1.0, True) +class TestJoinDisjointDicts(unittest.TestCase): + def setUp(self): + self.non_empty_dict = {"a": 1, "b": 2} + self.empty_dict = dict() + + def testWithEmptyDicts(self): + self.assertEqual(self.empty_dict, algo.JoinDisjointDicts(self.empty_dict, + self.empty_dict)) + self.assertEqual(self.non_empty_dict, algo.JoinDisjointDicts( + self.empty_dict, self.non_empty_dict)) + self.assertEqual(self.non_empty_dict, algo.JoinDisjointDicts( + self.non_empty_dict, self.empty_dict)) + + def testNonDisjoint(self): + self.assertRaises(AssertionError, algo.JoinDisjointDicts, + self.non_empty_dict, self.non_empty_dict) + + def testCommonCase(self): + dict_a = {"TEST1": 1, "TEST2": 2} + dict_b = {"TEST3": 3, "TEST4": 4} + + result = dict_a.copy() + result.update(dict_b) + + self.assertEqual(result, algo.JoinDisjointDicts(dict_a, dict_b)) + self.assertEqual(result, algo.JoinDisjointDicts(dict_b, dict_a)) + + if __name__ == "__main__": testutils.GanetiTestProgram() -- GitLab