Skip to content
Snippets Groups Projects
Commit eb202c13 authored by Manuel Franceschini's avatar Manuel Franceschini
Browse files

Always use address instead of hostname in rpc.Client


In light of the upcoming IPv6 support, this patch enables the rpc.Client
to always use a node's address to connect to it. This is necessary as we
do not want to rely on name resolution to connect to the correct IP
address on a dual-stack machine.

Signed-off-by: default avatarManuel Franceschini <livewire@google.com>
Reviewed-by: default avatarGuido Trotter <ultrotter@google.com>
parent d367b66c
No related branches found
No related tags found
No related merge requests found
# #
# #
# Copyright (C) 2006, 2007 Google Inc. # Copyright (C) 2006, 2007, 2010 Google Inc.
# #
# This program is free software; you can redistribute it and/or modify # This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by # it under the terms of the GNU General Public License as published by
...@@ -44,6 +44,7 @@ from ganeti import serializer ...@@ -44,6 +44,7 @@ from ganeti import serializer
from ganeti import constants from ganeti import constants
from ganeti import errors from ganeti import errors
from ganeti import netutils from ganeti import netutils
from ganeti import ssconf
# pylint has a bug here, doesn't see this import # pylint has a bug here, doesn't see this import
import ganeti.http.client # pylint: disable-msg=W0611 import ganeti.http.client # pylint: disable-msg=W0611
...@@ -256,6 +257,41 @@ class RpcResult(object): ...@@ -256,6 +257,41 @@ class RpcResult(object):
raise ec(*args) # pylint: disable-msg=W0142 raise ec(*args) # pylint: disable-msg=W0142
def _AddressLookup(node_list,
ssc=ssconf.SimpleStore,
nslookup_fn=netutils.HostInfo.LookupHostname):
"""Return addresses for given node names.
@type node_list: list
@param node_list: List of node names
@type ssc: class
@param ssc: SimpleStore class that is used to obtain node->ip mappings
@type lookup_fn: callable
@param lookup_fn: function use to do NS lookup
@rtype: list of addresses and/or None's
@returns: List of corresponding addresses, if found
"""
def _NSLookup(name):
_, _, addrs = nslookup_fn(name)
return addrs[0]
addresses = []
try:
iplist = ssc().GetNodePrimaryIPList()
ipmap = dict(entry.split() for entry in iplist)
for node in node_list:
address = ipmap.get(node)
if address is None:
address = _NSLookup(node)
addresses.append(address)
except errors.ConfigurationError:
# Address not found in so we do a NS lookup
addresses = [_NSLookup(node) for node in node_list]
return addresses
class Client: class Client:
"""RPC Client class. """RPC Client class.
...@@ -268,13 +304,14 @@ class Client: ...@@ -268,13 +304,14 @@ class Client:
cause bugs. cause bugs.
""" """
def __init__(self, procedure, body, port): def __init__(self, procedure, body, port, address_lookup_fn=_AddressLookup):
assert procedure in _TIMEOUTS, ("New RPC call not declared in the" assert procedure in _TIMEOUTS, ("New RPC call not declared in the"
" timeouts table") " timeouts table")
self.procedure = procedure self.procedure = procedure
self.body = body self.body = body
self.port = port self.port = port
self._request = {} self._request = {}
self._address_lookup_fn = address_lookup_fn
def ConnectList(self, node_list, address_list=None, read_timeout=None): def ConnectList(self, node_list, address_list=None, read_timeout=None):
"""Add a list of nodes to the target nodes. """Add a list of nodes to the target nodes.
...@@ -285,15 +322,16 @@ class Client: ...@@ -285,15 +322,16 @@ class Client:
@keyword address_list: either None or a list with node addresses, @keyword address_list: either None or a list with node addresses,
which must have the same length as the node list which must have the same length as the node list
@type read_timeout: int @type read_timeout: int
@param read_timeout: overwrites the default read timeout for the @param read_timeout: overwrites default timeout for operation
given operation
""" """
if address_list is None: if address_list is None:
address_list = [None for _ in node_list] # Always use IP address instead of node name
else: address_list = self._address_lookup_fn(node_list)
assert len(node_list) == len(address_list), \
"Name and address lists should have the same length" assert len(node_list) == len(address_list), \
"Name and address lists must have the same length"
for node, address in zip(node_list, address_list): for node, address in zip(node_list, address_list):
self.ConnectNode(node, address, read_timeout=read_timeout) self.ConnectNode(node, address, read_timeout=read_timeout)
...@@ -303,11 +341,16 @@ class Client: ...@@ -303,11 +341,16 @@ class Client:
@type name: str @type name: str
@param name: the node name @param name: the node name
@type address: str @type address: str
@keyword address: the node address, if known @param address: the node address, if known
@type read_timeout: int
@param read_timeout: overwrites default timeout for operation
""" """
if address is None: if address is None:
address = name # Always use IP address instead of node name
address = self._address_lookup_fn([name])[0]
assert(address is not None)
if read_timeout is None: if read_timeout is None:
read_timeout = _TIMEOUTS[self.procedure] read_timeout = _TIMEOUTS[self.procedure]
......
...@@ -56,7 +56,17 @@ class FakeHttpPool: ...@@ -56,7 +56,17 @@ class FakeHttpPool:
self._response_fn(req) self._response_fn(req)
def GetFakeSimpleStoreClass(fn):
class FakeSimpleStore:
GetNodePrimaryIPList = fn
return FakeSimpleStore
class TestClient(unittest.TestCase): class TestClient(unittest.TestCase):
def _FakeAddressLookup(self, map):
return lambda node_list: [map.get(node) for node in node_list]
def _GetVersionResponse(self, req): def _GetVersionResponse(self, req):
self.assertEqual(req.host, "localhost") self.assertEqual(req.host, "localhost")
self.assertEqual(req.port, 24094) self.assertEqual(req.port, 24094)
...@@ -66,7 +76,8 @@ class TestClient(unittest.TestCase): ...@@ -66,7 +76,8 @@ class TestClient(unittest.TestCase):
req.resp_body = serializer.DumpJson((True, 123)) req.resp_body = serializer.DumpJson((True, 123))
def testVersionSuccess(self): def testVersionSuccess(self):
client = rpc.Client("version", None, 24094) fn = self._FakeAddressLookup({"localhost": "localhost"})
client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
client.ConnectNode("localhost") client.ConnectNode("localhost")
pool = FakeHttpPool(self._GetVersionResponse) pool = FakeHttpPool(self._GetVersionResponse)
result = client.GetResults(http_pool=pool) result = client.GetResults(http_pool=pool)
...@@ -90,7 +101,8 @@ class TestClient(unittest.TestCase): ...@@ -90,7 +101,8 @@ class TestClient(unittest.TestCase):
def testMultiVersionSuccess(self): def testMultiVersionSuccess(self):
nodes = ["node%s" % i for i in range(50)] nodes = ["node%s" % i for i in range(50)]
client = rpc.Client("version", None, 23245) fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
client.ConnectList(nodes) client.ConnectList(nodes)
pool = FakeHttpPool(self._GetMultiVersionResponse) pool = FakeHttpPool(self._GetMultiVersionResponse)
...@@ -115,7 +127,9 @@ class TestClient(unittest.TestCase): ...@@ -115,7 +127,9 @@ class TestClient(unittest.TestCase):
req.resp_body = serializer.DumpJson((False, "Unknown error")) req.resp_body = serializer.DumpJson((False, "Unknown error"))
def testVersionFailure(self): def testVersionFailure(self):
client = rpc.Client("version", None, 5903) lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
fn = self._FakeAddressLookup(lookup_map)
client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
client.ConnectNode("aef9ur4i.example.com") client.ConnectNode("aef9ur4i.example.com")
pool = FakeHttpPool(self._GetVersionResponseFail) pool = FakeHttpPool(self._GetVersionResponseFail)
result = client.GetResults(http_pool=pool) result = client.GetResults(http_pool=pool)
...@@ -152,6 +166,7 @@ class TestClient(unittest.TestCase): ...@@ -152,6 +166,7 @@ class TestClient(unittest.TestCase):
def testHttpError(self): def testHttpError(self):
nodes = ["uaf6pbbv%s" % i for i in range(50)] nodes = ["uaf6pbbv%s" % i for i in range(50)]
fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
httperrnodes = set(nodes[1::7]) httperrnodes = set(nodes[1::7])
self.assertEqual(len(httperrnodes), 7) self.assertEqual(len(httperrnodes), 7)
...@@ -161,7 +176,7 @@ class TestClient(unittest.TestCase): ...@@ -161,7 +176,7 @@ class TestClient(unittest.TestCase):
self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29) self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
client = rpc.Client("vg_list", None, 15165) client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
client.ConnectList(nodes) client.ConnectList(nodes)
pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse, pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
...@@ -203,7 +218,9 @@ class TestClient(unittest.TestCase): ...@@ -203,7 +218,9 @@ class TestClient(unittest.TestCase):
req.resp_body = serializer.DumpJson("invalid response") req.resp_body = serializer.DumpJson("invalid response")
def testInvalidResponse(self): def testInvalidResponse(self):
client = rpc.Client("version", None, 19978) lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
fn = self._FakeAddressLookup(lookup_map)
client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]: for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
client.ConnectNode("oqo7lanhly.example.com") client.ConnectNode("oqo7lanhly.example.com")
pool = FakeHttpPool(fn) pool = FakeHttpPool(fn)
...@@ -218,6 +235,34 @@ class TestClient(unittest.TestCase): ...@@ -218,6 +235,34 @@ class TestClient(unittest.TestCase):
self.assertRaises(errors.OpExecError, lhresp.Raise, "failed") self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
self.assertEqual(pool.reqcount, 1) self.assertEqual(pool.reqcount, 1)
def testAddressLookupSimpleStore(self):
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list)
result = rpc._AddressLookup(node_list, ssc=ssc)
self.assertEqual(result, addr_list)
def testAddressLookupNSLookup(self):
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
ssc = GetFakeSimpleStoreClass(lambda s: [])
node_addr_map = dict(zip(node_list, addr_list))
nslookup_fn = lambda name: (None, None, [node_addr_map.get(name)])
result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
self.assertEqual(result, addr_list)
def testAddressLookupBoth(self):
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
n = len(addr_list) / 2
node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])]
ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list)
node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
nslookup_fn = lambda name: (None, None, [node_addr_map.get(name)])
result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
self.assertEqual(result, addr_list)
if __name__ == "__main__": if __name__ == "__main__":
testutils.GanetiTestProgram() testutils.GanetiTestProgram()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment