diff --git a/lib/ssconf.py b/lib/ssconf.py index ad5969c7f2482d899473957318ae493ceaffa5a5..7e34b5d3b5b9a19effd79275f363370e219e5c79 100644 --- a/lib/ssconf.py +++ b/lib/ssconf.py @@ -28,6 +28,7 @@ configuration data, which is mostly static and available to all nodes. import sys import errno +import logging from ganeti import errors from ganeti import constants @@ -368,3 +369,21 @@ def CheckMaster(debug, ss=None): if debug: sys.stderr.write("Not master, exiting.\n") sys.exit(constants.EXIT_NOTMASTER) + + +def VerifyClusterName(name, _cfg_location=None): + """Verifies cluster name against a local cluster name. + + @type name: string + @param name: Cluster name + + """ + sstore = SimpleStore(cfg_location=_cfg_location) + + try: + local_name = sstore.GetClusterName() + except errors.ConfigurationError, err: + logging.debug("Can't get local cluster name: %s", err) + else: + if name != local_name: + raise errors.GenericError("Current cluster name is '%s'" % local_name) diff --git a/lib/tools/prepare_node_join.py b/lib/tools/prepare_node_join.py index b88e02e7e8cafc83516caef2ba8b1521b95066eb..2441785e77d50e5efe7e793fd65d8af43c513f6e 100644 --- a/lib/tools/prepare_node_join.py +++ b/lib/tools/prepare_node_join.py @@ -150,31 +150,7 @@ def VerifyCertificate(data, _verify_fn=_VerifyCertificate): _verify_fn(cert) -def _VerifyClusterName(name, _ss_cluster_name_file=None): - """Verifies cluster name against a local cluster name. - - @type name: string - @param name: Cluster name - - """ - if _ss_cluster_name_file is None: - _ss_cluster_name_file = \ - ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME) - - try: - local_name = utils.ReadOneLineFile(_ss_cluster_name_file) - except EnvironmentError, err: - if err.errno != errno.ENOENT: - raise - - logging.debug("Local cluster name was not found (file %s)", - _ss_cluster_name_file) - else: - if name != local_name: - raise JoinError("Current cluster name is '%s'" % local_name) - - -def VerifyClusterName(data, _verify_fn=_VerifyClusterName): +def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName): """Verifies cluster name. @type data: dict diff --git a/test/ganeti.ssconf_unittest.py b/test/ganeti.ssconf_unittest.py index 86d93be415c8c9d391aadca13007373d518bacd8..1db88e8896c2be36960044e4bed2bddc59a70aac 100755 --- a/test/ganeti.ssconf_unittest.py +++ b/test/ganeti.ssconf_unittest.py @@ -141,5 +141,33 @@ class TestSimpleStore(unittest.TestCase): "cluster.example.com") +class TestVerifyClusterName(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testMissingFile(self): + tmploc = utils.PathJoin(self.tmpdir, "does-not-exist") + ssconf.VerifyClusterName(NotImplemented, _cfg_location=tmploc) + + def testMatchingName(self): + tmpfile = utils.PathJoin(self.tmpdir, "ssconf_cluster_name") + + for content in ["cluster.example.com", "cluster.example.com\n\n"]: + utils.WriteFile(tmpfile, data=content) + ssconf.VerifyClusterName("cluster.example.com", + _cfg_location=self.tmpdir) + + def testNameMismatch(self): + tmpfile = utils.PathJoin(self.tmpdir, "ssconf_cluster_name") + + for content in ["something.example.com", "foobar\n\ncluster.example.com"]: + utils.WriteFile(tmpfile, data=content) + self.assertRaises(errors.GenericError, ssconf.VerifyClusterName, + "cluster.example.com", _cfg_location=self.tmpdir) + + if __name__ == "__main__": testutils.GanetiTestProgram() diff --git a/test/ganeti.tools.prepare_node_join_unittest.py b/test/ganeti.tools.prepare_node_join_unittest.py index 1cda5d2174f2a8b218424990832d7a3990e686d2..c014280f315480384a20865b9682c96ab4b2030b 100755 --- a/test/ganeti.tools.prepare_node_join_unittest.py +++ b/test/ganeti.tools.prepare_node_join_unittest.py @@ -130,26 +130,18 @@ class TestVerifyClusterName(unittest.TestCase): self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName, {}, _verify_fn=NotImplemented) - def testMissingFile(self): - tmpfile = utils.PathJoin(self.tmpdir, "does-not-exist") - prepare_node_join._VerifyClusterName(NotImplemented, - _ss_cluster_name_file=tmpfile) - - def testMatchingName(self): - tmpfile = utils.PathJoin(self.tmpdir, "cluster_name") - - for content in ["cluster.example.com", "cluster.example.com\n\n"]: - utils.WriteFile(tmpfile, data=content) - prepare_node_join._VerifyClusterName("cluster.example.com", - _ss_cluster_name_file=tmpfile) + @staticmethod + def _FailingVerify(name): + assert name == "cluster.example.com" + raise errors.GenericError() - def testNameMismatch(self): - tmpfile = utils.PathJoin(self.tmpdir, "cluster_name") + def testFailingVerification(self): + data = { + constants.SSHS_CLUSTER_NAME: "cluster.example.com", + } - for content in ["something.example.com", "foobar\n\ncluster.example.com"]: - utils.WriteFile(tmpfile, data=content) - self.assertRaises(_JoinError, prepare_node_join._VerifyClusterName, - "cluster.example.com", _ss_cluster_name_file=tmpfile) + self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName, + data, _verify_fn=self._FailingVerify) class TestUpdateSshDaemon(unittest.TestCase):