Commit 8a3c9e8a authored by Michael Hanselmann's avatar Michael Hanselmann
Browse files

ssh.GetUserFiles: RSA support, unit tests

This patch changes “ssh.GetUserFiles” to support two different kinds of
SSH keys, RSA and DSA. Before it would always use DSA. Newly written
unit tests are included.
Signed-off-by: default avatarMichael Hanselmann <>
Reviewed-by: default avatarIustin Pop <>
parent 986efb78
......@@ -2044,5 +2044,10 @@ IALLOC_HAIL = "hail"
# SSH key types
SSHK_RSA = "rsa"
SSHK_DSA = "dsa"
SSHK_ALL = frozenset([SSHK_RSA, SSHK_DSA])
# Do not re-export imported modules
del re, _vcsversion, _autoconf, socket, pathutils
......@@ -48,25 +48,35 @@ def FormatParamikoFingerprint(fingerprint):
return ":".join(re.findall(r"..", fingerprint.lower()))
def GetUserFiles(user, mkdir=False):
"""Return the paths of a user's ssh files.
The function will return a triplet (priv_key_path, pub_key_path,
auth_key_path) that are used for ssh authentication. Currently, the
keys used are DSA keys, so this function will return:
(~user/.ssh/id_dsa, ~user/.ssh/,
If the optional parameter mkdir is True, the ssh directory will be
created if it doesn't exist.
Regardless of the mkdir parameters, the script will raise an error
if ~user/.ssh is not a directory.
def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
"""Return the paths of a user's SSH files.
@type user: string
@param user: Username
@type mkdir: bool
@param mkdir: Whether to create ".ssh" directory if it doesn't exist
@type kind: string
@param kind: One of L{constants.SSHK_ALL}
@rtype: tuple; (string, string, string)
@return: Tuple containing three file system paths; the private SSH key file,
the public SSH key file and the user's C{authorized_keys} file
@raise errors.OpExecError: When home directory of the user can not be
@raise errors.OpExecError: Regardless of the C{mkdir} parameters, this
exception is raised if C{~$user/.ssh} is not a directory
user_dir = utils.GetHomeDir(user)
user_dir = _homedir_fn(user)
if not user_dir:
raise errors.OpExecError("Cannot resolve home of user %s" % user)
raise errors.OpExecError("Cannot resolve home of user '%s'" % user)
if kind == constants.SSHK_DSA:
suffix = "dsa"
elif kind == constants.SSHK_RSA:
suffix = "rsa"
raise errors.ProgrammerError("Unknown SSH key kind '%s'" % kind)
ssh_dir = utils.PathJoin(user_dir, ".ssh")
if mkdir:
......@@ -75,7 +85,8 @@ def GetUserFiles(user, mkdir=False):
raise errors.OpExecError("Path %s is not a directory" % ssh_dir)
return [utils.PathJoin(ssh_dir, base)
for base in ["id_dsa", "", "authorized_keys"]]
for base in ["id_%s" % suffix, "" % suffix,
class SshRunner:
......@@ -24,6 +24,7 @@
import os
import tempfile
import unittest
import shutil
import testutils
import mocks
......@@ -31,6 +32,7 @@ import mocks
from ganeti import constants
from ganeti import utils
from ganeti import ssh
from ganeti import errors
class TestKnownHosts(testutils.GanetiTestCase):
......@@ -54,5 +56,74 @@ class TestKnownHosts(testutils.GanetiTestCase):
self.assertRaises(AssertionError, ssh.FormatParamikoFingerprint, "C0Ffe")
if __name__ == '__main__':
class TestGetUserFiles(unittest.TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
def _GetNoHomedir(_):
return None
def _GetTempHomedir(self, _):
return self.tmpdir
def testNonExistantUser(self):
for kind in constants.SSHK_ALL:
self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
kind=kind, _homedir_fn=self._GetNoHomedir)
def testUnknownKind(self):
kind = "something-else"
assert kind not in constants.SSHK_ALL
self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
kind=kind, _homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [])
def testNoSshDirectory(self):
for kind in constants.SSHK_ALL:
self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
kind=kind, _homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [])
def testSshIsFile(self):
utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
for kind in constants.SSHK_ALL:
self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
kind=kind, _homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
def testMakeSshDirectory(self):
sshdir = os.path.join(self.tmpdir, ".ssh")
self.assertEqual(os.listdir(self.tmpdir), [])
for kind in constants.SSHK_ALL:
ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
def testFilenames(self):
sshdir = os.path.join(self.tmpdir, ".ssh")
for kind in constants.SSHK_ALL:
result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
self.assertEqual(result, [
os.path.join(self.tmpdir, ".ssh", "id_%s" % kind),
os.path.join(self.tmpdir, ".ssh", "" % kind),
os.path.join(self.tmpdir, ".ssh", "authorized_keys"),
self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
self.assertEqual(os.listdir(sshdir), [])
if __name__ == "__main__":
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment