From c4da9eaf29593f2ddb62b0c01191d9fb7183ff90 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Nussbaumer?= <rn@google.com>
Date: Fri, 29 Oct 2010 14:52:52 +0200
Subject: [PATCH] setup-ssh: Better error reporting
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Together with Michael, we refactored the code to provide better and
simpler error reporting, without printing backtraces for authentication
and host key verification issues.

Signed-off-by: René Nussbaumer <rn@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
---
 tools/setup-ssh | 203 ++++++++++++++++++++++++++++--------------------
 1 file changed, 117 insertions(+), 86 deletions(-)

diff --git a/tools/setup-ssh b/tools/setup-ssh
index 4bc63bab8..112caf61b 100755
--- a/tools/setup-ssh
+++ b/tools/setup-ssh
@@ -55,6 +55,18 @@ class JoinCheckError(errors.GenericError):
   """
 
 
+class HostKeyVerificationError(errors.GenericError):
+  """Exception if host keys do not match.
+
+  """
+
+
+class AuthError(errors.GenericError):
+  """Exception for authentication errors to hosts.
+
+  """
+
+
 def _CheckJoin(transport):
   """Checks if a join is safe or dangerous.
 
@@ -70,31 +82,24 @@ def _CheckJoin(transport):
   ss = ssconf.SimpleStore()
   ss_cluster_name_path = ss.KeyToFilename(constants.SS_CLUSTER_NAME)
 
-  cluster_files = {
-    ss_cluster_name_path: utils.ReadFile(ss_cluster_name_path),
-    constants.NODED_CERT_FILE: utils.ReadFile(constants.NODED_CERT_FILE),
-    }
+  cluster_files = [
+    (constants.NODED_CERT_FILE, utils.ReadFile(constants.NODED_CERT_FILE)),
+    (ss_cluster_name_path, utils.ReadFile(ss_cluster_name_path)),
+    ]
 
-  try:
-    remote_noded_file = _ReadSftpFile(sftp, constants.NODED_CERT_FILE)
-  except IOError:
-    # We can just assume that the file doesn't exist as such error reporting
-    # is lacking from paramiko
-    #
-    # We don't have the noded certificate. As without the cert, the
-    # noded is not running, we are on the safe bet to say that this
-    # node doesn't belong to a cluster
-    return True
+  for (filename, local_content) in cluster_files:
+    try:
+      remote_content = _ReadSftpFile(sftp, filename)
+    except IOError, err:
+      # Assume file does not exist. Paramiko's error reporting is lacking.
+      logging.debug("Failed to read %s: %s", filename, err)
+      continue
 
-  try:
-    remote_cluster_name = _ReadSftpFile(sftp, ss_cluster_name_path)
-  except IOError:
-    # This can indicate that a previous join was not successful
-    # So if the noded cert was found and matches we are fine
-    return cluster_files[constants.NODED_CERT_FILE] == remote_noded_file
+    if remote_content != local_content:
+      logging.error("File %s doesn't match local version", filename)
+      return False
 
-  return (cluster_files[constants.NODED_CERT_FILE] == remote_noded_file and
-          cluster_files[ss_cluster_name_path] == remote_cluster_name)
+  return True
 
 
 def _RunRemoteCommand(transport, command):
@@ -190,11 +195,11 @@ def SetupSSH(transport):
 
   try:
     sftp.mkdir(auth_path, 0700)
-  except IOError:
+  except IOError, err:
     # Sadly paramiko doesn't provide errno or similiar
     # so we can just assume that the path already exists
-    logging.info("Path %s seems already to exist on remote node. Ignoring.",
-                 auth_path)
+    logging.info("Assuming directory %s on remote node exists: %s",
+                 auth_path, err)
 
   for name, (data, perm) in filemap.iteritems():
     _WriteSftpFile(sftp, name, perm, data)
@@ -322,6 +327,16 @@ def LoadPrivateKeys(options):
   return [private_key] + list(agent_keys)
 
 
+def _FormatFingerprint(fpr):
+  """Formats a paramiko.PKey.get_fingerprint() result to be human readable.
+
+  @param fpr: The fingerprint to be formatted
+  @return: A human readable fingerprint
+
+  """
+  return ssh.FormatParamikoFingerprint(paramiko.util.hexify(fpr))
+
+
 def LoginViaKeys(transport, username, keys):
   """Try to login on the given transport via a list of keys.
 
@@ -337,7 +352,7 @@ def LoginViaKeys(transport, username, keys):
   for private_key in keys:
     try:
       transport.auth_publickey(username, private_key)
-      fpr = ":".join("%02x" % ord(i) for i in private_key.get_fingerprint())
+      fpr = _FormatFingerprint(private_key.get_fingerprint())
       if isinstance(private_key, paramiko.AgentKey):
         logging.debug("Authentication via the ssh-agent key %s", fpr)
       else:
@@ -362,10 +377,36 @@ def LoadKnownHosts():
   try:
     return paramiko.util.load_host_keys(known_hosts)
   except EnvironmentError:
-    # We didn't found the path, silently ignore and return an empty dict
+    # We didn't find the path, silently ignore and return an empty dict
     return {}
 
 
+def _VerifyServerKey(transport, host, host_keys):
+  """Verify the server keys.
+
+  @param transport: A paramiko.transport instance
+  @param host: Name of the host we verify
+  @param host_keys: Loaded host keys
+  @raises HostKeyVerificationError: When the host identity couldn't be verified
+
+  """
+
+  server_key = transport.get_remote_server_key()
+  keytype = server_key.get_name()
+
+  our_server_key = host_keys.get(host, {}).get(keytype, None)
+  if not our_server_key:
+    hexified_key = _FormatFingerprint(server_key.get_fingerprint())
+    msg = ("Unable to verify hostkey of host %s: %s. Do you want to accept"
+           " it?" % (host, hexified_key))
+
+    if cli.AskUser(msg):
+      our_server_key = server_key
+
+  if our_server_key != server_key:
+    raise HostKeyVerificationError("Unable to verify host identity")
+
+
 def main():
   """Main routine.
 
@@ -374,8 +415,6 @@ def main():
 
   SetupLogging(options)
 
-  errs = 0
-
   all_keys = LoadPrivateKeys(options)
 
   passwd = None
@@ -392,73 +431,65 @@ def main():
   #   wants to log one more message, which fails as the file is closed
   #   now
 
+  success = True
+
   for host in args:
-    transport = paramiko.Transport((host, ssh_port))
-    transport.start_client()
-    server_key = transport.get_remote_server_key()
-    keytype = server_key.get_name()
-
-    our_server_key = host_keys.get(host, {}).get(keytype, None)
-    if options.ssh_key_check:
-      if not our_server_key:
-        hexified_key = ssh.FormatParamikoFingerprint(
-            paramiko.util.hexify(server_key.get_fingerprint()))
-        msg = ("Unable to verify hostkey of host %s: %s. Do you want to accept"
-               " it?" % (host, hexified_key))
-
-        if cli.AskUser(msg):
-          our_server_key = server_key
-
-      if our_server_key != server_key:
-        logging.error("Unable to verify identity of host. Aborting")
-        transport.close()
-        transport.join()
-        # TODO: Run over all hosts, fetch the keys and let them verify from the
-        #       user beforehand then proceed with actual work later on
-        raise paramiko.SSHException("Unable to verify identity of host")
+    logging.info("Configuring %s", host)
 
-    try:
-      if LoginViaKeys(transport, username, all_keys):
-        logging.info("Authenticated to %s via public key", host)
-      else:
-        logging.warning("Authentication to %s via public key failed, trying"
-                        " password", host)
-        if passwd is None:
-          passwd = getpass.getpass(prompt="%s password:" % username)
-        transport.auth_password(username=username, password=passwd)
-        logging.info("Authenticated to %s via password", host)
-    except paramiko.SSHException, err:
-      logging.error("Connection or authentication failed to host %s: %s",
-                    host, err)
-      transport.close()
-      # this is needed for compatibility with older Paramiko or Python
-      # versions
-      transport.join()
-      continue
+    transport = paramiko.Transport((host, ssh_port))
     try:
       try:
-        if not _CheckJoin(transport):
-          if options.force_join:
-            logging.warning("Host %s failed join check, forced to continue",
-                            host)
+        transport.start_client()
+
+        if options.ssh_key_check:
+          _VerifyServerKey(transport, host, host_keys)
+
+        try:
+          if LoginViaKeys(transport, username, all_keys):
+            logging.info("Authenticated to %s via public key", host)
           else:
+            if all_keys:
+              logging.warning("Authentication to %s via public key failed,"
+                              " trying password", host)
+            if passwd is None:
+              passwd = getpass.getpass(prompt="%s password:" % username)
+            transport.auth_password(username=username, password=passwd)
+            logging.info("Authenticated to %s via password", host)
+        except paramiko.SSHException, err:
+          raise AuthError("Auth error: %s" % err)
+
+        if not _CheckJoin(transport):
+          if not options.force_join:
             raise JoinCheckError(("Host %s failed join check; Please verify"
                                   " that the host was not previously joined"
                                   " to another cluster and use --force-join"
                                   " to continue") % host)
+
+          logging.warning("Host %s failed join check, forced to continue",
+                          host)
+
         SetupSSH(transport)
-      except errors.GenericError, err:
-        logging.error("While doing setup on host %s an error occurred: %s",
-                      host, err)
-        errs += 1
-    finally:
-      transport.close()
-      # this is needed for compatibility with older Paramiko or Python
-      # versions
-      transport.join()
-
-    if errs > 0:
-      sys.exit(1)
+        logging.info("%s successfully configured", host)
+      finally:
+        transport.close()
+        # this is needed for compatibility with older Paramiko or Python
+        # versions
+        transport.join()
+    except AuthError, err:
+      logging.error("Authentication error: %s", err)
+      success = False
+      break
+    except HostKeyVerificationError, err:
+      logging.error("Host key verification error: %s", err)
+      success = False
+    except Exception, err:
+      logging.exception("During setup of %s: %s", host, err)
+      success = False
+
+  if success:
+    sys.exit(constants.EXIT_SUCCESS)
+
+  sys.exit(constants.EXIT_FAILURE)
 
 
 if __name__ == "__main__":
-- 
GitLab