diff --git a/Makefile.am b/Makefile.am index afc106c9a3717c736b4c3b0a64b2518dad890678..67475f37dafb6c62d6d85c9131711d5f1d5ae26b 100644 --- a/Makefile.am +++ b/Makefile.am @@ -958,6 +958,7 @@ python_tests = \ test/ganeti.runtime_unittest.py \ test/ganeti.serializer_unittest.py \ test/ganeti.server.rapi_unittest.py \ + test/ganeti.ssconf_unittest.py \ test/ganeti.ssh_unittest.py \ test/ganeti.storage_unittest.py \ test/ganeti.tools.ensure_dirs_unittest.py \ diff --git a/lib/ssconf.py b/lib/ssconf.py index 30c9d784f50cc2a57281a738d1bb25e09fbf42e2..ad5969c7f2482d899473957318ae493ceaffa5a5 100644 --- a/lib/ssconf.py +++ b/lib/ssconf.py @@ -69,6 +69,28 @@ _VALID_KEYS = frozenset([ _MAX_SIZE = 128 * 1024 +def ReadSsconfFile(filename): + """Reads an ssconf file and verifies its size. + + @type filename: string + @param filename: Path to file + @rtype: string + @return: File contents without newlines at the end + @raise RuntimeError: When the file size exceeds L{_MAX_SIZE} + + """ + statcb = utils.FileStatHelper() + + data = utils.ReadFile(filename, size=_MAX_SIZE, preread=statcb) + + if statcb.st.st_size > _MAX_SIZE: + msg = ("File '%s' has a size of %s bytes (up to %s allowed)" % + (filename, statcb.st.st_size, _MAX_SIZE)) + raise RuntimeError(msg) + + return data.rstrip("\n") + + class SimpleStore(object): """Interface to static cluster data. @@ -106,15 +128,13 @@ class SimpleStore(object): """ filename = self.KeyToFilename(key) try: - data = utils.ReadFile(filename, size=_MAX_SIZE) + return ReadSsconfFile(filename) except EnvironmentError, err: if err.errno == errno.ENOENT and default is not None: return default raise errors.ConfigurationError("Can't read ssconf file %s: %s" % (filename, str(err))) - return data.rstrip("\n") - def WriteFiles(self, values): """Writes ssconf files used by external scripts. diff --git a/test/ganeti.ssconf_unittest.py b/test/ganeti.ssconf_unittest.py new file mode 100755 index 0000000000000000000000000000000000000000..86d93be415c8c9d391aadca13007373d518bacd8 --- /dev/null +++ b/test/ganeti.ssconf_unittest.py @@ -0,0 +1,145 @@ +#!/usr/bin/python +# + +# Copyright (C) 2012 Google Inc. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + + +"""Script for testing ganeti.ssconf""" + +import os +import unittest +import tempfile +import shutil +import errno + +from ganeti import utils +from ganeti import constants +from ganeti import errors +from ganeti import ssconf + +import testutils + + +class TestReadSsconfFile(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testReadDirectory(self): + self.assertRaises(EnvironmentError, ssconf.ReadSsconfFile, self.tmpdir) + + def testNonExistantFile(self): + testfile = utils.PathJoin(self.tmpdir, "does.not.exist") + + self.assertFalse(os.path.exists(testfile)) + + try: + ssconf.ReadSsconfFile(testfile) + except EnvironmentError, err: + self.assertEqual(err.errno, errno.ENOENT) + else: + self.fail("Exception was not raised") + + def testEmptyFile(self): + testfile = utils.PathJoin(self.tmpdir, "empty") + + utils.WriteFile(testfile, data="") + + self.assertEqual(ssconf.ReadSsconfFile(testfile), "") + + def testSingleLine(self): + testfile = utils.PathJoin(self.tmpdir, "data") + + for nl in range(0, 10): + utils.WriteFile(testfile, data="Hello World" + ("\n" * nl)) + + self.assertEqual(ssconf.ReadSsconfFile(testfile), + "Hello World") + + def testExactlyMaxSize(self): + testfile = utils.PathJoin(self.tmpdir, "data") + + data = "A" * ssconf._MAX_SIZE + utils.WriteFile(testfile, data=data) + + self.assertEqual(os.path.getsize(testfile), ssconf._MAX_SIZE) + + self.assertEqual(ssconf.ReadSsconfFile(testfile), + data) + + def testLargeFile(self): + testfile = utils.PathJoin(self.tmpdir, "data") + + for size in [ssconf._MAX_SIZE + 1, ssconf._MAX_SIZE * 2]: + utils.WriteFile(testfile, data="A" * size) + self.assertTrue(os.path.getsize(testfile) > ssconf._MAX_SIZE) + self.assertRaises(RuntimeError, ssconf.ReadSsconfFile, testfile) + + +class TestSimpleStore(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.sstore = ssconf.SimpleStore(cfg_location=self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testInvalidKey(self): + self.assertRaises(errors.ProgrammerError, self.sstore.KeyToFilename, + "not a valid key") + self.assertRaises(errors.ProgrammerError, self.sstore._ReadFile, + "not a valid key") + + def testKeyToFilename(self): + for key in ssconf._VALID_KEYS: + result = self.sstore.KeyToFilename(key) + self.assertTrue(utils.IsBelowDir(self.tmpdir, result)) + self.assertTrue(os.path.basename(result).startswith("ssconf_")) + + def testReadFileNonExistingFile(self): + filename = self.sstore.KeyToFilename(constants.SS_CLUSTER_NAME) + + self.assertFalse(os.path.exists(filename)) + try: + self.sstore._ReadFile(constants.SS_CLUSTER_NAME) + except errors.ConfigurationError, err: + self.assertTrue(str(err).startswith("Can't read ssconf file")) + else: + self.fail("Exception was not raised") + + for default in ["", "Hello World", 0, 100]: + self.assertFalse(os.path.exists(filename)) + result = self.sstore._ReadFile(constants.SS_CLUSTER_NAME, default=default) + self.assertEqual(result, default) + + def testReadFile(self): + utils.WriteFile(self.sstore.KeyToFilename(constants.SS_CLUSTER_NAME), + data="cluster.example.com") + + self.assertEqual(self.sstore._ReadFile(constants.SS_CLUSTER_NAME), + "cluster.example.com") + + self.assertEqual(self.sstore._ReadFile(constants.SS_CLUSTER_NAME, + default="something.example.com"), + "cluster.example.com") + + +if __name__ == "__main__": + testutils.GanetiTestProgram()