Skip to content
Snippets Groups Projects
Commit fce5efd1 authored by Michael Hanselmann's avatar Michael Hanselmann
Browse files

rpc: Pass resolver options to actual resolver


Signed-off-by: default avatarMichael Hanselmann <hansmi@google.com>
Reviewed-by: default avatarIustin Pop <iustin@google.com>
parent dd6d2d09
No related branches found
No related tags found
No related merge requests found
...@@ -240,7 +240,7 @@ class RpcResult(object): ...@@ -240,7 +240,7 @@ class RpcResult(object):
raise ec(*args) # pylint: disable=W0142 raise ec(*args) # pylint: disable=W0142
def _SsconfResolver(node_list, def _SsconfResolver(node_list, _,
ssc=ssconf.SimpleStore, ssc=ssconf.SimpleStore,
nslookup_fn=netutils.Hostname.GetIP): nslookup_fn=netutils.Hostname.GetIP):
"""Return addresses for given node names. """Return addresses for given node names.
...@@ -277,7 +277,7 @@ class _StaticResolver: ...@@ -277,7 +277,7 @@ class _StaticResolver:
""" """
self._addresses = addresses self._addresses = addresses
def __call__(self, hosts): def __call__(self, hosts, _):
"""Returns static addresses for hosts. """Returns static addresses for hosts.
""" """
...@@ -304,7 +304,7 @@ def _CheckConfigNode(name, node): ...@@ -304,7 +304,7 @@ def _CheckConfigNode(name, node):
return (name, ip) return (name, ip)
def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts): def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts, _):
"""Calculate node addresses using configuration. """Calculate node addresses using configuration.
""" """
...@@ -391,7 +391,7 @@ class _RpcProcessor: ...@@ -391,7 +391,7 @@ class _RpcProcessor:
return results return results
def __call__(self, hosts, procedure, body, read_timeout, def __call__(self, hosts, procedure, body, read_timeout, resolver_opts,
_req_process_fn=http.client.ProcessRequests): _req_process_fn=http.client.ProcessRequests):
"""Makes an RPC request to a number of nodes. """Makes an RPC request to a number of nodes.
...@@ -409,8 +409,8 @@ class _RpcProcessor: ...@@ -409,8 +409,8 @@ class _RpcProcessor:
"Missing RPC read timeout for procedure '%s'" % procedure "Missing RPC read timeout for procedure '%s'" % procedure
(results, requests) = \ (results, requests) = \
self._PrepareRequests(self._resolver(hosts), self._port, procedure, self._PrepareRequests(self._resolver(hosts, resolver_opts), self._port,
body, read_timeout) procedure, body, read_timeout)
_req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb) _req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb)
...@@ -469,7 +469,8 @@ class _RpcClientBase: ...@@ -469,7 +469,8 @@ class _RpcClientBase:
pnbody = dict((n, serializer.DumpJson(prep_fn(n, enc_args))) pnbody = dict((n, serializer.DumpJson(prep_fn(n, enc_args)))
for n in node_list) for n in node_list)
result = self._proc(node_list, procedure, pnbody, read_timeout) result = self._proc(node_list, procedure, pnbody, read_timeout,
req_resolver_opts)
if postproc_fn: if postproc_fn:
return dict(map(lambda (key, value): (key, postproc_fn(value)), return dict(map(lambda (key, value): (key, postproc_fn(value)),
......
...@@ -74,7 +74,7 @@ class TestRpcProcessor(unittest.TestCase): ...@@ -74,7 +74,7 @@ class TestRpcProcessor(unittest.TestCase):
http_proc = _FakeRequestProcessor(self._GetVersionResponse) http_proc = _FakeRequestProcessor(self._GetVersionResponse)
proc = rpc._RpcProcessor(resolver, 24094) proc = rpc._RpcProcessor(resolver, 24094)
result = proc(["localhost"], "version", {"localhost": ""}, 60, result = proc(["localhost"], "version", {"localhost": ""}, 60,
_req_process_fn=http_proc) NotImplemented, _req_process_fn=http_proc)
self.assertEqual(result.keys(), ["localhost"]) self.assertEqual(result.keys(), ["localhost"])
lhresp = result["localhost"] lhresp = result["localhost"]
self.assertFalse(lhresp.offline) self.assertFalse(lhresp.offline)
...@@ -100,7 +100,8 @@ class TestRpcProcessor(unittest.TestCase): ...@@ -100,7 +100,8 @@ class TestRpcProcessor(unittest.TestCase):
proc = rpc._RpcProcessor(resolver, 19176) proc = rpc._RpcProcessor(resolver, 19176)
host = "node31856" host = "node31856"
body = {host: ""} body = {host: ""}
result = proc([host], "version", body, 12356, _req_process_fn=http_proc) result = proc([host], "version", body, 12356, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(result.keys(), [host]) self.assertEqual(result.keys(), [host])
lhresp = result[host] lhresp = result[host]
self.assertFalse(lhresp.offline) self.assertFalse(lhresp.offline)
...@@ -117,7 +118,8 @@ class TestRpcProcessor(unittest.TestCase): ...@@ -117,7 +118,8 @@ class TestRpcProcessor(unittest.TestCase):
proc = rpc._RpcProcessor(resolver, 30668) proc = rpc._RpcProcessor(resolver, 30668)
host = "n17296" host = "n17296"
body = {host: ""} body = {host: ""}
result = proc([host], "version", body, 60, _req_process_fn=http_proc) result = proc([host], "version", body, 60, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(result.keys(), [host]) self.assertEqual(result.keys(), [host])
lhresp = result[host] lhresp = result[host]
self.assertTrue(lhresp.offline) self.assertTrue(lhresp.offline)
...@@ -148,7 +150,8 @@ class TestRpcProcessor(unittest.TestCase): ...@@ -148,7 +150,8 @@ class TestRpcProcessor(unittest.TestCase):
resolver = rpc._StaticResolver(nodes) resolver = rpc._StaticResolver(nodes)
http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse) http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
proc = rpc._RpcProcessor(resolver, 23245) proc = rpc._RpcProcessor(resolver, 23245)
result = proc(nodes, "version", body, 60, _req_process_fn=http_proc,) result = proc(nodes, "version", body, 60, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(sorted(result.keys()), sorted(nodes)) self.assertEqual(sorted(result.keys()), sorted(nodes))
for name in nodes: for name in nodes:
...@@ -177,7 +180,7 @@ class TestRpcProcessor(unittest.TestCase): ...@@ -177,7 +180,7 @@ class TestRpcProcessor(unittest.TestCase):
errinfo)) errinfo))
host = "aef9ur4i.example.com" host = "aef9ur4i.example.com"
body = {host: ""} body = {host: ""}
result = proc(body.keys(), "version", body, 60, result = proc(body.keys(), "version", body, 60, NotImplemented,
_req_process_fn=http_proc) _req_process_fn=http_proc)
self.assertEqual(result.keys(), [host]) self.assertEqual(result.keys(), [host])
lhresp = result[host] lhresp = result[host]
...@@ -227,7 +230,7 @@ class TestRpcProcessor(unittest.TestCase): ...@@ -227,7 +230,7 @@ class TestRpcProcessor(unittest.TestCase):
http_proc = \ http_proc = \
_FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse, _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
httperrnodes, failnodes)) httperrnodes, failnodes))
result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
_req_process_fn=http_proc) _req_process_fn=http_proc)
self.assertEqual(sorted(result.keys()), sorted(nodes)) self.assertEqual(sorted(result.keys()), sorted(nodes))
...@@ -272,7 +275,7 @@ class TestRpcProcessor(unittest.TestCase): ...@@ -272,7 +275,7 @@ class TestRpcProcessor(unittest.TestCase):
http_proc = _FakeRequestProcessor(fn) http_proc = _FakeRequestProcessor(fn)
host = "oqo7lanhly.example.com" host = "oqo7lanhly.example.com"
body = {host: ""} body = {host: ""}
result = proc([host], "version", body, 60, result = proc([host], "version", body, 60, NotImplemented,
_req_process_fn=http_proc) _req_process_fn=http_proc)
self.assertEqual(result.keys(), [host]) self.assertEqual(result.keys(), [host])
lhresp = result[host] lhresp = result[host]
...@@ -304,7 +307,8 @@ class TestRpcProcessor(unittest.TestCase): ...@@ -304,7 +307,8 @@ class TestRpcProcessor(unittest.TestCase):
proc = rpc._RpcProcessor(resolver, 18700) proc = rpc._RpcProcessor(resolver, 18700)
host = "node19759" host = "node19759"
body = {host: serializer.DumpJson(test_data)} body = {host: serializer.DumpJson(test_data)}
result = proc([host], "upload_file", body, 30, _req_process_fn=http_proc) result = proc([host], "upload_file", body, 30, NotImplemented,
_req_process_fn=http_proc)
self.assertEqual(result.keys(), [host]) self.assertEqual(result.keys(), [host])
lhresp = result[host] lhresp = result[host]
self.assertFalse(lhresp.offline) self.assertFalse(lhresp.offline)
...@@ -322,7 +326,8 @@ class TestSsconfResolver(unittest.TestCase): ...@@ -322,7 +326,8 @@ class TestSsconfResolver(unittest.TestCase):
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)] node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list) ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented) result = rpc._SsconfResolver(node_list, NotImplemented,
ssc=ssc, nslookup_fn=NotImplemented)
self.assertEqual(result, zip(node_list, addr_list)) self.assertEqual(result, zip(node_list, addr_list))
def testNsLookup(self): def testNsLookup(self):
...@@ -331,7 +336,8 @@ class TestSsconfResolver(unittest.TestCase): ...@@ -331,7 +336,8 @@ class TestSsconfResolver(unittest.TestCase):
ssc = GetFakeSimpleStoreClass(lambda _: []) ssc = GetFakeSimpleStoreClass(lambda _: [])
node_addr_map = dict(zip(node_list, addr_list)) node_addr_map = dict(zip(node_list, addr_list))
nslookup_fn = lambda name, family=None: node_addr_map.get(name) nslookup_fn = lambda name, family=None: node_addr_map.get(name)
result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn) result = rpc._SsconfResolver(node_list, NotImplemented,
ssc=ssc, nslookup_fn=nslookup_fn)
self.assertEqual(result, zip(node_list, addr_list)) self.assertEqual(result, zip(node_list, addr_list))
def testBothLookups(self): def testBothLookups(self):
...@@ -342,7 +348,8 @@ class TestSsconfResolver(unittest.TestCase): ...@@ -342,7 +348,8 @@ class TestSsconfResolver(unittest.TestCase):
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list) ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
node_addr_map = dict(zip(node_list[:n], addr_list[:n])) node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
nslookup_fn = lambda name, family=None: node_addr_map.get(name) nslookup_fn = lambda name, family=None: node_addr_map.get(name)
result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn) result = rpc._SsconfResolver(node_list, NotImplemented,
ssc=ssc, nslookup_fn=nslookup_fn)
self.assertEqual(result, zip(node_list, addr_list)) self.assertEqual(result, zip(node_list, addr_list))
def testAddressLookupIPv6(self): def testAddressLookupIPv6(self):
...@@ -350,7 +357,8 @@ class TestSsconfResolver(unittest.TestCase): ...@@ -350,7 +357,8 @@ class TestSsconfResolver(unittest.TestCase):
node_list = ["node%d.example.com" % n for n in range(0, 255, 11)] node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)] node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list) ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented) result = rpc._SsconfResolver(node_list, NotImplemented,
ssc=ssc, nslookup_fn=NotImplemented)
self.assertEqual(result, zip(node_list, addr_list)) self.assertEqual(result, zip(node_list, addr_list))
...@@ -359,11 +367,11 @@ class TestStaticResolver(unittest.TestCase): ...@@ -359,11 +367,11 @@ class TestStaticResolver(unittest.TestCase):
addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)] addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
nodes = ["node%s.example.com" % n for n in range(0, 123, 7)] nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
res = rpc._StaticResolver(addresses) res = rpc._StaticResolver(addresses)
self.assertEqual(res(nodes), zip(nodes, addresses)) self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
def testWrongLength(self): def testWrongLength(self):
res = rpc._StaticResolver([]) res = rpc._StaticResolver([])
self.assertRaises(AssertionError, res, ["abc"]) self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
class TestNodeConfigResolver(unittest.TestCase): class TestNodeConfigResolver(unittest.TestCase):
...@@ -380,24 +388,24 @@ class TestNodeConfigResolver(unittest.TestCase): ...@@ -380,24 +388,24 @@ class TestNodeConfigResolver(unittest.TestCase):
def testSingleOnline(self): def testSingleOnline(self):
self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode, self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
NotImplemented, NotImplemented,
["node90.example.com"]), ["node90.example.com"], None),
[("node90.example.com", "192.0.2.90")]) [("node90.example.com", "192.0.2.90")])
def testSingleOffline(self): def testSingleOffline(self):
self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode, self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
NotImplemented, NotImplemented,
["node100.example.com"]), ["node100.example.com"], None),
[("node100.example.com", rpc._OFFLINE)]) [("node100.example.com", rpc._OFFLINE)])
def testUnknownSingleNode(self): def testUnknownSingleNode(self):
self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented, self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
["node110.example.com"]), ["node110.example.com"], None),
[("node110.example.com", "node110.example.com")]) [("node110.example.com", "node110.example.com")])
def testMultiEmpty(self): def testMultiEmpty(self):
self.assertEqual(rpc._NodeConfigResolver(NotImplemented, self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
lambda: {}, lambda: {},
[]), [], None),
[]) [])
def testMultiSomeOffline(self): def testMultiSomeOffline(self):
...@@ -410,7 +418,7 @@ class TestNodeConfigResolver(unittest.TestCase): ...@@ -410,7 +418,7 @@ class TestNodeConfigResolver(unittest.TestCase):
# Resolve no names # Resolve no names
self.assertEqual(rpc._NodeConfigResolver(NotImplemented, self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
lambda: nodes, lambda: nodes,
[]), [], None),
[]) [])
# Offline, online and unknown hosts # Offline, online and unknown hosts
...@@ -419,7 +427,8 @@ class TestNodeConfigResolver(unittest.TestCase): ...@@ -419,7 +427,8 @@ class TestNodeConfigResolver(unittest.TestCase):
["node3.example.com", ["node3.example.com",
"node92.example.com", "node92.example.com",
"node54.example.com", "node54.example.com",
"unknown.example.com",]), [ "unknown.example.com",],
None), [
("node3.example.com", rpc._OFFLINE), ("node3.example.com", rpc._OFFLINE),
("node92.example.com", "192.0.2.92"), ("node92.example.com", "192.0.2.92"),
("node54.example.com", rpc._OFFLINE), ("node54.example.com", rpc._OFFLINE),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment