diff --git a/lib/errors.py b/lib/errors.py index 19517577374edeb2d3868c3050794440e1a55618..9bc9a593ec33f54fbb537f752ea8cb969970c75c 100644 --- a/lib/errors.py +++ b/lib/errors.py @@ -98,6 +98,17 @@ class RemoteError(GenericError): pass +class SignatureError(GenericError): + """Error authenticating a remote message. + + This is raised when the hmac signature on a message doesn't verify correctly + to the message itself. It can happen because of network unreliability or + because of spurious traffic. + + """ + pass + + class ParameterError(GenericError): """A passed parameter to a command is invalid. diff --git a/lib/serializer.py b/lib/serializer.py index fcde99270d79c6f30176c2e22116ca6a247dafe5..97811879d74d3c65d9ca8f87b971ae0e3e97699d 100644 --- a/lib/serializer.py +++ b/lib/serializer.py @@ -27,6 +27,10 @@ backend (currently json). import simplejson import re +import hmac +import hashlib + +from ganeti import errors # Check whether the simplejson module supports indentation @@ -70,5 +74,104 @@ def LoadJson(txt): return simplejson.loads(txt) +def DumpSignedJson(data, key, salt=None): + """Serialize a given object and authenticate it. + + @param data: the data to serialize + @param key: shared hmac key + @return: the string representation of data signed by the hmac key + + """ + txt = DumpJson(data, indent=False) + if salt is None: + salt = '' + signed_dict = { + 'msg': txt, + 'salt': salt, + 'hmac': hmac.new(key, salt + txt, hashlib.sha256).hexdigest(), + } + return DumpJson(signed_dict) + + +def LoadSignedJson(txt, key, salt_verifier=None): + """Verify that a given message was signed with the given key, and load it. + + @param txt: json-encoded hmac-signed message + @param key: shared hmac key + @param salt_verifier: function taking a salt as input and returning boolean + @rtype: tuple of original data, string + @return: (original data, salt) + @raises errors.SignatureError: if the message signature doesn't verify + + """ + signed_dict = LoadJson(txt) + if not isinstance(signed_dict, dict): + raise errors.SignatureError('Invalid external message') + try: + msg = signed_dict['msg'] + salt = signed_dict['salt'] + hmac_sign = signed_dict['hmac'] + except KeyError: + raise errors.SignatureError('Invalid external message') + + if salt and not salt_verifier: + raise errors.SignatureError('Salted message is not verified') + elif salt_verifier is not None: + if not salt_verifier(salt): + raise errors.SignatureError('Invalid salt') + + if hmac.new(key, salt + msg, hashlib.sha256).hexdigest() != hmac_sign: + raise errors.SignatureError('Invalid Signature') + return LoadJson(msg) + + +def SaltEqualTo(expected): + """Helper salt verifier function that checks for equality. + + @type expected: string + @param expected: expected salt + @rtype: function + @return: salt verifier that returns True if the target salt is "x" + + """ + return lambda salt: salt == expected + + +def SaltIn(expected): + """Helper salt verifier function that checks for equality. + + @type expected: collection + @param expected: collection of possible valid salts + @rtype: function + @return: salt verifier that returns True if the salt is in the collection + + """ + return lambda salt: salt in expected + + +def SaltInRange(min, max): + """Helper salt verifier function that checks for equality. + + @type min: integer + @param min: minimum salt value + @type max: integer + @param max: maximum salt value + @rtype: function + @return: salt verifier that returns True if the salt is in the min,max range + + """ + def _CheckSaltInRange(salt): + try: + i_salt = int(salt) + except (TypeError, ValueError), err: + return False + + return i_salt > min and i_salt < max + + return _CheckSaltInRange + + Dump = DumpJson Load = LoadJson +DumpSigned = DumpSignedJson +LoadSigned = LoadSignedJson diff --git a/test/ganeti.serializer_unittest.py b/test/ganeti.serializer_unittest.py index 08aad675d0c6656d203d261da98c3b7bc1bf1ffb..b1612d6a956afe004b50e12cae7d5633842a24a8 100755 --- a/test/ganeti.serializer_unittest.py +++ b/test/ganeti.serializer_unittest.py @@ -25,6 +25,7 @@ import unittest from ganeti import serializer +from ganeti import errors class SimplejsonMock(object): @@ -59,6 +60,40 @@ class TestSerializer(unittest.TestCase): def testJson(self): return self._TestSerializer(serializer.DumpJson, serializer.LoadJson) + def testSignedMessage(self): + LoadSigned = serializer.LoadSigned + DumpSigned = serializer.DumpSigned + SaltEqualTo = serializer.SaltEqualTo + SaltIn = serializer.SaltIn + SaltInRange = serializer.SaltInRange + + for data in self._TESTDATA: + self.assertEqual(LoadSigned(DumpSigned(data, "mykey"), "mykey"), data) + self.assertEqual(LoadSigned( + DumpSigned(data, "myprivatekey", "mysalt"), + "myprivatekey", SaltEqualTo("mysalt")), data) + self.assertEqual(LoadSigned( + DumpSigned(data, "myprivatekey", "mysalt"), + "myprivatekey", SaltIn(["notmysalt", "mysalt"])), data) + self.assertEqual(LoadSigned( + DumpSigned(data, "myprivatekey", "12345"), + "myprivatekey", SaltInRange(12340, 12346)), data) + self.assertRaises(errors.SignatureError, serializer.LoadSigned, + serializer.DumpSigned("test", "myprivatekey"), + "myotherkey") + self.assertRaises(errors.SignatureError, serializer.LoadSigned, + serializer.DumpSigned("test", "myprivatekey", "salt"), + "myprivatekey") + self.assertRaises(errors.SignatureError, serializer.LoadSigned, + serializer.DumpSigned("test", "myprivatekey", "salt"), + "myprivatekey", SaltIn(["notmysalt", "morenotmysalt"])) + self.assertRaises(errors.SignatureError, serializer.LoadSigned, + serializer.DumpSigned("test", "myprivatekey", "salt"), + "myprivatekey", SaltInRange(1, 2)) + self.assertRaises(errors.SignatureError, serializer.LoadSigned, + serializer.DumpSigned("test", "myprivatekey", "12345"), + "myprivatekey", SaltInRange(1, 2)) + def _TestSerializer(self, dump_fn, load_fn): for data in self._TESTDATA: self.failUnless(dump_fn(data).endswith("\n"))