From 5d630c2270df50f352aec9cc59e8a973276a66a5 Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Tue, 27 Nov 2012 11:43:07 +0100
Subject: [PATCH] Factorize code to load and verify JSON

A new tool to configure the node daemon will also have to load and
verify JSON data.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Guido Trotter <ultrotter@google.com>
---
 lib/serializer.py                  | 22 ++++++++++++++++++++++
 lib/tools/prepare_node_join.py     | 11 +----------
 test/ganeti.serializer_unittest.py | 25 +++++++++++++++++++++++++
 3 files changed, 48 insertions(+), 10 deletions(-)

diff --git a/lib/serializer.py b/lib/serializer.py
index cbc11fa65..f0081551e 100644
--- a/lib/serializer.py
+++ b/lib/serializer.py
@@ -140,6 +140,28 @@ def LoadSignedJson(txt, key):
   return LoadJson(msg), salt
 
 
+def LoadAndVerifyJson(raw, verify_fn):
+  """Parses and verifies JSON data.
+
+  @type raw: string
+  @param raw: Input data in JSON format
+  @type verify_fn: callable
+  @param verify_fn: Verification function, usually from L{ht}
+  @return: De-serialized data
+
+  """
+  try:
+    data = LoadJson(raw)
+  except Exception, err:
+    raise errors.ParseError("Can't parse input data: %s" % err)
+
+  if not verify_fn(data):
+    raise errors.ParseError("Data does not match expected format: %s" %
+                            verify_fn)
+
+  return data
+
+
 Dump = DumpJson
 Load = LoadJson
 DumpSigned = DumpSignedJson
diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py
index deedd0b32..b88e02e7e 100644
--- a/lib/tools/prepare_node_join.py
+++ b/lib/tools/prepare_node_join.py
@@ -281,16 +281,7 @@ def LoadData(raw):
   @rtype: dict
 
   """
-  try:
-    data = serializer.LoadJson(raw)
-  except Exception, err:
-    raise errors.ParseError("Can't parse input data: %s" % err)
-
-  if not _DATA_CHECK(data):
-    raise errors.ParseError("Input data does not match expected format: %s" %
-                            _DATA_CHECK)
-
-  return data
+  return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
 
 
 def Main():
diff --git a/test/ganeti.serializer_unittest.py b/test/ganeti.serializer_unittest.py
index 75c498c06..45a393238 100755
--- a/test/ganeti.serializer_unittest.py
+++ b/test/ganeti.serializer_unittest.py
@@ -26,6 +26,7 @@ import unittest
 
 from ganeti import serializer
 from ganeti import errors
+from ganeti import ht
 
 import testutils
 
@@ -106,5 +107,29 @@ class TestSerializer(testutils.GanetiTestCase):
                       serializer.DumpJson(tdata), "mykey")
 
 
+class TestLoadAndVerifyJson(unittest.TestCase):
+  def testNoJson(self):
+    self.assertRaises(errors.ParseError, serializer.LoadAndVerifyJson,
+                      "", NotImplemented)
+    self.assertRaises(errors.ParseError, serializer.LoadAndVerifyJson,
+                      "}", NotImplemented)
+
+  def testVerificationFails(self):
+    self.assertRaises(errors.ParseError, serializer.LoadAndVerifyJson,
+                      "{}", lambda _: False)
+
+    verify_fn = ht.TListOf(ht.TNonEmptyString)
+    try:
+      serializer.LoadAndVerifyJson("{}", verify_fn)
+    except errors.ParseError, err:
+      self.assertTrue(str(err).endswith(str(verify_fn)))
+    else:
+      self.fail("Exception not raised")
+
+  def testSuccess(self):
+    self.assertEqual(serializer.LoadAndVerifyJson("{}", ht.TAny), {})
+    self.assertEqual(serializer.LoadAndVerifyJson("\"Foo\"", ht.TAny), "Foo")
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
-- 
GitLab