ssh.py 35.9 KB
Newer Older
Iustin Pop's avatar
Iustin Pop committed
1
#
Iustin Pop's avatar
Iustin Pop committed
2 3
#

4
# Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
Klaus Aehlig's avatar
Klaus Aehlig committed
5
# All rights reserved.
Iustin Pop's avatar
Iustin Pop committed
6
#
Klaus Aehlig's avatar
Klaus Aehlig committed
7 8 9
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
Iustin Pop's avatar
Iustin Pop committed
10
#
Klaus Aehlig's avatar
Klaus Aehlig committed
11 12
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
Iustin Pop's avatar
Iustin Pop committed
13
#
Klaus Aehlig's avatar
Klaus Aehlig committed
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Iustin Pop's avatar
Iustin Pop committed
29 30 31 32 33 34 35


"""Module encapsulating ssh functionality.

"""


36
import logging
37 38
import os
import tempfile
Iustin Pop's avatar
Iustin Pop committed
39

40 41
from functools import partial

Iustin Pop's avatar
Iustin Pop committed
42 43
from ganeti import utils
from ganeti import errors
Iustin Pop's avatar
Iustin Pop committed
44
from ganeti import constants
45
from ganeti import netutils
46
from ganeti import pathutils
47
from ganeti import vcluster
48
from ganeti import compat
49
from ganeti import serializer
50
from ganeti import ssconf
Iustin Pop's avatar
Iustin Pop committed
51 52


53
def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA,
                 _homedir_fn=None):
  """Return the paths of a user's SSH files.

  @type user: string
  @param user: Username
  @type mkdir: bool
  @param mkdir: Whether to create the ".ssh" directory if it doesn't exist
  @type dircheck: bool
  @param dircheck: Whether to check that the ".ssh" directory exists; only
    consulted when C{mkdir} is false
  @type kind: string
  @param kind: One of L{constants.SSHK_DSA}, L{constants.SSHK_RSA} or
    L{constants.SSHK_ECDSA}
  @param _homedir_fn: function mapping a user name to its home directory;
    defaults to L{utils.GetHomeDir} (meant for testing)
  @rtype: list of three strings
  @return: the private SSH key file, the public SSH key file and the user's
    C{authorized_keys} file, in that order
  @raise errors.OpExecError: When the home directory of the user can not be
    determined, or when C{~$user/.ssh} is not a directory while C{mkdir} is
    false and C{dircheck} is true
  @raise errors.ProgrammerError: When C{kind} is not a known key kind

  """
  home_lookup = utils.GetHomeDir if _homedir_fn is None else _homedir_fn

  home_dir = home_lookup(user)
  if not home_dir:
    raise errors.OpExecError("Cannot resolve home of user '%s'" % user)

  known_kinds = (
    (constants.SSHK_DSA, "dsa"),
    (constants.SSHK_RSA, "rsa"),
    (constants.SSHK_ECDSA, "ecdsa"),
    )
  for (known_kind, suffix) in known_kinds:
    if kind == known_kind:
      break
  else:
    raise errors.ProgrammerError("Unknown SSH key kind '%s'" % kind)

  ssh_dir = utils.PathJoin(home_dir, ".ssh")
  if mkdir:
    utils.EnsureDirs([(ssh_dir, constants.SECURE_DIR_MODE)])
  elif dircheck and not os.path.isdir(ssh_dir):
    raise errors.OpExecError("Path %s is not a directory" % ssh_dir)

  file_names = ["id_%s" % suffix, "id_%s.pub" % suffix, "authorized_keys"]
  return [utils.PathJoin(ssh_dir, name) for name in file_names]
100 101


102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
def GetAllUserFiles(user, mkdir=False, dircheck=True, _homedir_fn=None):
  """Retrieve a user's SSH file paths for every supported key kind.

  Calls L{GetUserFiles} once per entry of L{constants.SSHK_ALL}; see there
  for the meaning of the parameters.

  @rtype: tuple; (string, dict with string as key, tuple of (string, string) as
    value)
  @return: the path of the user's C{authorized_keys} file, and a dictionary
    mapping each key kind to its C{(private_key_path, public_key_path)}

  """
  lookup = compat.partial(GetUserFiles, user, mkdir=mkdir, dircheck=dircheck,
                          _homedir_fn=_homedir_fn)

  key_paths = {}
  auth_files = []
  for kind in constants.SSHK_ALL:
    (priv_path, pub_path, auth_path) = lookup(kind=kind)
    key_paths[kind] = (priv_path, pub_path)
    auth_files.append(auth_path)

  # Every key kind must share the same authorized_keys file
  assert len(frozenset(auth_files)) == 1, \
    "Different paths for authorized_keys were returned"

  return (auth_files[0], key_paths)


125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
def _SplitSshKey(key):
  """Split a line of SSH's C{authorized_keys} file into comparable fields.

  A line that starts directly with a known key type (one of
  L{constants.SSHAK_ALL}), i.e. has no options such as C{command="..."}, is
  reduced to its significant parts, the key type and the key data. Any
  other line (including an empty one) is kept in full, split at whitespace.

  @type key: string
  @param key: Key line
  @rtype: tuple; (bool, list of strings)
  @return: C{(False, first_two_fields)} for option-less lines,
    C{(True, all_fields)} otherwise

  """
  fields = key.split()
  use_all_fields = not (fields and fields[0] in constants.SSHAK_ALL)
  if use_all_fields:
    # Can't reliably pick out the key, so compare the whole line
    return (True, fields)
  return (False, fields[:2])


148 149
def AddAuthorizedKeys(file_obj, keys):
  """Append SSH public keys to an authorized_keys file if not yet present.

  Keys are compared to the existing lines via L{_SplitSshKey}, so whitespace
  changes are ignored. If the file does not end with a newline, one is
  written before the new keys. The file is closed before returning, also
  when an open file handle was passed in.

  @type file_obj: str or file handle
  @param file_obj: path to the authorized_keys file, or an open file handle
    usable for reading and appending
  @type keys: list of str
  @param keys: list of strings containing keys

  """
  pending = [(key, _SplitSshKey(key)) for key in keys]

  if isinstance(file_obj, basestring):
    handle = open(file_obj, "a+")
  else:
    handle = file_obj

  try:
    existing = []
    ends_with_newline = True
    for line in handle:
      existing.append(_SplitSshKey(line))
      ends_with_newline = line.endswith("\n")

    if not ends_with_newline:
      # Terminate the last line so the first new key is not glued onto it
      handle.write("\n")

    for (key, split_key) in pending:
      if split_key not in existing:
        handle.write(key.rstrip("\r\n"))
        handle.write("\n")

    handle.flush()
  finally:
    handle.close()


Helga Velroyen's avatar
Helga Velroyen committed
184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
def HasAuthorizedKey(file_obj, key):
  """Tell whether a key is listed in an C{authorized_keys} file.

  Lines are compared via L{_SplitSshKey}, ignoring whitespace changes. The
  file is closed before returning, also when an open handle was passed in.

  @type file_obj: str or file handle
  @param file_obj: path to authorized_keys file, or an open file handle
  @type key: str
  @param key: string containing key
  @rtype: bool
  @return: whether a line matching C{key} was found

  """
  wanted = _SplitSshKey(key)

  handle = (open(file_obj, "r") if isinstance(file_obj, basestring)
            else file_obj)
  try:
    present = any(_SplitSshKey(line) == wanted for line in handle)
  finally:
    handle.close()

  return present


212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253
def CheckForMultipleKeys(file_obj, node_names):
  """Check if there is at most one key per host in 'authorized_keys' file.

  Each key line is attributed to its first whitespace-separated field that
  contains an "@" (the C{user@hostname} comment of the key). Empty lines,
  whitespace-only lines, comment lines (a "#" possibly preceded by
  whitespace) and key lines without such a field are ignored.

  @type file_obj: str or file handle
  @param file_obj: path to authorized_keys file, or an open file handle; the
    file is closed before returning in both cases
  @type node_names: list of str
  @param node_names: list of names of nodes of the cluster
  @rtype: dict of str to list of int
  @return: a dictionary mapping each C{user@hostname} whose hostname is in
    C{node_names} and which occurs more than once to the 1-based line
    numbers of its occurrences

  """
  if isinstance(file_obj, basestring):
    f = open(file_obj, "r")
  else:
    f = file_obj

  occurrences = {}

  try:
    for (index, line) in enumerate(f, 1):
      stripped = line.strip()
      # Blank lines used to crash the user@hostname lookup below with an
      # IndexError; sshd also ignores comments indented with whitespace
      if not stripped or stripped.startswith("#"):
        continue
      user_hostname = next((chunk for chunk in stripped.split()
                            if "@" in chunk), None)
      if user_hostname is None:
        # A key without a user@hostname comment can't be attributed to a host
        continue
      occurrences.setdefault(user_hostname, []).append(index)
  finally:
    f.close()

  node_name_set = frozenset(node_names)
  bad_occurrences = {}
  for (user_hostname, lines) in occurrences.items():
    # rpartition tolerates additional "@" characters in the user part
    hostname = user_hostname.rpartition("@")[2]
    if hostname in node_name_set and len(lines) > 1:
      bad_occurrences[user_hostname] = lines

  return bad_occurrences


254 255 256 257 258 259 260 261 262 263 264 265
def AddAuthorizedKey(file_obj, key):
  """Add a single SSH public key to an authorized_keys file.

  Convenience wrapper around L{AddAuthorizedKeys} for exactly one key.

  @type file_obj: str or file handle
  @param file_obj: path to authorized_keys file, or an open file handle
  @type key: str
  @param key: string containing key

  """
  single_key = [key]
  AddAuthorizedKeys(file_obj, single_key)


266 267
def RemoveAuthorizedKeys(file_name, keys):
  """Remove public SSH keys from an authorized_keys file.

  The remaining lines are copied into a temporary file in the same
  directory, which is then renamed over C{file_name}. Lines are compared via
  L{_SplitSshKey}, ignoring whitespace changes. On any error the temporary
  file is removed and the exception is re-raised.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type keys: list of str
  @param keys: list of strings containing keys

  """
  unwanted = [_SplitSshKey(key) for key in keys]

  (tmp_fd, tmp_path) = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    with os.fdopen(tmp_fd, "w") as out:
      with open(file_name, "r") as src:
        for line in src:
          if _SplitSshKey(line) not in unwanted:
            out.write(line)

        out.flush()
        os.rename(tmp_path, file_name)
  except:
    # Deliberately catch everything: clean up the temporary file, re-raise
    utils.RemoveFile(tmp_path)
    raise


