From f4a2f532f9c6243444ab52b67fad18e0628e036e Mon Sep 17 00:00:00 2001
From: Guido Trotter <ultrotter@google.com>
Date: Mon, 13 Jul 2009 15:40:06 +0200
Subject: [PATCH] HMAC authenticated json messages

This patch includes HMAC authenticated json messages to the serializer.
The new interface works on any json-encodable data type, and can sign it
with a private key and an optional salt. The same private key must be
used upon message loading to verify the message.

Signed-off-by: Guido Trotter <ultrotter@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
---
 lib/errors.py                      |  11 +++
 lib/serializer.py                  | 103 +++++++++++++++++++++++++++++
 test/ganeti.serializer_unittest.py |  35 ++++++++++
 3 files changed, 149 insertions(+)

diff --git a/lib/errors.py b/lib/errors.py
index 195175773..9bc9a593e 100644
--- a/lib/errors.py
+++ b/lib/errors.py
@@ -98,6 +98,17 @@ class RemoteError(GenericError):
   pass
 
 
+class SignatureError(GenericError):
+  """Error authenticating a remote message.
+
+  This is raised when the hmac signature on a message doesn't verify correctly
+  to the message itself. It can happen because of network unreliability or
+  because of spurious traffic.
+
+  """
+  pass
+
+
 class ParameterError(GenericError):
   """A passed parameter to a command is invalid.
 
diff --git a/lib/serializer.py b/lib/serializer.py
index fcde99270..97811879d 100644
--- a/lib/serializer.py
+++ b/lib/serializer.py
@@ -27,6 +27,10 @@ backend (currently json).
 
 import simplejson
 import re
+import hmac
+import hashlib
+
+from ganeti import errors
 
 
 # Check whether the simplejson module supports indentation
@@ -70,5 +74,104 @@ def LoadJson(txt):
   return simplejson.loads(txt)
 
 
+def DumpSignedJson(data, key, salt=None):
+  """Serialize a given object and authenticate it.
+
+  @param data: the data to serialize
+  @param key: shared hmac key
+  @return: the string representation of data signed by the hmac key
+
+  """
+  txt = DumpJson(data, indent=False)
+  if salt is None:
+    salt = ''
+  signed_dict = {
+    'msg': txt,
+    'salt': salt,
+    'hmac': hmac.new(key, salt + txt, hashlib.sha256).hexdigest(),
+  }
+  return DumpJson(signed_dict)
+
+
+def LoadSignedJson(txt, key, salt_verifier=None):
+  """Verify that a given message was signed with the given key, and load it.
+
+  @param txt: json-encoded hmac-signed message
+  @param key: shared hmac key
+  @param salt_verifier: function taking a salt as input and returning boolean
+  @rtype: tuple of original data, string
+  @return: (original data, salt)
+  @raises errors.SignatureError: if the message signature doesn't verify
+
+  """
+  signed_dict = LoadJson(txt)
+  if not isinstance(signed_dict, dict):
+    raise errors.SignatureError('Invalid external message')
+  try:
+    msg = signed_dict['msg']
+    salt = signed_dict['salt']
+    hmac_sign = signed_dict['hmac']
+  except KeyError:
+    raise errors.SignatureError('Invalid external message')
+
+  if salt and not salt_verifier:
+    raise errors.SignatureError('Salted message is not verified')
+  elif salt_verifier is not None:
+    if not salt_verifier(salt):
+      raise errors.SignatureError('Invalid salt')
+
+  if hmac.new(key, salt + msg, hashlib.sha256).hexdigest() != hmac_sign:
+    raise errors.SignatureError('Invalid Signature')
+  return LoadJson(msg)
+
+
+def SaltEqualTo(expected):
+  """Helper salt verifier function that checks for equality.
+
+  @type expected: string
+  @param expected: expected salt
+  @rtype: function
+  @return: salt verifier that returns True if the target salt is "x"
+
+  """
+  return lambda salt: salt == expected
+
+
+def SaltIn(expected):
+  """Helper salt verifier function that checks for equality.
+
+  @type expected: collection
+  @param expected: collection of possible valid salts
+  @rtype: function
+  @return: salt verifier that returns True if the salt is in the collection
+
+  """
+  return lambda salt: salt in expected
+
+
+def SaltInRange(min, max):
+  """Helper salt verifier function that checks for equality.
+
+  @type min: integer
+  @param min: minimum salt value
+  @type max: integer
+  @param max: maximum salt value
+  @rtype: function
+  @return: salt verifier that returns True if the salt is in the min,max range
+
+  """
+  def _CheckSaltInRange(salt):
+    try:
+      i_salt = int(salt)
+    except (TypeError, ValueError), err:
+      return False
+
+    return i_salt > min and i_salt < max
+
+  return _CheckSaltInRange
+
+
 Dump = DumpJson
 Load = LoadJson
+DumpSigned = DumpSignedJson
+LoadSigned = LoadSignedJson
diff --git a/test/ganeti.serializer_unittest.py b/test/ganeti.serializer_unittest.py
index 08aad675d..b1612d6a9 100755
--- a/test/ganeti.serializer_unittest.py
+++ b/test/ganeti.serializer_unittest.py
@@ -25,6 +25,7 @@
 import unittest
 
 from ganeti import serializer
+from ganeti import errors
 
 
 class SimplejsonMock(object):
@@ -59,6 +60,40 @@ class TestSerializer(unittest.TestCase):
   def testJson(self):
     return self._TestSerializer(serializer.DumpJson, serializer.LoadJson)
 
+  def testSignedMessage(self):
+    LoadSigned = serializer.LoadSigned
+    DumpSigned = serializer.DumpSigned
+    SaltEqualTo = serializer.SaltEqualTo
+    SaltIn = serializer.SaltIn
+    SaltInRange = serializer.SaltInRange
+
+    for data in self._TESTDATA:
+      self.assertEqual(LoadSigned(DumpSigned(data, "mykey"), "mykey"), data)
+      self.assertEqual(LoadSigned(
+                         DumpSigned(data, "myprivatekey", "mysalt"),
+                         "myprivatekey", SaltEqualTo("mysalt")), data)
+      self.assertEqual(LoadSigned(
+                         DumpSigned(data, "myprivatekey", "mysalt"),
+                         "myprivatekey", SaltIn(["notmysalt", "mysalt"])), data)
+      self.assertEqual(LoadSigned(
+                         DumpSigned(data, "myprivatekey", "12345"),
+                         "myprivatekey", SaltInRange(12340, 12346)), data)
+    self.assertRaises(errors.SignatureError, serializer.LoadSigned,
+                      serializer.DumpSigned("test", "myprivatekey"),
+                      "myotherkey")
+    self.assertRaises(errors.SignatureError, serializer.LoadSigned,
+                      serializer.DumpSigned("test", "myprivatekey", "salt"),
+                      "myprivatekey")
+    self.assertRaises(errors.SignatureError, serializer.LoadSigned,
+                      serializer.DumpSigned("test", "myprivatekey", "salt"),
+                      "myprivatekey", SaltIn(["notmysalt", "morenotmysalt"]))
+    self.assertRaises(errors.SignatureError, serializer.LoadSigned,
+                      serializer.DumpSigned("test", "myprivatekey", "salt"),
+                      "myprivatekey", SaltInRange(1, 2))
+    self.assertRaises(errors.SignatureError, serializer.LoadSigned,
+                      serializer.DumpSigned("test", "myprivatekey", "12345"),
+                      "myprivatekey", SaltInRange(1, 2))
+
   def _TestSerializer(self, dump_fn, load_fn):
     for data in self._TESTDATA:
       self.failUnless(dump_fn(data).endswith("\n"))
-- 
GitLab