From 310a894471516eb1e3a89bde25e521a408558f85 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Nussbaumer?= <rn@google.com>
Date: Tue, 24 Aug 2010 13:54:01 +0200
Subject: [PATCH] Adding host key verification to setup-ssh
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: RenΓ© Nussbaumer <rn@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>
---
 tools/setup-ssh | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/tools/setup-ssh b/tools/setup-ssh
index e5a512bf1..2dd2c9ab6 100755
--- a/tools/setup-ssh
+++ b/tools/setup-ssh
@@ -177,6 +177,7 @@ def ParseOptions():
                                         " <node...>"), prog=program)
   parser.add_option(cli.DEBUG_OPT)
   parser.add_option(cli.VERBOSE_OPT)
+  parser.add_option(cli.NOSSH_KEYCHECK_OPT)
   default_key = ssh.GetUserFiles(constants.GANETI_RUNAS)[0]
   parser.add_option(optparse.Option("-f", dest="private_key",
                                     default=default_key,
@@ -296,6 +297,22 @@ def LoginViaKeys(transport, username, keys):
     return False
 
 
+def LoadKnownHosts():
+  """Loads the known hosts
+
+    @return L{paramiko.util.load_host_keys} dict
+
+  """
+  homedir = utils.GetHomeDir(constants.GANETI_RUNAS)
+  known_hosts = os.path.join(homedir, ".ssh", "known_hosts")
+
+  try:
+    return paramiko.util.load_host_keys(known_hosts)
+  except EnvironmentError:
+    # We didn't found the path, silently ignore and return an empty dict
+    return {}
+
+
 def main():
   """Main routine.
 
@@ -309,6 +326,7 @@ def main():
   passwd = None
   username = constants.GANETI_RUNAS
   ssh_port = netutils.GetDaemonPort("ssh")
+  host_keys = LoadKnownHosts()
 
   # Below, we need to join() the transport objects, as otherwise the
   # following happens:
@@ -322,6 +340,28 @@ def main():
   for host in args:
     transport = paramiko.Transport((host, ssh_port))
     transport.start_client()
+    server_key = transport.get_remote_server_key()
+    keytype = server_key.get_name()
+
+    our_server_key = host_keys.get(host, {}).get(keytype, None)
+    if options.ssh_key_check:
+      if not our_server_key:
+        hexified_key = ssh.FormatParamikoFingerprint(
+            server_key.get_fingerprint())
+        msg = ("Unable to verify hostkey of host %s: %s. Do you want to accept"
+               " it?" % (host, hexified_key))
+
+        if cli.AskUser(msg):
+          our_server_key = server_key
+
+      if our_server_key != server_key:
+        logging.error("Unable to verify identity of host. Aborting")
+        transport.close()
+        transport.join()
+        # TODO: Run over all hosts, fetch the keys and let them verify from the
+        #       user beforehand then proceed with actual work later on
+        raise paramiko.SSHException("Unable to verify identity of host")
+
     try:
       if LoginViaKeys(transport, username, all_keys):
         logging.info("Authenticated to %s via public key", host)
-- 
GitLab