diff --git a/lib/rpc.py b/lib/rpc.py index 08f69ef291f835f6f64623b6541e954b79a41442..2b26e6205e3e3d9ffe3e1170677fd68ae923211b 100644 --- a/lib/rpc.py +++ b/lib/rpc.py @@ -1,7 +1,7 @@ # # -# Copyright (C) 2006, 2007 Google Inc. +# Copyright (C) 2006, 2007, 2010 Google Inc. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -44,6 +44,7 @@ from ganeti import serializer from ganeti import constants from ganeti import errors from ganeti import netutils +from ganeti import ssconf # pylint has a bug here, doesn't see this import import ganeti.http.client # pylint: disable-msg=W0611 @@ -256,6 +257,41 @@ class RpcResult(object): raise ec(*args) # pylint: disable-msg=W0142 +def _AddressLookup(node_list, + ssc=ssconf.SimpleStore, + nslookup_fn=netutils.HostInfo.LookupHostname): + """Return addresses for given node names. + + @type node_list: list + @param node_list: List of node names + @type ssc: class + @param ssc: SimpleStore class that is used to obtain node->ip mappings + @type lookup_fn: callable + @param lookup_fn: function use to do NS lookup + @rtype: list of addresses and/or None's + @returns: List of corresponding addresses, if found + + """ + def _NSLookup(name): + _, _, addrs = nslookup_fn(name) + return addrs[0] + + addresses = [] + try: + iplist = ssc().GetNodePrimaryIPList() + ipmap = dict(entry.split() for entry in iplist) + for node in node_list: + address = ipmap.get(node) + if address is None: + address = _NSLookup(node) + addresses.append(address) + except errors.ConfigurationError: + # Address not found in so we do a NS lookup + addresses = [_NSLookup(node) for node in node_list] + + return addresses + + class Client: """RPC Client class. @@ -268,13 +304,14 @@ class Client: cause bugs. """ - def __init__(self, procedure, body, port): + def __init__(self, procedure, body, port, address_lookup_fn=_AddressLookup): assert procedure in _TIMEOUTS, ("New RPC call not declared in the" " timeouts table") self.procedure = procedure self.body = body self.port = port self._request = {} + self._address_lookup_fn = address_lookup_fn def ConnectList(self, node_list, address_list=None, read_timeout=None): """Add a list of nodes to the target nodes. @@ -285,15 +322,16 @@ class Client: @keyword address_list: either None or a list with node addresses, which must have the same length as the node list @type read_timeout: int - @param read_timeout: overwrites the default read timeout for the - given operation + @param read_timeout: overwrites default timeout for operation """ if address_list is None: - address_list = [None for _ in node_list] - else: - assert len(node_list) == len(address_list), \ - "Name and address lists should have the same length" + # Always use IP address instead of node name + address_list = self._address_lookup_fn(node_list) + + assert len(node_list) == len(address_list), \ + "Name and address lists must have the same length" + for node, address in zip(node_list, address_list): self.ConnectNode(node, address, read_timeout=read_timeout) @@ -303,11 +341,16 @@ class Client: @type name: str @param name: the node name @type address: str - @keyword address: the node address, if known + @param address: the node address, if known + @type read_timeout: int + @param read_timeout: overwrites default timeout for operation """ if address is None: - address = name + # Always use IP address instead of node name + address = self._address_lookup_fn([name])[0] + + assert(address is not None) if read_timeout is None: read_timeout = _TIMEOUTS[self.procedure] diff --git a/test/ganeti.rpc_unittest.py b/test/ganeti.rpc_unittest.py index acbd1ab0c5e25e9904cccaaade3620c90a9e9d58..d59d790f2453dd5edba3a4e3cc5cae93ff53570b 100755 --- a/test/ganeti.rpc_unittest.py +++ b/test/ganeti.rpc_unittest.py @@ -56,7 +56,17 @@ class FakeHttpPool: self._response_fn(req) +def GetFakeSimpleStoreClass(fn): + class FakeSimpleStore: + GetNodePrimaryIPList = fn + + return FakeSimpleStore + + class TestClient(unittest.TestCase): + def _FakeAddressLookup(self, map): + return lambda node_list: [map.get(node) for node in node_list] + def _GetVersionResponse(self, req): self.assertEqual(req.host, "localhost") self.assertEqual(req.port, 24094) @@ -66,7 +76,8 @@ class TestClient(unittest.TestCase): req.resp_body = serializer.DumpJson((True, 123)) def testVersionSuccess(self): - client = rpc.Client("version", None, 24094) + fn = self._FakeAddressLookup({"localhost": "localhost"}) + client = rpc.Client("version", None, 24094, address_lookup_fn=fn) client.ConnectNode("localhost") pool = FakeHttpPool(self._GetVersionResponse) result = client.GetResults(http_pool=pool) @@ -90,7 +101,8 @@ class TestClient(unittest.TestCase): def testMultiVersionSuccess(self): nodes = ["node%s" % i for i in range(50)] - client = rpc.Client("version", None, 23245) + fn = self._FakeAddressLookup(dict(zip(nodes, nodes))) + client = rpc.Client("version", None, 23245, address_lookup_fn=fn) client.ConnectList(nodes) pool = FakeHttpPool(self._GetMultiVersionResponse) @@ -115,7 +127,9 @@ class TestClient(unittest.TestCase): req.resp_body = serializer.DumpJson((False, "Unknown error")) def testVersionFailure(self): - client = rpc.Client("version", None, 5903) + lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"} + fn = self._FakeAddressLookup(lookup_map) + client = rpc.Client("version", None, 5903, address_lookup_fn=fn) client.ConnectNode("aef9ur4i.example.com") pool = FakeHttpPool(self._GetVersionResponseFail) result = client.GetResults(http_pool=pool) @@ -152,6 +166,7 @@ class TestClient(unittest.TestCase): def testHttpError(self): nodes = ["uaf6pbbv%s" % i for i in range(50)] + fn = self._FakeAddressLookup(dict(zip(nodes, nodes))) httperrnodes = set(nodes[1::7]) self.assertEqual(len(httperrnodes), 7) @@ -161,7 +176,7 @@ class TestClient(unittest.TestCase): self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29) - client = rpc.Client("vg_list", None, 15165) + client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn) client.ConnectList(nodes) pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse, @@ -203,7 +218,9 @@ class TestClient(unittest.TestCase): req.resp_body = serializer.DumpJson("invalid response") def testInvalidResponse(self): - client = rpc.Client("version", None, 19978) + lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"} + fn = self._FakeAddressLookup(lookup_map) + client = rpc.Client("version", None, 19978, address_lookup_fn=fn) for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]: client.ConnectNode("oqo7lanhly.example.com") pool = FakeHttpPool(fn) @@ -218,6 +235,34 @@ class TestClient(unittest.TestCase): self.assertRaises(errors.OpExecError, lhresp.Raise, "failed") self.assertEqual(pool.reqcount, 1) + def testAddressLookupSimpleStore(self): + addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)] + node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] + node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)] + ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list) + result = rpc._AddressLookup(node_list, ssc=ssc) + self.assertEqual(result, addr_list) + + def testAddressLookupNSLookup(self): + addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)] + node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] + ssc = GetFakeSimpleStoreClass(lambda s: []) + node_addr_map = dict(zip(node_list, addr_list)) + nslookup_fn = lambda name: (None, None, [node_addr_map.get(name)]) + result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn) + self.assertEqual(result, addr_list) + + def testAddressLookupBoth(self): + addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)] + node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] + n = len(addr_list) / 2 + node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])] + ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list) + node_addr_map = dict(zip(node_list[:n], addr_list[:n])) + nslookup_fn = lambda name: (None, None, [node_addr_map.get(name)]) + result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn) + self.assertEqual(result, addr_list) + if __name__ == "__main__": testutils.GanetiTestProgram()