Commit d12b9f66 authored by Michael Hanselmann's avatar Michael Hanselmann
Browse files

Add initial implementation of prepare-node-join



This is a new tool as per the design document “design-ssh-setup”. It
receives a JSON data structure on its standard input and configures the
SSH daemon and root's SSH keys accordingly. Unit tests are included.
Signed-off-by: default avatarMichael Hanselmann <hansmi@google.com>
Reviewed-by: default avatarIustin Pop <iustin@google.com>
parent 8a3c9e8a
......@@ -94,6 +94,7 @@
/tools/kvm-ifup
/tools/ensure-dirs
/tools/vcluster-setup
/tools/prepare-node-join
# scripts
/scripts/gnt-backup
......
......@@ -315,7 +315,8 @@ server_PYTHON = \
pytools_PYTHON = \
lib/tools/__init__.py \
lib/tools/ensure_dirs.py
lib/tools/ensure_dirs.py \
lib/tools/prepare_node_join.py
utils_PYTHON = \
lib/utils/__init__.py \
......@@ -578,7 +579,8 @@ PYTHON_BOOTSTRAP_SBIN = \
PYTHON_BOOTSTRAP = \
$(PYTHON_BOOTSTRAP_SBIN) \
tools/ensure-dirs
tools/ensure-dirs \
tools/prepare-node-join
qa_scripts = \
qa/__init__.py \
......@@ -690,7 +692,8 @@ pkglib_python_scripts = \
tools/check-cert-expired
nodist_pkglib_python_scripts = \
tools/ensure-dirs
tools/ensure-dirs \
tools/prepare-node-join
myexeclib_SCRIPTS = \
daemons/daemon-util \
......@@ -822,6 +825,7 @@ TEST_FILES = \
test/data/bdev-drbd-net-ip4.txt \
test/data/bdev-drbd-net-ip6.txt \
test/data/cert1.pem \
test/data/cert2.pem \
test/data/ip-addr-show-dummy0.txt \
test/data/ip-addr-show-lo-ipv4.txt \
test/data/ip-addr-show-lo-ipv6.txt \
......@@ -926,6 +930,7 @@ python_tests = \
test/ganeti.ssh_unittest.py \
test/ganeti.storage_unittest.py \
test/ganeti.tools.ensure_dirs_unittest.py \
test/ganeti.tools.prepare_node_join_unittest.py \
test/ganeti.uidpool_unittest.py \
test/ganeti.utils.algo_unittest.py \
test/ganeti.utils.filelock_unittest.py \
......@@ -1327,6 +1332,7 @@ daemons/ganeti-%: MODULE = ganeti.server.$(patsubst ganeti-%,%,$(notdir $@))
daemons/ganeti-watcher: MODULE = ganeti.watcher
scripts/%: MODULE = ganeti.client.$(subst -,_,$(notdir $@))
tools/ensure-dirs: MODULE = ganeti.tools.ensure_dirs
tools/prepare-node-join: MODULE = ganeti.tools.prepare_node_join
$(HS_BUILT_TEST_HELPERS): TESTROLE = $(patsubst htest/%,%,$@)
$(PYTHON_BOOTSTRAP): Makefile | stamp-directories
......
......@@ -2049,5 +2049,17 @@ SSHK_RSA = "rsa"
SSHK_DSA = "dsa"
SSHK_ALL = frozenset([SSHK_RSA, SSHK_DSA])
# SSH authorized key types
SSHAK_RSA = "ssh-rsa"
SSHAK_DSS = "ssh-dss"
SSHAK_ALL = frozenset([SSHAK_RSA, SSHAK_DSS])
# SSH setup
SSHS_CLUSTER_NAME = "cluster_name"
SSHS_FORCE = "force"
SSHS_SSH_HOST_KEY = "ssh_host_key"
SSHS_SSH_ROOT_KEY = "ssh_root_key"
SSHS_NODE_DAEMON_CERTIFICATE = "node_daemon_certificate"
# Do not re-export imported modules
del re, _vcsversion, _autoconf, socket, pathutils
......@@ -49,7 +49,7 @@ def FormatParamikoFingerprint(fingerprint):
def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
_homedir_fn=utils.GetHomeDir):
_homedir_fn=None):
"""Return the paths of a user's SSH files.
@type user: string
......@@ -67,6 +67,9 @@ def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
exception is raised if C{~$user/.ssh} is not a directory
"""
if _homedir_fn is None:
_homedir_fn = utils.GetHomeDir
user_dir = _homedir_fn(user)
if not user_dir:
raise errors.OpExecError("Cannot resolve home of user '%s'" % user)
......
#
#
# Copyright (C) 2012 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Script to prepare a node for joining a cluster.
"""
import os
import os.path
import optparse
import sys
import logging
import errno
import OpenSSL
from ganeti import cli
from ganeti import constants
from ganeti import errors
from ganeti import pathutils
from ganeti import utils
from ganeti import serializer
from ganeti import ht
from ganeti import ssh
from ganeti import ssconf
_SSH_KEY_LIST = \
ht.TListOf(ht.TAnd(ht.TIsLength(3),
ht.TItems([
ht.TElemOf(constants.SSHK_ALL),
ht.Comment("public")(ht.TNonEmptyString),
ht.Comment("private")(ht.TNonEmptyString),
])))
_DATA_CHECK = ht.TStrictDict(False, True, {
constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
constants.SSHS_FORCE: ht.TBool,
constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
})
_SSHK_TO_SSHAK = {
constants.SSHK_RSA: constants.SSHAK_RSA,
constants.SSHK_DSA: constants.SSHAK_DSS,
}
_SSH_DAEMON_KEYFILES = {
constants.SSHK_RSA:
(pathutils.SSH_HOST_RSA_PUB, pathutils.SSH_HOST_RSA_PRIV),
constants.SSHK_DSA:
(pathutils.SSH_HOST_DSA_PUB, pathutils.SSH_HOST_DSA_PRIV),
}
assert frozenset(_SSHK_TO_SSHAK.keys()) == constants.SSHK_ALL
assert frozenset(_SSHK_TO_SSHAK.values()) == constants.SSHAK_ALL
class JoinError(errors.GenericError):
"""Local class for reporting errors.
"""
def ParseOptions():
"""Parses the options passed to the program.
@return: Options and arguments
"""
program = os.path.basename(sys.argv[0])
parser = optparse.OptionParser(usage="%prog [--dry-run]",
prog=program)
parser.add_option(cli.DEBUG_OPT)
parser.add_option(cli.VERBOSE_OPT)
parser.add_option(cli.DRY_RUN_OPT)
(opts, args) = parser.parse_args()
return VerifyOptions(parser, opts, args)
def VerifyOptions(parser, opts, args):
"""Verifies options and arguments for correctness.
"""
if args:
parser.error("No arguments are expected")
return opts
def SetupLogging(opts):
"""Configures the logging module.
"""
formatter = logging.Formatter("%(asctime)s: %(message)s")
stderr_handler = logging.StreamHandler()
stderr_handler.setFormatter(formatter)
if opts.debug:
stderr_handler.setLevel(logging.NOTSET)
elif opts.verbose:
stderr_handler.setLevel(logging.INFO)
else:
stderr_handler.setLevel(logging.WARNING)
root_logger = logging.getLogger("")
root_logger.setLevel(logging.NOTSET)
root_logger.addHandler(stderr_handler)
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
"""Verifies a certificate against the local node daemon certificate.
@type cert: string
@param cert: Certificate in PEM format (no key)
"""
try:
OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
except OpenSSL.crypto.Error, err:
pass
else:
raise JoinError("No private key may be given")
try:
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
except Exception, err:
raise errors.X509CertError("(stdin)",
"Unable to load certificate: %s" % err)
try:
noded_pem = utils.ReadFile(_noded_cert_file)
except EnvironmentError, err:
if err.errno != errno.ENOENT:
raise
logging.debug("Local node certificate was not found (file %s)",
_noded_cert_file)
return
try:
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
except Exception, err:
raise errors.X509CertError(_noded_cert_file,
"Unable to load private key: %s" % err)
ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
ctx.use_privatekey(key)
ctx.use_certificate(cert)
try:
ctx.check_privatekey()
except OpenSSL.SSL.Error:
raise JoinError("Given cluster certificate does not match local key")
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
"""Verifies cluster certificate.
@type data: dict
"""
cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
if cert:
_verify_fn(cert)
def _VerifyClusterName(name, _ss_cluster_name_file=None):
"""Verifies cluster name against a local cluster name.
@type name: string
@param name: Cluster name
"""
if _ss_cluster_name_file is None:
_ss_cluster_name_file = \
ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
try:
local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
except EnvironmentError, err:
if err.errno != errno.ENOENT:
raise
logging.debug("Local cluster name was not found (file %s)",
_ss_cluster_name_file)
else:
if name != local_name:
raise JoinError("Current cluster name is '%s'" % local_name)
def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
"""Verifies cluster name.
@type data: dict
"""
name = data.get(constants.SSHS_CLUSTER_NAME)
if name:
_verify_fn(name)
else:
raise JoinError("Cluster name must be specified")
def _UpdateKeyFiles(keys, dry_run, keyfiles):
"""Updates SSH key files.
@type keys: sequence of tuple; (string, string, string)
@param keys: Keys to write, tuples consist of key type
(L{constants.SSHK_ALL}), public and private key
@type dry_run: boolean
@param dry_run: Whether to perform a dry run
@type keyfiles: dict; (string as key, tuple with (string, string) as values)
@param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
names; value tuples consist of public key filename and private key filename
"""
assert set(keyfiles) == constants.SSHK_ALL
for (kind, public_key, private_key) in keys:
(public_file, private_file) = keyfiles[kind]
logging.debug("Writing %s ...", public_file)
utils.WriteFile(public_file, data=public_key, mode=0644,
backup=True, dry_run=dry_run)
logging.debug("Writing %s ...", private_file)
utils.WriteFile(private_file, data=private_key, mode=0600,
backup=True, dry_run=dry_run)
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
_keyfiles=None):
"""Updates SSH daemon's keys.
Unless C{dry_run} is set, the daemon is restarted at the end.
@type data: dict
@param data: Input data
@type dry_run: boolean
@param dry_run: Whether to perform a dry run
"""
keys = data.get(constants.SSHS_SSH_HOST_KEY)
if not keys:
return
if _keyfiles is None:
_keyfiles = _SSH_DAEMON_KEYFILES
logging.info("Updating SSH daemon key files")
_UpdateKeyFiles(keys, dry_run, _keyfiles)
if dry_run:
logging.info("This is a dry run, not restarting SSH daemon")
else:
result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
interactive=True)
if result.failed:
raise JoinError("Could not reload SSH keys, command '%s'"
" had exitcode %s and error %s" %
(result.cmd, result.exit_code, result.output))
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
"""Updates root's SSH keys.
Root's C{authorized_keys} file is also updated with new public keys.
@type data: dict
@param data: Input data
@type dry_run: boolean
@param dry_run: Whether to perform a dry run
"""
keys = data.get(constants.SSHS_SSH_ROOT_KEY)
if not keys:
return
(dsa_private_file, dsa_public_file, auth_keys_file) = \
ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
kind=constants.SSHK_DSA, _homedir_fn=_homedir_fn)
(rsa_private_file, rsa_public_file, _) = \
ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
kind=constants.SSHK_RSA, _homedir_fn=_homedir_fn)
_UpdateKeyFiles(keys, dry_run, {
constants.SSHK_RSA: (rsa_public_file, rsa_private_file),
constants.SSHK_DSA: (dsa_public_file, dsa_private_file),
})
if dry_run:
logging.info("This is a dry run, not modifying %s", auth_keys_file)
else:
for (kind, public_key, _) in keys:
line = "%s %s" % (_SSHK_TO_SSHAK[kind], public_key)
utils.AddAuthorizedKey(auth_keys_file, line)
def LoadData(raw):
"""Parses and verifies input data.
@rtype: dict
"""
try:
data = serializer.LoadJson(raw)
except Exception, err:
raise errors.ParseError("Can't parse input data: %s" % err)
if not _DATA_CHECK(data):
raise errors.ParseError("Input data does not match expected format: %s" %
_DATA_CHECK)
return data
def Main():
"""Main routine.
"""
opts = ParseOptions()
SetupLogging(opts)
try:
data = LoadData(sys.stdin.read())
# Check if input data is correct
VerifyClusterName(data)
VerifyCertificate(data)
# Update SSH files
UpdateSshDaemon(data, opts.dry_run)
UpdateSshRoot(data, opts.dry_run)
logging.info("Setup finished successfully")
except Exception, err: # pylint: disable=W0703
logging.debug("Caught unhandled exception", exc_info=True)
(retcode, message) = cli.FormatError(err)
logging.error(message)
return retcode
else:
return constants.EXIT_SUCCESS
......@@ -828,9 +828,6 @@ def ReadLockedPidFile(path):
return None
_SSH_KEYS_WITH_TWO_PARTS = frozenset(["ssh-dss", "ssh-rsa"])
def _SplitSshKey(key):
"""Splits a line for SSH's C{authorized_keys} file.
......@@ -845,7 +842,7 @@ def _SplitSshKey(key):
"""
parts = key.split()
if parts and parts[0] in _SSH_KEYS_WITH_TWO_PARTS:
if parts and parts[0] in constants.SSHAK_ALL:
# If the key has no options in front of it, we only want the significant
# fields
return (False, parts[:2])
......
-----BEGIN PRIVATE KEY-----
MIIBUwIBADANBgkqhkiG9w0BAQEFAASCAT0wggE5AgEAAkEAt8OZYvvi8noVPlpR
/SrHcya9ne7RG5DjvMssksUqyGriUs/WGnpZlL4nz+BcLFGwNNntoxqR30Tjk47S
cmSBRQIDAQABAkAqTP5MCMuPIYcuWUAyVNygpzRS3JyKCepClUpnZreYdo4sUQE3
/AM7xeb92R06iZ3f9/MPrbaMKTWRh3uCyfKBAiEA5TxdacnVxdS8+ZLyys4p/C1s
iajrarBb/j+NIAnsdnECIQDNOCDO7Jq/iN5qE4Vbi/3zmnP1Ca5aBo+KJ/hhSjRq
FQIgIBpWEqybbXsfg+waaGB67MAHxTeM0IImP/LydpwtK2ECIB3SrlHj6Ik1Jr1b
oOGw8nLYW0mc4o2KrolxTZM16XARAiBKW3aSjY5UrnoEqa8pAeiO8LJaRj73Epmr
zC89IuLZfg==
-----END PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIIB0zCCAX2gAwIBAgIJAKrAqGX6UolVMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTIxMDE5MTQ1NjA4WhcNMTIxMDIwMTQ1NjA4WjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALfD
mWL74vJ6FT5aUf0qx3MmvZ3u0RuQ47zLLJLFKshq4lLP1hp6WZS+J8/gXCxRsDTZ
7aMakd9E45OO0nJkgUUCAwEAAaNQME4wHQYDVR0OBBYEFA1Fc/GIVtd6nMocrSsA
e5bxmVhMMB8GA1UdIwQYMBaAFA1Fc/GIVtd6nMocrSsAe5bxmVhMMAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQCTUwzDGU+IJTQ3PIJrA3fHMyKbBvc4Rkvi
ZNFsmgsidWhb+5APlPjtlS7rXlonNHBzDoGb4RNArtxhEx+rBcAE
-----END CERTIFICATE-----
#!/usr/bin/python
#
# Copyright (C) 2012 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Script for testing ganeti.tools.prepare_node_join"""
import unittest
import shutil
import tempfile
import os.path
import OpenSSL
from ganeti import errors
from ganeti import constants
from ganeti import serializer
from ganeti import pathutils
from ganeti import compat
from ganeti import utils
from ganeti.tools import prepare_node_join
import testutils
_JoinError = prepare_node_join.JoinError
class TestLoadData(unittest.TestCase):
def testNoJson(self):
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
def testInvalidDataStructure(self):
raw = serializer.DumpJson({
"some other thing": False,
})
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
raw = serializer.DumpJson([])
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
def testValidData(self):
raw = serializer.DumpJson({})
self.assertEqual(prepare_node_join.LoadData(raw), {})
class TestVerifyCertificate(testutils.GanetiTestCase):
def setUp(self):
testutils.GanetiTestCase.setUp(self)
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
testutils.GanetiTestCase.tearDown(self)
shutil.rmtree(self.tmpdir)
def testNoCert(self):
prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)
def testMismatchingKey(self):
other_cert = self._TestDataFilename("cert1.pem")
node_cert = self._TestDataFilename("cert2.pem")
self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
utils.ReadFile(other_cert), _noded_cert_file=node_cert)
def testGivenPrivateKey(self):
cert_filename = self._TestDataFilename("cert2.pem")
cert_pem = utils.ReadFile(cert_filename)
self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
cert_pem, _noded_cert_file=cert_filename)
def testMatchingKey(self):
cert_filename = self._TestDataFilename("cert2.pem")
# Extract certificate
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
utils.ReadFile(cert_filename))
cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
cert)
prepare_node_join._VerifyCertificate(cert_pem,
_noded_cert_file=cert_filename)
def testMissingFile(self):
cert = self._TestDataFilename("cert1.pem")
nodecert = utils.PathJoin(self.tmpdir, "does-not-exist")
prepare_node_join._VerifyCertificate(utils.ReadFile(cert),
_noded_cert_file=nodecert)
def testInvalidCertificate(self):
self.assertRaises(errors.X509CertError,
prepare_node_join._VerifyCertificate,
"Something that's not a certificate",
_noded_cert_file=NotImplemented)
def testNoPrivateKey(self):
cert = self._TestDataFilename("cert1.pem")