diff --git a/lib/rpc.py b/lib/rpc.py index 2f65ad53f1944506a7afae6fe0e2aa5fe32280b3..bf05cfb25e31571d41ab0366d046b816a955e718 100644 --- a/lib/rpc.py +++ b/lib/rpc.py @@ -338,10 +338,19 @@ class _RpcProcessor: def _PrepareRequests(hosts, port, procedure, body, read_timeout): """Prepares requests by sorting offline hosts into separate list. + @type body: dict + @param body: a dictionary with per-host body data + """ results = {} requests = {} + assert isinstance(body, dict) + assert len(body) == len(hosts) + assert compat.all(isinstance(v, str) for v in body.values()) + assert frozenset(map(compat.fst, hosts)) == frozenset(body.keys()), \ + "%s != %s" % (hosts, body.keys()) + for (name, ip) in hosts: if ip is _OFFLINE: # Node is marked as offline @@ -351,7 +360,7 @@ class _RpcProcessor: http.client.HttpClientRequest(str(ip), port, http.HTTP_PUT, str("/%s" % procedure), headers=_RPC_CLIENT_HEADERS, - post_data=body, + post_data=body[name], read_timeout=read_timeout, nicename="%s/%s" % (name, procedure), curl_config_fn=_ConfigRpcCurl) @@ -390,8 +399,8 @@ class _RpcProcessor: @param hosts: Hostnames @type procedure: string @param procedure: Request path - @type body: string - @param body: Request body + @type body: dictionary + @param body: dictionary with request bodies per host @type read_timeout: int or None @param read_timeout: Read timeout for request @@ -401,7 +410,7 @@ class _RpcProcessor: (results, requests) = \ self._PrepareRequests(self._resolver(hosts), self._port, procedure, - str(body), read_timeout) + body, read_timeout) _req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb) @@ -434,17 +443,28 @@ class _RpcClientBase: """Entry point for automatically generated RPC wrappers. """ - (procedure, _, timeout, argdefs, _, postproc_fn, _) = cdef + (procedure, _, timeout, argdefs, prep_fn, postproc_fn, _) = cdef if callable(timeout): read_timeout = timeout(args) else: read_timeout = timeout - body = serializer.DumpJson(map(self._encoder, - zip(map(compat.snd, argdefs), args))) - - result = self._proc(node_list, procedure, body, read_timeout=read_timeout) + enc_args = map(self._encoder, zip(map(compat.snd, argdefs), args)) + if prep_fn is None: + # for a no-op prep_fn, we serialise the body once, and then we + # reuse it in the dictionary values + body = serializer.DumpJson(enc_args) + pnbody = dict((n, body) for n in node_list) + else: + # for a custom prep_fn, we pass the encoded arguments and the + # node name to the prep_fn, and we serialise its return value + assert(callable(prep_fn)) + pnbody = dict((n, serializer.DumpJson(prep_fn(n, enc_args))) + for n in node_list) + + result = self._proc(node_list, procedure, pnbody, + read_timeout=read_timeout) if postproc_fn: return dict(map(lambda (key, value): (key, postproc_fn(value)), diff --git a/test/ganeti.rpc_unittest.py b/test/ganeti.rpc_unittest.py index c434ad46e365d217609cf0468306ad1ccd1556b7..e16802eb935f153d8e3d5fcfd335ea192d19f8e8 100755 --- a/test/ganeti.rpc_unittest.py +++ b/test/ganeti.rpc_unittest.py @@ -1,7 +1,7 @@ #!/usr/bin/python # -# Copyright (C) 2010 Google Inc. +# Copyright (C) 2010, 2011 Google Inc. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -73,8 +73,8 @@ class TestRpcProcessor(unittest.TestCase): resolver = rpc._StaticResolver(["127.0.0.1"]) http_proc = _FakeRequestProcessor(self._GetVersionResponse) proc = rpc._RpcProcessor(resolver, 24094) - result = proc(["localhost"], "version", None, _req_process_fn=http_proc, - read_timeout=60) + result = proc(["localhost"], "version", {"localhost": ""}, + _req_process_fn=http_proc, read_timeout=60) self.assertEqual(result.keys(), ["localhost"]) lhresp = result["localhost"] self.assertFalse(lhresp.offline) @@ -98,12 +98,14 @@ class TestRpcProcessor(unittest.TestCase): resolver = rpc._StaticResolver(["192.0.2.13"]) http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse) proc = rpc._RpcProcessor(resolver, 19176) - result = proc(["node31856"], "version", None, _req_process_fn=http_proc, + host = "node31856" + body = {host: ""} + result = proc([host], "version", body, _req_process_fn=http_proc, read_timeout=12356) - self.assertEqual(result.keys(), ["node31856"]) - lhresp = result["node31856"] + self.assertEqual(result.keys(), [host]) + lhresp = result[host] self.assertFalse(lhresp.offline) - self.assertEqual(lhresp.node, "node31856") + self.assertEqual(lhresp.node, host) self.assertFalse(lhresp.fail_msg) self.assertEqual(lhresp.payload, -1) self.assertEqual(lhresp.call, "version") @@ -114,12 +116,14 @@ class TestRpcProcessor(unittest.TestCase): resolver = rpc._StaticResolver([rpc._OFFLINE]) http_proc = _FakeRequestProcessor(NotImplemented) proc = rpc._RpcProcessor(resolver, 30668) - result = proc(["n17296"], "version", None, _req_process_fn=http_proc, + host = "n17296" + body = {host: ""} + result = proc([host], "version", body, _req_process_fn=http_proc, read_timeout=60) - self.assertEqual(result.keys(), ["n17296"]) - lhresp = result["n17296"] + self.assertEqual(result.keys(), [host]) + lhresp = result[host] self.assertTrue(lhresp.offline) - self.assertEqual(lhresp.node, "n17296") + self.assertEqual(lhresp.node, host) self.assertTrue(lhresp.fail_msg) self.assertFalse(lhresp.payload) self.assertEqual(lhresp.call, "version") @@ -142,10 +146,11 @@ class TestRpcProcessor(unittest.TestCase): def testMultiVersionSuccess(self): nodes = ["node%s" % i for i in range(50)] + body = dict((n, "") for n in nodes) resolver = rpc._StaticResolver(nodes) http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse) proc = rpc._RpcProcessor(resolver, 23245) - result = proc(nodes, "version", None, _req_process_fn=http_proc, + result = proc(nodes, "version", body, _req_process_fn=http_proc, read_timeout=60) self.assertEqual(sorted(result.keys()), sorted(nodes)) @@ -173,12 +178,14 @@ class TestRpcProcessor(unittest.TestCase): http_proc = \ _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail, errinfo)) - result = proc(["aef9ur4i.example.com"], "version", None, + host = "aef9ur4i.example.com" + body = {host: ""} + result = proc(body.keys(), "version", body, _req_process_fn=http_proc, read_timeout=60) - self.assertEqual(result.keys(), ["aef9ur4i.example.com"]) - lhresp = result["aef9ur4i.example.com"] + self.assertEqual(result.keys(), [host]) + lhresp = result[host] self.assertFalse(lhresp.offline) - self.assertEqual(lhresp.node, "aef9ur4i.example.com") + self.assertEqual(lhresp.node, host) self.assert_(lhresp.fail_msg) self.assertFalse(lhresp.payload) self.assertEqual(lhresp.call, "version") @@ -208,6 +215,7 @@ class TestRpcProcessor(unittest.TestCase): def testHttpError(self): nodes = ["uaf6pbbv%s" % i for i in range(50)] + body = dict((n, "") for n in nodes) resolver = rpc._StaticResolver(nodes) httperrnodes = set(nodes[1::7]) @@ -222,7 +230,7 @@ class TestRpcProcessor(unittest.TestCase): http_proc = \ _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse, httperrnodes, failnodes)) - result = proc(nodes, "vg_list", None, _req_process_fn=http_proc, + result = proc(nodes, "vg_list", body, _req_process_fn=http_proc, read_timeout=rpc._TMO_URGENT) self.assertEqual(sorted(result.keys()), sorted(nodes)) @@ -265,12 +273,14 @@ class TestRpcProcessor(unittest.TestCase): for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]: http_proc = _FakeRequestProcessor(fn) - result = proc(["oqo7lanhly.example.com"], "version", None, + host = "oqo7lanhly.example.com" + body = {host: ""} + result = proc([host], "version", body, _req_process_fn=http_proc, read_timeout=60) - self.assertEqual(result.keys(), ["oqo7lanhly.example.com"]) - lhresp = result["oqo7lanhly.example.com"] + self.assertEqual(result.keys(), [host]) + lhresp = result[host] self.assertFalse(lhresp.offline) - self.assertEqual(lhresp.node, "oqo7lanhly.example.com") + self.assertEqual(lhresp.node, host) self.assert_(lhresp.fail_msg) self.assertFalse(lhresp.payload) self.assertEqual(lhresp.call, "version") @@ -295,13 +305,14 @@ class TestRpcProcessor(unittest.TestCase): http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse, test_data)) proc = rpc._RpcProcessor(resolver, 18700) - body = serializer.DumpJson(test_data) - result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc, + host = "node19759" + body = {host: serializer.DumpJson(test_data)} + result = proc([host], "upload_file", body, _req_process_fn=http_proc, read_timeout=30) - self.assertEqual(result.keys(), ["node19759"]) - lhresp = result["node19759"] + self.assertEqual(result.keys(), [host]) + lhresp = result[host] self.assertFalse(lhresp.offline) - self.assertEqual(lhresp.node, "node19759") + self.assertEqual(lhresp.node, host) self.assertFalse(lhresp.fail_msg) self.assertEqual(lhresp.payload, None) self.assertEqual(lhresp.call, "upload_file")