From 5484cda525eba647c16014e680bdb00d01dc2b20 Mon Sep 17 00:00:00 2001 From: Michael Hanselmann <hansmi@google.com> Date: Wed, 24 Oct 2012 01:10:36 +0200 Subject: [PATCH] ssh: Add function to get all of user's SSH files This new function returns the file paths for all of a user's SSH-related files (RSA, DSA and authorized_keys). Signed-off-by: Michael Hanselmann <hansmi@google.com> Reviewed-by: Iustin Pop <iustin@google.com> --- lib/ssh.py | 24 ++++++++++++++++++++++++ test/ganeti.ssh_unittest.py | 20 ++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/lib/ssh.py b/lib/ssh.py index cec442df1..13071485d 100644 --- a/lib/ssh.py +++ b/lib/ssh.py @@ -34,6 +34,7 @@ from ganeti import constants from ganeti import netutils from ganeti import pathutils from ganeti import vcluster +from ganeti import compat def FormatParamikoFingerprint(fingerprint): @@ -95,6 +96,29 @@ def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA, "authorized_keys"]] +def GetAllUserFiles(user, mkdir=False, dircheck=True, _homedir_fn=None): + """Wrapper over L{GetUserFiles} to retrieve files for all SSH key types. + + See L{GetUserFiles} for details. + + @rtype: tuple; (string, dict with string as key, tuple of (string, string) as + value) + + """ + helper = compat.partial(GetUserFiles, user, mkdir=mkdir, dircheck=dircheck, + _homedir_fn=_homedir_fn) + result = [(kind, helper(kind=kind)) for kind in constants.SSHK_ALL] + + authorized_keys = [i for (_, (_, _, i)) in result] + + assert len(frozenset(authorized_keys)) == 1, \ + "Different paths for authorized_keys were returned" + + return (authorized_keys[0], + dict((kind, (privkey, pubkey)) + for (kind, (privkey, pubkey, _)) in result)) + + class SshRunner: """Wrapper for SSH commands. diff --git a/test/ganeti.ssh_unittest.py b/test/ganeti.ssh_unittest.py index 419c05e3e..bb4f01530 100755 --- a/test/ganeti.ssh_unittest.py +++ b/test/ganeti.ssh_unittest.py @@ -132,6 +132,26 @@ class TestGetUserFiles(unittest.TestCase): _homedir_fn=self._GetTempHomedir) self.assertEqual(os.listdir(self.tmpdir), []) + def testGetAllUserFiles(self): + result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False, + _homedir_fn=self._GetTempHomedir) + self.assertEqual(result, + (os.path.join(self.tmpdir, ".ssh", "authorized_keys"), { + constants.SSHK_RSA: + (os.path.join(self.tmpdir, ".ssh", "id_rsa"), + os.path.join(self.tmpdir, ".ssh", "id_rsa.pub")), + constants.SSHK_DSA: + (os.path.join(self.tmpdir, ".ssh", "id_dsa"), + os.path.join(self.tmpdir, ".ssh", "id_dsa.pub")), + })) + self.assertEqual(os.listdir(self.tmpdir), []) + + def testGetAllUserFilesNoDirectoryNoMkdir(self): + self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles, + "example17270", mkdir=False, dircheck=True, + _homedir_fn=self._GetTempHomedir) + self.assertEqual(os.listdir(self.tmpdir), []) + if __name__ == "__main__": testutils.GanetiTestProgram() -- GitLab