diff --git a/Makefile.am b/Makefile.am index 7a377572841ea9c4fc10e7b3f138a8cd9288331d..c4b4d06b7d5b01317f46d205d299f2e90da7cdef 100644 --- a/Makefile.am +++ b/Makefile.am @@ -225,6 +225,7 @@ pkgpython_PYTHON = \ lib/mcpu.py \ lib/netutils.py \ lib/objects.py \ + lib/objectutils.py \ lib/opcodes.py \ lib/ovf.py \ lib/qlang.py \ @@ -877,6 +878,7 @@ python_tests = \ test/ganeti.mcpu_unittest.py \ test/ganeti.netutils_unittest.py \ test/ganeti.objects_unittest.py \ + test/ganeti.objectutils_unittest.py \ test/ganeti.opcodes_unittest.py \ test/ganeti.ovf_unittest.py \ test/ganeti.qlang_unittest.py \ diff --git a/lib/objects.py b/lib/objects.py index 1d511824dea6466074ddcdc93edeba2f19037fac..30223b1a4f82eb959d11713f97f307e99a0fd9b4 100644 --- a/lib/objects.py +++ b/lib/objects.py @@ -44,6 +44,7 @@ from cStringIO import StringIO from ganeti import errors from ganeti import constants from ganeti import netutils +from ganeti import objectutils from ganeti import utils from socket import AF_INET @@ -191,7 +192,7 @@ def MakeEmptyIPolicy(): ]) -class ConfigObject(object): +class ConfigObject(objectutils.ValidatedSlots): """A generic config object. It has the following properties: @@ -206,34 +207,22 @@ class ConfigObject(object): """ __slots__ = [] - def __init__(self, **kwargs): - for k, v in kwargs.iteritems(): - setattr(self, k, v) - def __getattr__(self, name): - if name not in self._all_slots(): + if name not in self.GetAllSlots(): raise AttributeError("Invalid object attribute %s.%s" % (type(self).__name__, name)) return None def __setstate__(self, state): - slots = self._all_slots() + slots = self.GetAllSlots() for name in state: if name in slots: setattr(self, name, state[name]) - @classmethod - def _all_slots(cls): - """Compute the list of all declared slots for a class. + def Validate(self): + """Validates the slots. """ - slots = [] - for parent in cls.__mro__: - slots.extend(getattr(parent, "__slots__", [])) - return slots - - #: Public getter for the defined slots - GetAllSlots = _all_slots def ToDict(self): """Convert to a dict holding only standard python types. @@ -246,7 +235,7 @@ class ConfigObject(object): """ result = {} - for name in self._all_slots(): + for name in self.GetAllSlots(): value = getattr(self, name, None) if value is not None: result[name] = value diff --git a/lib/objectutils.py b/lib/objectutils.py new file mode 100644 index 0000000000000000000000000000000000000000..e2742b14c1413a1297f26cfb9bf0534d999b8dd7 --- /dev/null +++ b/lib/objectutils.py @@ -0,0 +1,93 @@ +# +# + +# Copyright (C) 2012 Google Inc. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +"""Module for object related utils.""" + + +class AutoSlots(type): + """Meta base class for __slots__ definitions. + + """ + def __new__(mcs, name, bases, attrs): + """Called when a class should be created. + + @param mcs: The meta class + @param name: Name of created class + @param bases: Base classes + @type attrs: dict + @param attrs: Class attributes + + """ + assert "__slots__" not in attrs, \ + "Class '%s' defines __slots__ when it should not" % name + + attrs["__slots__"] = mcs._GetSlots(attrs) + + return type.__new__(mcs, name, bases, attrs) + + @classmethod + def _GetSlots(mcs, attrs): + """Used to get the list of defined slots. + + @param attrs: The attributes of the class + + """ + raise NotImplementedError + + +class ValidatedSlots(object): + """Sets and validates slots. + + """ + __slots__ = [] + + def __init__(self, **kwargs): + """Constructor for BaseOpCode. + + The constructor takes only keyword arguments and will set + attributes on this object based on the passed arguments. As such, + it means that you should not pass arguments which are not in the + __slots__ attribute for this class. + + """ + slots = self.GetAllSlots() + for (key, value) in kwargs.items(): + if key not in slots: + raise TypeError("Object %s doesn't support the parameter '%s'" % + (self.__class__.__name__, key)) + setattr(self, key, value) + + @classmethod + def GetAllSlots(cls): + """Compute the list of all declared slots for a class. + + """ + slots = [] + for parent in cls.__mro__: + slots.extend(getattr(parent, "__slots__", [])) + return slots + + def Validate(self): + """Validates the slots. + + This method must be implemented by the child classes. + + """ + raise NotImplementedError diff --git a/lib/opcodes.py b/lib/opcodes.py index 3ca448bd0468e45763f3688bc74ee94c03e0cd9f..9c2bbcd1fd2dda2ed839ea1413d79651c74fe600 100644 --- a/lib/opcodes.py +++ b/lib/opcodes.py @@ -40,6 +40,7 @@ from ganeti import constants from ganeti import errors from ganeti import ht from ganeti import objects +from ganeti import objectutils # Common opcode attributes @@ -342,7 +343,7 @@ _PStorageType = ("storage_type", ht.NoDefault, _CheckStorageType, "Storage type") -class _AutoOpParamSlots(type): +class _AutoOpParamSlots(objectutils.AutoSlots): """Meta class for opcode definitions. """ @@ -356,27 +357,29 @@ class _AutoOpParamSlots(type): @param attrs: Class attributes """ - assert "__slots__" not in attrs, \ - "Class '%s' defines __slots__ when it should use OP_PARAMS" % name assert "OP_ID" not in attrs, "Class '%s' defining OP_ID" % name + slots = mcs._GetSlots(attrs) + assert "OP_DSC_FIELD" not in attrs or attrs["OP_DSC_FIELD"] in slots, \ + "Class '%s' uses unknown field in OP_DSC_FIELD" % name + attrs["OP_ID"] = _NameToId(name) + return objectutils.AutoSlots.__new__(mcs, name, bases, attrs) + + @classmethod + def _GetSlots(mcs, attrs): + """Build the slots out of OP_PARAMS. + + """ # Always set OP_PARAMS to avoid duplicates in BaseOpCode.GetAllParams params = attrs.setdefault("OP_PARAMS", []) # Use parameter names as slots - slots = [pname for (pname, _, _, _) in params] - - assert "OP_DSC_FIELD" not in attrs or attrs["OP_DSC_FIELD"] in slots, \ - "Class '%s' uses unknown field in OP_DSC_FIELD" % name - - attrs["__slots__"] = slots - - return type.__new__(mcs, name, bases, attrs) + return [pname for (pname, _, _, _) in params] -class BaseOpCode(object): +class BaseOpCode(objectutils.ValidatedSlots): """A simple serializable object. This object serves as a parent class for OpCode without any custom @@ -387,22 +390,6 @@ class BaseOpCode(object): # as OP_ID is dynamically defined __metaclass__ = _AutoOpParamSlots - def __init__(self, **kwargs): - """Constructor for BaseOpCode. - - The constructor takes only keyword arguments and will set - attributes on this object based on the passed arguments. As such, - it means that you should not pass arguments which are not in the - __slots__ attribute for this class. - - """ - slots = self._all_slots() - for key in kwargs: - if key not in slots: - raise TypeError("Object %s doesn't support the parameter '%s'" % - (self.__class__.__name__, key)) - setattr(self, key, kwargs[key]) - def __getstate__(self): """Generic serializer. @@ -414,7 +401,7 @@ class BaseOpCode(object): """ state = {} - for name in self._all_slots(): + for name in self.GetAllSlots(): if hasattr(self, name): state[name] = getattr(self, name) return state @@ -433,23 +420,13 @@ class BaseOpCode(object): raise ValueError("Invalid data to __setstate__: expected dict, got %s" % type(state)) - for name in self._all_slots(): + for name in self.GetAllSlots(): if name not in state and hasattr(self, name): delattr(self, name) for name in state: setattr(self, name, state[name]) - @classmethod - def _all_slots(cls): - """Compute the list of all declared slots for a class. - - """ - slots = [] - for parent in cls.__mro__: - slots.extend(getattr(parent, "__slots__", [])) - return slots - @classmethod def GetAllParams(cls): """Compute list of all parameters for an opcode. @@ -460,7 +437,7 @@ class BaseOpCode(object): slots.extend(getattr(parent, "OP_PARAMS", [])) return slots - def Validate(self, set_defaults): + def Validate(self, set_defaults): # pylint: disable=W0221 """Validate opcode parameters, optionally setting default values. @type set_defaults: bool diff --git a/test/ganeti.objectutils_unittest.py b/test/ganeti.objectutils_unittest.py new file mode 100755 index 0000000000000000000000000000000000000000..aa3a20515d698ec91c36e3fd883b2a0095102942 --- /dev/null +++ b/test/ganeti.objectutils_unittest.py @@ -0,0 +1,50 @@ +#!/usr/bin/python +# + +# Copyright (C) 2012 Google Inc. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + + +"""Script for unittesting the objectutils module""" + + +import unittest + +from ganeti import objectutils + +import testutils + + +class SlotsAutoSlot(objectutils.AutoSlots): + @classmethod + def _GetSlots(mcs, attr): + return attr["SLOTS"] + + +class AutoSlotted(object): + __metaclass__ = SlotsAutoSlot + + SLOTS = ["foo", "bar", "baz"] + + +class TestAutoSlot(unittest.TestCase): + def test(self): + slotted = AutoSlotted() + self.assertEqual(slotted.__slots__, AutoSlotted.SLOTS) + +if __name__ == "__main__": + testutils.GanetiTestProgram() diff --git a/test/ganeti.opcodes_unittest.py b/test/ganeti.opcodes_unittest.py index 7e6cb6e340424ba9f4115f1742f931245927c915..9de861cf0a954ce7782d430290bf07b7b1fd6d50 100755 --- a/test/ganeti.opcodes_unittest.py +++ b/test/ganeti.opcodes_unittest.py @@ -77,7 +77,7 @@ class TestOpcodes(unittest.TestCase): {"dry_run": False, "debug_level": 0, }, # All variables - dict([(name, False) for name in cls._all_slots()]) + dict([(name, False) for name in cls.GetAllSlots()]) ] for i in args: @@ -95,7 +95,7 @@ class TestOpcodes(unittest.TestCase): self._checkSummary(restored) for name in ["x_y_z", "hello_world"]: - assert name not in cls._all_slots() + assert name not in cls.GetAllSlots() for value in [None, True, False, [], "Hello World"]: self.assertRaises(AttributeError, setattr, op, name, value) @@ -158,7 +158,7 @@ class TestOpcodes(unittest.TestCase): self.assertTrue(opcodes.OpCode not in opcodes.OP_MAPPING.values()) for cls in opcodes.OP_MAPPING.values() + [opcodes.OpCode]: - all_slots = cls._all_slots() + all_slots = cls.GetAllSlots() self.assertEqual(len(set(all_slots) & supported_by_all), 3, msg=("Opcode %s doesn't support all base" diff --git a/tools/cfgshell b/tools/cfgshell index 3e79561b109d986d6204891e822fcbccb2eacc9c..fa145c238b68ad59c59bb8aff4797ddc33ef754a 100755 --- a/tools/cfgshell +++ b/tools/cfgshell @@ -95,7 +95,7 @@ class ConfigShell(cmd.Cmd): if isinstance(obj, objects.ConfigObject): # pylint: disable=W0212 # yes, we're using a protected member - for name in obj._all_slots(): + for name in obj.GetAllSlots(): child = getattr(obj, name, None) if isinstance(child, (list, dict, tuple, objects.ConfigObject)): dirs.append(name)