Commit 32683096 authored by René Nussbaumer's avatar René Nussbaumer
Browse files

Make the __slots__ functionality more modular



As we will introduce another set of containers using the __slots__ trick
we abstract away as much as possible to separate bases classes. The
child classes then adapt them for their needs. This leads to less code
duplication.
Signed-off-by: default avatarRené Nussbaumer <rn@google.com>
Reviewed-by: default avatarMichael Hanselmann <hansmi@google.com>
parent b112bfc4
......@@ -225,6 +225,7 @@ pkgpython_PYTHON = \
lib/mcpu.py \
lib/netutils.py \
lib/objects.py \
lib/objectutils.py \
lib/opcodes.py \
lib/ovf.py \
lib/qlang.py \
......@@ -877,6 +878,7 @@ python_tests = \
test/ganeti.mcpu_unittest.py \
test/ganeti.netutils_unittest.py \
test/ganeti.objects_unittest.py \
test/ganeti.objectutils_unittest.py \
test/ganeti.opcodes_unittest.py \
test/ganeti.ovf_unittest.py \
test/ganeti.qlang_unittest.py \
......
......@@ -44,6 +44,7 @@ from cStringIO import StringIO
from ganeti import errors
from ganeti import constants
from ganeti import netutils
from ganeti import objectutils
from ganeti import utils
from socket import AF_INET
......@@ -191,7 +192,7 @@ def MakeEmptyIPolicy():
])
class ConfigObject(object):
class ConfigObject(objectutils.ValidatedSlots):
"""A generic config object.
It has the following properties:
......@@ -206,34 +207,22 @@ class ConfigObject(object):
"""
__slots__ = []
def __init__(self, **kwargs):
for k, v in kwargs.iteritems():
setattr(self, k, v)
def __getattr__(self, name):
if name not in self._all_slots():
if name not in self.GetAllSlots():
raise AttributeError("Invalid object attribute %s.%s" %
(type(self).__name__, name))
return None
def __setstate__(self, state):
slots = self._all_slots()
slots = self.GetAllSlots()
for name in state:
if name in slots:
setattr(self, name, state[name])
@classmethod
def _all_slots(cls):
"""Compute the list of all declared slots for a class.
def Validate(self):
"""Validates the slots.
"""
slots = []
for parent in cls.__mro__:
slots.extend(getattr(parent, "__slots__", []))
return slots
#: Public getter for the defined slots
GetAllSlots = _all_slots
def ToDict(self):
"""Convert to a dict holding only standard python types.
......@@ -246,7 +235,7 @@ class ConfigObject(object):
"""
result = {}
for name in self._all_slots():
for name in self.GetAllSlots():
value = getattr(self, name, None)
if value is not None:
result[name] = value
......
#
#
# Copyright (C) 2012 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Module for object related utils."""
class AutoSlots(type):
"""Meta base class for __slots__ definitions.
"""
def __new__(mcs, name, bases, attrs):
"""Called when a class should be created.
@param mcs: The meta class
@param name: Name of created class
@param bases: Base classes
@type attrs: dict
@param attrs: Class attributes
"""
assert "__slots__" not in attrs, \
"Class '%s' defines __slots__ when it should not" % name
attrs["__slots__"] = mcs._GetSlots(attrs)
return type.__new__(mcs, name, bases, attrs)
@classmethod
def _GetSlots(mcs, attrs):
"""Used to get the list of defined slots.
@param attrs: The attributes of the class
"""
raise NotImplementedError
class ValidatedSlots(object):
"""Sets and validates slots.
"""
__slots__ = []
def __init__(self, **kwargs):
"""Constructor for BaseOpCode.
The constructor takes only keyword arguments and will set
attributes on this object based on the passed arguments. As such,
it means that you should not pass arguments which are not in the
__slots__ attribute for this class.
"""
slots = self.GetAllSlots()
for (key, value) in kwargs.items():
if key not in slots:
raise TypeError("Object %s doesn't support the parameter '%s'" %
(self.__class__.__name__, key))
setattr(self, key, value)
@classmethod
def GetAllSlots(cls):
"""Compute the list of all declared slots for a class.
"""
slots = []
for parent in cls.__mro__:
slots.extend(getattr(parent, "__slots__", []))
return slots
def Validate(self):
"""Validates the slots.
This method must be implemented by the child classes.
"""
raise NotImplementedError
......@@ -40,6 +40,7 @@ from ganeti import constants
from ganeti import errors
from ganeti import ht
from ganeti import objects
from ganeti import objectutils
# Common opcode attributes
......@@ -342,7 +343,7 @@ _PStorageType = ("storage_type", ht.NoDefault, _CheckStorageType,
"Storage type")
class _AutoOpParamSlots(type):
class _AutoOpParamSlots(objectutils.AutoSlots):
"""Meta class for opcode definitions.
"""
......@@ -356,27 +357,29 @@ class _AutoOpParamSlots(type):
@param attrs: Class attributes
"""
assert "__slots__" not in attrs, \
"Class '%s' defines __slots__ when it should use OP_PARAMS" % name
assert "OP_ID" not in attrs, "Class '%s' defining OP_ID" % name
slots = mcs._GetSlots(attrs)
assert "OP_DSC_FIELD" not in attrs or attrs["OP_DSC_FIELD"] in slots, \
"Class '%s' uses unknown field in OP_DSC_FIELD" % name
attrs["OP_ID"] = _NameToId(name)
return objectutils.AutoSlots.__new__(mcs, name, bases, attrs)
@classmethod
def _GetSlots(mcs, attrs):
"""Build the slots out of OP_PARAMS.
"""
# Always set OP_PARAMS to avoid duplicates in BaseOpCode.GetAllParams
params = attrs.setdefault("OP_PARAMS", [])
# Use parameter names as slots
slots = [pname for (pname, _, _, _) in params]
assert "OP_DSC_FIELD" not in attrs or attrs["OP_DSC_FIELD"] in slots, \
"Class '%s' uses unknown field in OP_DSC_FIELD" % name
attrs["__slots__"] = slots
return type.__new__(mcs, name, bases, attrs)
return [pname for (pname, _, _, _) in params]
class BaseOpCode(object):
class BaseOpCode(objectutils.ValidatedSlots):
"""A simple serializable object.
This object serves as a parent class for OpCode without any custom
......@@ -387,22 +390,6 @@ class BaseOpCode(object):
# as OP_ID is dynamically defined
__metaclass__ = _AutoOpParamSlots
def __init__(self, **kwargs):
"""Constructor for BaseOpCode.
The constructor takes only keyword arguments and will set
attributes on this object based on the passed arguments. As such,
it means that you should not pass arguments which are not in the
__slots__ attribute for this class.
"""
slots = self._all_slots()
for key in kwargs:
if key not in slots:
raise TypeError("Object %s doesn't support the parameter '%s'" %
(self.__class__.__name__, key))
setattr(self, key, kwargs[key])
def __getstate__(self):
"""Generic serializer.
......@@ -414,7 +401,7 @@ class BaseOpCode(object):
"""
state = {}
for name in self._all_slots():
for name in self.GetAllSlots():
if hasattr(self, name):
state[name] = getattr(self, name)
return state
......@@ -433,23 +420,13 @@ class BaseOpCode(object):
raise ValueError("Invalid data to __setstate__: expected dict, got %s" %
type(state))
for name in self._all_slots():
for name in self.GetAllSlots():
if name not in state and hasattr(self, name):
delattr(self, name)
for name in state:
setattr(self, name, state[name])
@classmethod
def _all_slots(cls):
"""Compute the list of all declared slots for a class.
"""
slots = []
for parent in cls.__mro__:
slots.extend(getattr(parent, "__slots__", []))
return slots
@classmethod
def GetAllParams(cls):
"""Compute list of all parameters for an opcode.
......@@ -460,7 +437,7 @@ class BaseOpCode(object):
slots.extend(getattr(parent, "OP_PARAMS", []))
return slots
def Validate(self, set_defaults):
def Validate(self, set_defaults): # pylint: disable=W0221
"""Validate opcode parameters, optionally setting default values.
@type set_defaults: bool
......
#!/usr/bin/python
#
# Copyright (C) 2012 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Script for unittesting the objectutils module"""
import unittest
from ganeti import objectutils
import testutils
class SlotsAutoSlot(objectutils.AutoSlots):
@classmethod
def _GetSlots(mcs, attr):
return attr["SLOTS"]
class AutoSlotted(object):
__metaclass__ = SlotsAutoSlot
SLOTS = ["foo", "bar", "baz"]
class TestAutoSlot(unittest.TestCase):
def test(self):
slotted = AutoSlotted()
self.assertEqual(slotted.__slots__, AutoSlotted.SLOTS)
if __name__ == "__main__":
testutils.GanetiTestProgram()
......@@ -77,7 +77,7 @@ class TestOpcodes(unittest.TestCase):
{"dry_run": False, "debug_level": 0, },
# All variables
dict([(name, False) for name in cls._all_slots()])
dict([(name, False) for name in cls.GetAllSlots()])
]
for i in args:
......@@ -95,7 +95,7 @@ class TestOpcodes(unittest.TestCase):
self._checkSummary(restored)
for name in ["x_y_z", "hello_world"]:
assert name not in cls._all_slots()
assert name not in cls.GetAllSlots()
for value in [None, True, False, [], "Hello World"]:
self.assertRaises(AttributeError, setattr, op, name, value)
......@@ -158,7 +158,7 @@ class TestOpcodes(unittest.TestCase):
self.assertTrue(opcodes.OpCode not in opcodes.OP_MAPPING.values())
for cls in opcodes.OP_MAPPING.values() + [opcodes.OpCode]:
all_slots = cls._all_slots()
all_slots = cls.GetAllSlots()
self.assertEqual(len(set(all_slots) & supported_by_all), 3,
msg=("Opcode %s doesn't support all base"
......
......@@ -95,7 +95,7 @@ class ConfigShell(cmd.Cmd):
if isinstance(obj, objects.ConfigObject):
# pylint: disable=W0212
# yes, we're using a protected member
for name in obj._all_slots():
for name in obj.GetAllSlots():
child = getattr(obj, name, None)
if isinstance(child, (list, dict, tuple, objects.ConfigObject)):
dirs.append(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment