diff --git a/lib/ht.py b/lib/ht.py index f2fc0b388e4698bd0d23452c39a822fedd0a41ac..a0c787e397e9df5681f9312be1806a142bd4f5ba 100644 --- a/lib/ht.py +++ b/lib/ht.py @@ -292,3 +292,51 @@ def TDictOf(key_type, val_type): compat.all(val_type(v) for v in container.values())) return desc(TAnd(TDict, fn)) + + +def _TStrictDictCheck(require_all, exclusive, items, val): + """Helper function for L{TStrictDict}. + + """ + notfound_fn = lambda _: not exclusive + + if require_all and not frozenset(val.keys()).issuperset(items.keys()): + # Requires items not found in value + return False + + return compat.all(items.get(key, notfound_fn)(value) + for (key, value) in val.items()) + + +def TStrictDict(require_all, exclusive, items): + """Strict dictionary check with specific keys. + + @type require_all: boolean + @param require_all: Whether all keys in L{items} are required + @type exclusive: boolean + @param exclusive: Whether only keys listed in L{items} should be accepted + @type items: dictionary + @param items: Mapping from key (string) to verification function + + """ + descparts = ["Dictionary containing"] + + if exclusive: + descparts.append(" none but the") + + if require_all: + descparts.append(" required") + + if len(items) == 1: + descparts.append(" key ") + else: + descparts.append(" keys ") + + descparts.append(utils.CommaJoin("\"%s\" (value %s)" % (key, value) + for (key, value) in items.items())) + + desc = WithDesc("".join(descparts)) + + return desc(TAnd(TDict, + compat.partial(_TStrictDictCheck, require_all, exclusive, + items))) diff --git a/test/ganeti.ht_unittest.py b/test/ganeti.ht_unittest.py index 34ae6715e0c0d49fcf1e14c2123d600e2637b75f..1dba3165223e496276218671a7a5b7d4e50f7646 100755 --- a/test/ganeti.ht_unittest.py +++ b/test/ganeti.ht_unittest.py @@ -187,6 +187,49 @@ class TestTypeChecks(unittest.TestCase): self.assertFalse(fn({"x": None})) self.assertFalse(fn({"": 8234})) + def testStrictDictRequireAllExclusive(self): + fn = ht.TStrictDict(True, True, { "a": ht.TInt, }) + self.assertFalse(fn(1)) + self.assertFalse(fn(None)) + self.assertFalse(fn({})) + self.assertFalse(fn({"a": "Hello", })) + self.assertFalse(fn({"unknown": 999,})) + self.assertFalse(fn({"unknown": None,})) + + self.assertTrue(fn({"a": 123, })) + self.assertTrue(fn({"a": -5, })) + + fn = ht.TStrictDict(True, True, { "a": ht.TInt, "x": ht.TString, }) + self.assertFalse(fn({})) + self.assertFalse(fn({"a": -5, })) + self.assertTrue(fn({"a": 123, "x": "", })) + self.assertFalse(fn({"a": 123, "x": None, })) + + def testStrictDictExclusive(self): + fn = ht.TStrictDict(False, True, { "a": ht.TInt, "b": ht.TList, }) + self.assertTrue(fn({})) + self.assertTrue(fn({"a": 123, })) + self.assertTrue(fn({"b": range(4), })) + self.assertFalse(fn({"b": 123, })) + + self.assertFalse(fn({"foo": {}, })) + self.assertFalse(fn({"bar": object(), })) + + def testStrictDictRequireAll(self): + fn = ht.TStrictDict(True, False, { "a": ht.TInt, "m": ht.TInt, }) + self.assertTrue(fn({"a": 1, "m": 2, "bar": object(), })) + self.assertFalse(fn({})) + self.assertFalse(fn({"a": 1, "bar": object(), })) + self.assertFalse(fn({"a": 1, "m": [], "bar": object(), })) + + def testStrictDict(self): + fn = ht.TStrictDict(False, False, { "a": ht.TInt, }) + self.assertTrue(fn({})) + self.assertFalse(fn({"a": ""})) + self.assertTrue(fn({"a": 11})) + self.assertTrue(fn({"other": 11})) + self.assertTrue(fn({"other": object()})) + if __name__ == "__main__": testutils.GanetiTestProgram()