diff --git a/lib/rpc.py b/lib/rpc.py index beb6438a545151796d4ff0bb2a3b7e5373901b8a..bb03448002eacac34acfe60482a280aa46994ff2 100644 --- a/lib/rpc.py +++ b/lib/rpc.py @@ -46,6 +46,7 @@ from ganeti import errors from ganeti import netutils from ganeti import ssconf from ganeti import runtime +from ganeti import compat # pylint has a bug here, doesn't see this import import ganeti.http.client # pylint: disable=W0611 @@ -77,6 +78,9 @@ _TMO_1DAY = 86400 _TIMEOUTS = { } +#: Special value to describe an offline host +_OFFLINE = object() + def Init(): """Initializes the module-global HTTP client manager. @@ -285,9 +289,9 @@ class RpcResult(object): raise ec(*args) # pylint: disable=W0142 -def _AddressLookup(node_list, - ssc=ssconf.SimpleStore, - nslookup_fn=netutils.Hostname.GetIP): +def _SsconfResolver(node_list, + ssc=ssconf.SimpleStore, + nslookup_fn=netutils.Hostname.GetIP): """Return addresses for given node names. @type node_list: list @@ -296,126 +300,163 @@ def _AddressLookup(node_list, @param ssc: SimpleStore class that is used to obtain node->ip mappings @type nslookup_fn: callable @param nslookup_fn: function use to do NS lookup - @rtype: list of addresses and/or None's - @returns: List of corresponding addresses, if found + @rtype: list of tuple; (string, string) + @return: List of tuples containing node name and IP address """ ss = ssc() iplist = ss.GetNodePrimaryIPList() family = ss.GetPrimaryIPFamily() - addresses = [] ipmap = dict(entry.split() for entry in iplist) + + result = [] for node in node_list: - address = ipmap.get(node) - if address is None: - address = nslookup_fn(node, family=family) - addresses.append(address) + ip = ipmap.get(node) + if ip is None: + ip = nslookup_fn(node, family=family) + result.append((node, ip)) + + return result + + +class _StaticResolver: + def __init__(self, addresses): + """Initializes this class. + + """ + self._addresses = addresses - return addresses + def __call__(self, hosts): + """Returns static addresses for hosts. + """ + assert len(hosts) == len(self._addresses) + return zip(hosts, self._addresses) -class Client: - """RPC Client class. - This class, given a (remote) method name, a list of parameters and a - list of nodes, will contact (in parallel) all nodes, and return a - dict of results (key: node name, value: result). +def _CheckConfigNode(name, node): + """Checks if a node is online. - One current bug is that generic failure is still signaled by - 'False' result, which is not good. This overloading of values can - cause bugs. + @type name: string + @param name: Node name + @type node: L{objects.Node} or None + @param node: Node object """ - def __init__(self, procedure, body, port, address_lookup_fn=_AddressLookup): - assert procedure in _TIMEOUTS, ("New RPC call not declared in the" - " timeouts table") - self.procedure = procedure - self.body = body - self.port = port - self._request = {} - self._address_lookup_fn = address_lookup_fn - - def ConnectList(self, node_list, address_list=None, read_timeout=None): - """Add a list of nodes to the target nodes. + if node is None: + # Depend on DNS for name resolution + ip = name + elif node.offline: + ip = _OFFLINE + else: + ip = node.primary_ip + return (name, ip) - @type node_list: list - @param node_list: the list of node names to connect - @type address_list: list or None - @keyword address_list: either None or a list with node addresses, - which must have the same length as the node list - @type read_timeout: int - @param read_timeout: overwrites default timeout for operation + +def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts): + """Calculate node addresses using configuration. + + """ + # Special case for single-host lookups + if len(hosts) == 1: + (name, ) = hosts + return [_CheckConfigNode(name, single_node_fn(name))] + else: + all_nodes = all_nodes_fn() + return [_CheckConfigNode(name, all_nodes.get(name, None)) + for name in hosts] + + +class _RpcProcessor: + def __init__(self, resolver, port): + """Initializes this class. + + @param resolver: callable accepting a list of hostnames, returning a list + of tuples containing name and IP address (IP address can be the name or + the special value L{_OFFLINE} to mark offline machines) + @type port: int + @param port: TCP port """ - if address_list is None: - # Always use IP address instead of node name - address_list = self._address_lookup_fn(node_list) + self._resolver = resolver + self._port = port + + @staticmethod + def _PrepareRequests(hosts, port, procedure, body, read_timeout): + """Prepares requests by sorting offline hosts into separate list. - assert len(node_list) == len(address_list), \ - "Name and address lists must have the same length" + """ + results = {} + requests = {} - for node, address in zip(node_list, address_list): - self.ConnectNode(node, address, read_timeout=read_timeout) + for (name, ip) in hosts: + if ip is _OFFLINE: + # Node is marked as offline + results[name] = RpcResult(node=name, offline=True, call=procedure) + else: + requests[name] = \ + http.client.HttpClientRequest(str(ip), port, + http.HTTP_PUT, str("/%s" % procedure), + headers=_RPC_CLIENT_HEADERS, + post_data=body, + read_timeout=read_timeout) - def ConnectNode(self, name, address=None, read_timeout=None): - """Add a node to the target list. + return (results, requests) - @type name: str - @param name: the node name - @type address: str - @param address: the node address, if known - @type read_timeout: int - @param read_timeout: overwrites default timeout for operation + @staticmethod + def _CombineResults(results, requests, procedure): + """Combines pre-computed results for offline hosts with actual call results. """ - if address is None: - # Always use IP address instead of node name - address = self._address_lookup_fn([name])[0] + for name, req in requests.items(): + if req.success and req.resp_status_code == http.HTTP_OK: + host_result = RpcResult(data=serializer.LoadJson(req.resp_body), + node=name, call=procedure) + else: + # TODO: Better error reporting + if req.error: + msg = req.error + else: + msg = req.resp_body - assert(address is not None) + logging.error("RPC error in %s on node %s: %s", procedure, name, msg) + host_result = RpcResult(data=msg, failed=True, node=name, + call=procedure) - if read_timeout is None: - read_timeout = _TIMEOUTS[self.procedure] + results[name] = host_result - self._request[name] = \ - http.client.HttpClientRequest(str(address), self.port, - http.HTTP_PUT, str("/%s" % self.procedure), - headers=_RPC_CLIENT_HEADERS, - post_data=str(self.body), - read_timeout=read_timeout) + return results - def GetResults(self, http_pool=None): - """Call nodes and return results. + def __call__(self, hosts, procedure, body, read_timeout=None, http_pool=None): + """Makes an RPC request to a number of nodes. - @rtype: list - @return: List of RPC results + @type hosts: sequence + @param hosts: Hostnames + @type procedure: string + @param procedure: Request path + @type body: string + @param body: Request body + @type read_timeout: int or None + @param read_timeout: Read timeout for request """ + assert procedure in _TIMEOUTS, "RPC call not declared in the timeouts table" + if not http_pool: http_pool = _thread_local.GetHttpClientPool() - http_pool.ProcessRequests(self._request.values()) - - results = {} + if read_timeout is None: + read_timeout = _TIMEOUTS[procedure] - for name, req in self._request.iteritems(): - if req.success and req.resp_status_code == http.HTTP_OK: - results[name] = RpcResult(data=serializer.LoadJson(req.resp_body), - node=name, call=self.procedure) - continue + (results, requests) = \ + self._PrepareRequests(self._resolver(hosts), self._port, procedure, + str(body), read_timeout) - # TODO: Better error reporting - if req.error: - msg = req.error - else: - msg = req.resp_body + http_pool.ProcessRequests(requests.values()) - logging.error("RPC error in %s from node %s: %s", - self.procedure, name, msg) - results[name] = RpcResult(data=msg, failed=True, node=name, - call=self.procedure) + assert not frozenset(results).intersection(requests) - return results + return self._CombineResults(results, requests, procedure) def _EncodeImportExportIO(ieio, ieioargs): @@ -445,7 +486,10 @@ class RpcRunner(object): """ self._cfg = context.cfg - self.port = netutils.GetDaemonPort(constants.NODED) + self._proc = _RpcProcessor(compat.partial(_NodeConfigResolver, + self._cfg.GetNodeInfo, + self._cfg.GetAllNodesInfo), + netutils.GetDaemonPort(constants.NODED)) def _InstDict(self, instance, hvp=None, bep=None, osp=None): """Convert the given instance to a dict. @@ -483,98 +527,37 @@ class RpcRunner(object): nic['nicparams']) return idict - def _ConnectList(self, client, node_list, call, read_timeout=None): - """Helper for computing node addresses. - - @type client: L{ganeti.rpc.Client} - @param client: a C{Client} instance - @type node_list: list - @param node_list: the node list we should connect - @type call: string - @param call: the name of the remote procedure call, for filling in - correctly any eventual offline nodes' results - @type read_timeout: int - @param read_timeout: overwrites the default read timeout for the - given operation - - """ - all_nodes = self._cfg.GetAllNodesInfo() - name_list = [] - addr_list = [] - skip_dict = {} - for node in node_list: - if node in all_nodes: - if all_nodes[node].offline: - skip_dict[node] = RpcResult(node=node, offline=True, call=call) - continue - val = all_nodes[node].primary_ip - else: - val = None - addr_list.append(val) - name_list.append(node) - if name_list: - client.ConnectList(name_list, address_list=addr_list, - read_timeout=read_timeout) - return skip_dict - - def _ConnectNode(self, client, node, call, read_timeout=None): - """Helper for computing one node's address. - - @type client: L{ganeti.rpc.Client} - @param client: a C{Client} instance - @type node: str - @param node: the node we should connect - @type call: string - @param call: the name of the remote procedure call, for filling in - correctly any eventual offline nodes' results - @type read_timeout: int - @param read_timeout: overwrites the default read timeout for the - given operation - - """ - node_info = self._cfg.GetNodeInfo(node) - if node_info is not None: - if node_info.offline: - return RpcResult(node=node, offline=True, call=call) - addr = node_info.primary_ip - else: - addr = None - client.ConnectNode(node, address=addr, read_timeout=read_timeout) - def _MultiNodeCall(self, node_list, procedure, args, read_timeout=None): """Helper for making a multi-node call """ body = serializer.DumpJson(args, indent=False) - c = Client(procedure, body, self.port) - skip_dict = self._ConnectList(c, node_list, procedure, - read_timeout=read_timeout) - skip_dict.update(c.GetResults()) - return skip_dict + return self._proc(node_list, procedure, body, read_timeout=read_timeout) - @classmethod - def _StaticMultiNodeCall(cls, node_list, procedure, args, + @staticmethod + def _StaticMultiNodeCall(node_list, procedure, args, address_list=None, read_timeout=None): """Helper for making a multi-node static call """ body = serializer.DumpJson(args, indent=False) - c = Client(procedure, body, netutils.GetDaemonPort(constants.NODED)) - c.ConnectList(node_list, address_list=address_list, - read_timeout=read_timeout) - return c.GetResults() + + if address_list is None: + resolver = _SsconfResolver + else: + # Caller provided an address list + resolver = _StaticResolver(address_list) + + proc = _RpcProcessor(resolver, + netutils.GetDaemonPort(constants.NODED)) + return proc(node_list, procedure, body, read_timeout=read_timeout) def _SingleNodeCall(self, node, procedure, args, read_timeout=None): """Helper for making a single-node call """ body = serializer.DumpJson(args, indent=False) - c = Client(procedure, body, self.port) - result = self._ConnectNode(c, node, procedure, read_timeout=read_timeout) - if result is None: - # we did connect, node is not offline - result = c.GetResults()[node] - return result + return self._proc([node], procedure, body, read_timeout=read_timeout)[node] @classmethod def _StaticSingleNodeCall(cls, node, procedure, args, read_timeout=None): @@ -582,9 +565,9 @@ class RpcRunner(object): """ body = serializer.DumpJson(args, indent=False) - c = Client(procedure, body, netutils.GetDaemonPort(constants.NODED)) - c.ConnectNode(node, read_timeout=read_timeout) - return c.GetResults()[node] + proc = _RpcProcessor(_SsconfResolver, + netutils.GetDaemonPort(constants.NODED)) + return proc([node], procedure, body, read_timeout=read_timeout)[node] # # Begin RPC calls diff --git a/test/ganeti.rpc_unittest.py b/test/ganeti.rpc_unittest.py index ce09d8b8fa31c6c2980f8b3cefb58fb7300f27e4..9ed0628ddc4d07e8003b9498574caf64f8bc86d5 100755 --- a/test/ganeti.rpc_unittest.py +++ b/test/ganeti.rpc_unittest.py @@ -31,6 +31,7 @@ from ganeti import rpc from ganeti import http from ganeti import errors from ganeti import serializer +from ganeti import objects import testutils @@ -64,24 +65,24 @@ def GetFakeSimpleStoreClass(fn): return FakeSimpleStore -class TestClient(unittest.TestCase): +class TestRpcProcessor(unittest.TestCase): def _FakeAddressLookup(self, map): return lambda node_list: [map.get(node) for node in node_list] def _GetVersionResponse(self, req): - self.assertEqual(req.host, "localhost") + self.assertEqual(req.host, "127.0.0.1") self.assertEqual(req.port, 24094) self.assertEqual(req.path, "/version") + self.assertEqual(req.read_timeout, rpc._TMO_URGENT) req.success = True req.resp_status_code = http.HTTP_OK req.resp_body = serializer.DumpJson((True, 123)) def testVersionSuccess(self): - fn = self._FakeAddressLookup({"localhost": "localhost"}) - client = rpc.Client("version", None, 24094, address_lookup_fn=fn) - client.ConnectNode("localhost") + resolver = rpc._StaticResolver(["127.0.0.1"]) pool = FakeHttpPool(self._GetVersionResponse) - result = client.GetResults(http_pool=pool) + proc = rpc._RpcProcessor(resolver, 24094) + result = proc(["localhost"], "version", None, http_pool=pool) self.assertEqual(result.keys(), ["localhost"]) lhresp = result["localhost"] self.assertFalse(lhresp.offline) @@ -92,6 +93,52 @@ class TestClient(unittest.TestCase): lhresp.Raise("should not raise") self.assertEqual(pool.reqcount, 1) + def _ReadTimeoutResponse(self, req): + self.assertEqual(req.host, "192.0.2.13") + self.assertEqual(req.port, 19176) + self.assertEqual(req.path, "/version") + self.assertEqual(req.read_timeout, 12356) + req.success = True + req.resp_status_code = http.HTTP_OK + req.resp_body = serializer.DumpJson((True, -1)) + + def testReadTimeout(self): + resolver = rpc._StaticResolver(["192.0.2.13"]) + pool = FakeHttpPool(self._ReadTimeoutResponse) + proc = rpc._RpcProcessor(resolver, 19176) + result = proc(["node31856"], "version", None, http_pool=pool, + read_timeout=12356) + self.assertEqual(result.keys(), ["node31856"]) + lhresp = result["node31856"] + self.assertFalse(lhresp.offline) + self.assertEqual(lhresp.node, "node31856") + self.assertFalse(lhresp.fail_msg) + self.assertEqual(lhresp.payload, -1) + self.assertEqual(lhresp.call, "version") + lhresp.Raise("should not raise") + self.assertEqual(pool.reqcount, 1) + + def testOfflineNode(self): + resolver = rpc._StaticResolver([rpc._OFFLINE]) + pool = FakeHttpPool(NotImplemented) + proc = rpc._RpcProcessor(resolver, 30668) + result = proc(["n17296"], "version", None, http_pool=pool) + self.assertEqual(result.keys(), ["n17296"]) + lhresp = result["n17296"] + self.assertTrue(lhresp.offline) + self.assertEqual(lhresp.node, "n17296") + self.assertTrue(lhresp.fail_msg) + self.assertFalse(lhresp.payload) + self.assertEqual(lhresp.call, "version") + + # With a message + self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise") + + # No message + self.assertRaises(errors.OpExecError, lhresp.Raise, None) + + self.assertEqual(pool.reqcount, 0) + def _GetMultiVersionResponse(self, req): self.assert_(req.host.startswith("node")) self.assertEqual(req.port, 23245) @@ -102,12 +149,10 @@ class TestClient(unittest.TestCase): def testMultiVersionSuccess(self): nodes = ["node%s" % i for i in range(50)] - fn = self._FakeAddressLookup(dict(zip(nodes, nodes))) - client = rpc.Client("version", None, 23245, address_lookup_fn=fn) - client.ConnectList(nodes) - + resolver = rpc._StaticResolver(nodes) pool = FakeHttpPool(self._GetMultiVersionResponse) - result = client.GetResults(http_pool=pool) + proc = rpc._RpcProcessor(resolver, 23245) + result = proc(nodes, "version", None, http_pool=pool) self.assertEqual(sorted(result.keys()), sorted(nodes)) for name in nodes: @@ -121,28 +166,27 @@ class TestClient(unittest.TestCase): self.assertEqual(pool.reqcount, len(nodes)) - def _GetVersionResponseFail(self, req): + def _GetVersionResponseFail(self, errinfo, req): self.assertEqual(req.path, "/version") req.success = True req.resp_status_code = http.HTTP_OK - req.resp_body = serializer.DumpJson((False, "Unknown error")) + req.resp_body = serializer.DumpJson((False, errinfo)) def testVersionFailure(self): - lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"} - fn = self._FakeAddressLookup(lookup_map) - client = rpc.Client("version", None, 5903, address_lookup_fn=fn) - client.ConnectNode("aef9ur4i.example.com") - pool = FakeHttpPool(self._GetVersionResponseFail) - result = client.GetResults(http_pool=pool) - self.assertEqual(result.keys(), ["aef9ur4i.example.com"]) - lhresp = result["aef9ur4i.example.com"] - self.assertFalse(lhresp.offline) - self.assertEqual(lhresp.node, "aef9ur4i.example.com") - self.assert_(lhresp.fail_msg) - self.assertFalse(lhresp.payload) - self.assertEqual(lhresp.call, "version") - self.assertRaises(errors.OpExecError, lhresp.Raise, "failed") - self.assertEqual(pool.reqcount, 1) + resolver = rpc._StaticResolver(["aef9ur4i.example.com"]) + proc = rpc._RpcProcessor(resolver, 5903) + for errinfo in [None, "Unknown error"]: + pool = FakeHttpPool(compat.partial(self._GetVersionResponseFail, errinfo)) + result = proc(["aef9ur4i.example.com"], "version", None, http_pool=pool) + self.assertEqual(result.keys(), ["aef9ur4i.example.com"]) + lhresp = result["aef9ur4i.example.com"] + self.assertFalse(lhresp.offline) + self.assertEqual(lhresp.node, "aef9ur4i.example.com") + self.assert_(lhresp.fail_msg) + self.assertFalse(lhresp.payload) + self.assertEqual(lhresp.call, "version") + self.assertRaises(errors.OpExecError, lhresp.Raise, "failed") + self.assertEqual(pool.reqcount, 1) def _GetHttpErrorResponse(self, httperrnodes, failnodes, req): self.assertEqual(req.path, "/vg_list") @@ -167,7 +211,7 @@ class TestClient(unittest.TestCase): def testHttpError(self): nodes = ["uaf6pbbv%s" % i for i in range(50)] - fn = self._FakeAddressLookup(dict(zip(nodes, nodes))) + resolver = rpc._StaticResolver(nodes) httperrnodes = set(nodes[1::7]) self.assertEqual(len(httperrnodes), 7) @@ -177,12 +221,10 @@ class TestClient(unittest.TestCase): self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29) - client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn) - client.ConnectList(nodes) - + proc = rpc._RpcProcessor(resolver, 15165) pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse, httperrnodes, failnodes)) - result = client.GetResults(http_pool=pool) + result = proc(nodes, "vg_list", None, http_pool=pool) self.assertEqual(sorted(result.keys()), sorted(nodes)) for name in nodes: @@ -219,13 +261,12 @@ class TestClient(unittest.TestCase): req.resp_body = serializer.DumpJson("invalid response") def testInvalidResponse(self): - lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"} - fn = self._FakeAddressLookup(lookup_map) - client = rpc.Client("version", None, 19978, address_lookup_fn=fn) + resolver = rpc._StaticResolver(["oqo7lanhly.example.com"]) + proc = rpc._RpcProcessor(resolver, 19978) + for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]: - client.ConnectNode("oqo7lanhly.example.com") pool = FakeHttpPool(fn) - result = client.GetResults(http_pool=pool) + result = proc(["oqo7lanhly.example.com"], "version", None, http_pool=pool) self.assertEqual(result.keys(), ["oqo7lanhly.example.com"]) lhresp = result["oqo7lanhly.example.com"] self.assertFalse(lhresp.offline) @@ -236,41 +277,145 @@ class TestClient(unittest.TestCase): self.assertRaises(errors.OpExecError, lhresp.Raise, "failed") self.assertEqual(pool.reqcount, 1) - def testAddressLookupSimpleStore(self): + def _GetBodyTestResponse(self, test_data, req): + self.assertEqual(req.host, "192.0.2.84") + self.assertEqual(req.port, 18700) + self.assertEqual(req.path, "/upload_file") + self.assertEqual(serializer.LoadJson(req.post_data), test_data) + req.success = True + req.resp_status_code = http.HTTP_OK + req.resp_body = serializer.DumpJson((True, None)) + + def testResponseBody(self): + test_data = { + "Hello": "World", + "xyz": range(10), + } + resolver = rpc._StaticResolver(["192.0.2.84"]) + pool = FakeHttpPool(compat.partial(self._GetBodyTestResponse, test_data)) + proc = rpc._RpcProcessor(resolver, 18700) + body = serializer.DumpJson(test_data) + result = proc(["node19759"], "upload_file", body, http_pool=pool) + self.assertEqual(result.keys(), ["node19759"]) + lhresp = result["node19759"] + self.assertFalse(lhresp.offline) + self.assertEqual(lhresp.node, "node19759") + self.assertFalse(lhresp.fail_msg) + self.assertEqual(lhresp.payload, None) + self.assertEqual(lhresp.call, "upload_file") + lhresp.Raise("should not raise") + self.assertEqual(pool.reqcount, 1) + + +class TestSsconfResolver(unittest.TestCase): + def testSsconfLookup(self): addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)] node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] - node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)] + node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)] ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list) - result = rpc._AddressLookup(node_list, ssc=ssc) - self.assertEqual(result, addr_list) + result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented) + self.assertEqual(result, zip(node_list, addr_list)) - def testAddressLookupNSLookup(self): + def testNsLookup(self): addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)] node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] ssc = GetFakeSimpleStoreClass(lambda _: []) node_addr_map = dict(zip(node_list, addr_list)) nslookup_fn = lambda name, family=None: node_addr_map.get(name) - result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn) - self.assertEqual(result, addr_list) + result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn) + self.assertEqual(result, zip(node_list, addr_list)) - def testAddressLookupBoth(self): + def testBothLookups(self): addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)] node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] n = len(addr_list) / 2 - node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])] + node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])] ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list) node_addr_map = dict(zip(node_list[:n], addr_list[:n])) nslookup_fn = lambda name, family=None: node_addr_map.get(name) - result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn) - self.assertEqual(result, addr_list) + result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn) + self.assertEqual(result, zip(node_list, addr_list)) def testAddressLookupIPv6(self): - addr_list = ["2001:db8::%d" % n for n in range(0, 255, 13)] - node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] - node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)] + addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)] + node_list = ["node%d.example.com" % n for n in range(0, 255, 11)] + node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)] ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list) - result = rpc._AddressLookup(node_list, ssc=ssc) - self.assertEqual(result, addr_list) + result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented) + self.assertEqual(result, zip(node_list, addr_list)) + + +class TestStaticResolver(unittest.TestCase): + def test(self): + addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)] + nodes = ["node%s.example.com" % n for n in range(0, 123, 7)] + res = rpc._StaticResolver(addresses) + self.assertEqual(res(nodes), zip(nodes, addresses)) + + def testWrongLength(self): + res = rpc._StaticResolver([]) + self.assertRaises(AssertionError, res, ["abc"]) + + +class TestNodeConfigResolver(unittest.TestCase): + @staticmethod + def _GetSingleOnlineNode(name): + assert name == "node90.example.com" + return objects.Node(name=name, offline=False, primary_ip="192.0.2.90") + + @staticmethod + def _GetSingleOfflineNode(name): + assert name == "node100.example.com" + return objects.Node(name=name, offline=True, primary_ip="192.0.2.100") + + def testSingleOnline(self): + self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode, + NotImplemented, + ["node90.example.com"]), + [("node90.example.com", "192.0.2.90")]) + + def testSingleOffline(self): + self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode, + NotImplemented, + ["node100.example.com"]), + [("node100.example.com", rpc._OFFLINE)]) + + def testUnknownSingleNode(self): + self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented, + ["node110.example.com"]), + [("node110.example.com", "node110.example.com")]) + + def testMultiEmpty(self): + self.assertEqual(rpc._NodeConfigResolver(NotImplemented, + lambda: {}, + []), + []) + + def testMultiSomeOffline(self): + nodes = dict(("node%s.example.com" % i, + objects.Node(name="node%s.example.com" % i, + offline=((i % 3) == 0), + primary_ip="192.0.2.%s" % i)) + for i in range(1, 255)) + + # Resolve no names + self.assertEqual(rpc._NodeConfigResolver(NotImplemented, + lambda: nodes, + []), + []) + + # Offline, online and unknown hosts + self.assertEqual(rpc._NodeConfigResolver(NotImplemented, + lambda: nodes, + ["node3.example.com", + "node92.example.com", + "node54.example.com", + "unknown.example.com",]), [ + ("node3.example.com", rpc._OFFLINE), + ("node92.example.com", "192.0.2.92"), + ("node54.example.com", rpc._OFFLINE), + ("unknown.example.com", "unknown.example.com"), + ]) if __name__ == "__main__":