Commit 00267bfe authored by Michael Hanselmann

rpc: Overhaul client structure



- Clearly separate node name to IP address resolution into dedicated
  functions
- Simplify code structure (one code path instead of several)
- Fully unittested
- Preparation for more RPC improvements
Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>
parent 30474135
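At a high level, the new client structure separates name resolution from request processing: a resolver is any callable that maps a list of node names to (name, address) pairs, using a special marker for offline nodes, and a processor turns those pairs into HTTP requests while pre-filling results for offline nodes. The following is a minimal, hypothetical sketch of that split, not Ganeti's actual code: StaticResolver, FakeRequest and RpcProcessor are simplified stand-ins, the parallel HTTP pool and RpcResult handling are elided, and the port and addresses are made up.

# Minimal sketch of the resolver/processor split (hypothetical, not Ganeti code)

_OFFLINE = object()  # marker a resolver returns for nodes known to be offline


class StaticResolver(object):
  """Pairs a host list with a fixed, equally long list of addresses."""

  def __init__(self, addresses):
    self._addresses = addresses

  def __call__(self, hosts):
    assert len(hosts) == len(self._addresses)
    return list(zip(hosts, self._addresses))


class FakeRequest(object):
  """Invented stand-in for an HTTP request object; only stores parameters."""

  def __init__(self, address, port, procedure, body):
    self.address = address
    self.port = port
    self.path = "/%s" % procedure
    self.body = body


class RpcProcessor(object):
  """Prepares one request per online host; offline hosts get canned results."""

  def __init__(self, resolver, port):
    self._resolver = resolver
    self._port = port

  def __call__(self, hosts, procedure, body):
    results = {}   # name -> pre-computed result for offline nodes
    requests = {}  # name -> request object for online nodes
    for (name, addr) in self._resolver(hosts):
      if addr is _OFFLINE:
        results[name] = "node marked offline, not contacted"
      else:
        requests[name] = FakeRequest(addr, self._port, procedure, body)
    # The real implementation hands "requests" to an HTTP client pool and
    # then merges the responses into "results" before returning.
    return (results, requests)


proc = RpcProcessor(StaticResolver(["192.0.2.1", _OFFLINE]), 1811)
results, requests = proc(["node1", "node2"], "version", "null")
# node2 gets a pre-filled offline result; only node1 produces a request.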
@@ -46,6 +46,7 @@ from ganeti import errors
from ganeti import netutils
from ganeti import ssconf
from ganeti import runtime
from ganeti import compat
# pylint has a bug here, doesn't see this import
import ganeti.http.client # pylint: disable=W0611
@@ -77,6 +78,9 @@ _TMO_1DAY = 86400
_TIMEOUTS = {
}
#: Special value to describe an offline host
_OFFLINE = object()
def Init():
"""Initializes the module-global HTTP client manager.
@@ -285,9 +289,9 @@ class RpcResult(object):
raise ec(*args) # pylint: disable=W0142
def _AddressLookup(node_list,
ssc=ssconf.SimpleStore,
nslookup_fn=netutils.Hostname.GetIP):
def _SsconfResolver(node_list,
ssc=ssconf.SimpleStore,
nslookup_fn=netutils.Hostname.GetIP):
"""Return addresses for given node names.
@type node_list: list
@@ -296,126 +300,163 @@ def _AddressLookup(node_list,
@param ssc: SimpleStore class that is used to obtain node->ip mappings
@type nslookup_fn: callable
@param nslookup_fn: function used to do NS lookup
@rtype: list of addresses and/or None's
@returns: List of corresponding addresses, if found
@rtype: list of tuple; (string, string)
@return: List of tuples containing node name and IP address
"""
ss = ssc()
iplist = ss.GetNodePrimaryIPList()
family = ss.GetPrimaryIPFamily()
addresses = []
ipmap = dict(entry.split() for entry in iplist)
result = []
for node in node_list:
address = ipmap.get(node)
if address is None:
address = nslookup_fn(node, family=family)
addresses.append(address)
ip = ipmap.get(node)
if ip is None:
ip = nslookup_fn(node, family=family)
result.append((node, ip))
return result
class _StaticResolver:
def __init__(self, addresses):
"""Initializes this class.
"""
self._addresses = addresses
return addresses
def __call__(self, hosts):
"""Returns static addresses for hosts.
"""
assert len(hosts) == len(self._addresses)
return zip(hosts, self._addresses)
class Client:
"""RPC Client class.
This class, given a (remote) method name, a list of parameters and a
list of nodes, will contact (in parallel) all nodes, and return a
dict of results (key: node name, value: result).
def _CheckConfigNode(name, node):
"""Checks if a node is online.
One current bug is that generic failure is still signaled by
'False' result, which is not good. This overloading of values can
cause bugs.
@type name: string
@param name: Node name
@type node: L{objects.Node} or None
@param node: Node object
"""
def __init__(self, procedure, body, port, address_lookup_fn=_AddressLookup):
assert procedure in _TIMEOUTS, ("New RPC call not declared in the"
" timeouts table")
self.procedure = procedure
self.body = body
self.port = port
self._request = {}
self._address_lookup_fn = address_lookup_fn
def ConnectList(self, node_list, address_list=None, read_timeout=None):
"""Add a list of nodes to the target nodes.
if node is None:
# Depend on DNS for name resolution
ip = name
elif node.offline:
ip = _OFFLINE
else:
ip = node.primary_ip
return (name, ip)
@type node_list: list
@param node_list: the list of node names to connect
@type address_list: list or None
@keyword address_list: either None or a list with node addresses,
which must have the same length as the node list
@type read_timeout: int
@param read_timeout: overwrites default timeout for operation
def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts):
"""Calculate node addresses using configuration.
"""
# Special case for single-host lookups
if len(hosts) == 1:
(name, ) = hosts
return [_CheckConfigNode(name, single_node_fn(name))]
else:
all_nodes = all_nodes_fn()
return [_CheckConfigNode(name, all_nodes.get(name, None))
for name in hosts]
class _RpcProcessor:
def __init__(self, resolver, port):
"""Initializes this class.
@param resolver: callable accepting a list of hostnames, returning a list
of tuples containing name and IP address (IP address can be the name or
the special value L{_OFFLINE} to mark offline machines)
@type port: int
@param port: TCP port
"""
if address_list is None:
# Always use IP address instead of node name
address_list = self._address_lookup_fn(node_list)
self._resolver = resolver
self._port = port
@staticmethod
def _PrepareRequests(hosts, port, procedure, body, read_timeout):
"""Prepares requests by sorting offline hosts into separate list.
assert len(node_list) == len(address_list), \
"Name and address lists must have the same length"
"""
results = {}
requests = {}
for node, address in zip(node_list, address_list):
self.ConnectNode(node, address, read_timeout=read_timeout)
for (name, ip) in hosts:
if ip is _OFFLINE:
# Node is marked as offline
results[name] = RpcResult(node=name, offline=True, call=procedure)
else:
requests[name] = \
http.client.HttpClientRequest(str(ip), port,
http.HTTP_PUT, str("/%s" % procedure),
headers=_RPC_CLIENT_HEADERS,
post_data=body,
read_timeout=read_timeout)
def ConnectNode(self, name, address=None, read_timeout=None):
"""Add a node to the target list.
return (results, requests)
@type name: str
@param name: the node name
@type address: str
@param address: the node address, if known
@type read_timeout: int
@param read_timeout: overwrites default timeout for operation
@staticmethod
def _CombineResults(results, requests, procedure):
"""Combines pre-computed results for offline hosts with actual call results.
"""
if address is None:
# Always use IP address instead of node name
address = self._address_lookup_fn([name])[0]
for name, req in requests.items():
if req.success and req.resp_status_code == http.HTTP_OK:
host_result = RpcResult(data=serializer.LoadJson(req.resp_body),
node=name, call=procedure)
else:
# TODO: Better error reporting
if req.error:
msg = req.error
else:
msg = req.resp_body
assert(address is not None)
logging.error("RPC error in %s on node %s: %s", procedure, name, msg)
host_result = RpcResult(data=msg, failed=True, node=name,
call=procedure)
if read_timeout is None:
read_timeout = _TIMEOUTS[self.procedure]
results[name] = host_result
self._request[name] = \
http.client.HttpClientRequest(str(address), self.port,
http.HTTP_PUT, str("/%s" % self.procedure),
headers=_RPC_CLIENT_HEADERS,
post_data=str(self.body),
read_timeout=read_timeout)
return results
def GetResults(self, http_pool=None):
"""Call nodes and return results.
def __call__(self, hosts, procedure, body, read_timeout=None, http_pool=None):
"""Makes an RPC request to a number of nodes.
@rtype: list
@return: List of RPC results
@type hosts: sequence
@param hosts: Hostnames
@type procedure: string
@param procedure: Request path
@type body: string
@param body: Request body
@type read_timeout: int or None
@param read_timeout: Read timeout for request
"""
assert procedure in _TIMEOUTS, "RPC call not declared in the timeouts table"
if not http_pool:
http_pool = _thread_local.GetHttpClientPool()
http_pool.ProcessRequests(self._request.values())
results = {}
if read_timeout is None:
read_timeout = _TIMEOUTS[procedure]
for name, req in self._request.iteritems():
if req.success and req.resp_status_code == http.HTTP_OK:
results[name] = RpcResult(data=serializer.LoadJson(req.resp_body),
node=name, call=self.procedure)
continue
(results, requests) = \
self._PrepareRequests(self._resolver(hosts), self._port, procedure,
str(body), read_timeout)
# TODO: Better error reporting
if req.error:
msg = req.error
else:
msg = req.resp_body
http_pool.ProcessRequests(requests.values())
logging.error("RPC error in %s from node %s: %s",
self.procedure, name, msg)
results[name] = RpcResult(data=msg, failed=True, node=name,
call=self.procedure)
assert not frozenset(results).intersection(requests)
return results
return self._CombineResults(results, requests, procedure)
def _EncodeImportExportIO(ieio, ieioargs):
@@ -445,7 +486,10 @@ class RpcRunner(object):
"""
self._cfg = context.cfg
self.port = netutils.GetDaemonPort(constants.NODED)
self._proc = _RpcProcessor(compat.partial(_NodeConfigResolver,
self._cfg.GetNodeInfo,
self._cfg.GetAllNodesInfo),
netutils.GetDaemonPort(constants.NODED))
def _InstDict(self, instance, hvp=None, bep=None, osp=None):
"""Convert the given instance to a dict.
@@ -483,98 +527,37 @@ class RpcRunner(object):
nic['nicparams'])
return idict
def _ConnectList(self, client, node_list, call, read_timeout=None):
"""Helper for computing node addresses.
@type client: L{ganeti.rpc.Client}
@param client: a C{Client} instance
@type node_list: list
@param node_list: the node list we should connect
@type call: string
@param call: the name of the remote procedure call, for filling in
correctly any eventual offline nodes' results
@type read_timeout: int
@param read_timeout: overwrites the default read timeout for the
given operation
"""
all_nodes = self._cfg.GetAllNodesInfo()
name_list = []
addr_list = []
skip_dict = {}
for node in node_list:
if node in all_nodes:
if all_nodes[node].offline:
skip_dict[node] = RpcResult(node=node, offline=True, call=call)
continue
val = all_nodes[node].primary_ip
else:
val = None
addr_list.append(val)
name_list.append(node)
if name_list:
client.ConnectList(name_list, address_list=addr_list,
read_timeout=read_timeout)
return skip_dict
def _ConnectNode(self, client, node, call, read_timeout=None):
"""Helper for computing one node's address.
@type client: L{ganeti.rpc.Client}
@param client: a C{Client} instance
@type node: str
@param node: the node we should connect
@type call: string
@param call: the name of the remote procedure call, for filling in
correctly any eventual offline nodes' results
@type read_timeout: int
@param read_timeout: overwrites the default read timeout for the
given operation
"""
node_info = self._cfg.GetNodeInfo(node)
if node_info is not None:
if node_info.offline:
return RpcResult(node=node, offline=True, call=call)
addr = node_info.primary_ip
else:
addr = None
client.ConnectNode(node, address=addr, read_timeout=read_timeout)
def _MultiNodeCall(self, node_list, procedure, args, read_timeout=None):
"""Helper for making a multi-node call
"""
body = serializer.DumpJson(args, indent=False)
c = Client(procedure, body, self.port)
skip_dict = self._ConnectList(c, node_list, procedure,
read_timeout=read_timeout)
skip_dict.update(c.GetResults())
return skip_dict
return self._proc(node_list, procedure, body, read_timeout=read_timeout)
@classmethod
def _StaticMultiNodeCall(cls, node_list, procedure, args,
@staticmethod
def _StaticMultiNodeCall(node_list, procedure, args,
address_list=None, read_timeout=None):
"""Helper for making a multi-node static call
"""
body = serializer.DumpJson(args, indent=False)
c = Client(procedure, body, netutils.GetDaemonPort(constants.NODED))
c.ConnectList(node_list, address_list=address_list,
read_timeout=read_timeout)
return c.GetResults()
if address_list is None:
resolver = _SsconfResolver
else:
# Caller provided an address list
resolver = _StaticResolver(address_list)
proc = _RpcProcessor(resolver,
netutils.GetDaemonPort(constants.NODED))
return proc(node_list, procedure, body, read_timeout=read_timeout)
def _SingleNodeCall(self, node, procedure, args, read_timeout=None):
"""Helper for making a single-node call
"""
body = serializer.DumpJson(args, indent=False)
c = Client(procedure, body, self.port)
result = self._ConnectNode(c, node, procedure, read_timeout=read_timeout)
if result is None:
# we did connect, node is not offline
result = c.GetResults()[node]
return result
return self._proc([node], procedure, body, read_timeout=read_timeout)[node]
@classmethod
def _StaticSingleNodeCall(cls, node, procedure, args, read_timeout=None):
@@ -582,9 +565,9 @@ class RpcRunner(object):
"""
body = serializer.DumpJson(args, indent=False)
c = Client(procedure, body, netutils.GetDaemonPort(constants.NODED))
c.ConnectNode(node, read_timeout=read_timeout)
return c.GetResults()[node]
proc = _RpcProcessor(_SsconfResolver,
netutils.GetDaemonPort(constants.NODED))
return proc([node], procedure, body, read_timeout=read_timeout)[node]
#
# Begin RPC calls
......
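The configuration-based resolver is the piece that replaces the old _ConnectList/_ConnectNode helpers: it looks up each node in the cluster configuration, falls back to DNS for unknown names, and flags offline nodes so no request is sent to them. Below is a standalone, hedged paraphrase of _CheckConfigNode and _NodeConfigResolver from the diff above; the Node namedtuple and the cfg dict are invented stand-ins for objects.Node and the configuration accessors (GetNodeInfo/GetAllNodesInfo).

import collections

_OFFLINE = object()

# Invented stand-in for objects.Node; only the two fields used here.
Node = collections.namedtuple("Node", ["primary_ip", "offline"])


def CheckConfigNode(name, node):
  """Paraphrase of _CheckConfigNode: DNS fallback, offline marker."""
  if node is None:
    return (name, name)          # unknown to the config: let DNS resolve it
  if node.offline:
    return (name, _OFFLINE)      # offline: the processor fakes the result
  return (name, node.primary_ip)


def NodeConfigResolver(single_node_fn, all_nodes_fn, hosts):
  """Paraphrase of _NodeConfigResolver: cheap path for single-host calls."""
  if len(hosts) == 1:
    (name, ) = hosts
    return [CheckConfigNode(name, single_node_fn(name))]
  all_nodes = all_nodes_fn()
  return [CheckConfigNode(name, all_nodes.get(name, None))
          for name in hosts]


# Example with a made-up two-node configuration
cfg = {
  "node1.example.com": Node("192.0.2.1", False),
  "node2.example.com": Node("192.0.2.2", True),
  }
print(NodeConfigResolver(cfg.get, lambda: cfg,
                         ["node1.example.com", "node2.example.com",
                          "node3.example.com"]))
# node1 resolves from the config, node2 comes back with the _OFFLINE
# marker, and node3 (unknown to the config) falls back to its own name.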
@@ -31,6 +31,7 @@ from ganeti import rpc
from ganeti import http
from ganeti import errors
from ganeti import serializer
from ganeti import objects
import testutils
@@ -64,24 +65,24 @@ def GetFakeSimpleStoreClass(fn):
return FakeSimpleStore
class TestClient(unittest.TestCase):
class TestRpcProcessor(unittest.TestCase):
def _FakeAddressLookup(self, map):
return lambda node_list: [map.get(node) for node in node_list]
def _GetVersionResponse(self, req):
self.assertEqual(req.host, "localhost")
self.assertEqual(req.host, "127.0.0.1")
self.assertEqual(req.port, 24094)
self.assertEqual(req.path, "/version")
self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, 123))
def testVersionSuccess(self):
fn = self._FakeAddressLookup({"localhost": "localhost"})
client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
client.ConnectNode("localhost")
resolver = rpc._StaticResolver(["127.0.0.1"])
pool = FakeHttpPool(self._GetVersionResponse)
result = client.GetResults(http_pool=pool)
proc = rpc._RpcProcessor(resolver, 24094)
result = proc(["localhost"], "version", None, http_pool=pool)
self.assertEqual(result.keys(), ["localhost"])
lhresp = result["localhost"]
self.assertFalse(lhresp.offline)
@@ -92,6 +93,52 @@ class TestClient(unittest.TestCase):
lhresp.Raise("should not raise")
self.assertEqual(pool.reqcount, 1)
def _ReadTimeoutResponse(self, req):
self.assertEqual(req.host, "192.0.2.13")
self.assertEqual(req.port, 19176)
self.assertEqual(req.path, "/version")
self.assertEqual(req.read_timeout, 12356)
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, -1))
def testReadTimeout(self):
resolver = rpc._StaticResolver(["192.0.2.13"])
pool = FakeHttpPool(self._ReadTimeoutResponse)
proc = rpc._RpcProcessor(resolver, 19176)
result = proc(["node31856"], "version", None, http_pool=pool,
read_timeout=12356)
self.assertEqual(result.keys(), ["node31856"])
lhresp = result["node31856"]
self.assertFalse(lhresp.offline)
self.assertEqual(lhresp.node, "node31856")
self.assertFalse(lhresp.fail_msg)
self.assertEqual(lhresp.payload, -1)
self.assertEqual(lhresp.call, "version")
lhresp.Raise("should not raise")
self.assertEqual(pool.reqcount, 1)
def testOfflineNode(self):
resolver = rpc._StaticResolver([rpc._OFFLINE])
pool = FakeHttpPool(NotImplemented)
proc = rpc._RpcProcessor(resolver, 30668)
result = proc(["n17296"], "version", None, http_pool=pool)
self.assertEqual(result.keys(), ["n17296"])
lhresp = result["n17296"]
self.assertTrue(lhresp.offline)
self.assertEqual(lhresp.node, "n17296")
self.assertTrue(lhresp.fail_msg)
self.assertFalse(lhresp.payload)
self.assertEqual(lhresp.call, "version")
# With a message
self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
# No message
self.assertRaises(errors.OpExecError, lhresp.Raise, None)
self.assertEqual(pool.reqcount, 0)
def _GetMultiVersionResponse(self, req):
self.assert_(req.host.startswith("node"))
self.assertEqual(req.port, 23245)
@@ -102,12 +149,10 @@ class TestClient(unittest.TestCase):
def testMultiVersionSuccess(self):
nodes = ["node%s" % i for i in range(50)]
fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
client.ConnectList(nodes)
resolver = rpc._StaticResolver(nodes)
pool = FakeHttpPool(self._GetMultiVersionResponse)
result = client.GetResults(http_pool=pool)
proc = rpc._RpcProcessor(resolver, 23245)
result = proc(nodes, "version", None, http_pool=pool)
self.assertEqual(sorted(result.keys()), sorted(nodes))
for name in nodes:
@@ -121,28 +166,27 @@ class TestClient(unittest.TestCase):
self.assertEqual(pool.reqcount, len(nodes))
def _GetVersionResponseFail(self, req):
def _GetVersionResponseFail(self, errinfo, req):
self.assertEqual(req.path, "/version")
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((False, "Unknown error"))
req.resp_body = serializer.DumpJson((False, errinfo))
def testVersionFailure(self):
lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
fn = self._FakeAddressLookup(lookup_map)
client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
client.ConnectNode("aef9ur4i.example.com")
pool = FakeHttpPool(self._GetVersionResponseFail)
result = client.GetResults(http_pool=pool)
self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
lhresp = result["aef9ur4i.example.com"]
self.assertFalse(lhresp.offline)
self.assertEqual(lhresp.node, "aef9ur4i.example.com")
self.assert_(lhresp.fail_msg)
self.assertFalse(lhresp.payload)
self.assertEqual(lhresp.call, "version")
self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
self.assertEqual(pool.reqcount, 1)
resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
proc = rpc._RpcProcessor(resolver, 5903)
for errinfo in [None, "Unknown error"]:
pool = FakeHttpPool(compat.partial(self._GetVersionResponseFail, errinfo))
result = proc(["aef9ur4i.example.com"], "version", None, http_pool=pool)
self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
lhresp = result["aef9ur4i.example.com"]
self.assertFalse(lhresp.offline)
self.assertEqual(lhresp.node, "aef9ur4i.example.com")
self.assert_(lhresp.fail_msg)
self.assertFalse(lhresp.payload)
self.assertEqual(lhresp.call, "version")
self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
self.assertEqual(pool.reqcount, 1)
def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
self.assertEqual(req.path, "/vg_list")
@@ -167,7 +211,7 @@ class TestClient(unittest.TestCase):
def testHttpError(self):
nodes = ["uaf6pbbv%s" % i for i in range(50)]
fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
resolver = rpc._StaticResolver(nodes)
httperrnodes = set(nodes[1::7])
self.assertEqual(len(httperrnodes), 7)
@@ -177,12 +221,10 @@ class TestClient(unittest.TestCase):
self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
client.ConnectList(nodes)
proc = rpc._RpcProcessor(resolver, 15165)
pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
httperrnodes, failnodes))
result = client.GetResults(http_pool=pool)
result = proc(nodes, "vg_list", None, http_pool=pool)
self.assertEqual(sorted(result.keys()), sorted(nodes))
for name in nodes:
@@ -219,13 +261,12 @@ class TestClient(unittest.TestCase):
req.resp_body = serializer.DumpJson("invalid response")
def testInvalidResponse(self):
lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
fn = self._FakeAddressLookup(lookup_map)
client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
proc = rpc._RpcProcessor(resolver, 19978)
for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
client.ConnectNode("oqo7lanhly.example.com")
pool = FakeHttpPool(fn)
result = client.GetResults(http_pool=pool)
result = proc(["oqo7lanhly.example.com"], "version", None, http_pool=pool)
self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
lhresp = result["oqo7lanhly.example.com"]
self.assertFalse(lhresp.offline)
@@ -236,41 +277,145 @@ class TestClient(unittest.TestCase):
self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
self.assertEqual(pool.reqcount, 1)
def testAddressLookupSimpleStore(self):
def _GetBodyTestResponse(self, test_data, req):
self.assertEqual(req.host, "192.0.2.84")
self.assertEqual(req.port, 18700)