diff --git a/lib/ssconf.py b/lib/ssconf.py index 7823e6624fe4f3784c8c43631f51870519424904..29db858da4e5cd8968fb21524659fe618dd5d50d 100644 --- a/lib/ssconf.py +++ b/lib/ssconf.py @@ -103,12 +103,14 @@ class SimpleStore(object): - keys are restricted to predefined values """ - def __init__(self, cfg_location=None): + def __init__(self, cfg_location=None, _lockfile=pathutils.SSCONF_LOCK_FILE): if cfg_location is None: self._cfg_dir = pathutils.DATA_DIR else: self._cfg_dir = cfg_location + self._lockfile = _lockfile + def KeyToFilename(self, key): """Convert a given key into filename. @@ -136,14 +138,16 @@ class SimpleStore(object): raise errors.ConfigurationError("Can't read ssconf file %s: %s" % (filename, str(err))) - def WriteFiles(self, values): + def WriteFiles(self, values, dry_run=False): """Writes ssconf files used by external scripts. @type values: dict @param values: Dictionary of (name, value) + @type dry_run boolean + @param dry_run: Whether to perform a dry run """ - ssconf_lock = utils.FileLock.Open(pathutils.SSCONF_LOCK_FILE) + ssconf_lock = utils.FileLock.Open(self._lockfile) # Get lock while writing files ssconf_lock.Exclusive(blocking=True, timeout=SSCONF_LOCK_TIMEOUT) @@ -151,11 +155,15 @@ class SimpleStore(object): for name, value in values.iteritems(): if value and not value.endswith("\n"): value += "\n" + if len(value) > _MAX_SIZE: - raise errors.ConfigurationError("ssconf file %s above maximum size" % - name) + msg = ("Value '%s' has a length of %s bytes, but only up to %s are" + " allowed" % (name, len(value), _MAX_SIZE)) + raise errors.ConfigurationError(msg) + utils.WriteFile(self.KeyToFilename(name), data=value, - mode=constants.SS_FILE_PERMS) + mode=constants.SS_FILE_PERMS, + dry_run=dry_run) finally: ssconf_lock.Unlock() @@ -320,13 +328,13 @@ class SimpleStore(object): " family: %s" % err) -def WriteSsconfFiles(values): +def WriteSsconfFiles(values, dry_run=False): """Update all ssconf files. Wrapper around L{SimpleStore.WriteFiles}. """ - SimpleStore().WriteFiles(values) + SimpleStore().WriteFiles(values, dry_run=dry_run) def GetMasterAndMyself(ss=None): diff --git a/test/ganeti.ssconf_unittest.py b/test/ganeti.ssconf_unittest.py index 958859b35efb6aed3a71bd23413d5a2a3c9b6b69..eb9273545f3d1fbdb5710e80f1b13839429131e3 100755 --- a/test/ganeti.ssconf_unittest.py +++ b/test/ganeti.ssconf_unittest.py @@ -95,11 +95,20 @@ class TestReadSsconfFile(unittest.TestCase): class TestSimpleStore(unittest.TestCase): def setUp(self): - self.tmpdir = tempfile.mkdtemp() - self.sstore = ssconf.SimpleStore(cfg_location=self.tmpdir) + self._tmpdir = tempfile.mkdtemp() + self.ssdir = utils.PathJoin(self._tmpdir, "files") + lockfile = utils.PathJoin(self._tmpdir, "lock") + + os.mkdir(self.ssdir) + + self.sstore = ssconf.SimpleStore(cfg_location=self.ssdir, + _lockfile=lockfile) def tearDown(self): - shutil.rmtree(self.tmpdir) + shutil.rmtree(self._tmpdir) + + def _ReadSsFile(self, filename): + return utils.ReadFile(utils.PathJoin(self.ssdir, "ssconf_%s" % filename)) def testInvalidKey(self): self.assertRaises(errors.ProgrammerError, self.sstore.KeyToFilename, @@ -110,7 +119,7 @@ class TestSimpleStore(unittest.TestCase): def testKeyToFilename(self): for key in ssconf._VALID_KEYS: result = self.sstore.KeyToFilename(key) - self.assertTrue(utils.IsBelowDir(self.tmpdir, result)) + self.assertTrue(utils.IsBelowDir(self.ssdir, result)) self.assertTrue(os.path.basename(result).startswith("ssconf_")) def testReadFileNonExistingFile(self): @@ -140,6 +149,67 @@ class TestSimpleStore(unittest.TestCase): default="something.example.com"), "cluster.example.com") + def testWriteFiles(self): + values = { + constants.SS_CLUSTER_NAME: "cluster.example.com", + constants.SS_CLUSTER_TAGS: "value\nwith\nnewlines\n", + constants.SS_INSTANCE_LIST: "", + } + + self.sstore.WriteFiles(values) + + self.assertEqual(sorted(os.listdir(self.ssdir)), sorted([ + "ssconf_cluster_name", + "ssconf_cluster_tags", + "ssconf_instance_list", + ])) + + self.assertEqual(self._ReadSsFile(constants.SS_CLUSTER_NAME), + "cluster.example.com\n") + self.assertEqual(self._ReadSsFile(constants.SS_CLUSTER_TAGS), + "value\nwith\nnewlines\n") + self.assertEqual(self._ReadSsFile(constants.SS_INSTANCE_LIST), "") + + def testWriteFilesUnknownKey(self): + values = { + "unknown key": "value", + } + + self.assertRaises(errors.ProgrammerError, self.sstore.WriteFiles, + values, dry_run=True) + + self.assertEqual(os.listdir(self.ssdir), []) + + def testWriteFilesDryRun(self): + values = { + constants.SS_CLUSTER_NAME: "cluster.example.com", + } + + self.sstore.WriteFiles(values, dry_run=True) + + self.assertEqual(os.listdir(self.ssdir), []) + + def testWriteFilesNoValues(self): + for dry_run in [False, True]: + self.sstore.WriteFiles({}, dry_run=dry_run) + + self.assertEqual(os.listdir(self.ssdir), []) + + def testWriteFilesTooLong(self): + values = { + constants.SS_INSTANCE_LIST: "A" * ssconf._MAX_SIZE, + } + + for dry_run in [False, True]: + try: + self.sstore.WriteFiles(values, dry_run=dry_run) + except errors.ConfigurationError, err: + self.assertTrue(str(err).startswith("Value 'instance_list' has")) + else: + self.fail("Exception was not raised") + + self.assertEqual(os.listdir(self.ssdir), []) + class TestVerifyClusterName(unittest.TestCase): def setUp(self):