Commit d40f0bc6 authored by Helga Velroyen's avatar Helga Velroyen

Prepare-node-join: use common functions

This patch makes prepare_node_join use some of the functions
that were moved to tools/common.py. The respective unittests
are removed, because they are already tested in
common_unittest.py.
Signed-off-by: default avatarHelga Velroyen <helgav@google.com>
Reviewed-by: default avatarKlaus Aehlig <aehlig@google.com>
parent a12aa2d2
...@@ -43,10 +43,9 @@ from ganeti import constants ...@@ -43,10 +43,9 @@ from ganeti import constants
from ganeti import errors from ganeti import errors
from ganeti import pathutils from ganeti import pathutils
from ganeti import utils from ganeti import utils
from ganeti import serializer
from ganeti import ht from ganeti import ht
from ganeti import ssh from ganeti import ssh
from ganeti import ssconf from ganeti.tools import common
_SSH_KEY_LIST_ITEM = \ _SSH_KEY_LIST_ITEM = \
...@@ -89,17 +88,7 @@ def ParseOptions(): ...@@ -89,17 +88,7 @@ def ParseOptions():
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
return VerifyOptions(parser, opts, args) return common.VerifyOptions(parser, opts, args)
def VerifyOptions(parser, opts, args):
"""Verifies options and arguments for correctness.
"""
if args:
parser.error("No arguments are expected")
return opts
def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate): def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate):
...@@ -137,19 +126,6 @@ def VerifyCertificate(data, _verify_fn=_VerifyCertificate): ...@@ -137,19 +126,6 @@ def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
_verify_fn(cert) _verify_fn(cert)
def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
"""Verifies cluster name.
@type data: dict
"""
name = data.get(constants.SSHS_CLUSTER_NAME)
if name:
_verify_fn(name)
else:
raise JoinError("Cluster name must be specified")
def _UpdateKeyFiles(keys, dry_run, keyfiles): def _UpdateKeyFiles(keys, dry_run, keyfiles):
"""Updates SSH key files. """Updates SSH key files.
...@@ -238,15 +214,6 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None): ...@@ -238,15 +214,6 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None):
utils.AddAuthorizedKey(auth_keys_file, public_key) utils.AddAuthorizedKey(auth_keys_file, public_key)
def LoadData(raw):
"""Parses and verifies input data.
@rtype: dict
"""
return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
def Main(): def Main():
"""Main routine. """Main routine.
...@@ -256,10 +223,10 @@ def Main(): ...@@ -256,10 +223,10 @@ def Main():
utils.SetupToolLogging(opts.debug, opts.verbose) utils.SetupToolLogging(opts.debug, opts.verbose)
try: try:
data = LoadData(sys.stdin.read()) data = common.LoadData(sys.stdin.read(), _DATA_CHECK)
# Check if input data is correct # Check if input data is correct
VerifyClusterName(data) common.VerifyClusterName(data, JoinError)
VerifyCertificate(data) VerifyCertificate(data)
# Update SSH files # Update SSH files
......
...@@ -38,6 +38,8 @@ import OpenSSL ...@@ -38,6 +38,8 @@ import OpenSSL
import time import time
from ganeti import constants from ganeti import constants
from ganeti import errors
from ganeti import serializer
from ganeti import utils from ganeti import utils
from ganeti.tools import common from ganeti.tools import common
...@@ -78,5 +80,56 @@ class TestGenerateClientCert(unittest.TestCase): ...@@ -78,5 +80,56 @@ class TestGenerateClientCert(unittest.TestCase):
self.assertEqual(client_cert.get_subject().CN, my_node_name) self.assertEqual(client_cert.get_subject().CN, my_node_name)
class TestLoadData(unittest.TestCase):
def testNoJson(self):
self.assertRaises(errors.ParseError, common.LoadData, Exception, "")
self.assertRaises(errors.ParseError, common.LoadData, Exception, "}")
def testInvalidDataStructure(self):
raw = serializer.DumpJson({
"some other thing": False,
})
self.assertRaises(errors.ParseError, common.LoadData, Exception, raw)
raw = serializer.DumpJson([])
self.assertRaises(errors.ParseError, common.LoadData, Exception, raw)
def testValidData(self):
raw = serializer.DumpJson({})
self.assertEqual(common.LoadData(raw, Exception), {})
class TestVerifyClusterName(unittest.TestCase):
class MyException(Exception):
pass
def setUp(self):
unittest.TestCase.setUp(self)
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
unittest.TestCase.tearDown(self)
shutil.rmtree(self.tmpdir)
def testNoName(self):
self.assertRaises(self.MyException, common.VerifyClusterName,
{}, self.MyException, _verify_fn=NotImplemented)
@staticmethod
def _FailingVerify(name):
assert name == "cluster.example.com"
raise errors.GenericError()
def testFailingVerification(self):
data = {
constants.SSHS_CLUSTER_NAME: "cluster.example.com",
}
self.assertRaises(errors.GenericError, common.VerifyClusterName,
data, Exception, _verify_fn=self._FailingVerify)
if __name__ == "__main__": if __name__ == "__main__":
testutils.GanetiTestProgram() testutils.GanetiTestProgram()
...@@ -34,11 +34,9 @@ import unittest ...@@ -34,11 +34,9 @@ import unittest
import shutil import shutil
import tempfile import tempfile
import os.path import os.path
import OpenSSL
from ganeti import errors from ganeti import errors
from ganeti import constants from ganeti import constants
from ganeti import serializer
from ganeti import pathutils from ganeti import pathutils
from ganeti import compat from ganeti import compat
from ganeti import utils from ganeti import utils
...@@ -50,25 +48,6 @@ import testutils ...@@ -50,25 +48,6 @@ import testutils
_JoinError = prepare_node_join.JoinError _JoinError = prepare_node_join.JoinError
class TestLoadData(unittest.TestCase):
def testNoJson(self):
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
def testInvalidDataStructure(self):
raw = serializer.DumpJson({
"some other thing": False,
})
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
raw = serializer.DumpJson([])
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
def testValidData(self):
raw = serializer.DumpJson({})
self.assertEqual(prepare_node_join.LoadData(raw), {})
class TestVerifyCertificate(testutils.GanetiTestCase): class TestVerifyCertificate(testutils.GanetiTestCase):
def setUp(self): def setUp(self):
testutils.GanetiTestCase.setUp(self) testutils.GanetiTestCase.setUp(self)
...@@ -104,33 +83,6 @@ class TestVerifyCertificate(testutils.GanetiTestCase): ...@@ -104,33 +83,6 @@ class TestVerifyCertificate(testutils.GanetiTestCase):
prepare_node_join._VerifyCertificate(cert_pem, _check_fn=self._Check) prepare_node_join._VerifyCertificate(cert_pem, _check_fn=self._Check)
class TestVerifyClusterName(unittest.TestCase):
def setUp(self):
unittest.TestCase.setUp(self)
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
unittest.TestCase.tearDown(self)
shutil.rmtree(self.tmpdir)
def testNoName(self):
self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
{}, _verify_fn=NotImplemented)
@staticmethod
def _FailingVerify(name):
assert name == "cluster.example.com"
raise errors.GenericError()
def testFailingVerification(self):
data = {
constants.SSHS_CLUSTER_NAME: "cluster.example.com",
}
self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
data, _verify_fn=self._FailingVerify)
class TestUpdateSshDaemon(unittest.TestCase): class TestUpdateSshDaemon(unittest.TestCase):
def setUp(self): def setUp(self):
unittest.TestCase.setUp(self) unittest.TestCase.setUp(self)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment