Skip to content
Snippets Groups Projects
Commit cdf71b12 authored by Andrea Spadaccini's avatar Andrea Spadaccini
Browse files

Add the JoinDisjointDicts function to utils.algo


Add a function that joins two dictionaries, enforcing the constraint
that the two key sets should be disjoint. Also, add unit tests for this
function.

Signed-off-by: default avatarAndrea Spadaccini <spadaccio@google.com>
Reviewed-by: default avatarMichael Hanselmann <hansmi@google.com>
parent 97298dc9
No related branches found
No related tags found
No related merge requests found
......@@ -46,6 +46,29 @@ def UniqueSequence(seq):
return [i for i in seq if i not in seen and not seen.add(i)]
def JoinDisjointDicts(dict_a, dict_b):
"""Joins dictionaries with no conflicting keys.
Enforces the constraint that the two key sets must be disjoint, and then
merges the two dictionaries in a new dictionary that is returned to the
caller.
@type dict_a: dict
@param dict_a: the first dictionary
@type dict_b: dict
@param dict_b: the second dictionary
@rtype: dict
@return: a new dictionary containing all the key/value pairs contained in the
two dictionaries.
"""
assert not (set(dict_a) & set(dict_b)), ("Duplicate keys found while joining"
" %s and %s" % (dict_a, dict_b))
result = dict_a.copy()
result.update(dict_b)
return result
def FindDuplicates(seq):
"""Identifies duplicates in a list.
......
......@@ -272,5 +272,33 @@ class TestRunningTimeout(unittest.TestCase):
self.assertRaises(ValueError, algo.RunningTimeout, -1.0, True)
class TestJoinDisjointDicts(unittest.TestCase):
def setUp(self):
self.non_empty_dict = {"a": 1, "b": 2}
self.empty_dict = dict()
def testWithEmptyDicts(self):
self.assertEqual(self.empty_dict, algo.JoinDisjointDicts(self.empty_dict,
self.empty_dict))
self.assertEqual(self.non_empty_dict, algo.JoinDisjointDicts(
self.empty_dict, self.non_empty_dict))
self.assertEqual(self.non_empty_dict, algo.JoinDisjointDicts(
self.non_empty_dict, self.empty_dict))
def testNonDisjoint(self):
self.assertRaises(AssertionError, algo.JoinDisjointDicts,
self.non_empty_dict, self.non_empty_dict)
def testCommonCase(self):
dict_a = {"TEST1": 1, "TEST2": 2}
dict_b = {"TEST3": 3, "TEST4": 4}
result = dict_a.copy()
result.update(dict_b)
self.assertEqual(result, algo.JoinDisjointDicts(dict_a, dict_b))
self.assertEqual(result, algo.JoinDisjointDicts(dict_b, dict_a))
if __name__ == "__main__":
testutils.GanetiTestProgram()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment