From d9de612c0c146448ffcc510ed7b74ba047d35de6 Mon Sep 17 00:00:00 2001
From: Iustin Pop <iustin@google.com>
Date: Wed, 21 Dec 2011 14:11:32 +0100
Subject: [PATCH] Change internal RPC client body values
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Currently, all RPC payloads sent by the client to the remote node
daemons must be identical, due to how the data is passed
internally. This is deficient both in use (from the programmer's point
of view) and in network traffic (cluster verify/disk data
gathering has a total payload which is O(n²) in the number of nodes
being queried, instead of O(n)).

This patch changes the RPC internals so that we always pass
dictionaries indexed by target node name. For the default use case,
when the payload is identical, we only serialise the payload once, so
the extra overhead is just a dict with the node names and values all
pointing to the same object. For different payloads, we will encode
the body multiple times, but hopefully the bodies will be smaller.

Signed-off-by: Iustin Pop <iustin@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
---
 lib/rpc.py                  | 38 ++++++++++++++++------
 test/ganeti.rpc_unittest.py | 63 ++++++++++++++++++++++---------------
 2 files changed, 66 insertions(+), 35 deletions(-)

diff --git a/lib/rpc.py b/lib/rpc.py
index 2f65ad53f..bf05cfb25 100644
--- a/lib/rpc.py
+++ b/lib/rpc.py
@@ -338,10 +338,19 @@ class _RpcProcessor:
   def _PrepareRequests(hosts, port, procedure, body, read_timeout):
     """Prepares requests by sorting offline hosts into separate list.
 
+    @type body: dict
+    @param body: a dictionary with per-host body data
+
     """
     results = {}
     requests = {}
 
+    assert isinstance(body, dict)
+    assert len(body) == len(hosts)
+    assert compat.all(isinstance(v, str) for v in body.values())
+    assert frozenset(map(compat.fst, hosts)) == frozenset(body.keys()), \
+        "%s != %s" % (hosts, body.keys())
+
     for (name, ip) in hosts:
       if ip is _OFFLINE:
         # Node is marked as offline
@@ -351,7 +360,7 @@ class _RpcProcessor:
           http.client.HttpClientRequest(str(ip), port,
                                         http.HTTP_PUT, str("/%s" % procedure),
                                         headers=_RPC_CLIENT_HEADERS,
-                                        post_data=body,
+                                        post_data=body[name],
                                         read_timeout=read_timeout,
                                         nicename="%s/%s" % (name, procedure),
                                         curl_config_fn=_ConfigRpcCurl)
@@ -390,8 +399,8 @@ class _RpcProcessor:
     @param hosts: Hostnames
     @type procedure: string
     @param procedure: Request path
-    @type body: string
-    @param body: Request body
+    @type body: dictionary
+    @param body: dictionary with request bodies per host
     @type read_timeout: int or None
     @param read_timeout: Read timeout for request
 
@@ -401,7 +410,7 @@ class _RpcProcessor:
 
     (results, requests) = \
       self._PrepareRequests(self._resolver(hosts), self._port, procedure,
-                            str(body), read_timeout)
+                            body, read_timeout)
 
     _req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb)
 
@@ -434,17 +443,28 @@ class _RpcClientBase:
     """Entry point for automatically generated RPC wrappers.
 
     """
-    (procedure, _, timeout, argdefs, _, postproc_fn, _) = cdef
+    (procedure, _, timeout, argdefs, prep_fn, postproc_fn, _) = cdef
 
     if callable(timeout):
       read_timeout = timeout(args)
     else:
       read_timeout = timeout
 
-    body = serializer.DumpJson(map(self._encoder,
-                                   zip(map(compat.snd, argdefs), args)))
-
-    result = self._proc(node_list, procedure, body, read_timeout=read_timeout)
+    enc_args = map(self._encoder, zip(map(compat.snd, argdefs), args))
+    if prep_fn is None:
+      # for a no-op prep_fn, we serialise the body once, and then we
+      # reuse it in the dictionary values
+      body = serializer.DumpJson(enc_args)
+      pnbody = dict((n, body) for n in node_list)
+    else:
+      # for a custom prep_fn, we pass the encoded arguments and the
+      # node name to the prep_fn, and we serialise its return value
+      assert(callable(prep_fn))
+      pnbody = dict((n, serializer.DumpJson(prep_fn(n, enc_args)))
+                    for n in node_list)
+
+    result = self._proc(node_list, procedure, pnbody,
+                        read_timeout=read_timeout)
 
     if postproc_fn:
       return dict(map(lambda (key, value): (key, postproc_fn(value)),
diff --git a/test/ganeti.rpc_unittest.py b/test/ganeti.rpc_unittest.py
index c434ad46e..e16802eb9 100755
--- a/test/ganeti.rpc_unittest.py
+++ b/test/ganeti.rpc_unittest.py
@@ -1,7 +1,7 @@
 #!/usr/bin/python
 #
 
-# Copyright (C) 2010 Google Inc.
+# Copyright (C) 2010, 2011 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -73,8 +73,8 @@ class TestRpcProcessor(unittest.TestCase):
     resolver = rpc._StaticResolver(["127.0.0.1"])
     http_proc = _FakeRequestProcessor(self._GetVersionResponse)
     proc = rpc._RpcProcessor(resolver, 24094)
-    result = proc(["localhost"], "version", None, _req_process_fn=http_proc,
-                  read_timeout=60)
+    result = proc(["localhost"], "version", {"localhost": ""},
+                  _req_process_fn=http_proc, read_timeout=60)
     self.assertEqual(result.keys(), ["localhost"])
     lhresp = result["localhost"]
     self.assertFalse(lhresp.offline)
@@ -98,12 +98,14 @@ class TestRpcProcessor(unittest.TestCase):
     resolver = rpc._StaticResolver(["192.0.2.13"])
     http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
     proc = rpc._RpcProcessor(resolver, 19176)
-    result = proc(["node31856"], "version", None, _req_process_fn=http_proc,
+    host = "node31856"
+    body = {host: ""}
+    result = proc([host], "version", body, _req_process_fn=http_proc,
                   read_timeout=12356)
-    self.assertEqual(result.keys(), ["node31856"])
-    lhresp = result["node31856"]
+    self.assertEqual(result.keys(), [host])
+    lhresp = result[host]
     self.assertFalse(lhresp.offline)
-    self.assertEqual(lhresp.node, "node31856")
+    self.assertEqual(lhresp.node, host)
     self.assertFalse(lhresp.fail_msg)
     self.assertEqual(lhresp.payload, -1)
     self.assertEqual(lhresp.call, "version")
@@ -114,12 +116,14 @@ class TestRpcProcessor(unittest.TestCase):
     resolver = rpc._StaticResolver([rpc._OFFLINE])
     http_proc = _FakeRequestProcessor(NotImplemented)
     proc = rpc._RpcProcessor(resolver, 30668)
-    result = proc(["n17296"], "version", None, _req_process_fn=http_proc,
+    host = "n17296"
+    body = {host: ""}
+    result = proc([host], "version", body, _req_process_fn=http_proc,
                   read_timeout=60)
-    self.assertEqual(result.keys(), ["n17296"])
-    lhresp = result["n17296"]
+    self.assertEqual(result.keys(), [host])
+    lhresp = result[host]
     self.assertTrue(lhresp.offline)
-    self.assertEqual(lhresp.node, "n17296")
+    self.assertEqual(lhresp.node, host)
     self.assertTrue(lhresp.fail_msg)
     self.assertFalse(lhresp.payload)
     self.assertEqual(lhresp.call, "version")
@@ -142,10 +146,11 @@ class TestRpcProcessor(unittest.TestCase):
 
   def testMultiVersionSuccess(self):
     nodes = ["node%s" % i for i in range(50)]
+    body = dict((n, "") for n in nodes)
     resolver = rpc._StaticResolver(nodes)
     http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
     proc = rpc._RpcProcessor(resolver, 23245)
-    result = proc(nodes, "version", None, _req_process_fn=http_proc,
+    result = proc(nodes, "version", body, _req_process_fn=http_proc,
                   read_timeout=60)
     self.assertEqual(sorted(result.keys()), sorted(nodes))
 
@@ -173,12 +178,14 @@ class TestRpcProcessor(unittest.TestCase):
       http_proc = \
         _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
                                              errinfo))
-      result = proc(["aef9ur4i.example.com"], "version", None,
+      host = "aef9ur4i.example.com"
+      body = {host: ""}
+      result = proc(body.keys(), "version", body,
                     _req_process_fn=http_proc, read_timeout=60)
-      self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
-      lhresp = result["aef9ur4i.example.com"]
+      self.assertEqual(result.keys(), [host])
+      lhresp = result[host]
       self.assertFalse(lhresp.offline)
-      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
+      self.assertEqual(lhresp.node, host)
       self.assert_(lhresp.fail_msg)
       self.assertFalse(lhresp.payload)
       self.assertEqual(lhresp.call, "version")
@@ -208,6 +215,7 @@ class TestRpcProcessor(unittest.TestCase):
 
   def testHttpError(self):
     nodes = ["uaf6pbbv%s" % i for i in range(50)]
+    body = dict((n, "") for n in nodes)
     resolver = rpc._StaticResolver(nodes)
 
     httperrnodes = set(nodes[1::7])
@@ -222,7 +230,7 @@ class TestRpcProcessor(unittest.TestCase):
     http_proc = \
       _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                            httperrnodes, failnodes))
-    result = proc(nodes, "vg_list", None, _req_process_fn=http_proc,
+    result = proc(nodes, "vg_list", body, _req_process_fn=http_proc,
                   read_timeout=rpc._TMO_URGENT)
     self.assertEqual(sorted(result.keys()), sorted(nodes))
 
@@ -265,12 +273,14 @@ class TestRpcProcessor(unittest.TestCase):
 
     for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
       http_proc = _FakeRequestProcessor(fn)
-      result = proc(["oqo7lanhly.example.com"], "version", None,
+      host = "oqo7lanhly.example.com"
+      body = {host: ""}
+      result = proc([host], "version", body,
                     _req_process_fn=http_proc, read_timeout=60)
-      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
-      lhresp = result["oqo7lanhly.example.com"]
+      self.assertEqual(result.keys(), [host])
+      lhresp = result[host]
       self.assertFalse(lhresp.offline)
-      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
+      self.assertEqual(lhresp.node, host)
       self.assert_(lhresp.fail_msg)
       self.assertFalse(lhresp.payload)
       self.assertEqual(lhresp.call, "version")
@@ -295,13 +305,14 @@ class TestRpcProcessor(unittest.TestCase):
     http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                      test_data))
     proc = rpc._RpcProcessor(resolver, 18700)
-    body = serializer.DumpJson(test_data)
-    result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc,
+    host = "node19759"
+    body = {host: serializer.DumpJson(test_data)}
+    result = proc([host], "upload_file", body, _req_process_fn=http_proc,
                   read_timeout=30)
-    self.assertEqual(result.keys(), ["node19759"])
-    lhresp = result["node19759"]
+    self.assertEqual(result.keys(), [host])
+    lhresp = result[host]
     self.assertFalse(lhresp.offline)
-    self.assertEqual(lhresp.node, "node19759")
+    self.assertEqual(lhresp.node, host)
     self.assertFalse(lhresp.fail_msg)
     self.assertEqual(lhresp.payload, None)
     self.assertEqual(lhresp.call, "upload_file")
-- 
GitLab