From 1cbef6d820ee63e01a09c190f55e77c8b09188c3 Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Fri, 31 Dec 2010 16:39:53 +0100
Subject: [PATCH] Migrate code verifying opcode parameters to base class

This allows the function to be used in other places as well.
An optional parameter is added to control whether default
values should be set. Unittests are added, providing full
coverage.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>
---
 lib/cmdlib.py                   | 29 +---------
 lib/opcodes.py                  | 39 ++++++++++++++
 test/ganeti.opcodes_unittest.py | 94 +++++++++++++++++++++++++++++++++
 3 files changed, 135 insertions(+), 27 deletions(-)

diff --git a/lib/cmdlib.py b/lib/cmdlib.py
index d7230b9e5..8c9e06d68 100644
--- a/lib/cmdlib.py
+++ b/lib/cmdlib.py
@@ -54,7 +54,6 @@ from ganeti import uidpool
 from ganeti import compat
 from ganeti import masterd
 from ganeti import netutils
-from ganeti import ht
 from ganeti import query
 from ganeti import qlang
 from ganeti import opcodes
@@ -134,32 +133,8 @@ class LogicalUnit(object):
     # Tasklets
     self.tasklets = None
 
-    # The new kind-of-type-system
-    op_id = self.op.OP_ID
-    for attr_name, aval, test in self.op.GetAllParams():
-      if not hasattr(op, attr_name):
-        if aval == ht.NoDefault:
-          raise errors.OpPrereqError("Required parameter '%s.%s' missing" %
-                                     (op_id, attr_name), errors.ECODE_INVAL)
-        else:
-          if callable(aval):
-            dval = aval()
-          else:
-            dval = aval
-          setattr(self.op, attr_name, dval)
-      attr_val = getattr(op, attr_name)
-      if test == ht.NoType:
-        # no tests here
-        continue
-      if not callable(test):
-        raise errors.ProgrammerError("Validation for parameter '%s.%s' failed,"
-                                     " given type is not a proper type (%s)" %
-                                     (op_id, attr_name, test))
-      if not test(attr_val):
-        logging.error("OpCode %s, parameter %s, has invalid type %s/value %s",
-                      self.op.OP_ID, attr_name, type(attr_val), attr_val)
-        raise errors.OpPrereqError("Parameter '%s.%s' fails validation" %
-                                   (op_id, attr_name), errors.ECODE_INVAL)
+    # Validate opcode parameters and set defaults
+    self.op.Validate(True)
 
     self.CheckArguments()
 
diff --git a/lib/opcodes.py b/lib/opcodes.py
index 4211c717c..927b0788c 100644
--- a/lib/opcodes.py
+++ b/lib/opcodes.py
@@ -33,6 +33,8 @@ opcodes.
 # few public methods:
 # pylint: disable-msg=R0903
 
+import logging
+
 from ganeti import constants
 from ganeti import errors
 from ganeti import ht
@@ -236,6 +238,43 @@ class BaseOpCode(object):
       slots.extend(getattr(parent, "OP_PARAMS", []))
     return slots
 
