Skip to content
Snippets Groups Projects
Commit 5484cda5 authored by Michael Hanselmann's avatar Michael Hanselmann
Browse files

ssh: Add function to get all of user's SSH files


This new function returns the file paths for all of a user's SSH-related
files (RSA, DSA and authorized_keys).

Signed-off-by: default avatarMichael Hanselmann <hansmi@google.com>
Reviewed-by: default avatarIustin Pop <iustin@google.com>
parent d5d76ab2
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,7 @@ from ganeti import constants
from ganeti import netutils
from ganeti import pathutils
from ganeti import vcluster
from ganeti import compat
def FormatParamikoFingerprint(fingerprint):
......@@ -95,6 +96,29 @@ def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA,
"authorized_keys"]]
def GetAllUserFiles(user, mkdir=False, dircheck=True, _homedir_fn=None):
"""Wrapper over L{GetUserFiles} to retrieve files for all SSH key types.
See L{GetUserFiles} for details.
@rtype: tuple; (string, dict with string as key, tuple of (string, string) as
value)
"""
helper = compat.partial(GetUserFiles, user, mkdir=mkdir, dircheck=dircheck,
_homedir_fn=_homedir_fn)
result = [(kind, helper(kind=kind)) for kind in constants.SSHK_ALL]
authorized_keys = [i for (_, (_, _, i)) in result]
assert len(frozenset(authorized_keys)) == 1, \
"Different paths for authorized_keys were returned"
return (authorized_keys[0],
dict((kind, (privkey, pubkey))
for (kind, (privkey, pubkey, _)) in result))
class SshRunner:
"""Wrapper for SSH commands.
......
......@@ -132,6 +132,26 @@ class TestGetUserFiles(unittest.TestCase):
_homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [])
def testGetAllUserFiles(self):
result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False,
_homedir_fn=self._GetTempHomedir)
self.assertEqual(result,
(os.path.join(self.tmpdir, ".ssh", "authorized_keys"), {
constants.SSHK_RSA:
(os.path.join(self.tmpdir, ".ssh", "id_rsa"),
os.path.join(self.tmpdir, ".ssh", "id_rsa.pub")),
constants.SSHK_DSA:
(os.path.join(self.tmpdir, ".ssh", "id_dsa"),
os.path.join(self.tmpdir, ".ssh", "id_dsa.pub")),
}))
self.assertEqual(os.listdir(self.tmpdir), [])
def testGetAllUserFilesNoDirectoryNoMkdir(self):
self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles,
"example17270", mkdir=False, dircheck=True,
_homedir_fn=self._GetTempHomedir)
self.assertEqual(os.listdir(self.tmpdir), [])
if __name__ == "__main__":
testutils.GanetiTestProgram()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment