Commit d40f0bc6 authored by Helga Velroyen's avatar Helga Velroyen

Prepare-node-join: use common functions

This patch makes prepare_node_join use some of the functions
that were moved to tools/common.py. The respective unittests
are removed, because they are already tested in
common_unittest.py.
Signed-off-by: default avatarHelga Velroyen <helgav@google.com>
Reviewed-by: default avatarKlaus Aehlig <aehlig@google.com>
parent a12aa2d2
......@@ -43,10 +43,9 @@ from ganeti import constants
from ganeti import errors
from ganeti import pathutils
from ganeti import utils
from ganeti import serializer
from ganeti import ht
from ganeti import ssh
from ganeti import ssconf
from ganeti.tools import common
_SSH_KEY_LIST_ITEM = \
......@@ -89,17 +88,7 @@ def ParseOptions():
(opts, args) = parser.parse_args()
return VerifyOptions(parser, opts, args)
def VerifyOptions(parser, opts, args):
"""Verifies options and arguments for correctness.
"""
if args:
parser.error("No arguments are expected")
return opts
return common.VerifyOptions(parser, opts, args)
def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate):
......@@ -137,19 +126,6 @@ def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
_verify_fn(cert)
def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
"""Verifies cluster name.
@type data: dict
"""
name = data.get(constants.SSHS_CLUSTER_NAME)
if name:
_verify_fn(name)
else:
raise JoinError("Cluster name must be specified")
def _UpdateKeyFiles(keys, dry_run, keyfiles):
"""Updates SSH key files.
......@@ -238,15 +214,6 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None):
utils.AddAuthorizedKey(auth_keys_file, public_key)
def LoadData(raw):
"""Parses and verifies input data.
@rtype: dict
"""
return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
def Main():
"""Main routine.
......@@ -256,10 +223,10 @@ def Main():
utils.SetupToolLogging(opts.debug, opts.verbose)
try:
data = LoadData(sys.stdin.read())
data = common.LoadData(sys.stdin.read(), _DATA_CHECK)
# Check if input data is correct
VerifyClusterName(data)
common.VerifyClusterName(data, JoinError)
VerifyCertificate(data)
# Update SSH files
......
......@@ -38,6 +38,8 @@ import OpenSSL
import time
from ganeti import constants
from ganeti import errors
from ganeti import serializer
from ganeti import utils
from ganeti.tools import common
......@@ -78,5 +80,56 @@ class TestGenerateClientCert(unittest.TestCase):
self.assertEqual(client_cert.get_subject().CN, my_node_name)
class TestLoadData(unittest.TestCase):
def testNoJson(self):
self.assertRaises(errors.ParseError, common.LoadData, Exception, "")
self.assertRaises(errors.ParseError, common.LoadData, Exception, "}")
def testInvalidDataStructure(self):
raw = serializer.DumpJson({
"some other thing": False,
})
self.assertRaises(errors.ParseError, common.LoadData, Exception, raw)
raw = serializer.DumpJson([])
self.assertRaises(errors.ParseError, common.LoadData, Exception, raw)
def testValidData(self):
raw = serializer.DumpJson({})
self.assertEqual(common.LoadData(raw, Exception), {})
class TestVerifyClusterName(unittest.TestCase):
class MyException(Exception):
pass
def setUp(self):
unittest.TestCase.setUp(self)
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
unittest.TestCase.tearDown(self)
shutil.rmtree(self.tmpdir)
def testNoName(self):
self.assertRaises(self.MyException, common.VerifyClusterName,
{}, self.MyException, _verify_fn=NotImplemented)
@staticmethod
def _FailingVerify(name):
assert name == "cluster.example.com"
raise errors.GenericError()
def testFailingVerification(self):
data = {
constants.SSHS_CLUSTER_NAME: "cluster.example.com",
}
self.assertRaises(errors.GenericError, common.VerifyClusterName,
data, Exception, _verify_fn=self._FailingVerify)
if __name__ == "__main__":
testutils.GanetiTestProgram()
......@@ -34,11 +34,9 @@ import unittest
import shutil
import tempfile
import os.path
import OpenSSL
from ganeti import errors
from ganeti import constants
from ganeti import serializer
from ganeti import pathutils
from ganeti import compat
from ganeti import utils
......@@ -50,25 +48,6 @@ import testutils
_JoinError = prepare_node_join.JoinError
class TestLoadData(unittest.TestCase):
def testNoJson(self):
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
def testInvalidDataStructure(self):
raw = serializer.DumpJson({
"some other thing": False,
})
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
raw = serializer.DumpJson([])
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
def testValidData(self):
raw = serializer.DumpJson({})
self.assertEqual(prepare_node_join.LoadData(raw), {})
class TestVerifyCertificate(testutils.GanetiTestCase):
def setUp(self):
testutils.GanetiTestCase.setUp(self)
......@@ -104,33 +83,6 @@ class TestVerifyCertificate(testutils.GanetiTestCase):
prepare_node_join._VerifyCertificate(cert_pem, _check_fn=self._Check)
class TestVerifyClusterName(unittest.TestCase):
def setUp(self):
unittest.TestCase.setUp(self)
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
unittest.TestCase.tearDown(self)
shutil.rmtree(self.tmpdir)
def testNoName(self):
self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
{}, _verify_fn=NotImplemented)
@staticmethod
def _FailingVerify(name):
assert name == "cluster.example.com"
raise errors.GenericError()
def testFailingVerification(self):
data = {
constants.SSHS_CLUSTER_NAME: "cluster.example.com",
}
self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
data, _verify_fn=self._FailingVerify)
class TestUpdateSshDaemon(unittest.TestCase):
def setUp(self):
unittest.TestCase.setUp(self)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment