From f40ae421dcbeeab00b8559cca94e5b847b544330 Mon Sep 17 00:00:00 2001
From: Iustin Pop <iustin@google.com>
Date: Sat, 27 Nov 2010 18:58:24 +0000
Subject: [PATCH] Improve unittests for the utils module

This just a random collection of unittest improvements. Coverage
increases from 73% to 76%.

Signed-off-by: Iustin Pop <iustin@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
---
 test/ganeti.utils_unittest.py | 219 +++++++++++++++++++++++++++++-----
 1 file changed, 190 insertions(+), 29 deletions(-)

diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py
index ef11e8ed0..9faf3bf31 100755
--- a/test/ganeti.utils_unittest.py
+++ b/test/ganeti.utils_unittest.py
@@ -363,8 +363,22 @@ class TestRunCmd(testutils.GanetiTestCase):
     self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                 env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
 
+  def testNoFork(self):
+    """Test that nofork raise an error"""
+    assert not utils.no_fork
+    utils.no_fork = True
+    try:
+      self.assertRaises(errors.ProgrammerError, RunCmd, ["true"])
+    finally:
+      utils.no_fork = False
+
+  def testWrongParams(self):
+    """Test wrong parameters"""
+    self.assertRaises(errors.ProgrammerError, RunCmd, ["true"],
+                      output="/dev/null", interactive=True)
+
 
-class TestRunParts(unittest.TestCase):
+class TestRunParts(testutils.GanetiTestCase):
   """Testing case for the RunParts function"""
 
   def setUp(self):
@@ -489,6 +503,10 @@ class TestRunParts(unittest.TestCase):
     self.failUnlessEqual(runresult.exit_code, 0)
     self.failUnless(not runresult.failed)
 
+  def testMissingDirectory(self):
+    nosuchdir = utils.PathJoin(self.rundir, "no/such/directory")
+    self.assertEqual(RunParts(nosuchdir), [])
+
 
 class TestStartDaemon(testutils.GanetiTestCase):
   def setUp(self):
@@ -681,6 +699,26 @@ class TestRemoveFile(unittest.TestCase):
       self.fail("File '%s' not removed" % symlink)
 
 
+class TestRemoveDir(unittest.TestCase):
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    try:
+      shutil.rmtree(self.tmpdir)
+    except EnvironmentError:
+      pass
+
+  def testEmptyDir(self):
+    utils.RemoveDir(self.tmpdir)
+    self.assertFalse(os.path.isdir(self.tmpdir))
+
+  def testNonEmptyDir(self):
+    self.tmpfile = os.path.join(self.tmpdir, "test1")
+    open(self.tmpfile, "w").close()
+    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
+
+
 class TestRename(unittest.TestCase):
   """Test case for RenameFile"""
 
@@ -882,6 +920,10 @@ class TestReadOneLineFile(testutils.GanetiTestCase):
           datastrict = ReadOneLineFile(myfile, strict=True)
           self.assertEqual(myline, datastrict)
 
+  def testEmptyfile(self):
+    myfile = self._CreateTempFile()
+    self.assertRaises(errors.GenericError, ReadOneLineFile, myfile)
+
 
 class TestTimestampForFilename(unittest.TestCase):
   def test(self):
@@ -973,6 +1015,9 @@ class TestFormatUnit(unittest.TestCase):
     self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
     self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
 
+  def testErrors(self):
+    self.assertRaises(errors.ProgrammerError, FormatUnit, 1, "a")
+
 
 class TestParseUnit(unittest.TestCase):
   """Test case for the ParseUnit function"""
@@ -1035,18 +1080,9 @@ class TestParseCpuMask(unittest.TestCase):
     self.assertEqual(utils.ParseCpuMask("0-2,4,5-5"), [0,1,2,4,5])
 
   def testInvalidInput(self):
-    self.assertRaises(errors.ParseError,
-                      utils.ParseCpuMask,
-                      "garbage")
-    self.assertRaises(errors.ParseError,
-                      utils.ParseCpuMask,
-                      "0,")
-    self.assertRaises(errors.ParseError,
-                      utils.ParseCpuMask,
-                      "0-1-2")
-    self.assertRaises(errors.ParseError,
-                      utils.ParseCpuMask,
-                      "2-1")
+    for data in ["garbage", "0,", "0-1-2", "2-1", "1-a"]:
+      self.assertRaises(errors.ParseError, utils.ParseCpuMask, data)
+
 
 class TestSshKeys(testutils.GanetiTestCase):
   """Test case for the AddAuthorizedKey function"""
@@ -1523,21 +1559,19 @@ class FieldSetTestCase(unittest.TestCase):
 
 class TestForceDictType(unittest.TestCase):
   """Test case for ForceDictType"""
-
-  def setUp(self):
-    self.key_types = {
-      'a': constants.VTYPE_INT,
-      'b': constants.VTYPE_BOOL,
-      'c': constants.VTYPE_STRING,
-      'd': constants.VTYPE_SIZE,
-      "e": constants.VTYPE_MAYBE_STRING,
-      }
+  KEY_TYPES = {
+    "a": constants.VTYPE_INT,
+    "b": constants.VTYPE_BOOL,
+    "c": constants.VTYPE_STRING,
+    "d": constants.VTYPE_SIZE,
+    "e": constants.VTYPE_MAYBE_STRING,
+    }
 
   def _fdt(self, dict, allowed_values=None):
     if allowed_values is None:
-      utils.ForceDictType(dict, self.key_types)
+      utils.ForceDictType(dict, self.KEY_TYPES)
     else:
-      utils.ForceDictType(dict, self.key_types, allowed_values=allowed_values)
+      utils.ForceDictType(dict, self.KEY_TYPES, allowed_values=allowed_values)
 
     return dict
 
@@ -1550,6 +1584,7 @@ class TestForceDictType(unittest.TestCase):
     self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
     self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
     self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
+    self.assertEqual(self._fdt({'b': False}), {'b': False})
     self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
     self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
     self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
@@ -1557,14 +1592,20 @@ class TestForceDictType(unittest.TestCase):
     self.assertEqual(self._fdt({"e": None, }), {"e": None, })
     self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", })
     self.assertEqual(self._fdt({"e": False, }), {"e": '', })
+    self.assertEqual(self._fdt({"b": "hello", }, ["hello"]), {"b": "hello"})
 
   def testErrors(self):
     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
+    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"b": "hello"})
     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
     self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), })
     self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], })
+    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"x": None, })
+    self.assertRaises(errors.TypeEnforcementError, self._fdt, [])
+    self.assertRaises(errors.ProgrammerError, utils.ForceDictType,
+                      {"b": "hello"}, {"b": "no-such-type"})
 
 
 class TestIsNormAbsPath(unittest.TestCase):
@@ -1659,17 +1700,32 @@ class RunInSeparateProcess(unittest.TestCase):
                       utils.RunInSeparateProcess, _exc)
 
 
-class TestFingerprintFile(unittest.TestCase):
+class TestFingerprintFiles(unittest.TestCase):
   def setUp(self):
     self.tmpfile = tempfile.NamedTemporaryFile()
+    self.tmpfile2 = tempfile.NamedTemporaryFile()
+    utils.WriteFile(self.tmpfile2.name, data="Hello World\n")
+    self.results = {
+      self.tmpfile.name: "da39a3ee5e6b4b0d3255bfef95601890afd80709",
+      self.tmpfile2.name: "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a",
+      }
 
-  def test(self):
+  def testSingleFile(self):
     self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
-                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+                     self.results[self.tmpfile.name])
+
+    self.assertEqual(utils._FingerprintFile("/no/such/file"), None)
 
-    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
+  def testBigFile(self):
+    self.tmpfile.write("A" * 8192)
+    self.tmpfile.flush()
     self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
-                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
+                     "35b6795ca20d6dc0aff8c7c110c96cd1070b8c38")
+
+  def testMultiple(self):
+    all_files = self.results.keys()
+    all_files.append("/no/such/file")
+    self.assertEqual(utils.FingerprintFiles(self.results.keys()), self.results)
 
 
 class TestUnescapeAndSplit(unittest.TestCase):
@@ -2414,6 +2470,11 @@ class TestFileID(testutils.GanetiTestCase):
     # this doesn't raise, since we passed None
     utils.SafeWriteFile(name, None, data="")
 
+  def testError(self):
+    t = tempfile.NamedTemporaryFile()
+    self.assertRaises(errors.ProgrammerError, utils.GetFileID,
+                      path=t.name, fd=t.fileno())
+
 
 class TimeMock:
   def __init__(self, values):
@@ -2451,5 +2512,105 @@ class TestRunningTimeout(unittest.TestCase):
     self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
 
 
+class TestTryConvert(unittest.TestCase):
+  def test(self):
+    for src, fn, result in [
+      ("1", int, 1),
+      ("a", int, "a"),
+      ("", bool, False),
+      ("a", bool, True),
+      ]:
+      self.assertEqual(utils.TryConvert(fn, src), result)
+
+
+class TestIsValidShellParam(unittest.TestCase):
+  def test(self):
+    for val, result in [
+      ("abc", True),
+      ("ab;cd", False),
+      ]:
+      self.assertEqual(utils.IsValidShellParam(val), result)
+
+
+class TestBuildShellCmd(unittest.TestCase):
+  def test(self):
+    self.assertRaises(errors.ProgrammerError, utils.BuildShellCmd,
+                      "ls %s", "ab;cd")
+    self.assertEqual(utils.BuildShellCmd("ls %s", "ab"), "ls ab")
+
+
+class TestWriteFile(unittest.TestCase):
+  def setUp(self):
+    self.tfile = tempfile.NamedTemporaryFile()
+    self.did_pre = False
+    self.did_post = False
+    self.did_write = False
+
+  def markPre(self, fd):
+    self.did_pre = True
+
+  def markPost(self, fd):
+    self.did_post = True
+
+  def markWrite(self, fd):
+    self.did_write = True
+
+  def testWrite(self):
+    data = "abc"
+    utils.WriteFile(self.tfile.name, data=data)
+    self.assertEqual(utils.ReadFile(self.tfile.name), data)
+
+  def testErrors(self):
+    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
+                      self.tfile.name, data="test", fn=lambda fd: None)
+    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
+    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
+                      self.tfile.name, data="test", atime=0)
+
+  def testCalls(self):
+    utils.WriteFile(self.tfile.name, fn=self.markWrite,
+                    prewrite=self.markPre, postwrite=self.markPost)
+    self.assertTrue(self.did_pre)
+    self.assertTrue(self.did_post)
+    self.assertTrue(self.did_write)
+
+  def testDryRun(self):
+    orig = "abc"
+    self.tfile.write(orig)
+    self.tfile.flush()
+    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
+    self.assertEqual(utils.ReadFile(self.tfile.name), orig)
+
+  def testTimes(self):
+    f = self.tfile.name
+    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
+                   (int(time.time()), 5000)]:
+      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
+      st = os.stat(f)
+      self.assertEqual(st.st_atime, at)
+      self.assertEqual(st.st_mtime, mt)
+
+
+  def testNoClose(self):
+    data = "hello"
+    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
+    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
+    try:
+      os.lseek(fd, 0, 0)
+      self.assertEqual(os.read(fd, 4096), data)
+    finally:
+      os.close(fd)
+
+
+class TestNormalizeAndValidateMac(unittest.TestCase):
+  def testInvalid(self):
+    self.assertRaises(errors.OpPrereqError,
+                      utils.NormalizeAndValidateMac, "xxx")
+
+  def testNormalization(self):
+    for mac in ["aa:bb:cc:dd:ee:ff", "00:AA:11:bB:22:cc"]:
+      self.assertEqual(utils.NormalizeAndValidateMac(mac), mac.lower())
+
+
 if __name__ == '__main__':
   testutils.GanetiTestProgram()
-- 
GitLab