Commit fce5efd1 authored by Michael Hanselmann's avatar Michael Hanselmann
Browse files

rpc: Pass resolver options to actual resolver


Signed-off-by: default avatarMichael Hanselmann <hansmi@google.com>
Reviewed-by: default avatarIustin Pop <iustin@google.com>
parent dd6d2d09
......@@ -240,7 +240,7 @@ class RpcResult(object):
raise ec(*args) # pylint: disable=W0142
def _SsconfResolver(node_list,
def _SsconfResolver(node_list, _,
ssc=ssconf.SimpleStore,
nslookup_fn=netutils.Hostname.GetIP):
"""Return addresses for given node names.
......@@ -277,7 +277,7 @@ class _StaticResolver:
"""
self._addresses = addresses
def __call__(self, hosts):
def __call__(self, hosts, _):
"""Returns static addresses for hosts.
"""
......@@ -304,7 +304,7 @@ def _CheckConfigNode(name, node):
return (name, ip)
def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts):
def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts, _):
"""Calculate node addresses using configuration.
"""
......@@ -391,7 +391,7 @@ class _RpcProcessor:
return results
def __call__(self, hosts, procedure, body, read_timeout,
def __call__(self, hosts, procedure, body, read_timeout, resolver_opts,
_req_process_fn=http.client.ProcessRequests):
"""Makes an RPC request to a number of nodes.
......@@ -409,8 +409,8 @@ class _RpcProcessor:
"Missing RPC read timeout for procedure '%s'" % procedure
(results, requests) = \
self._PrepareRequests(self._resolver(hosts), self._port, procedure,
body, read_timeout)
self._PrepareRequests(self._resolver(hosts, resolver_opts), self._port,
procedure, body, read_timeout)
_req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb)
......@@ -469,7 +469,8 @@ class _RpcClientBase:
pnbody = dict((n, serializer.DumpJson(prep_fn(n, enc_args)))
for n in node_list)
result = self._proc(node_list, procedure, pnbody, read_timeout)
result = self._proc(node_list, procedure, pnbody, read_timeout,
req_resolver_opts)
if postproc_fn:
return dict(map(lambda (key, value): (key, postproc_fn(value)),
......
......@@ -74,7 +74,7 @@ class TestRpcProcessor(unittest.TestCase):
http_proc = _FakeRequestProcessor(self._GetVersionResponse)
proc = rpc._RpcProcessor(resolver, 24094)
result = proc(["localhost"], "version", {"localhost": ""}, 60,
_req_process_fn=http_proc)
NotImplemented, _req_process_fn=http_proc)
self.assertEqual(result.keys(), ["localhost"])
lhresp = result["localhost"]
self.assertFalse(lhresp.offline)
......@@ -100,7 +100,8 @@ class TestRpcProcessor(unittest.TestCase):
proc = rpc._RpcProcessor(resolver, 19176)
host = "node31856"
body = {host: ""}
result = proc([host], "version", body, 12356, _req_process_fn=http_proc)
result = proc([host], "version", body, 12356, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(result.keys(), [host])
lhresp = result[host]
self.assertFalse(lhresp.offline)
......@@ -117,7 +118,8 @@ class TestRpcProcessor(unittest.TestCase):
proc = rpc._RpcProcessor(resolver, 30668)
host = "n17296"
body = {host: ""}
result = proc([host], "version", body, 60, _req_process_fn=http_proc)
result = proc([host], "version", body, 60, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(result.keys(), [host])
lhresp = result[host]
self.assertTrue(lhresp.offline)
......@@ -148,7 +150,8 @@ class TestRpcProcessor(unittest.TestCase):
resolver = rpc._StaticResolver(nodes)
http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
proc = rpc._RpcProcessor(resolver, 23245)
result = proc(nodes, "version", body, 60, _req_process_fn=http_proc,)
result = proc(nodes, "version", body, 60, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(sorted(result.keys()), sorted(nodes))
for name in nodes:
......@@ -177,7 +180,7 @@ class TestRpcProcessor(unittest.TestCase):
errinfo))
host = "aef9ur4i.example.com"
body = {host: ""}
result = proc(body.keys(), "version", body, 60,
result = proc(body.keys(), "version", body, 60, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(result.keys(), [host])
lhresp = result[host]
......@@ -227,7 +230,7 @@ class TestRpcProcessor(unittest.TestCase):
http_proc = \
_FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
httperrnodes, failnodes))
result = proc(nodes, "vg_list", body, rpc._TMO_URGENT,
result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(sorted(result.keys()), sorted(nodes))
......@@ -272,7 +275,7 @@ class TestRpcProcessor(unittest.TestCase):
http_proc = _FakeRequestProcessor(fn)
host = "oqo7lanhly.example.com"
body = {host: ""}
result = proc([host], "version", body, 60,
result = proc([host], "version", body, 60, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(result.keys(), [host])
lhresp = result[host]
......@@ -304,7 +307,8 @@ class TestRpcProcessor(unittest.TestCase):
proc = rpc._RpcProcessor(resolver, 18700)
host = "node19759"
body = {host: serializer.DumpJson(test_data)}
result = proc([host], "upload_file", body, 30, _req_process_fn=http_proc)
result = proc([host], "upload_file", body, 30, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(result.keys(), [host])
lhresp = result[host]
self.assertFalse(lhresp.offline)
......@@ -322,7 +326,8 @@ class TestSsconfResolver(unittest.TestCase):
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
result = rpc._SsconfResolver(node_list, NotImplemented,
ssc=ssc, nslookup_fn=NotImplemented)
self.assertEqual(result, zip(node_list, addr_list))
def testNsLookup(self):
......@@ -331,7 +336,8 @@ class TestSsconfResolver(unittest.TestCase):
ssc = GetFakeSimpleStoreClass(lambda _: [])
node_addr_map = dict(zip(node_list, addr_list))
nslookup_fn = lambda name, family=None: node_addr_map.get(name)
result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
result = rpc._SsconfResolver(node_list, NotImplemented,
ssc=ssc, nslookup_fn=nslookup_fn)
self.assertEqual(result, zip(node_list, addr_list))
def testBothLookups(self):
......@@ -342,7 +348,8 @@ class TestSsconfResolver(unittest.TestCase):
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
nslookup_fn = lambda name, family=None: node_addr_map.get(name)
result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
result = rpc._SsconfResolver(node_list, NotImplemented,
ssc=ssc, nslookup_fn=nslookup_fn)
self.assertEqual(result, zip(node_list, addr_list))
def testAddressLookupIPv6(self):
......@@ -350,7 +357,8 @@ class TestSsconfResolver(unittest.TestCase):
node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
result = rpc._SsconfResolver(node_list, NotImplemented,
ssc=ssc, nslookup_fn=NotImplemented)
self.assertEqual(result, zip(node_list, addr_list))
......@@ -359,11 +367,11 @@ class TestStaticResolver(unittest.TestCase):
addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
res = rpc._StaticResolver(addresses)
self.assertEqual(res(nodes), zip(nodes, addresses))
self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
def testWrongLength(self):
res = rpc._StaticResolver([])
self.assertRaises(AssertionError, res, ["abc"])
self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
class TestNodeConfigResolver(unittest.TestCase):
......@@ -380,24 +388,24 @@ class TestNodeConfigResolver(unittest.TestCase):
def testSingleOnline(self):
self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
NotImplemented,
["node90.example.com"]),
["node90.example.com"], None),
[("node90.example.com", "192.0.2.90")])
def testSingleOffline(self):
self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
NotImplemented,
["node100.example.com"]),
["node100.example.com"], None),
[("node100.example.com", rpc._OFFLINE)])
def testUnknownSingleNode(self):
self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
["node110.example.com"]),
["node110.example.com"], None),
[("node110.example.com", "node110.example.com")])
def testMultiEmpty(self):
self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
lambda: {},
[]),
[], None),
[])
def testMultiSomeOffline(self):
......@@ -410,7 +418,7 @@ class TestNodeConfigResolver(unittest.TestCase):
# Resolve no names
self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
lambda: nodes,
[]),
[], None),
[])
# Offline, online and unknown hosts
......@@ -419,7 +427,8 @@ class TestNodeConfigResolver(unittest.TestCase):
["node3.example.com",
"node92.example.com",
"node54.example.com",
"unknown.example.com",]), [
"unknown.example.com",],
None), [
("node3.example.com", rpc._OFFLINE),
("node92.example.com", "192.0.2.92"),
("node54.example.com", rpc._OFFLINE),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment