diff --git a/lib/constants.py b/lib/constants.py index 0342ddf1a74ec23618ae2b6628eb734d65cb23c3..39b68953450eb0c74a37f86bc3aa746eb323fd0a 100644 --- a/lib/constants.py +++ b/lib/constants.py @@ -2044,5 +2044,10 @@ IALLOC_HAIL = "hail" FAKE_OP_MASTER_TURNUP = "OP_CLUSTER_IP_TURNUP" FAKE_OP_MASTER_TURNDOWN = "OP_CLUSTER_IP_TURNDOWN" +# SSH key types +SSHK_RSA = "rsa" +SSHK_DSA = "dsa" +SSHK_ALL = frozenset([SSHK_RSA, SSHK_DSA]) + # Do not re-export imported modules del re, _vcsversion, _autoconf, socket, pathutils diff --git a/lib/ssh.py b/lib/ssh.py index 4c4a18c4f5e65119663f0b7ae70841d3a6dfb5c3..b61f7c481e6f2cf434964d7d5d96e9c6d45053f0 100644 --- a/lib/ssh.py +++ b/lib/ssh.py @@ -48,25 +48,35 @@ def FormatParamikoFingerprint(fingerprint): return ":".join(re.findall(r"..", fingerprint.lower())) -def GetUserFiles(user, mkdir=False): - """Return the paths of a user's ssh files. - - The function will return a triplet (priv_key_path, pub_key_path, - auth_key_path) that are used for ssh authentication. Currently, the - keys used are DSA keys, so this function will return: - (~user/.ssh/id_dsa, ~user/.ssh/id_dsa.pub, - ~user/.ssh/authorized_keys). - - If the optional parameter mkdir is True, the ssh directory will be - created if it doesn't exist. - - Regardless of the mkdir parameters, the script will raise an error - if ~user/.ssh is not a directory. +def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA, + _homedir_fn=utils.GetHomeDir): + """Return the paths of a user's SSH files. + + @type user: string + @param user: Username + @type mkdir: bool + @param mkdir: Whether to create ".ssh" directory if it doesn't exist + @type kind: string + @param kind: One of L{constants.SSHK_ALL} + @rtype: tuple; (string, string, string) + @return: Tuple containing three file system paths; the private SSH key file, + the public SSH key file and the user's C{authorized_keys} file + @raise errors.OpExecError: When home directory of the user can not be + determined + @raise errors.OpExecError: Regardless of the C{mkdir} parameters, this + exception is raised if C{~$user/.ssh} is not a directory """ - user_dir = utils.GetHomeDir(user) + user_dir = _homedir_fn(user) if not user_dir: - raise errors.OpExecError("Cannot resolve home of user %s" % user) + raise errors.OpExecError("Cannot resolve home of user '%s'" % user) + + if kind == constants.SSHK_DSA: + suffix = "dsa" + elif kind == constants.SSHK_RSA: + suffix = "rsa" + else: + raise errors.ProgrammerError("Unknown SSH key kind '%s'" % kind) ssh_dir = utils.PathJoin(user_dir, ".ssh") if mkdir: @@ -75,7 +85,8 @@ def GetUserFiles(user, mkdir=False): raise errors.OpExecError("Path %s is not a directory" % ssh_dir) return [utils.PathJoin(ssh_dir, base) - for base in ["id_dsa", "id_dsa.pub", "authorized_keys"]] + for base in ["id_%s" % suffix, "id_%s.pub" % suffix, + "authorized_keys"]] class SshRunner: diff --git a/test/ganeti.ssh_unittest.py b/test/ganeti.ssh_unittest.py index bd6c951efd742518112e37c3b5740f8c33d07b32..77960c237a2bc852dc551cd61e2596dc5ec44446 100755 --- a/test/ganeti.ssh_unittest.py +++ b/test/ganeti.ssh_unittest.py @@ -24,6 +24,7 @@ import os import tempfile import unittest +import shutil import testutils import mocks @@ -31,6 +32,7 @@ import mocks from ganeti import constants from ganeti import utils from ganeti import ssh +from ganeti import errors class TestKnownHosts(testutils.GanetiTestCase): @@ -54,5 +56,74 @@ class TestKnownHosts(testutils.GanetiTestCase): self.assertRaises(AssertionError, ssh.FormatParamikoFingerprint, "C0Ffe") -if __name__ == '__main__': +class TestGetUserFiles(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + @staticmethod + def _GetNoHomedir(_): + return None + + def _GetTempHomedir(self, _): + return self.tmpdir + + def testNonExistantUser(self): + for kind in constants.SSHK_ALL: + self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example", + kind=kind, _homedir_fn=self._GetNoHomedir) + + def testUnknownKind(self): + kind = "something-else" + assert kind not in constants.SSHK_ALL + self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645", + kind=kind, _homedir_fn=self._GetTempHomedir) + + self.assertEqual(os.listdir(self.tmpdir), []) + + def testNoSshDirectory(self): + for kind in constants.SSHK_ALL: + self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694", + kind=kind, _homedir_fn=self._GetTempHomedir) + self.assertEqual(os.listdir(self.tmpdir), []) + + def testSshIsFile(self): + utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="") + for kind in constants.SSHK_ALL: + self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237", + kind=kind, _homedir_fn=self._GetTempHomedir) + self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) + + def testMakeSshDirectory(self): + sshdir = os.path.join(self.tmpdir, ".ssh") + + self.assertEqual(os.listdir(self.tmpdir), []) + + for kind in constants.SSHK_ALL: + ssh.GetUserFiles("example20745", mkdir=True, kind=kind, + _homedir_fn=self._GetTempHomedir) + self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) + self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700) + + def testFilenames(self): + sshdir = os.path.join(self.tmpdir, ".ssh") + + os.mkdir(sshdir) + + for kind in constants.SSHK_ALL: + result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind, + _homedir_fn=self._GetTempHomedir) + self.assertEqual(result, [ + os.path.join(self.tmpdir, ".ssh", "id_%s" % kind), + os.path.join(self.tmpdir, ".ssh", "id_%s.pub" % kind), + os.path.join(self.tmpdir, ".ssh", "authorized_keys"), + ]) + + self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) + self.assertEqual(os.listdir(sshdir), []) + + +if __name__ == "__main__": testutils.GanetiTestProgram()