diff --git a/lib/serializer.py b/lib/serializer.py index 50206e38afd421c84dac4e30341b0ddb55b66edc..34c5a3b936d35f04dcc7fab9eaa11328158c3c42 100644 --- a/lib/serializer.py +++ b/lib/serializer.py @@ -96,14 +96,13 @@ def DumpSignedJson(data, key, salt=None): return DumpJson(signed_dict) -def LoadSignedJson(txt, key, salt_verifier=None): +def LoadSignedJson(txt, key): """Verify that a given message was signed with the given key, and load it. @param txt: json-encoded hmac-signed message @param key: shared hmac key - @param salt_verifier: function taking a salt as input and returning boolean @rtype: tuple of original data, string - @return: original data + @return: original data, salt @raises errors.SignatureError: if the message signature doesn't verify """ @@ -117,61 +116,10 @@ def LoadSignedJson(txt, key, salt_verifier=None): except KeyError: raise errors.SignatureError('Invalid external message') - if salt and not salt_verifier: - raise errors.SignatureError('Salted message is not verified') - elif salt_verifier is not None: - if not salt_verifier(salt): - raise errors.SignatureError('Invalid salt') - if hmac.new(key, salt + msg, sha1).hexdigest() != hmac_sign: raise errors.SignatureError('Invalid Signature') - return LoadJson(msg) - - -def SaltEqualTo(expected): - """Helper salt verifier function that checks for equality. - - @type expected: string - @param expected: expected salt - @rtype: function - @return: salt verifier that returns True if the target salt is "x" - - """ - return lambda salt: salt == expected - - -def SaltIn(expected): - """Helper salt verifier function that checks for membership. - - @type expected: collection - @param expected: collection of possible valid salts - @rtype: function - @return: salt verifier that returns True if the salt is in the collection - - """ - return lambda salt: salt in expected - - -def SaltInRange(min, max): - """Helper salt verifier function that checks for integer range. - - @type min: integer - @param min: minimum salt value - @type max: integer - @param max: maximum salt value - @rtype: function - @return: salt verifier that returns True if the salt is in the min,max range - - """ - def _CheckSaltInRange(salt): - try: - i_salt = int(salt) - except (TypeError, ValueError), err: - return False - - return i_salt > min and i_salt < max - return _CheckSaltInRange + return LoadJson(msg), salt Dump = DumpJson diff --git a/test/ganeti.serializer_unittest.py b/test/ganeti.serializer_unittest.py index b1612d6a956afe004b50e12cae7d5633842a24a8..8dd8aea3b5693187a25eb2e2ef8f8bd148b50580 100755 --- a/test/ganeti.serializer_unittest.py +++ b/test/ganeti.serializer_unittest.py @@ -63,36 +63,16 @@ class TestSerializer(unittest.TestCase): def testSignedMessage(self): LoadSigned = serializer.LoadSigned DumpSigned = serializer.DumpSigned - SaltEqualTo = serializer.SaltEqualTo - SaltIn = serializer.SaltIn - SaltInRange = serializer.SaltInRange for data in self._TESTDATA: - self.assertEqual(LoadSigned(DumpSigned(data, "mykey"), "mykey"), data) + self.assertEqual(LoadSigned(DumpSigned(data, "mykey"), "mykey"), + (data, '')) self.assertEqual(LoadSigned( DumpSigned(data, "myprivatekey", "mysalt"), - "myprivatekey", SaltEqualTo("mysalt")), data) - self.assertEqual(LoadSigned( - DumpSigned(data, "myprivatekey", "mysalt"), - "myprivatekey", SaltIn(["notmysalt", "mysalt"])), data) - self.assertEqual(LoadSigned( - DumpSigned(data, "myprivatekey", "12345"), - "myprivatekey", SaltInRange(12340, 12346)), data) + "myprivatekey"), (data, "mysalt")) self.assertRaises(errors.SignatureError, serializer.LoadSigned, serializer.DumpSigned("test", "myprivatekey"), "myotherkey") - self.assertRaises(errors.SignatureError, serializer.LoadSigned, - serializer.DumpSigned("test", "myprivatekey", "salt"), - "myprivatekey") - self.assertRaises(errors.SignatureError, serializer.LoadSigned, - serializer.DumpSigned("test", "myprivatekey", "salt"), - "myprivatekey", SaltIn(["notmysalt", "morenotmysalt"])) - self.assertRaises(errors.SignatureError, serializer.LoadSigned, - serializer.DumpSigned("test", "myprivatekey", "salt"), - "myprivatekey", SaltInRange(1, 2)) - self.assertRaises(errors.SignatureError, serializer.LoadSigned, - serializer.DumpSigned("test", "myprivatekey", "12345"), - "myprivatekey", SaltInRange(1, 2)) def _TestSerializer(self, dump_fn, load_fn): for data in self._TESTDATA: