diff --git a/lib/utils/algo.py b/lib/utils/algo.py index 9182e8e9a9b2e594a6ec9f82a5bd23ebbbca8d12..0a12ba4c787ed9927a67c0206e6b4fd9b74b6c79 100644 --- a/lib/utils/algo.py +++ b/lib/utils/algo.py @@ -46,6 +46,29 @@ def UniqueSequence(seq): return [i for i in seq if i not in seen and not seen.add(i)] +def JoinDisjointDicts(dict_a, dict_b): + """Joins dictionaries with no conflicting keys. + + Enforces the constraint that the two key sets must be disjoint, and then + merges the two dictionaries in a new dictionary that is returned to the + caller. + + @type dict_a: dict + @param dict_a: the first dictionary + @type dict_b: dict + @param dict_b: the second dictionary + @rtype: dict + @return: a new dictionary containing all the key/value pairs contained in the + two dictionaries. + + """ + assert not (set(dict_a) & set(dict_b)), ("Duplicate keys found while joining" + " %s and %s" % (dict_a, dict_b)) + result = dict_a.copy() + result.update(dict_b) + return result + + def FindDuplicates(seq): """Identifies duplicates in a list. diff --git a/test/ganeti.utils.algo_unittest.py b/test/ganeti.utils.algo_unittest.py index 96b4ff5a0d44c40b51cdb29e2b68c9ac4d9c3e11..6a3b6d6d536d7bd41e155483148b5af09b879ba9 100755 --- a/test/ganeti.utils.algo_unittest.py +++ b/test/ganeti.utils.algo_unittest.py @@ -272,5 +272,33 @@ class TestRunningTimeout(unittest.TestCase): self.assertRaises(ValueError, algo.RunningTimeout, -1.0, True) +class TestJoinDisjointDicts(unittest.TestCase): + def setUp(self): + self.non_empty_dict = {"a": 1, "b": 2} + self.empty_dict = dict() + + def testWithEmptyDicts(self): + self.assertEqual(self.empty_dict, algo.JoinDisjointDicts(self.empty_dict, + self.empty_dict)) + self.assertEqual(self.non_empty_dict, algo.JoinDisjointDicts( + self.empty_dict, self.non_empty_dict)) + self.assertEqual(self.non_empty_dict, algo.JoinDisjointDicts( + self.non_empty_dict, self.empty_dict)) + + def testNonDisjoint(self): + self.assertRaises(AssertionError, algo.JoinDisjointDicts, + self.non_empty_dict, self.non_empty_dict) + + def testCommonCase(self): + dict_a = {"TEST1": 1, "TEST2": 2} + dict_b = {"TEST3": 3, "TEST4": 4} + + result = dict_a.copy() + result.update(dict_b) + + self.assertEqual(result, algo.JoinDisjointDicts(dict_a, dict_b)) + self.assertEqual(result, algo.JoinDisjointDicts(dict_b, dict_a)) + + if __name__ == "__main__": testutils.GanetiTestProgram()