diff --git a/lib/ssh.py b/lib/ssh.py index 009071c9bb6447a24879c9e69b6afd6b72d07010..cec442df1607b22da306afe36fb91cd765be7614 100644 --- a/lib/ssh.py +++ b/lib/ssh.py @@ -48,7 +48,7 @@ def FormatParamikoFingerprint(fingerprint): return ":".join(re.findall(r"..", fingerprint.lower())) -def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA, +def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA, _homedir_fn=None): """Return the paths of a user's SSH files. @@ -56,6 +56,8 @@ def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA, @param user: Username @type mkdir: bool @param mkdir: Whether to create ".ssh" directory if it doesn't exist + @type dircheck: bool + @param dircheck: Whether to check if ".ssh" directory exists @type kind: string @param kind: One of L{constants.SSHK_ALL} @rtype: tuple; (string, string, string) @@ -64,7 +66,8 @@ def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA, @raise errors.OpExecError: When home directory of the user can not be determined @raise errors.OpExecError: Regardless of the C{mkdir} parameters, this - exception is raised if C{~$user/.ssh} is not a directory + exception is raised if C{~$user/.ssh} is not a directory and C{dircheck} + is set to C{True} """ if _homedir_fn is None: @@ -84,7 +87,7 @@ def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA, ssh_dir = utils.PathJoin(user_dir, ".ssh") if mkdir: utils.EnsureDirs([(ssh_dir, constants.SECURE_DIR_MODE)]) - elif not os.path.isdir(ssh_dir): + elif dircheck and not os.path.isdir(ssh_dir): raise errors.OpExecError("Path %s is not a directory" % ssh_dir) return [utils.PathJoin(ssh_dir, base) diff --git a/test/ganeti.ssh_unittest.py b/test/ganeti.ssh_unittest.py index 77960c237a2bc852dc551cd61e2596dc5ec44446..419c05e3e44b3f9830334345db99b6216d800d50 100755 --- a/test/ganeti.ssh_unittest.py +++ b/test/ganeti.ssh_unittest.py @@ -124,6 +124,14 @@ class TestGetUserFiles(unittest.TestCase): self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) self.assertEqual(os.listdir(sshdir), []) + def testNoDirCheck(self): + self.assertEqual(os.listdir(self.tmpdir), []) + + for kind in constants.SSHK_ALL: + ssh.GetUserFiles("example14528", mkdir=False, dircheck=False, kind=kind, + _homedir_fn=self._GetTempHomedir) + self.assertEqual(os.listdir(self.tmpdir), []) + if __name__ == "__main__": testutils.GanetiTestProgram()