diff --git a/lib/constants.py b/lib/constants.py index 3b4c38b4c5e654684e36377c61b85b824d8b78d0..a04dbadeb4a98d1e4e2ab43c26db0d4b200eab9d 100644 --- a/lib/constants.py +++ b/lib/constants.py @@ -510,11 +510,13 @@ REBOOT_TYPES = frozenset([INSTANCE_REBOOT_SOFT, INSTANCE_REBOOT_FULL]) VTYPE_STRING = 'string' +VTYPE_MAYBE_STRING = "maybe-string" VTYPE_BOOL = 'bool' VTYPE_SIZE = 'size' # size, in MiBs VTYPE_INT = 'int' ENFORCEABLE_TYPES = frozenset([ VTYPE_STRING, + VTYPE_MAYBE_STRING, VTYPE_BOOL, VTYPE_SIZE, VTYPE_INT, diff --git a/lib/utils.py b/lib/utils.py index ddf6882296b4c570fed1c74888dbee15847bcaeb..441e5bb11e6add131d262ef5ef744064f4aa7757 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -786,8 +786,10 @@ def ForceDictType(target, key_types, allowed_values=None): msg = "'%s' has non-enforceable type %s" % (key, ktype) raise errors.ProgrammerError(msg) - if ktype == constants.VTYPE_STRING: - if not isinstance(target[key], basestring): + if ktype in (constants.VTYPE_STRING, constants.VTYPE_MAYBE_STRING): + if target[key] is None and ktype == constants.VTYPE_MAYBE_STRING: + pass + elif not isinstance(target[key], basestring): if isinstance(target[key], bool) and not target[key]: target[key] = '' else: diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py index 839838e867f8b4886fed53cb1b6d91ec845fa424..df5ef2026a6bfb4422037270b1b7c7cb946e3348 100755 --- a/test/ganeti.utils_unittest.py +++ b/test/ganeti.utils_unittest.py @@ -1503,6 +1503,7 @@ class TestForceDictType(unittest.TestCase): 'b': constants.VTYPE_BOOL, 'c': constants.VTYPE_STRING, 'd': constants.VTYPE_SIZE, + "e": constants.VTYPE_MAYBE_STRING, } def _fdt(self, dict, allowed_values=None): @@ -1526,12 +1527,17 @@ class TestForceDictType(unittest.TestCase): self.assertEqual(self._fdt({'b': 'True'}), {'b': True}) self.assertEqual(self._fdt({'d': '4'}), {'d': 4}) self.assertEqual(self._fdt({'d': '4M'}), {'d': 4}) + self.assertEqual(self._fdt({"e": None, }), {"e": None, }) + self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", }) + self.assertEqual(self._fdt({"e": False, }), {"e": '', }) def testErrors(self): self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'}) self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True}) self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'}) self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'}) + self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), }) + self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], }) class TestIsNormAbsPath(unittest.TestCase):