+  def Validate(self, set_defaults):
+    """Validate opcode parameters, optionally setting default values.
+
+    @type set_defaults: bool
+    @param set_defaults: Whether to set default values
+    @raise errors.OpPrereqError: When a parameter value doesn't match
+                                 requirements
+
+    """
+    for (attr_name, default, test) in self.GetAllParams():
+      assert test == ht.NoType or callable(test)
+
+      if not hasattr(self, attr_name):
+        if default == ht.NoDefault:
+          raise errors.OpPrereqError("Required parameter '%s.%s' missing" %
+                                     (self.OP_ID, attr_name),
+                                     errors.ECODE_INVAL)
+        elif set_defaults:
+          if callable(default):
+            dval = default()
+          else:
+            dval = default
+          setattr(self, attr_name, dval)
+
+      if test == ht.NoType:
+        # no tests here
+        continue
+
+      if set_defaults or hasattr(self, attr_name):
+        attr_val = getattr(self, attr_name)
+        if not test(attr_val):
+          logging.error("OpCode %s, parameter %s, has invalid type %s/value %s",
+                        self.OP_ID, attr_name, type(attr_val), attr_val)
+          raise errors.OpPrereqError("Parameter '%s.%s' fails validation" %
+                                     (self.OP_ID, attr_name),
+                                     errors.ECODE_INVAL)
+
 
 class OpCode(BaseOpCode):
   """Abstract OpCode.
diff --git a/test/ganeti.opcodes_unittest.py b/test/ganeti.opcodes_unittest.py
index 6be101dd0..65c515e26 100755
--- a/test/ganeti.opcodes_unittest.py
+++ b/test/ganeti.opcodes_unittest.py
@@ -28,6 +28,8 @@ import unittest
 from ganeti import utils
 from ganeti import opcodes
 from ganeti import ht
+from ganeti import constants
+from ganeti import errors
 
 import testutils
 
@@ -131,6 +133,98 @@ class TestOpcodes(unittest.TestCase):
           self.assertFalse(callable(aval()),
                            msg="Default value returned by function is callable")
 
+  def testValidateNoModification(self):
+    class _TestOp(opcodes.OpCode):
+      OP_ID = "OP_TEST"
+      OP_PARAMS = [
+        ("nodef", ht.NoDefault, ht.TMaybeString),
+        ("wdef", "default", ht.TMaybeString),
+        ("number", 0, ht.TInt),
+        ("notype", None, ht.NoType),
+        ]
+
+    # Missing required parameter "nodef"
+    op = _TestOp()
+    before = op.__getstate__()
+    self.assertRaises(errors.OpPrereqError, op.Validate, False)
+    self.assertFalse(hasattr(op, "nodef"))
+    self.assertFalse(hasattr(op, "wdef"))
+    self.assertFalse(hasattr(op, "number"))
+    self.assertFalse(hasattr(op, "notype"))
+    self.assertEqual(op.__getstate__(), before, msg="Opcode was modified")
+
+    # Required parameter "nodef" is provided
+    op = _TestOp(nodef="foo")
+    before = op.__getstate__()
+    op.Validate(False)
+    self.assertEqual(op.__getstate__(), before, msg="Opcode was modified")
+    self.assertEqual(op.nodef, "foo")
+    self.assertFalse(hasattr(op, "wdef"))
+    self.assertFalse(hasattr(op, "number"))
+    self.assertFalse(hasattr(op, "notype"))
+
+    # Missing required parameter "nodef"
+    op = _TestOp(wdef="hello", number=999)
+    before = op.__getstate__()
+    self.assertRaises(errors.OpPrereqError, op.Validate, False)
+    self.assertFalse(hasattr(op, "nodef"))
+    self.assertFalse(hasattr(op, "notype"))
+    self.assertEqual(op.__getstate__(), before, msg="Opcode was modified")
+
+    # Wrong type for "nodef"
+    op = _TestOp(nodef=987)
+    before = op.__getstate__()
+    self.assertRaises(errors.OpPrereqError, op.Validate, False)
+    self.assertEqual(op.nodef, 987)
+    self.assertFalse(hasattr(op, "notype"))
+    self.assertEqual(op.__getstate__(), before, msg="Opcode was modified")
+
+    # Testing different types for "notype"
+    op = _TestOp(nodef="foo", notype=[1, 2, 3])
+    before = op.__getstate__()
+    op.Validate(False)
+    self.assertEqual(op.nodef, "foo")
+    self.assertEqual(op.notype, [1, 2, 3])
+    self.assertEqual(op.__getstate__(), before, msg="Opcode was modified")
+
+    op = _TestOp(nodef="foo", notype="Hello World")
+    before = op.__getstate__()
+    op.Validate(False)
+    self.assertEqual(op.nodef, "foo")
+    self.assertEqual(op.notype, "Hello World")
+    self.assertEqual(op.__getstate__(), before, msg="Opcode was modified")
+
+  def testValidateSetDefaults(self):
+    class _TestOp(opcodes.OpCode):
+      OP_ID = "OP_TEST"
+      OP_PARAMS = [
+        # Static default value
+        ("value1", "default", ht.TMaybeString),
+
+        # Default value callback
+        ("value2", lambda: "result", ht.TMaybeString),
+        ]
+
+    op = _TestOp()
+    before = op.__getstate__()
+    op.Validate(True)
+    self.assertNotEqual(op.__getstate__(), before,
+                        msg="Opcode was not modified")
+    self.assertEqual(op.value1, "default")
+    self.assertEqual(op.value2, "result")
+    self.assert_(op.dry_run is None)
+    self.assert_(op.debug_level is None)
+    self.assertEqual(op.priority, constants.OP_PRIO_DEFAULT)
+
+    op = _TestOp(value1="hello", value2="world", debug_level=123)
+    before = op.__getstate__()
+    op.Validate(True)
+    self.assertNotEqual(op.__getstate__(), before,
+                        msg="Opcode was not modified")
+    self.assertEqual(op.value1, "hello")
+    self.assertEqual(op.value2, "world")
+    self.assertEqual(op.debug_level, 123)
+
 
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
-- 
GitLab