diff --git a/lib/rpc.py b/lib/rpc.py index f94c701a466aaf3071d0e11e5673948fb3874eaa..bf854046f7ebda156b2cecab0c325700508a3562 100644 --- a/lib/rpc.py +++ b/lib/rpc.py @@ -397,7 +397,7 @@ class _RpcProcessor: return results def __call__(self, hosts, procedure, body, read_timeout, resolver_opts, - _req_process_fn=http.client.ProcessRequests): + _req_process_fn=None): """Makes an RPC request to a number of nodes. @type hosts: sequence @@ -413,6 +413,9 @@ class _RpcProcessor: assert read_timeout is not None, \ "Missing RPC read timeout for procedure '%s'" % procedure + if _req_process_fn is None: + _req_process_fn = http.client.ProcessRequests + (results, requests) = \ self._PrepareRequests(self._resolver(hosts, resolver_opts), self._port, procedure, body, read_timeout) @@ -425,13 +428,15 @@ class _RpcProcessor: class _RpcClientBase: - def __init__(self, resolver, encoder_fn, lock_monitor_cb=None): + def __init__(self, resolver, encoder_fn, lock_monitor_cb=None, + _req_process_fn=None): """Initializes this class. """ - self._proc = _RpcProcessor(resolver, - netutils.GetDaemonPort(constants.NODED), - lock_monitor_cb=lock_monitor_cb) + proc = _RpcProcessor(resolver, + netutils.GetDaemonPort(constants.NODED), + lock_monitor_cb=lock_monitor_cb) + self._proc = compat.partial(proc, _req_process_fn=_req_process_fn) self._encoder = compat.partial(self._EncodeArg, encoder_fn) @staticmethod diff --git a/test/ganeti.rpc_unittest.py b/test/ganeti.rpc_unittest.py index 990a4127db0ce57da4c35339633c154fc34bbe68..a9349a67e52df2b327b21f34a83b617f94e0d8d5 100755 --- a/test/ganeti.rpc_unittest.py +++ b/test/ganeti.rpc_unittest.py @@ -24,6 +24,7 @@ import os import sys import unittest +import random from ganeti import constants from ganeti import compat @@ -469,5 +470,222 @@ class TestCompress(unittest.TestCase): (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data")) +class TestRpcClientBase(unittest.TestCase): + def testNoHosts(self): + cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_SLOW, [], + None, None, NotImplemented) + http_proc = _FakeRequestProcessor(NotImplemented) + client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented, + _req_process_fn=http_proc) + self.assertEqual(client._Call(cdef, [], []), {}) + + # Test wrong number of arguments + self.assertRaises(errors.ProgrammerError, client._Call, + cdef, [], [0, 1, 2]) + + def testTimeout(self): + def _CalcTimeout((arg1, arg2)): + return arg1 + arg2 + + def _VerifyRequest(exp_timeout, req): + self.assertEqual(req.read_timeout, exp_timeout) + + req.success = True + req.resp_status_code = http.HTTP_OK + req.resp_body = serializer.DumpJson((True, hex(req.read_timeout))) + + resolver = rpc._StaticResolver([ + "192.0.2.1", + "192.0.2.2", + ]) + + nodes = [ + "node1.example.com", + "node2.example.com", + ] + + tests = [(100, None, 100), (30, None, 30)] + tests.extend((_CalcTimeout, i, i + 300) + for i in [0, 5, 16485, 30516]) + + for timeout, arg1, exp_timeout in tests: + cdef = ("test_call", NotImplemented, None, timeout, [ + ("arg1", None, NotImplemented), + ("arg2", None, NotImplemented), + ], None, None, NotImplemented) + + http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, + exp_timeout)) + client = rpc._RpcClientBase(resolver, NotImplemented, + _req_process_fn=http_proc) + result = client._Call(cdef, nodes, [arg1, 300]) + self.assertEqual(len(result), len(nodes)) + self.assertTrue(compat.all(not res.fail_msg and + res.payload == hex(exp_timeout) + for res in result.values())) + + def testArgumentEncoder(self): + (AT1, AT2) = range(1, 3) + + resolver = rpc._StaticResolver([ + "192.0.2.5", + "192.0.2.6", + ]) + + nodes = [ + "node5.example.com", + "node6.example.com", + ] + + encoders = { + AT1: hex, + AT2: hash, + } + + cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [ + ("arg0", None, NotImplemented), + ("arg1", AT1, NotImplemented), + ("arg1", AT2, NotImplemented), + ], None, None, NotImplemented) + + def _VerifyRequest(req): + req.success = True + req.resp_status_code = http.HTTP_OK + req.resp_body = serializer.DumpJson((True, req.post_data)) + + http_proc = _FakeRequestProcessor(_VerifyRequest) + + for num in [0, 3796, 9032119]: + client = rpc._RpcClientBase(resolver, encoders.get, + _req_process_fn=http_proc) + result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num]) + self.assertEqual(len(result), len(nodes)) + for res in result.values(): + self.assertFalse(res.fail_msg) + self.assertEqual(serializer.LoadJson(res.payload), + ["foo", hex(num), hash("Hello%s" % num)]) + + def testPostProc(self): + def _VerifyRequest(nums, req): + req.success = True + req.resp_status_code = http.HTTP_OK + req.resp_body = serializer.DumpJson((True, nums)) + + resolver = rpc._StaticResolver([ + "192.0.2.90", + "192.0.2.95", + ]) + + nodes = [ + "node90.example.com", + "node95.example.com", + ] + + def _PostProc(res): + self.assertFalse(res.fail_msg) + res.payload = sum(res.payload) + return res + + cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [], + None, _PostProc, NotImplemented) + + # Seeded random generator + rnd = random.Random(20299) + + for i in [0, 4, 74, 1391]: + nums = [rnd.randint(0, 1000) for _ in range(i)] + http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums)) + client = rpc._RpcClientBase(resolver, NotImplemented, + _req_process_fn=http_proc) + result = client._Call(cdef, nodes, []) + self.assertEqual(len(result), len(nodes)) + for res in result.values(): + self.assertFalse(res.fail_msg) + self.assertEqual(res.payload, sum(nums)) + + def testPreProc(self): + def _VerifyRequest(req): + req.success = True + req.resp_status_code = http.HTTP_OK + req.resp_body = serializer.DumpJson((True, req.post_data)) + + resolver = rpc._StaticResolver([ + "192.0.2.30", + "192.0.2.35", + ]) + + nodes = [ + "node30.example.com", + "node35.example.com", + ] + + def _PreProc(node, data): + self.assertEqual(len(data), 1) + return data[0] + node + + cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [ + ("arg0", None, NotImplemented), + ], _PreProc, None, NotImplemented) + + http_proc = _FakeRequestProcessor(_VerifyRequest) + client = rpc._RpcClientBase(resolver, NotImplemented, + _req_process_fn=http_proc) + + for prefix in ["foo", "bar", "baz"]: + result = client._Call(cdef, nodes, [prefix]) + self.assertEqual(len(result), len(nodes)) + for (idx, (node, res)) in enumerate(result.items()): + self.assertFalse(res.fail_msg) + self.assertEqual(serializer.LoadJson(res.payload), prefix + node) + + def testResolverOptions(self): + def _VerifyRequest(req): + req.success = True + req.resp_status_code = http.HTTP_OK + req.resp_body = serializer.DumpJson((True, req.post_data)) + + nodes = [ + "node30.example.com", + "node35.example.com", + ] + + def _Resolver(expected, hosts, options): + self.assertEqual(hosts, nodes) + self.assertEqual(options, expected) + return zip(hosts, nodes) + + def _DynamicResolverOptions((arg0, )): + return sum(arg0) + + tests = [ + (None, None, None), + (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE), + (False, None, False), + (True, None, True), + (0, None, 0), + (_DynamicResolverOptions, [1, 2, 3], 6), + (_DynamicResolverOptions, range(4, 19), 165), + ] + + for (resolver_opts, arg0, expected) in tests: + cdef = ("test_call", NotImplemented, resolver_opts, rpc_defs.TMO_NORMAL, [ + ("arg0", None, NotImplemented), + ], None, None, NotImplemented) + + http_proc = _FakeRequestProcessor(_VerifyRequest) + + client = rpc._RpcClientBase(compat.partial(_Resolver, expected), + NotImplemented, _req_process_fn=http_proc) + result = client._Call(cdef, nodes, [arg0]) + self.assertEqual(len(result), len(nodes)) + for (idx, (node, res)) in enumerate(result.items()): + self.assertFalse(res.fail_msg) + + +class TestRpcRunner(unittest.TestCase): + def testUploadFile(self): + runner = rpc.RpcRunner(_req_process_fn=http_proc) + + if __name__ == "__main__": testutils.GanetiTestProgram()