From 7a6a27af496f4cc477c4a9bc08fbd42e27d3f3ca Mon Sep 17 00:00:00 2001
From: Iustin Pop <iustin@google.com>
Date: Fri, 20 Aug 2010 15:12:02 +0200
Subject: [PATCH] setup-ssh: try to use key auth first

This patch changes the setup-ssh workflow to try key authentication
first, and then fall-back to password authentication. The password is
also read lazily, with no prompts if we can authenticate via keys.

Signed-off-by: Iustin Pop <iustin@google.com>
Reviewed-by: Guido Trotter <ultrotter@google.com>
---
 tools/setup-ssh | 40 +++++++++++++++++++++++++++++++++++++---
 1 file changed, 37 insertions(+), 3 deletions(-)

diff --git a/tools/setup-ssh b/tools/setup-ssh
index ff9880d55..8f4cfd194 100755
--- a/tools/setup-ssh
+++ b/tools/setup-ssh
@@ -169,6 +169,14 @@ def ParseOptions():
                                         " <node...>"), prog=program)
   parser.add_option(cli.DEBUG_OPT)
   parser.add_option(cli.VERBOSE_OPT)
+  default_key = ssh.GetUserFiles(constants.GANETI_RUNAS)[0]
+  parser.add_option(optparse.Option("-f", dest="private_key",
+                                    default=default_key,
+                                    help="The private key to (try to) use for"
+                                    "authentication "))
+  parser.add_option(optparse.Option("--key-type", dest="key_type",
+                                    choices=("rsa", "dsa"), default="dsa",
+                                    help="The private key type (rsa or dsa)"))
 
   (options, args) = parser.parse_args()
 
@@ -221,7 +229,23 @@ def main():
 
   SetupLogging(options)
 
-  passwd = getpass.getpass(prompt="%s password:" % constants.GANETI_RUNAS)
+  if options.key_type == "rsa":
+    pkclass = paramiko.RSAKey
+  elif options.key_type == "dsa":
+    pkclass = paramiko.DSSKey
+  else:
+    logging.critical("Unknown key type %s selected (choose either rsa or dsa)",
+                     options.key_type)
+    sys.exit(1)
+
+  try:
+    private_key = pkclass.from_private_key_file(options.private_key)
+  except (paramiko.SSHException, EnvironmentError), err:
+    logging.critical("Can't load private key %s: %s", options.private_key, err)
+    sys.exit(1)
+
+  passwd = None
+  username = constants.GANETI_RUNAS
   ssh_port = netutils.GetDaemonPort("ssh")
 
   # Below, we need to join() the transport objects, as otherwise the
@@ -235,9 +259,19 @@ def main():
 
   for host in args:
     transport = paramiko.Transport((host, ssh_port))
+    transport.start_client()
     try:
-      transport.connect(username=constants.GANETI_RUNAS, password=passwd)
-    except Exception, err:
+      try:
+        transport.auth_publickey(username, private_key)
+        logging.info("Authenticated to %s via public key", host)
+      except paramiko.SSHException:
+        logging.warning("Authentication to %s via public key failed, trying"
+                        " password", host)
+        if passwd is None:
+          passwd = getpass.getpass(prompt="%s password:" % username)
+        transport.auth_password(username=username, password=passwd)
+        logging.info("Authenticated to %s via password", host)
+    except paramiko.SSHException, err:
       logging.error("Connection or authentication failed to host %s: %s",
                     host, err)
       transport.close()
-- 
GitLab