From 32be86da251a315c2ffdb3cd2ffd8d5a079eb607 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Nussbaumer?= <rn@google.com>
Date: Fri, 15 Jun 2012 14:04:42 +0200
Subject: [PATCH] Verify user supplied dicts against defaults
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This verifies the user (especially in nested dicts) does not
provide a key which is not seen in the defaults dict for that dict.

Signed-off-by: RenΓ© Nussbaumer <rn@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
---
 lib/utils/__init__.py         | 40 ++++++++++++++++++++++++++
 test/ganeti.utils_unittest.py | 54 +++++++++++++++++++++++++++++++++++
 2 files changed, 94 insertions(+)

diff --git a/lib/utils/__init__.py b/lib/utils/__init__.py
index 185d28b42..a95cf7f15 100644
--- a/lib/utils/__init__.py
+++ b/lib/utils/__init__.py
@@ -155,6 +155,46 @@ def ValidateServiceName(name):
   return name
 
 
+def _ComputeMissingKeys(key_path, options, defaults):
+  """Helper functions to compute which keys a invalid.
+
+  @param key_path: The current key path (if any)
+  @param options: The user provided options
+  @param defaults: The default dictionary
+  @return: A list of invalid keys
+
+  """
+  defaults_keys = frozenset(defaults.keys())
+  invalid = []
+  for key, value in options.items():
+    if key_path:
+      new_path = "%s/%s" % (key_path, key)
+    else:
+      new_path = key
+
+    if key not in defaults_keys:
+      invalid.append(new_path)
+    elif isinstance(value, dict):
+      invalid.extend(_ComputeMissingKeys(new_path, value, defaults[key]))
+
+  return invalid
+
+
+def VerifyDictOptions(options, defaults):
+  """Verify a dict has only keys set which also are in the defaults dict.
+
+  @param options: The user provided options
+  @param defaults: The default dictionary
+  @raise error.OpPrereqError: If one of the keys is not supported
+
+  """
+  invalid = _ComputeMissingKeys("", options, defaults)
+
+  if invalid:
+    raise errors.OpPrereqError("Provided option keys not supported: %s" %
+                               CommaJoin(invalid), errors.ECODE_INVAL)
+
+
 def ListVolumeGroups():
   """List volume groups and their size
 
diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py
index 80e737c77..21316743b 100755
--- a/test/ganeti.utils_unittest.py
+++ b/test/ganeti.utils_unittest.py
@@ -315,5 +315,59 @@ class TestTryConvert(unittest.TestCase):
       self.assertEqual(utils.TryConvert(fn, src), result)
 
 
+class TestVerifyDictOptions(unittest.TestCase):
+  def setUp(self):
+    self.defaults = {
+      "first_key": "foobar",
+      "foobar": {
+        "key1": "value2",
+        "key2": "value1",
+        },
+      "another_key": "another_value",
+      }
+
+  def test(self):
+    some_keys = {
+      "first_key": "blubb",
+      "foobar": {
+        "key2": "foo",
+        },
+      }
+    utils.VerifyDictOptions(some_keys, self.defaults)
+
+  def testInvalid(self):
+    some_keys = {
+      "invalid_key": "blubb",
+      "foobar": {
+        "key2": "foo",
+        },
+      }
+    self.assertRaises(errors.OpPrereqError, utils.VerifyDictOptions,
+                      some_keys, self.defaults)
+
+  def testNestedInvalid(self):
+    some_keys = {
+      "foobar": {
+        "key2": "foo",
+        "key3": "blibb"
+        },
+      }
+    self.assertRaises(errors.OpPrereqError, utils.VerifyDictOptions,
+                      some_keys, self.defaults)
+
+  def testMultiInvalid(self):
+    some_keys = {
+        "foobar": {
+          "key1": "value3",
+          "key6": "Right here",
+        },
+        "invalid_with_sub": {
+          "sub1": "value3",
+        },
+      }
+    self.assertRaises(errors.OpPrereqError, utils.VerifyDictOptions,
+                      some_keys, self.defaults)
+
+
 if __name__ == '__main__':
   testutils.GanetiTestProgram()
-- 
GitLab