299 300 301 302 303 304 305 306 307 308 309 310
def RemoveAuthorizedKey(file_name, key):
  """Remove a single SSH public key from an authorized_keys file.

  Convenience wrapper around L{RemoveAuthorizedKeys} for exactly one key.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  single_key = [key]
  RemoveAuthorizedKeys(file_name, single_key)


311
def _AddPublicKeyProcessLine(new_uuid, new_key, line_uuid, line_key, found):
312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
  """Processes one line of the public key file when adding a key.

  This is a sub function that can be called within the
  C{_ManipulatePublicKeyFile} function. It processes one line of the public
  key file, checks if this line contains the key to add already and if so,
  notes the occurrence in the return value.

  @type new_uuid: string
  @param new_uuid: the node UUID of the node whose key is added
  @type new_key: string
  @param new_key: the SSH key to be added
  @type line_uuid: the UUID of the node whose line in the public key file
    is processed in this function call
  @param line_key: the SSH key of the node whose line in the public key
    file is processed in this function call
  @type found: boolean
  @param found: whether or not the (UUID, key) pair of the node whose key
    is being added was found in the public key file already.
330 331
  @rtype: (boolean, string)
  @return: a possibly updated value of C{found} and the processed line
332 333 334 335 336

  """
  if line_uuid == new_uuid and line_key == new_key:
    logging.debug("SSH key of node '%s' already in key file.", new_uuid)
    found = True
337
  return (found, "%s %s\n" % (line_uuid, line_key))
338 339


340
def _AddPublicKeyElse(new_uuid, new_key):
341 342 343 344 345 346 347 348 349 350 351
  """Adds a new SSH key to the key file if it did not exist already.

  This is an auxiliary function for C{_ManipulatePublicKeyFile} which
  is carried out when a new key is added to the public key file and
  after processing the whole file, we found out that the key does
  not exist in the file yet but needs to be appended at the end.

  @type new_uuid: string
  @param new_uuid: the UUID of the node whose key is added
  @type new_key: string
  @param new_key: the SSH key to be added
352 353
  @rtype: string
  @return: a new line to be added to the file
354 355

  """
356
  return "%s %s\n" % (new_uuid, new_key)
357 358 359 360


def _RemovePublicKeyProcessLine(
    target_uuid, _target_key,
361
    line_uuid, line_key, found):
362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378
  """Processes a line in the public key file when aiming for removing a key.

  This is an auxiliary function for C{_ManipulatePublicKeyFile} when we
  are removing a key from the public key file. This particular function
  only checks if the current line contains the UUID of the node in
  question and writes the line to the temporary file otherwise.

  @type target_uuid: string
  @param target_uuid: UUID of the node whose key is being removed
  @type _target_key: string
  @param _target_key: SSH key of the node (not used)
  @type line_uuid: string
  @param line_uuid: UUID of the node whose line is processed in this call
  @type line_key: string
  @param line_key: SSH key of the nodes whose line is processed in this call
  @type found: boolean
  @param found: whether or not the UUID was already found.
379 380 381
  @rtype: (boolean, string)
  @return: a tuple, indicating if the target line was found and the processed
    line; the line is 'None', if the original line is removed
382 383 384

  """
  if line_uuid != target_uuid:
385
    return (found, "%s %s\n" % (line_uuid, line_key))
386
  else:
387
    return (True, None)
388 389 390


def _RemovePublicKeyElse(
391
    target_uuid, _target_key):
392 393 394 395 396 397 398 399 400 401 402
  """Logs when we tried to remove a key that does not exist.

  This is an auxiliary function for C{_ManipulatePublicKeyFile} which is
  run after we have processed the complete public key file and did not find
  the key to be removed.

  @type target_uuid: string
  @param target_uuid: the UUID of the node whose key was supposed to be removed
  @type _target_key: string
  @param _target_key: the key of the node which was supposed to be removed
    (not used)
403 404
  @rtype: string
  @return: in this case, always None
405 406 407 408

  """
  logging.debug("Trying to remove key of node '%s' which is not in list"
                " of public keys.", target_uuid)
409
  return None
410 411 412


def _ReplaceNameByUuidProcessLine(
413
    node_name, _key, line_identifier, line_key, found, node_uuid=None):
414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
  """Replaces a node's name with its UUID on a matching line in the key file.

  This is an auxiliary function for C{_ManipulatePublicKeyFile} which processes
  a line of the ganeti public key file. If the line in question matches the
  node's name, the name will be replaced by the node's UUID.

  @type node_name: string
  @param node_name: name of the node to be replaced by the UUID
  @type _key: string
  @param _key: SSH key of the node (not used)
  @type line_identifier: string
  @param line_identifier: an identifier of a node in a line of the public key
    file. This can be either a node name or a node UUID, depending on if it
    got replaced already or not.
  @type line_key: string
  @param line_key: SSH key of the node whose line is processed
  @type found: boolean
  @param found: whether or not the line matches the node's name
  @type node_uuid: string
  @param node_uuid: the node's UUID which will replace the node name
434 435 436
  @rtype: (boolean, string)
  @return: a tuple indicating whether the target line was found and the
    processed line
437 438 439

  """
  if node_name == line_identifier:
440
    return (True, "%s %s\n" % (node_uuid, line_key))
441
  else:
442
    return (found, "%s %s\n" % (line_identifier, line_key))
443 444 445


def _ReplaceNameByUuidElse(
446
    node_uuid, node_name, _key):
447 448 449 450 451 452 453 454 455 456 457 458 459
  """Logs a debug message when we try to replace a key that is not there.

  This is an implementation of the auxiliary C{process_else_fn} function for
  the C{_ManipulatePubKeyFile} function when we use it to replace a line
  in the public key file that is indexed by the node's name instead of the
  node's UUID.

  @type node_uuid: string
  @param node_uuid: the node's UUID
  @type node_name: string
  @param node_name: the node's UUID
  @type _key: string (not used)
  @param _key: the node's SSH key (not used)
460 461
  @rtype: string
  @return: in this case, always None
462 463 464 465

  """
  logging.debug("Trying to replace node name '%s' with UUID '%s', but"
                " no line with that name was found.", node_name, node_uuid)
466
  return None
467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500


def _ParseKeyLine(line, error_fn):
  """Parses a line of the public key file.

  @type line: string
  @param line: line of the public key file
  @type error_fn: function
  @param error_fn: function to process error messages
  @rtype: tuple (string, string)
  @return: a tuple containing the UUID of the node and a string containing
    the SSH key and possible more parameters for the key

  """
  if len(line.rstrip()) == 0:
    return (None, None)
  chunks = line.split(" ")
  if len(chunks) < 2:
    raise error_fn("Error parsing public SSH key file. Line: '%s'"
                   % line)
  uuid = chunks[0]
  key = " ".join(chunks[1:]).rstrip()
  return (uuid, key)


def _ManipulatePubKeyFile(target_identifier, target_key,
                          key_file=pathutils.SSH_PUB_KEYS,
                          error_fn=errors.ProgrammerError,
                          process_line_fn=None, process_else_fn=None):
  """Manipulates the list of public SSH keys of the cluster.

  This is a general function to manipulate the public key file. It needs
  two auxiliary functions C{process_line_fn} and C{process_else_fn} to
  work. Generally, the public key file is processed as follows:
  1) The function processes each line of the original ganeti public key file,
  applies the C{process_line_fn} function on it, which returns a possibly
  manipulated line and an indicator whether the line in question was found.
  If a line is returned, it is added to a list of lines for later writing
  to the file.
  2) If all lines are processed and the 'found' variable is False, the
  second auxiliary function C{process_else_fn} is called to possibly
  add more lines to the list of lines.
  3) Finally, the list of lines is assembled to a string and written
  to the public key file with L{utils.WriteFile}, thereby overriding it.

  If the public key file does not exist, we create it. This is necessary for
  a smooth transition after an upgrade.

  @type target_identifier: str
  @param target_identifier: identifier of the node whose key is added; in most
    cases this is the node's UUID, but in some it is the node's host name
  @type target_key: str
  @param target_key: string containing a public SSH key (a complete line
    possibly including more parameters than just the key)
  @type key_file: str
  @param key_file: filename of the file of public node keys (optional
    parameter for testing)
  @type error_fn: function
  @param error_fn: Function that returns an exception, used to customize
    exception types depending on the calling context
  @type process_line_fn: function
  @param process_line_fn: function to process one line of the public key file
  @type process_else_fn: function
  @param process_else_fn: function to be called if no line of the key file
    matches the target uuid
  @raise errors.SshUpdateError: if the key file does not exist and cannot be
    created

  """
  assert process_else_fn is not None
  assert process_line_fn is not None

  old_lines = []
  if os.path.exists(key_file):
    # "with" ensures a failing open() propagates its own error instead of an
    # AttributeError from closing a handle that was never assigned
    with open(key_file, "r") as f_orig:
      old_lines = f_orig.readlines()
  else:
    try:
      open(key_file, "w").close()
    except IOError as e:
      raise errors.SshUpdateError("Cannot create public key file: %s" % e)

  found = False
  new_lines = []
  for line in old_lines:
    (uuid, key) = _ParseKeyLine(line, error_fn)
    if not uuid:
      # Lines without an identifier (e.g. blank lines) are dropped
      continue
    (line_found, new_line) = process_line_fn(target_identifier, target_key,
                                             uuid, key, found)
    if line_found:
      found = True
    if new_line is not None:
      new_lines.append(new_line)

  if not found:
    new_line = process_else_fn(target_identifier, target_key)
    if new_line is not None:
      new_lines.append(new_line)

  utils.WriteFile(key_file, data="".join(new_lines))
570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621


def AddPublicKey(new_uuid, new_key, key_file=pathutils.SSH_PUB_KEYS,
                 error_fn=errors.ProgrammerError):
  """Adds a new key to the list of public keys.

  A line for C{(new_uuid, new_key)} is appended only if that exact pair is
  not yet present in the file.

  @see: _ManipulatePubKeyFile for parameter descriptions.

  """
  _ManipulatePubKeyFile(new_uuid, new_key,
                        key_file=key_file,
                        error_fn=error_fn,
                        process_line_fn=_AddPublicKeyProcessLine,
                        process_else_fn=_AddPublicKeyElse)


def RemovePublicKey(target_uuid, key_file=pathutils.SSH_PUB_KEYS,
                    error_fn=errors.ProgrammerError):
  """Removes a key from the list of public keys.

  Every line whose identifier equals C{target_uuid} is dropped.

  @see: _ManipulatePubKeyFile for parameter descriptions.

  """
  _ManipulatePubKeyFile(target_uuid, None,
                        key_file=key_file,
                        error_fn=error_fn,
                        process_line_fn=_RemovePublicKeyProcessLine,
                        process_else_fn=_RemovePublicKeyElse)


def ReplaceNameByUuid(node_uuid, node_name, key_file=pathutils.SSH_PUB_KEYS,
                      error_fn=errors.ProgrammerError):
  """Replaces a host name with the node's corresponding UUID.

  When a node is added to the cluster, we don't know its UUID yet. So first
  its SSH key gets added to the public key file and in a second step, the
  node's name gets replaced with the node's UUID as soon as we know the UUID.
  If no line carries C{node_name}, only a debug message is logged.

  @type node_uuid: string
  @param node_uuid: the node's UUID to replace the node's name
  @type node_name: string
  @param node_name: the node's name to be replaced by the node's UUID

  @see: _ManipulatePubKeyFile for the other parameter descriptions.

  """
  process_line_fn = partial(_ReplaceNameByUuidProcessLine, node_uuid=node_uuid)
  # _ManipulatePubKeyFile invokes process_else_fn(node_name, key) with
  # positional arguments, while _ReplaceNameByUuidElse's signature is
  # (node_uuid, node_name, _key). Binding node_uuid by keyword would make the
  # positional node_name land on node_uuid as well and raise a TypeError
  # ("multiple values"), so node_uuid is bound positionally instead.
  process_else_fn = partial(_ReplaceNameByUuidElse, node_uuid)
  _ManipulatePubKeyFile(node_name, None, key_file=key_file,
                        process_line_fn=process_line_fn,
                        process_else_fn=process_else_fn,
                        error_fn=error_fn)


622 623 624 625 626 627 628
def ClearPubKeyFile(key_file=pathutils.SSH_PUB_KEYS, mode=0o600):
  """Resets the content of the public key file.

  Writes an empty string to C{key_file} using L{utils.WriteFile}.

  @type key_file: str
  @param key_file: path of the public key file
  @type mode: int
  @param mode: file mode passed on to L{utils.WriteFile}

  """
  empty_content = ""
  utils.WriteFile(key_file, mode=mode, data=empty_content)


629
def OverridePubKeyFile(key_map, key_file=pathutils.SSH_PUB_KEYS):
  """Overrides the public key file with the given keys.

  Every (uuid, key) pair becomes one line, space-separated and terminated by
  a newline; the previous content of the file is replaced.

  @type key_map: dict from str to list of str
  @param key_map: dictionary mapping uuids to lists of SSH keys

  """
  new_file_content = "".join("%s %s\n" % (uuid, key)
                             for (uuid, keys) in key_map.items()
                             for key in keys)
  utils.WriteFile(key_file, data=new_file_content)
642 643


644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660
def QueryPubKeyFile(target_uuids, key_file=pathutils.SSH_PUB_KEYS,
                    error_fn=errors.ProgrammerError):
  """Retrieves a map of keys for the requested node UUIDs.

  @type target_uuids: str, list of str or None
  @param target_uuids: UUID of the node to retrieve the key for, a list
    of UUIDs of nodes to retrieve the keys for, or C{None} to retrieve the
    keys of all nodes listed in the file
  @type key_file: str
  @param key_file: filename of the file of public node keys (optional
    parameter for testing)
  @type error_fn: function
  @param error_fn: Function that returns an exception, used to customize
    exception types depending on the calling context
  @rtype: dict mapping strings to list of strings
  @return: dictionary mapping node uuids to their ssh keys; requested nodes
    without any key in the file are not included

  """
  all_keys = target_uuids is None
  if isinstance(target_uuids, basestring):
    # Also wrap unicode UUIDs; otherwise the membership test below would be
    # a substring test on the string instead of an exact UUID comparison
    target_uuids = [target_uuids]
  wanted = None if all_keys else frozenset(target_uuids)

  result = {}
  with open(key_file, "r") as f:
    for line in f:
      (uuid, key) = _ParseKeyLine(line, error_fn)
      if not uuid:
        continue
      if all_keys or uuid in wanted:
        result.setdefault(uuid, []).append(key)
  return result


680 681
def InitSSHSetup(error_fn=errors.OpPrereqError, _homedir_fn=None,
                 _suffix=""):
  """Setup the SSH configuration for the node.

  Generates a new DSA key pair (C{ssh-keygen -t dsa}, no passphrase) for
  L{constants.SSH_LOGIN_USER}, stored at the user's private key path with
  C{_suffix} appended (public key: same path plus ".pub"), and adds the new
  public key to the user's C{authorized_keys} file. Existing files at both
  target paths are backed up before being removed.

  @type error_fn: callable
  @param error_fn: function returning the exception that is raised when
    C{ssh-keygen} fails
  @param _homedir_fn: passed on to L{GetUserFiles} (for testing)
  @type _suffix: str
  @param _suffix: suffix appended to the generated key file names

  """
  (priv_key, _, auth_keys) = GetUserFiles(constants.SSH_LOGIN_USER,
                                          _homedir_fn=_homedir_fn)

  key_path = priv_key + _suffix
  pub_key_path = key_path + ".pub"

  for path in (key_path, pub_key_path):
    if os.path.exists(path):
      utils.CreateBackup(path)
    utils.RemoveFile(path)

  keygen_cmd = ["ssh-keygen", "-t", "dsa", "-f", key_path, "-q", "-N", ""]
  result = utils.RunCmd(keygen_cmd)
  if result.failed:
    raise error_fn("Could not generate ssh keypair, error %s" %
                   result.output)

  new_pub_key = utils.ReadFile(pub_key_path)
  AddAuthorizedKey(auth_keys, new_pub_key)
707 708


709 710 711 712 713 714 715 716 717 718
def InitPubKeyFile(master_uuid, key_file=pathutils.SSH_PUB_KEYS):
  """Creates the public key file and adds the master node's SSH key.

  The file is emptied first; then the public key of
  L{constants.SSH_LOGIN_USER} (default key kind of L{GetUserFiles}) is added
  under C{master_uuid}.

  @type master_uuid: str
  @param master_uuid: the master node's UUID
  @type key_file: str
  @param key_file: name of the file containing the public keys

  """
  (_, pub_key_path, _) = GetUserFiles(constants.SSH_LOGIN_USER)
  ClearPubKeyFile(key_file=key_file)
  master_key = utils.ReadFile(pub_key_path)
  AddPublicKey(master_uuid, master_key, key_file=key_file)


724
class SshRunner:
  """Wrapper for SSH commands.

  """
  def __init__(self, cluster_name):
    """Initializes this class.

    @type cluster_name: str
    @param cluster_name: name of the cluster

    """
    self.cluster_name = cluster_name
    family = ssconf.SimpleStore().GetPrimaryIPFamily()
    # Selects whether ssh is forced to IPv6 (-6) or IPv4 (-4)
    self.ipv6 = (family == netutils.IP6Address.family)

  def _BuildSshOptions(self, batch, ask_key, use_cluster_key,
                       strict_host_check, private_key=None, quiet=True,
                       port=None):
    """Builds a list with needed SSH options.

    @param batch: same as ssh's batch option
    @param ask_key: allows ssh to ask for key confirmation; this
        parameter conflicts with the batch one
    @param use_cluster_key: if True, use the cluster name as the
        HostKeyAlias name
    @param strict_host_check: this makes the host key checking strict
    @param private_key: use this private key instead of the default
    @param quiet: whether to enable -q to ssh
    @param port: the SSH port to use, or None to use the default

    @rtype: list
    @return: the list of options ready to use in L{utils.process.RunCmd}
    @raise errors.ProgrammerError: if both C{batch} and C{ask_key} are set

    """
    options = [
      "-oEscapeChar=none",
      "-oHashKnownHosts=no",
      "-oGlobalKnownHostsFile=%s" % pathutils.SSH_KNOWN_HOSTS_FILE,
      "-oUserKnownHostsFile=/dev/null",
      "-oCheckHostIp=no",
      ]

    if use_cluster_key:
      options.append("-oHostKeyAlias=%s" % self.cluster_name)

    if quiet:
      options.append("-q")

    if private_key:
      options.append("-i%s" % private_key)

    if port:
      options.append("-oPort=%d" % port)

    # TODO: Too many boolean options, maybe convert them to more descriptive
    # constants.

    # ask_key needs ssh to prompt the user, which batch mode forbids
    if batch and ask_key:
      raise errors.ProgrammerError("SSH call requested conflicting options")

    if batch:
      options.append("-oBatchMode=yes")

    # At this point ask_key can only be set in non-batch mode
    if ask_key:
      options.append("-oStrictHostKeyChecking=ask")
    elif strict_host_check:
      options.append("-oStrictHostKeyChecking=yes")
    else:
      options.append("-oStrictHostKeyChecking=no")

    if self.ipv6:
      options.append("-6")
    else:
      options.append("-4")

    return options

  def BuildCmd(self, hostname, user, command, batch=True, ask_key=False,
               tty=False, use_cluster_key=True, strict_host_check=True,
               private_key=None, quiet=True, port=None):
    """Build an ssh command to execute a command on a remote node.

    @param hostname: the target host, string
    @param user: user to auth as
    @param command: the command
    @param batch: if true, ssh will run in batch mode with no prompting
    @param ask_key: if true, ssh will run with
        StrictHostKeyChecking=ask, so that we can connect to an
        unknown host (not valid in batch mode)
    @param tty: if true, pass C{-t} twice to ssh to force allocation of a
        pseudo-terminal
    @param use_cluster_key: whether to expect and use the
        cluster-global SSH key
    @param strict_host_check: whether to check the host's SSH key at all
    @param private_key: use this private key instead of the default
    @param quiet: whether to enable -q to ssh
    @param port: the SSH port on which the node's daemon is running

    @rtype: list of strings
    @return: the ssh call (argument vector) to run 'command' on the
        remote host

    """
    argv = [constants.SSH]
    argv.extend(self._BuildSshOptions(batch, ask_key, use_cluster_key,
                                      strict_host_check, private_key,
                                      quiet=quiet, port=port))
    if tty:
      argv.extend(["-t", "-t"])

    argv.append("%s@%s" % (user, hostname))

    # Insert variables for virtual nodes
    argv.extend("export %s=%s;" %
                (utils.ShellQuote(name), utils.ShellQuote(value))
                for (name, value) in
                  vcluster.EnvironmentForHost(hostname).items())

    argv.append(command)

    return argv

  def Run(self, *args, **kwargs):
    """Runs a command on a remote node.

    This method has the same return value as `utils.RunCmd()`, which it
    uses to launch ssh.

    Args: see SshRunner.BuildCmd.

    @rtype: L{utils.process.RunResult}
    @return: the result as from L{utils.process.RunCmd()}

    """
    return utils.RunCmd(self.BuildCmd(*args, **kwargs))

  def CopyFileToNode(self, node, port, filename):
    """Copy a file to another node with scp.

    The file is copied to the same path on the remote node (adjusted by
    L{vcluster.ExchangeNodeRoot} for virtual clusters), preserving times
    and modes (C{scp -p}).

    @param node: node in the cluster (name or IP address)
    @param port: the SSH port of the remote node, or None for the default
    @param filename: absolute pathname of a local file

    @rtype: boolean
    @return: the success of the operation

    """
    if not os.path.isabs(filename):
      logging.error("File %s must be an absolute path", filename)
      return False

    if not os.path.isfile(filename):
      logging.error("File %s does not exist", filename)
      return False

    command = [constants.SCP, "-p"]
    command.extend(self._BuildSshOptions(True, False, True, True, port=port))
    command.append(filename)
    # Bare IPv6 addresses must be bracketed for the "host:path" syntax
    if netutils.IP6Address.IsValid(node):
      node = netutils.FormatAddress((node, None))

    command.append("%s:%s" % (node, vcluster.ExchangeNodeRoot(node, filename)))

    result = utils.RunCmd(command)

    if result.failed:
      logging.error("Copy to node %s failed (%s) error '%s',"
                    " command was '%s'",
                    node, result.fail_reason, result.output, result.cmd)

    return not result.failed

  def VerifyNodeHostname(self, node, ssh_port):
    """Verify hostname consistency via SSH.

    This function connects via ssh to a node and compares the hostname
    reported by the node to the name we have (the one that we
    connected to). The remote side reports C{$GANETI_HOSTNAME} if set,
    otherwise the output of C{hostname --fqdn}.

    This is used to detect problems in ssh known_hosts files
    (conflicting known hosts) and inconsistencies between dns/hosts
    entries and local machine names

    @param node: nodename of a host to check; can be short or
        full qualified hostname
    @param ssh_port: the port of a SSH daemon running on the node

    @return: (success, detail), where:
        - success: True/False
        - detail: string with details

    """
    cmd = ("if test -z \"$GANETI_HOSTNAME\"; then"
           "  hostname --fqdn;"
           "else"
           "  echo \"$GANETI_HOSTNAME\";"
           "fi")
    retval = self.Run(node, constants.SSH_LOGIN_USER, cmd,
                      quiet=False, port=ssh_port)

    if retval.failed:
      msg = "ssh problem"
      output = retval.output
      if output:
        msg += ": %s" % output
      else:
        msg += ": %s (no output)" % retval.fail_reason
      logging.error("Command %s failed: %s", retval.cmd, msg)
      return False, msg

    remotehostname = retval.stdout.strip()

    if not remotehostname or remotehostname != node:
      # A short name that prefixes our FQDN indicates a missing domain part
      if node.startswith(remotehostname + "."):
        msg = "hostname not FQDN"
      else:
        msg = "hostname mismatch"
      return False, ("%s: expected %s but got %s" %
                     (msg, node, remotehostname))

    return True, "host matches"
949 950


Michael Hanselmann's avatar
Michael Hanselmann committed
951
def WriteKnownHostsFile(cfg, file_name):
  """Writes the cluster-wide equally known_hosts file.

  One line is written per host key present in the configuration (RSA
  and/or DSA), each associated with the cluster name.

  @param cfg: configuration object providing C{GetClusterName},
      C{GetRsaHostKey} and C{GetDsaHostKey}
  @type file_name: str
  @param file_name: path of the known_hosts file to write

  """
  cluster_name = cfg.GetClusterName()
  rsa_key = cfg.GetRsaHostKey()
  dsa_key = cfg.GetDsaHostKey()

  data = ""
  if rsa_key:
    data += "%s ssh-rsa %s\n" % (cluster_name, rsa_key)
  if dsa_key:
    data += "%s ssh-dss %s\n" % (cluster_name, dsa_key)

  # 0o600 (owner read/write only); the 0o prefix is valid on Python 2.6+
  # and 3, unlike the legacy "0600" literal
  utils.WriteFile(file_name, mode=0o600, data=data)
962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002


def _EnsureCorrectGanetiVersion(cmd):
  """Ensure the correct Ganeti version before running a command via SSH.

  Before a command is run on a node via SSH, it makes sense in some
  situations to ensure that this node is indeed running the correct
  version of Ganeti like the rest of the cluster. This prepends a check
  that the versioned library directory exists, followed by commands
  pointing the C{ganeti/lib} and C{ganeti/share} symlinks below
  C{SYSCONFDIR} at that version.

  @type cmd: list of strings
  @param cmd: the command (program and arguments) to run once the
      version has been ensured
  @rtype: list of lists of strings
  @return: a list of commands with the newly added ones at the beginning
      and C{cmd} as the last element

  """
  logging.debug("Ensure correct Ganeti version: %s", cmd)

  version = constants.DIR_VERSION
  lib_target = os.path.join(pathutils.PKGLIBDIR, version)
  share_target = os.path.join(pathutils.SHAREDIR, version)
  lib_link = os.path.join(pathutils.SYSCONFDIR, "ganeti/lib")
  share_link = os.path.join(pathutils.SYSCONFDIR, "ganeti/share")

  all_cmds = [["test", "-d", lib_target]]
  if constants.HAS_GNU_LN:
    # GNU ln's -T treats the link name as a plain file, so an existing
    # symlink to a directory is replaced rather than followed
    all_cmds.extend([["ln", "-s", "-f", "-T", lib_target, lib_link],
                     ["ln", "-s", "-f", "-T", share_target, share_link]])
  else:
    # Without -T, remove the old links first so ln does not create the
    # new link inside the directory the old one points to
    all_cmds.extend([["rm", "-f", lib_link],
                     ["ln", "-s", "-f", lib_target, lib_link],
                     ["rm", "-f", share_link],
                     ["ln", "-s", "-f", share_target, share_link]])
  all_cmds.append(cmd)
  return all_cmds


1003 1004 1005 1006
def RunSshCmdWithStdin(cluster_name, node, basecmd, port, data,
                       debug=False, verbose=False, use_cluster_key=False,
                       ask_key=False, strict_host_check=False,
                       ensure_version=False):
  """Executes a remote command over SSH, feeding it JSON on stdin.

  @type cluster_name: string
  @param cluster_name: Cluster name
  @type node: string
  @param node: Node name
  @type basecmd: string
  @param basecmd: Base command (path on the remote machine)
  @type port: int
  @param port: The SSH port of the remote machine or None for the default
  @param data: JSON-serializable input data for script (passed to stdin)
  @type debug: bool
  @param debug: Enable debug output (adds C{--debug} to the command)
  @type verbose: bool
  @param verbose: Enable verbose output (adds C{--verbose} to the command)
  @type use_cluster_key: bool
  @param use_cluster_key: See L{ssh.SshRunner.BuildCmd}
  @type ask_key: bool
  @param ask_key: See L{ssh.SshRunner.BuildCmd}
  @type strict_host_check: bool
  @param strict_host_check: See L{ssh.SshRunner.BuildCmd}
  @type ensure_version: bool
  @param ensure_version: if true, prepend the commands produced by
      L{_EnsureCorrectGanetiVersion} before running the command
  @raise errors.OpExecError: if the remote command fails

  """
  script_cmd = [basecmd]

  # Forward our own --debug/--verbose flags to the remote script
  for (enabled, flag) in [(debug, "--debug"), (verbose, "--verbose")]:
    if enabled:
      script_cmd.append(flag)

  if not ensure_version:
    remote_cmds = [script_cmd]
  else:
    remote_cmds = _EnsureCorrectGanetiVersion(script_cmd)

  ssh_port = port
  if ssh_port is None:
    ssh_port = netutils.GetDaemonPort(constants.SSH)

  runner = SshRunner(cluster_name)
  remote_cmdline = utils.ShellQuoteArgs(
    utils.ShellCombineCommands(remote_cmds))
  ssh_cmd = runner.BuildCmd(node, constants.SSH_LOGIN_USER, remote_cmdline,
                            batch=False, ask_key=ask_key, quiet=False,
                            strict_host_check=strict_host_check,
                            use_cluster_key=use_cluster_key,
                            port=ssh_port)

  # The JSON payload is spooled to a temporary file that serves as stdin
  stdin_fh = tempfile.TemporaryFile()
  try:
    stdin_fh.write(serializer.DumpJson(data))
    stdin_fh.seek(0)
    run_result = utils.RunCmd(ssh_cmd, interactive=True, input_fd=stdin_fh)
  finally:
    stdin_fh.close()

  if run_result.failed:
    raise errors.OpExecError("Command '%s' failed: %s" %
                             (run_result.cmd, run_result.fail_reason))


def GetSshPortMap(nodes, cfg):
  """Looks up the SSH port of each given node via its node group.

  @type nodes: list of strings
  @param nodes: the names of nodes
  @type cfg: L{ConfigWriter}
  @param cfg: a configuration object
  @rtype: dict from str to int
  @return: a map from node names to ssh ports; a node that is unknown, or
      whose group has no known port, maps to C{None}

  """
  group_of_node = dict((info.name, info.group)
                       for info in cfg.GetAllNodesInfo().values())
  port_of_group = cfg.GetGroupSshPorts()
  return dict((name, port_of_group.get(group_of_node.get(name)))
              for name in nodes)
1090 1091


Helga Velroyen's avatar
Helga Velroyen committed
1092
def ReadRemoteSshPubKeys(pub_key_file, node, cluster_name, port, ask_key,
                         strict_host_check):
  """Fetches the public SSH key stored in a file on a node via SSH.

  Runs C{cat} on C{pub_key_file} on the remote node as the SSH login user
  and returns what it printed.

  @type pub_key_file: string
  @param pub_key_file: path of the public key file on the remote node
  @type node: string
  @param node: name of the node to read the key from
  @type cluster_name: string
  @param cluster_name: name of the cluster
  @type port: int
  @param port: SSH port of the remote node, or None to use ssh's default
  @type ask_key: bool
  @param ask_key: See L{SshRunner.BuildCmd}
  @type strict_host_check: bool
  @param strict_host_check: See L{SshRunner.BuildCmd}
  @rtype: string
  @return: the content of C{pub_key_file} as printed by the remote C{cat}
  @raise errors.OpPrereqError: if the remote command fails

  """
  ssh_runner = SshRunner(cluster_name)

  cmd = ["cat", pub_key_file]
  ssh_cmd = ssh_runner.BuildCmd(node, constants.SSH_LOGIN_USER,
                                utils.ShellQuoteArgs(cmd),
                                batch=False, ask_key=ask_key, quiet=False,
                                strict_host_check=strict_host_check,
                                use_cluster_key=False,
                                port=port)

  result = utils.RunCmd(ssh_cmd)
  if result.failed:
    raise errors.OpPrereqError("Could not fetch a public DSA SSH key from node"
                               " '%s': ran command '%s', failure reason: '%s'."
                               % (node, cmd, result.fail_reason))
  return result.stdout