Skip to content
Snippets Groups Projects
Commit 065be3f0 authored by Michael Hanselmann's avatar Michael Hanselmann
Browse files

Add unittests for RPC client


This patch adds a number of unittests for the RPC client base class.
Some small changes were necessary in “rpc.py” to allow for better
testing.

Signed-off-by: default avatarMichael Hanselmann <hansmi@google.com>
Reviewed-by: default avatarIustin Pop <iustin@google.com>
parent e78667fe
No related branches found
No related tags found
No related merge requests found
......@@ -397,7 +397,7 @@ class _RpcProcessor:
return results
def __call__(self, hosts, procedure, body, read_timeout, resolver_opts,
_req_process_fn=http.client.ProcessRequests):
_req_process_fn=None):
"""Makes an RPC request to a number of nodes.
@type hosts: sequence
......@@ -413,6 +413,9 @@ class _RpcProcessor:
assert read_timeout is not None, \
"Missing RPC read timeout for procedure '%s'" % procedure
if _req_process_fn is None:
_req_process_fn = http.client.ProcessRequests
(results, requests) = \
self._PrepareRequests(self._resolver(hosts, resolver_opts), self._port,
procedure, body, read_timeout)
......@@ -425,13 +428,15 @@ class _RpcProcessor:
class _RpcClientBase:
def __init__(self, resolver, encoder_fn, lock_monitor_cb=None):
def __init__(self, resolver, encoder_fn, lock_monitor_cb=None,
_req_process_fn=None):
"""Initializes this class.
"""
self._proc = _RpcProcessor(resolver,
netutils.GetDaemonPort(constants.NODED),
lock_monitor_cb=lock_monitor_cb)
proc = _RpcProcessor(resolver,
netutils.GetDaemonPort(constants.NODED),
lock_monitor_cb=lock_monitor_cb)
self._proc = compat.partial(proc, _req_process_fn=_req_process_fn)
self._encoder = compat.partial(self._EncodeArg, encoder_fn)
@staticmethod
......
......@@ -24,6 +24,7 @@
import os
import sys
import unittest
import random
from ganeti import constants
from ganeti import compat
......@@ -469,5 +470,222 @@ class TestCompress(unittest.TestCase):
(constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
class TestRpcClientBase(unittest.TestCase):
def testNoHosts(self):
cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_SLOW, [],
None, None, NotImplemented)
http_proc = _FakeRequestProcessor(NotImplemented)
client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(client._Call(cdef, [], []), {})
# Test wrong number of arguments
self.assertRaises(errors.ProgrammerError, client._Call,
cdef, [], [0, 1, 2])
def testTimeout(self):
def _CalcTimeout((arg1, arg2)):
return arg1 + arg2
def _VerifyRequest(exp_timeout, req):
self.assertEqual(req.read_timeout, exp_timeout)
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
resolver = rpc._StaticResolver([
"192.0.2.1",
"192.0.2.2",
])
nodes = [
"node1.example.com",
"node2.example.com",
]
tests = [(100, None, 100), (30, None, 30)]
tests.extend((_CalcTimeout, i, i + 300)
for i in [0, 5, 16485, 30516])
for timeout, arg1, exp_timeout in tests:
cdef = ("test_call", NotImplemented, None, timeout, [
("arg1", None, NotImplemented),
("arg2", None, NotImplemented),
], None, None, NotImplemented)
http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
exp_timeout))
client = rpc._RpcClientBase(resolver, NotImplemented,
_req_process_fn=http_proc)
result = client._Call(cdef, nodes, [arg1, 300])
self.assertEqual(len(result), len(nodes))
self.assertTrue(compat.all(not res.fail_msg and
res.payload == hex(exp_timeout)
for res in result.values()))
def testArgumentEncoder(self):
(AT1, AT2) = range(1, 3)
resolver = rpc._StaticResolver([
"192.0.2.5",
"192.0.2.6",
])
nodes = [
"node5.example.com",
"node6.example.com",
]
encoders = {
AT1: hex,
AT2: hash,
}
cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
("arg0", None, NotImplemented),
("arg1", AT1, NotImplemented),
("arg1", AT2, NotImplemented),
], None, None, NotImplemented)
def _VerifyRequest(req):
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, req.post_data))
http_proc = _FakeRequestProcessor(_VerifyRequest)
for num in [0, 3796, 9032119]:
client = rpc._RpcClientBase(resolver, encoders.get,
_req_process_fn=http_proc)
result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
self.assertEqual(len(result), len(nodes))
for res in result.values():
self.assertFalse(res.fail_msg)
self.assertEqual(serializer.LoadJson(res.payload),
["foo", hex(num), hash("Hello%s" % num)])
def testPostProc(self):
def _VerifyRequest(nums, req):
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, nums))
resolver = rpc._StaticResolver([
"192.0.2.90",
"192.0.2.95",
])
nodes = [
"node90.example.com",
"node95.example.com",
]
def _PostProc(res):
self.assertFalse(res.fail_msg)
res.payload = sum(res.payload)
return res
cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [],
None, _PostProc, NotImplemented)
# Seeded random generator
rnd = random.Random(20299)
for i in [0, 4, 74, 1391]:
nums = [rnd.randint(0, 1000) for _ in range(i)]
http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
client = rpc._RpcClientBase(resolver, NotImplemented,
_req_process_fn=http_proc)
result = client._Call(cdef, nodes, [])
self.assertEqual(len(result), len(nodes))
for res in result.values():
self.assertFalse(res.fail_msg)
self.assertEqual(res.payload, sum(nums))
def testPreProc(self):
def _VerifyRequest(req):
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, req.post_data))
resolver = rpc._StaticResolver([
"192.0.2.30",
"192.0.2.35",
])
nodes = [
"node30.example.com",
"node35.example.com",
]
def _PreProc(node, data):
self.assertEqual(len(data), 1)
return data[0] + node
cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
("arg0", None, NotImplemented),
], _PreProc, None, NotImplemented)
http_proc = _FakeRequestProcessor(_VerifyRequest)
client = rpc._RpcClientBase(resolver, NotImplemented,
_req_process_fn=http_proc)
for prefix in ["foo", "bar", "baz"]:
result = client._Call(cdef, nodes, [prefix])
self.assertEqual(len(result), len(nodes))
for (idx, (node, res)) in enumerate(result.items()):
self.assertFalse(res.fail_msg)
self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
def testResolverOptions(self):
def _VerifyRequest(req):
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, req.post_data))
nodes = [
"node30.example.com",
"node35.example.com",
]
def _Resolver(expected, hosts, options):
self.assertEqual(hosts, nodes)
self.assertEqual(options, expected)
return zip(hosts, nodes)
def _DynamicResolverOptions((arg0, )):
return sum(arg0)
tests = [
(None, None, None),
(rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
(False, None, False),
(True, None, True),
(0, None, 0),
(_DynamicResolverOptions, [1, 2, 3], 6),
(_DynamicResolverOptions, range(4, 19), 165),
]
for (resolver_opts, arg0, expected) in tests:
cdef = ("test_call", NotImplemented, resolver_opts, rpc_defs.TMO_NORMAL, [
("arg0", None, NotImplemented),
], None, None, NotImplemented)
http_proc = _FakeRequestProcessor(_VerifyRequest)
client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
NotImplemented, _req_process_fn=http_proc)
result = client._Call(cdef, nodes, [arg0])
self.assertEqual(len(result), len(nodes))
for (idx, (node, res)) in enumerate(result.items()):
self.assertFalse(res.fail_msg)
class TestRpcRunner(unittest.TestCase):
def testUploadFile(self):
runner = rpc.RpcRunner(_req_process_fn=http_proc)
if __name__ == "__main__":
testutils.GanetiTestProgram()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment