From 07e0896f6bce624453a983b2d44d557181c25b35 Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Fri, 18 Mar 2011 13:08:08 +0100
Subject: [PATCH] Split BuildHooksEnv of LUs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Commit dd7f677623 added another call to BuildHooksEnv to provide
post-phase status variables. Since BuildHooksEnv also built the node
lists, that meant they have to be built twice. First a rather strict
check was used, but it turned out to be more tricky. Commit b423c51336
had to remove the strict check again.

With this patch the function is split in two parts, one generating the
actual environment variables, and another part returning the node lists.
The former is called twice.

Unittests are updated.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: RenΓ© Nussbaumer <rn@google.com>
---
 lib/cmdlib.py                 | 343 +++++++++++++++++++++++++---------
 lib/mcpu.py                   |  57 +++---
 test/ganeti.hooks_unittest.py | 104 +++++++++--
 3 files changed, 373 insertions(+), 131 deletions(-)

diff --git a/lib/cmdlib.py b/lib/cmdlib.py
index fa8716e5c..af51ec209 100644
--- a/lib/cmdlib.py
+++ b/lib/cmdlib.py
@@ -83,6 +83,7 @@ class LogicalUnit(object):
     - implement CheckPrereq (except when tasklets are used)
     - implement Exec (except when tasklets are used)
     - implement BuildHooksEnv
+    - implement BuildHooksNodes
     - redefine HPATH and HTYPE
     - optionally redefine their run requirements:
         REQ_BGL: the LU needs to hold the Big Ganeti Lock exclusively
@@ -273,21 +274,28 @@ class LogicalUnit(object):
   def BuildHooksEnv(self):
     """Build hooks environment for this LU.
 
-    This method should return a three-node tuple consisting of: a dict
-    containing the environment that will be used for running the
-    specific hook for this LU, a list of node names on which the hook
-    should run before the execution, and a list of node names on which
-    the hook should run after the execution.
+    @rtype: dict
+    @return: Dictionary containing the environment that will be used for
+      running the hooks for this LU. The keys of the dict must not be prefixed
+      with "GANETI_"--that'll be added by the hooks runner. The hooks runner
+      will extend the environment with additional variables. If no environment
+      should be defined, an empty dictionary should be returned (not C{None}).
+    @note: If the C{HPATH} attribute of the LU class is C{None}, this function
+      will not be called.
 
-    The keys of the dict must not have 'GANETI_' prefixed as this will
-    be handled in the hooks runner. Also note additional keys will be
-    added by the hooks runner. If the LU doesn't define any
-    environment, an empty dict (and not None) should be returned.
+    """
+    raise NotImplementedError
 
-    No nodes should be returned as an empty list (and not None).
+  def BuildHooksNodes(self):
+    """Build list of nodes to run LU's hooks.
 
-    Note that if the HPATH for a LU class is None, this function will
-    not be called.
+    @rtype: tuple; (list, list)
+    @return: Tuple containing a list of node names on which the hook
+      should run before the execution and a list of node names on which the
+      hook should run after the execution. No nodes should be returned as an
+      empty list (and not None).
+    @note: If the C{HPATH} attribute of the LU class is C{None}, this function
+      will not be called.
 
     """
     raise NotImplementedError
@@ -397,7 +405,13 @@ class NoHooksLU(LogicalUnit): # pylint: disable-msg=W0223
     This just raises an error.
 
     """
-    assert False, "BuildHooksEnv called for NoHooksLUs"
+    raise AssertionError("BuildHooksEnv called for NoHooksLUs")
+
+  def BuildHooksNodes(self):
+    """Empty BuildHooksNodes for NoHooksLU.
+
+    """
+    raise AssertionError("BuildHooksNodes called for NoHooksLU")
 
 
 class Tasklet:
@@ -1109,9 +1123,15 @@ class LUClusterPostInit(LogicalUnit):
     """Build hooks env.
 
     """
-    env = {"OP_TARGET": self.cfg.GetClusterName()}
-    mn = self.cfg.GetMasterNode()
-    return env, [], [mn]
+    return {
+      "OP_TARGET": self.cfg.GetClusterName(),
+      }
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    return ([], [self.cfg.GetMasterNode()])
 
   def Exec(self, feedback_fn):
     """Nothing to do.
@@ -1131,8 +1151,15 @@ class LUClusterDestroy(LogicalUnit):
     """Build hooks env.
 
     """
-    env = {"OP_TARGET": self.cfg.GetClusterName()}
-    return env, [], []
+    return {
+      "OP_TARGET": self.cfg.GetClusterName(),
+      }
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    return ([], [])
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -2065,7 +2092,6 @@ class LUClusterVerify(LogicalUnit):
       except errors.GenericError, err:
         self._ErrorIf(True, self.ECLUSTERCFG, None, msg % str(err))
 
-
   def BuildHooksEnv(self):
     """Build hooks env.
 
@@ -2073,14 +2099,22 @@ class LUClusterVerify(LogicalUnit):
     the output be logged in the verify output and the verification to fail.
 
     """
-    all_nodes = self.cfg.GetNodeList()
+    cfg = self.cfg
+
     env = {
-      "CLUSTER_TAGS": " ".join(self.cfg.GetClusterInfo().GetTags())
+      "CLUSTER_TAGS": " ".join(cfg.GetClusterInfo().GetTags())
       }
-    for node in self.cfg.GetAllNodesInfo().values():
-      env["NODE_TAGS_%s" % node.name] = " ".join(node.GetTags())
 
-    return env, [], all_nodes
+    env.update(("NODE_TAGS_%s" % node.name, " ".join(node.GetTags()))
+               for node in cfg.GetAllNodesInfo().values())
+
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    return ([], self.cfg.GetNodeList())
 
   def Exec(self, feedback_fn):
     """Verify integrity of cluster, performing various test on nodes.
@@ -2634,13 +2668,16 @@ class LUClusterRename(LogicalUnit):
     """Build hooks env.
 
     """
-    env = {
+    return {
       "OP_TARGET": self.cfg.GetClusterName(),
       "NEW_NAME": self.op.name,
       }
-    mn = self.cfg.GetMasterNode()
-    all_nodes = self.cfg.GetNodeList()
-    return env, [mn], all_nodes
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    return ([self.cfg.GetMasterNode()], self.cfg.GetNodeList())
 
   def CheckPrereq(self):
     """Verify that the passed name is a valid one.
@@ -2734,12 +2771,17 @@ class LUClusterSetParams(LogicalUnit):
     """Build hooks env.
 
     """
-    env = {
+    return {
       "OP_TARGET": self.cfg.GetClusterName(),
       "NEW_VG_NAME": self.op.vg_name,
       }
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     mn = self.cfg.GetMasterNode()
-    return env, [mn], [mn]
+    return ([mn], [mn])
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -3580,17 +3622,22 @@ class LUNodeRemove(LogicalUnit):
     node would then be impossible to remove.
 
     """
-    env = {
+    return {
       "OP_TARGET": self.op.node_name,
       "NODE_NAME": self.op.node_name,
       }
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     all_nodes = self.cfg.GetNodeList()
     try:
       all_nodes.remove(self.op.node_name)
     except ValueError:
-      logging.warning("Node %s which is about to be removed not found"
-                      " in the all nodes list", self.op.node_name)
-    return env, all_nodes, all_nodes
+      logging.warning("Node '%s', which is about to be removed, was not found"
+                      " in the list of all nodes", self.op.node_name)
+    return (all_nodes, all_nodes)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -4106,7 +4153,7 @@ class LUNodeAdd(LogicalUnit):
     This will run on all nodes before, and on all nodes + the new node after.
 
     """
-    env = {
+    return {
       "OP_TARGET": self.op.node_name,
       "NODE_NAME": self.op.node_name,
       "NODE_PIP": self.op.primary_ip,
@@ -4115,11 +4162,15 @@ class LUNodeAdd(LogicalUnit):
       "VM_CAPABLE": str(self.op.vm_capable),
       }
 
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     # Exclude added node
     pre_nodes = list(set(self.cfg.GetNodeList()) - set([self.op.node_name]))
     post_nodes = pre_nodes + [self.op.node_name, ]
 
-    return (env, pre_nodes, post_nodes)
+    return (pre_nodes, post_nodes)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -4432,7 +4483,7 @@ class LUNodeSetParams(LogicalUnit):
     This runs on the master node.
 
     """
-    env = {
+    return {
       "OP_TARGET": self.op.node_name,
       "MASTER_CANDIDATE": str(self.op.master_candidate),
       "OFFLINE": str(self.op.offline),
@@ -4440,9 +4491,13 @@ class LUNodeSetParams(LogicalUnit):
       "MASTER_CAPABLE": str(self.op.master_capable),
       "VM_CAPABLE": str(self.op.vm_capable),
       }
-    nl = [self.cfg.GetMasterNode(),
-          self.op.node_name]
-    return env, nl, nl
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    nl = [self.cfg.GetMasterNode(), self.op.node_name]
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -5136,9 +5191,17 @@ class LUInstanceStartup(LogicalUnit):
     env = {
       "FORCE": self.op.force,
       }
+
     env.update(_BuildInstanceHookEnvByObject(self, self.instance))
+
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode()] + list(self.instance.all_nodes)
-    return env, nl, nl
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -5233,9 +5296,17 @@ class LUInstanceReboot(LogicalUnit):
       "REBOOT_TYPE": self.op.reboot_type,
       "SHUTDOWN_TIMEOUT": self.op.shutdown_timeout,
       }
+
     env.update(_BuildInstanceHookEnvByObject(self, self.instance))
+
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode()] + list(self.instance.all_nodes)
-    return env, nl, nl
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -5315,8 +5386,14 @@ class LUInstanceShutdown(LogicalUnit):
     """
     env = _BuildInstanceHookEnvByObject(self, self.instance)
     env["TIMEOUT"] = self.op.timeout
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode()] + list(self.instance.all_nodes)
-    return env, nl, nl
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -5375,9 +5452,14 @@ class LUInstanceReinstall(LogicalUnit):
     This runs on master, primary and secondary nodes of the instance.
 
     """
-    env = _BuildInstanceHookEnvByObject(self, self.instance)
+    return _BuildInstanceHookEnvByObject(self, self.instance)
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode()] + list(self.instance.all_nodes)
-    return env, nl, nl
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -5461,9 +5543,14 @@ class LUInstanceRecreateDisks(LogicalUnit):
     This runs on master, primary and secondary nodes of the instance.
 
     """
-    env = _BuildInstanceHookEnvByObject(self, self.instance)
+    return _BuildInstanceHookEnvByObject(self, self.instance)
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode()] + list(self.instance.all_nodes)
-    return env, nl, nl
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -5528,8 +5615,14 @@ class LUInstanceRename(LogicalUnit):
     """
     env = _BuildInstanceHookEnvByObject(self, self.instance)
     env["INSTANCE_NEW_NAME"] = self.op.new_name
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode()] + list(self.instance.all_nodes)
-    return env, nl, nl
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -5639,9 +5732,15 @@ class LUInstanceRemove(LogicalUnit):
     """
     env = _BuildInstanceHookEnvByObject(self, self.instance)
     env["SHUTDOWN_TIMEOUT"] = self.op.shutdown_timeout
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode()]
     nl_post = list(self.instance.all_nodes) + nl
-    return env, nl, nl_post
+    return (nl, nl_post)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -5777,10 +5876,15 @@ class LUInstanceFailover(LogicalUnit):
       env["OLD_SECONDARY"] = env["NEW_SECONDARY"] = ""
 
     env.update(_BuildInstanceHookEnvByObject(self, instance))
-    nl = [self.cfg.GetMasterNode()] + list(instance.secondary_nodes)
-    nl_post = list(nl)
-    nl_post.append(source_node)
-    return env, nl, nl_post
+
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    nl = [self.cfg.GetMasterNode()] + list(self.instance.secondary_nodes)
+    return (nl, nl + [self.instance.primary_node])
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -5994,12 +6098,12 @@ class LUInstanceMigrate(LogicalUnit):
     source_node = instance.primary_node
     target_node = self._migrater.target_node
     env = _BuildInstanceHookEnvByObject(self, instance)
-    env["MIGRATE_LIVE"] = self._migrater.live
-    env["MIGRATE_CLEANUP"] = self.op.cleanup
     env.update({
-        "OLD_PRIMARY": source_node,
-        "NEW_PRIMARY": target_node,
-        })
+      "MIGRATE_LIVE": self._migrater.live,
+      "MIGRATE_CLEANUP": self.op.cleanup,
+      "OLD_PRIMARY": source_node,
+      "NEW_PRIMARY": target_node,
+      })
 
     if instance.disk_template in constants.DTS_INT_MIRROR:
       env["OLD_SECONDARY"] = target_node
@@ -6007,10 +6111,15 @@ class LUInstanceMigrate(LogicalUnit):
     else:
       env["OLD_SECONDARY"] = env["NEW_SECONDARY"] = None
 
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    instance = self._migrater.instance
     nl = [self.cfg.GetMasterNode()] + list(instance.secondary_nodes)
-    nl_post = list(nl)
-    nl_post.append(source_node)
-    return env, nl, nl_post
+    return (nl, nl + [instance.primary_node])
 
 
 class LUInstanceMove(LogicalUnit):
@@ -6043,9 +6152,18 @@ class LUInstanceMove(LogicalUnit):
       "SHUTDOWN_TIMEOUT": self.op.shutdown_timeout,
       }
     env.update(_BuildInstanceHookEnvByObject(self, self.instance))
-    nl = [self.cfg.GetMasterNode()] + [self.instance.primary_node,
-                                       self.op.target_node]
-    return env, nl, nl
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    nl = [
+      self.cfg.GetMasterNode(),
+      self.instance.primary_node,
+      self.op.target_node,
+      ]
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -6244,13 +6362,16 @@ class LUNodeMigrate(LogicalUnit):
     This runs on the master, the primary and all the secondaries.
 
     """
-    env = {
+    return {
       "NODE_NAME": self.op.node_name,
       }
 
-    nl = [self.cfg.GetMasterNode()]
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
 
-    return (env, nl, nl)
+    """
+    nl = [self.cfg.GetMasterNode()]
+    return (nl, nl)
 
 
 class TLMigrateInstance(Tasklet):
@@ -7469,9 +7590,14 @@ class LUInstanceCreate(LogicalUnit):
       hypervisor_name=self.op.hypervisor,
     ))
 
-    nl = ([self.cfg.GetMasterNode(), self.op.pnode] +
-          self.secondaries)
-    return env, nl, nl
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    nl = [self.cfg.GetMasterNode(), self.op.pnode] + self.secondaries
+    return nl, nl
 
   def _ReadExportInfo(self):
     """Reads the export information from disk.
@@ -8258,13 +8384,20 @@ class LUInstanceReplaceDisks(LogicalUnit):
       "OLD_SECONDARY": instance.secondary_nodes[0],
       }
     env.update(_BuildInstanceHookEnvByObject(self, instance))
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
+    instance = self.replacer.instance
     nl = [
       self.cfg.GetMasterNode(),
       instance.primary_node,
       ]
     if self.op.remote_node is not None:
       nl.append(self.op.remote_node)
-    return env, nl, nl
+    return nl, nl
 
 
 class TLReplaceDisks(Tasklet):
@@ -9099,8 +9232,14 @@ class LUInstanceGrowDisk(LogicalUnit):
       "AMOUNT": self.op.amount,
       }
     env.update(_BuildInstanceHookEnvByObject(self, self.instance))
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode()] + list(self.instance.all_nodes)
-    return env, nl, nl
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -9507,8 +9646,15 @@ class LUInstanceSetParams(LogicalUnit):
     env = _BuildInstanceHookEnvByObject(self, self.instance, override=args)
     if self.op.disk_template:
       env["NEW_DISK_TEMPLATE"] = self.op.disk_template
+
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode()] + list(self.instance.all_nodes)
-    return env, nl, nl
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -10125,12 +10271,18 @@ class LUBackupExport(LogicalUnit):
 
     env.update(_BuildInstanceHookEnvByObject(self, self.instance))
 
+    return env
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     nl = [self.cfg.GetMasterNode(), self.instance.primary_node]
 
     if self.op.mode == constants.EXPORT_MODE_LOCAL:
       nl.append(self.op.target_node)
 
-    return env, nl, nl
+    return (nl, nl)
 
   def CheckPrereq(self):
     """Check prerequisites.
@@ -10440,11 +10592,16 @@ class LUGroupAdd(LogicalUnit):
     """Build hooks env.
 
     """
-    env = {
+    return {
       "GROUP_NAME": self.op.group_name,
       }
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     mn = self.cfg.GetMasterNode()
-    return env, [mn], [mn]
+    return ([mn], [mn])
 
   def Exec(self, feedback_fn):
     """Add the node group to the cluster.
@@ -10711,12 +10868,17 @@ class LUGroupSetParams(LogicalUnit):
     """Build hooks env.
 
     """
-    env = {
+    return {
       "GROUP_NAME": self.op.group_name,
       "NEW_ALLOC_POLICY": self.op.alloc_policy,
       }
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     mn = self.cfg.GetMasterNode()
-    return env, [mn], [mn]
+    return ([mn], [mn])
 
   def Exec(self, feedback_fn):
     """Modifies the node group.
@@ -10779,11 +10941,16 @@ class LUGroupRemove(LogicalUnit):
     """Build hooks env.
 
     """
-    env = {
+    return {
       "GROUP_NAME": self.op.group_name,
       }
+
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     mn = self.cfg.GetMasterNode()
-    return env, [mn], [mn]
+    return ([mn], [mn])
 
   def Exec(self, feedback_fn):
     """Remove the node group.
@@ -10831,21 +10998,25 @@ class LUGroupRename(LogicalUnit):
     """Build hooks env.
 
     """
-    env = {
+    return {
       "OLD_NAME": self.op.group_name,
       "NEW_NAME": self.op.new_name,
       }
 
+  def BuildHooksNodes(self):
+    """Build hooks nodes.
+
+    """
     mn = self.cfg.GetMasterNode()
+
     all_nodes = self.cfg.GetAllNodesInfo()
-    run_nodes = [mn]
     all_nodes.pop(mn, None)
 
-    for node in all_nodes.values():
-      if node.group == self.group_uuid:
-        run_nodes.append(node.name)
+    run_nodes = [mn]
+    run_nodes.extend(node.name for node in all_nodes.values()
+                     if node.group == self.group_uuid)
 
-    return env, run_nodes, run_nodes
+    return (run_nodes, run_nodes)
 
   def Exec(self, feedback_fn):
     """Rename the node group.
diff --git a/lib/mcpu.py b/lib/mcpu.py
index 399fcbe63..863cf9aa8 100644
--- a/lib/mcpu.py
+++ b/lib/mcpu.py
@@ -427,8 +427,14 @@ class HooksMaster(object):
     self.callfn = callfn
     self.lu = lu
     self.op = lu.op
-    self.pre_env = None
-    self.post_nodes = None
+    self.pre_env = self._BuildEnv(constants.HOOKS_PHASE_PRE)
+
+    if self.lu.HPATH is None:
+      nodes = (None, None)
+    else:
+      nodes = map(frozenset, self.lu.BuildHooksNodes())
+
+    (self.pre_nodes, self.post_nodes) = nodes
 
   def _BuildEnv(self, phase):
     """Compute the environment and the target nodes.
@@ -447,34 +453,29 @@ class HooksMaster(object):
     env = {}
 
     if self.lu.HPATH is not None:
-      (lu_env, lu_nodes_pre, lu_nodes_post) = self.lu.BuildHooksEnv()
+      lu_env = self.lu.BuildHooksEnv()
       if lu_env:
-        assert not compat.any(key.upper().startswith(prefix)
-                              for key in lu_env)
+        assert not compat.any(key.upper().startswith(prefix) for key in lu_env)
         env.update(("%s%s" % (prefix, key), value)
                    for (key, value) in lu_env.items())
-    else:
-      lu_nodes_pre = lu_nodes_post = []
 
     if phase == constants.HOOKS_PHASE_PRE:
       assert compat.all((key.startswith("GANETI_") and
                          not key.startswith("GANETI_POST_"))
                         for key in env)
 
-      # Record environment for any post-phase hooks
-      self.pre_env = env
-
     elif phase == constants.HOOKS_PHASE_POST:
       assert compat.all(key.startswith("GANETI_POST_") for key in env)
+      assert isinstance(self.pre_env, dict)
 
-      if self.pre_env:
-        assert not compat.any(key.startswith("GANETI_POST_")
-                              for key in self.pre_env)
-        env.update(self.pre_env)
+      # Merge with pre-phase environment
+      assert not compat.any(key.startswith("GANETI_POST_")
+                            for key in self.pre_env)
+      env.update(self.pre_env)
     else:
       raise AssertionError("Unknown phase '%s'" % phase)
 
-    return env, frozenset(lu_nodes_pre), frozenset(lu_nodes_post)
+    return env
 
   def _RunWrapper(self, node_list, hpath, phase, phase_env):
     """Simple wrapper over self.callfn.
@@ -488,12 +489,14 @@ class HooksMaster(object):
       "PATH": "/sbin:/bin:/usr/sbin:/usr/bin",
       "GANETI_HOOKS_VERSION": constants.HOOKS_VERSION,
       "GANETI_OP_CODE": self.op.OP_ID,
-      "GANETI_OBJECT_TYPE": self.lu.HTYPE,
       "GANETI_DATA_DIR": constants.DATA_DIR,
       "GANETI_HOOKS_PHASE": phase,
       "GANETI_HOOKS_PATH": hpath,
       }
 
+    if self.lu.HTYPE:
+      env["GANETI_OBJECT_TYPE"] = self.lu.HTYPE
+
     if cfg is not None:
       env["GANETI_CLUSTER"] = cfg.GetClusterName()
       env["GANETI_MASTER"] = cfg.GetMasterNode()
@@ -523,17 +526,16 @@ class HooksMaster(object):
     @raise errors.HooksAbort: on failure of one of the hooks
 
     """
-    (env, node_list_pre, node_list_post) = self._BuildEnv(phase)
-    if nodes is None:
-      if phase == constants.HOOKS_PHASE_PRE:
-        nodes = node_list_pre
-        self.post_nodes = node_list_post
-      elif self.post_nodes is None:
-        raise AssertionError("Pre-phase must be run before post-phase")
-      elif phase == constants.HOOKS_PHASE_POST:
+    if phase == constants.HOOKS_PHASE_PRE:
+      if nodes is None:
+        nodes = self.pre_nodes
+      env = self.pre_env
+    elif phase == constants.HOOKS_PHASE_POST:
+      if nodes is None:
         nodes = self.post_nodes
-      else:
-        raise AssertionError("Unknown phase '%s'" % phase)
+      env = self._BuildEnv(phase)
+    else:
+      raise AssertionError("Unknown phase '%s'" % phase)
 
     if not nodes:
       # empty node list, we should not attempt to run this as either
@@ -584,9 +586,6 @@ class HooksMaster(object):
     top-level LI if the configuration has been updated.
 
     """
-    if self.pre_env is None:
-      raise AssertionError("Pre-phase must be run before configuration update")
-
     phase = constants.HOOKS_PHASE_POST
     hpath = constants.HOOKS_NAME_CFGUPDATE
     nodes = [self.lu.cfg.GetMasterNode()]
diff --git a/test/ganeti.hooks_unittest.py b/test/ganeti.hooks_unittest.py
index 30836cd67..abbb7c1ba 100755
--- a/test/ganeti.hooks_unittest.py
+++ b/test/ganeti.hooks_unittest.py
@@ -45,8 +45,12 @@ import testutils
 
 class FakeLU(cmdlib.LogicalUnit):
   HPATH = "test"
+
   def BuildHooksEnv(self):
-    return {}, ["localhost"], ["localhost"]
+    return {}
+
+  def BuildHooksNodes(self):
+    return ["localhost"], ["localhost"]
 
 
 class TestHooksRunner(unittest.TestCase):
@@ -282,8 +286,14 @@ class FakeEnvLU(cmdlib.LogicalUnit):
 
   def BuildHooksEnv(self):
     assert self.hook_env is not None
+    return self.hook_env
+
+  def BuildHooksNodes(self):
+    return (["localhost"], ["localhost"])
+
 
-    return self.hook_env, ["localhost"], ["localhost"]
+class FakeNoHooksLU(cmdlib.NoHooksLU):
+  pass
 
 
 class TestHooksRunnerEnv(unittest.TestCase):
@@ -292,7 +302,6 @@ class TestHooksRunnerEnv(unittest.TestCase):
 
     self.op = opcodes.OpTestDummy(result=False, messages=[], fail=False)
     self.lu = FakeEnvLU(FakeProc(), self.op, FakeContext(), None)
-    self.hm = mcpu.HooksMaster(self._HooksRpc, self.lu)
 
   def _HooksRpc(self, *args):
     self._rpcs.append(args)
@@ -303,14 +312,18 @@ class TestHooksRunnerEnv(unittest.TestCase):
     self.assertEqual(env["GANETI_HOOKS_PHASE"], phase)
     self.assertEqual(env["GANETI_HOOKS_PATH"], hpath)
     self.assertEqual(env["GANETI_OP_CODE"], self.op.OP_ID)
-    self.assertEqual(env["GANETI_OBJECT_TYPE"], constants.HTYPE_GROUP)
     self.assertEqual(env["GANETI_HOOKS_VERSION"], str(constants.HOOKS_VERSION))
     self.assertEqual(env["GANETI_DATA_DIR"], constants.DATA_DIR)
+    if "GANETI_OBJECT_TYPE" in env:
+      self.assertEqual(env["GANETI_OBJECT_TYPE"], constants.HTYPE_GROUP)
+    else:
+      self.assertTrue(self.lu.HTYPE is None)
 
   def testEmptyEnv(self):
     # Check pre-phase hook
     self.lu.hook_env = {}
-    self.hm.RunPhase(constants.HOOKS_PHASE_PRE)
+    hm = mcpu.HooksMaster(self._HooksRpc, self.lu)
+    hm.RunPhase(constants.HOOKS_PHASE_PRE)
 
     (node_list, hpath, phase, env) = self._rpcs.pop(0)
     self.assertEqual(node_list, set(["localhost"]))
@@ -320,7 +333,7 @@ class TestHooksRunnerEnv(unittest.TestCase):
 
     # Check post-phase hook
     self.lu.hook_env = {}
-    self.hm.RunPhase(constants.HOOKS_PHASE_POST)
+    hm.RunPhase(constants.HOOKS_PHASE_POST)
 
     (node_list, hpath, phase, env) = self._rpcs.pop(0)
     self.assertEqual(node_list, set(["localhost"]))
@@ -335,7 +348,8 @@ class TestHooksRunnerEnv(unittest.TestCase):
     self.lu.hook_env = {
       "FOO": "pre-foo-value",
       }
-    self.hm.RunPhase(constants.HOOKS_PHASE_PRE)
+    hm = mcpu.HooksMaster(self._HooksRpc, self.lu)
+    hm.RunPhase(constants.HOOKS_PHASE_PRE)
 
     (node_list, hpath, phase, env) = self._rpcs.pop(0)
     self.assertEqual(node_list, set(["localhost"]))
@@ -350,7 +364,7 @@ class TestHooksRunnerEnv(unittest.TestCase):
       "FOO": "post-value",
       "BAR": 123,
       }
-    self.hm.RunPhase(constants.HOOKS_PHASE_POST)
+    hm.RunPhase(constants.HOOKS_PHASE_POST)
 
     (node_list, hpath, phase, env) = self._rpcs.pop(0)
     self.assertEqual(node_list, set(["localhost"]))
@@ -365,7 +379,7 @@ class TestHooksRunnerEnv(unittest.TestCase):
     self.assertRaises(IndexError, self._rpcs.pop)
 
     # Check configuration update hook
-    self.hm.RunConfigUpdate()
+    hm.RunConfigUpdate()
     (node_list, hpath, phase, env) = self._rpcs.pop(0)
     self.assertEqual(set(node_list), set([self.lu.cfg.GetMasterNode()]))
     self.assertEqual(hpath, constants.HOOKS_NAME_CFGUPDATE)
@@ -389,7 +403,8 @@ class TestHooksRunnerEnv(unittest.TestCase):
 
   def testNoNodes(self):
     self.lu.hook_env = {}
-    self.hm.RunPhase(constants.HOOKS_PHASE_PRE, nodes=[])
+    hm = mcpu.HooksMaster(self._HooksRpc, self.lu)
+    hm.RunPhase(constants.HOOKS_PHASE_PRE, nodes=[])
     self.assertRaises(IndexError, self._rpcs.pop)
 
   def testSpecificNodes(self):
@@ -400,8 +415,10 @@ class TestHooksRunnerEnv(unittest.TestCase):
       "node93782.example.net",
       ]
 
+    hm = mcpu.HooksMaster(self._HooksRpc, self.lu)
+
     for phase in [constants.HOOKS_PHASE_PRE, constants.HOOKS_PHASE_POST]:
-      self.hm.RunPhase(phase, nodes=nodes)
+      hm.RunPhase(phase, nodes=nodes)
 
       (node_list, hpath, rpc_phase, env) = self._rpcs.pop(0)
       self.assertEqual(set(node_list), set(nodes))
@@ -412,14 +429,69 @@ class TestHooksRunnerEnv(unittest.TestCase):
       self.assertRaises(IndexError, self._rpcs.pop)
 
   def testRunConfigUpdateNoPre(self):
-    self.lu.hook_env = {}
-    self.assertRaises(AssertionError, self.hm.RunConfigUpdate)
+    self.lu.hook_env = {
+      "FOO": "value",
+      }
+
+    hm = mcpu.HooksMaster(self._HooksRpc, self.lu)
+    hm.RunConfigUpdate()
+
+    (node_list, hpath, phase, env) = self._rpcs.pop(0)
+    self.assertEqual(set(node_list), set([self.lu.cfg.GetMasterNode()]))
+    self.assertEqual(hpath, constants.HOOKS_NAME_CFGUPDATE)
+    self.assertEqual(phase, constants.HOOKS_PHASE_POST)
+    self.assertEqual(env["GANETI_FOO"], "value")
+    self.assertFalse(compat.any(key.startswith("GANETI_POST") for key in env))
+    self._CheckEnv(env, constants.HOOKS_PHASE_POST,
+                   constants.HOOKS_NAME_CFGUPDATE)
+
     self.assertRaises(IndexError, self._rpcs.pop)
 
   def testNoPreBeforePost(self):
-    self.lu.hook_env = {}
-    self.assertRaises(AssertionError, self.hm.RunPhase,
-                      constants.HOOKS_PHASE_POST)
+    self.lu.hook_env = {
+      "FOO": "value",
+      }
+
+    hm = mcpu.HooksMaster(self._HooksRpc, self.lu)
+    hm.RunPhase(constants.HOOKS_PHASE_POST)
+
+    (node_list, hpath, phase, env) = self._rpcs.pop(0)
+    self.assertEqual(node_list, set(["localhost"]))
+    self.assertEqual(hpath, self.lu.HPATH)
+    self.assertEqual(phase, constants.HOOKS_PHASE_POST)
+    self.assertEqual(env["GANETI_FOO"], "value")
+    self.assertEqual(env["GANETI_POST_FOO"], "value")
+    self._CheckEnv(env, constants.HOOKS_PHASE_POST, self.lu.HPATH)
+
+    self.assertRaises(IndexError, self._rpcs.pop)
+
+  def testNoHooksLU(self):
+    self.lu = FakeNoHooksLU(FakeProc(), self.op, FakeContext(), None)
+    self.assertRaises(AssertionError, self.lu.BuildHooksEnv)
+    self.assertRaises(AssertionError, self.lu.BuildHooksNodes)
+
+    hm = mcpu.HooksMaster(self._HooksRpc, self.lu)
+    self.assertEqual(hm.pre_env, {})
+    self.assertRaises(IndexError, self._rpcs.pop)
+
+    hm.RunPhase(constants.HOOKS_PHASE_PRE)
+    self.assertRaises(IndexError, self._rpcs.pop)
+
+    hm.RunPhase(constants.HOOKS_PHASE_POST)
+    self.assertRaises(IndexError, self._rpcs.pop)
+
+    hm.RunConfigUpdate()
+
+    (node_list, hpath, phase, env) = self._rpcs.pop(0)
+    self.assertEqual(set(node_list), set([self.lu.cfg.GetMasterNode()]))
+    self.assertEqual(hpath, constants.HOOKS_NAME_CFGUPDATE)
+    self.assertEqual(phase, constants.HOOKS_PHASE_POST)
+    self.assertFalse(compat.any(key.startswith("GANETI_POST") for key in env))
+    self._CheckEnv(env, constants.HOOKS_PHASE_POST,
+                   constants.HOOKS_NAME_CFGUPDATE)
+    self.assertRaises(IndexError, self._rpcs.pop)
+
+    assert isinstance(self.lu, FakeNoHooksLU), "LU was replaced"
 
 
 if __name__ == '__main__':
-- 
GitLab