diff --git a/lib/serializer.py b/lib/serializer.py index 5b806759c456a66801958e65c309dc6a275df9c7..090ae00b65a7f0cd572c9822dd9f45e85d0407d6 100644 --- a/lib/serializer.py +++ b/lib/serializer.py @@ -29,14 +29,18 @@ backend (currently json). # C0103: Invalid name, since pylint doesn't see that Dump points to a # function and not a constant +_OLD_SIMPLEJSON = False + try: import json except ImportError: # The "json" module was only added in Python 2.6. Earlier versions must use # the separate "simplejson" module. import simplejson as json + _OLD_SIMPLEJSON = True import re +import logging from ganeti import errors from ganeti import utils @@ -47,7 +51,23 @@ _JSON_INDENT = 2 _RE_EOLSP = re.compile("[ \t]+$", re.MULTILINE) -def _GetJsonDumpers(_encoder_class=json.JSONEncoder): +class _CustomJsonEncoder(json.JSONEncoder): + if __debug__ and not _OLD_SIMPLEJSON: + try: + _orig_fn = json.JSONEncoder._iterencode_dict + except AttributeError: + raise Exception("Can't override JSONEncoder's '_iterencode_dict'") + else: + def _iterencode_dict(self, data, *args, **kwargs): + for key in data.keys(): + if not (key is None or isinstance(key, (basestring, bool))): + raise ValueError("Key '%s' is of disallowed type '%s'" % + (key, type(key))) + + return self._orig_fn(data, *args, **kwargs) + + +def _GetJsonDumpers(_encoder_class=_CustomJsonEncoder): """Returns two JSON functions to serialize data. @rtype: (callable, callable) diff --git a/test/ganeti.serializer_unittest.py b/test/ganeti.serializer_unittest.py index 9d8d6568096133b2b8d36f7f5a71c2132a264e16..3e68efaa62ce41dd34cf04b6937b20a617d92bf4 100755 --- a/test/ganeti.serializer_unittest.py +++ b/test/ganeti.serializer_unittest.py @@ -23,9 +23,11 @@ import unittest +import warnings from ganeti import serializer from ganeti import errors +from ganeti import compat import testutils @@ -107,5 +109,24 @@ class TestSerializer(testutils.GanetiTestCase): serializer.DumpJson(tdata), "mykey") +class TestInvalidDictionaryKey(unittest.TestCase): + def _Test(self, data): + if serializer._OLD_SIMPLEJSON: + # Using old "simplejson", can't really test + warnings.warn("This test requires Python 2.6 or above to function" + " correctly") + self.assertTrue(serializer.DumpJson(data)) + else: + self.assertRaises(ValueError, serializer.DumpJson, data) + + def test(self): + for value in [123, 1.1, -1, -9492.1123, -3234e-4]: + self._Test({value: ""}) + + def testAllowed(self): + for value in ["", "Hello World", None, True, False]: + self.assertTrue(serializer.DumpJson({value: ""})) + + if __name__ == '__main__': testutils.GanetiTestProgram()