Commit 33231500 authored by Michael Hanselmann

Convert RPC client to PycURL

Replacing our custom HTTP client with PycURL's multi interface
allows us to get rid of the HTTP client thread pool.
The majority of the code is still in the ganeti.http.client
module.

A simple per-thread HTTP client pool gives cURL a chance to
cache and retain as much information as possible (e.g. SSL certs).
Unused HTTP clients (e.g. due to removed nodes) are deleted after
25 requests going through the pool.
Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>
parent 7f93570a
......@@ -389,6 +389,7 @@ python_tests = \
test/ganeti.rapi.client_unittest.py \
test/ganeti.rapi.resources_unittest.py \
test/ganeti.rapi.rlib2_unittest.py \
test/ganeti.rpc_unittest.py \
test/ganeti.serializer_unittest.py \
test/ganeti.ssh_unittest.py \
test/ganeti.uidpool_unittest.py \
......
......@@ -513,7 +513,7 @@ def CheckMasterd(options, args):
sys.exit(constants.EXIT_FAILURE)
def ExecMasterd (options, args): # pylint: disable-msg=W0613
def ExecMasterd(options, args): # pylint: disable-msg=W0613
"""Main master daemon function, executed with the PID file held.
"""
......
This diff is collapsed.
......@@ -34,6 +34,8 @@ import os
import logging
import zlib
import base64
import pycurl
import threading
from ganeti import utils
from ganeti import objects
......@@ -47,8 +49,12 @@ from ganeti import netutils
import ganeti.http.client # pylint: disable-msg=W0611
# NOTE(review): this hunk lists lines from before and after the change
# together (no +/- markers); confirm against the raw diff which of the
# assignments below survive the commit.
# Module level variable
_http_manager = None
# Timeout for connecting to nodes (seconds)
# (handed to cURL as its CONNECTTIMEOUT option when a handle is configured)
_RPC_CONNECT_TIMEOUT = 5
# Header lines passed as "headers=" when the RPC client builds a node request
_RPC_CLIENT_HEADERS = [
"Content-type: %s" % http.HTTP_APP_JSON,
]
# Various time constants for the timeout table
_TMO_URGENT = 60 # one minute
......@@ -72,29 +78,62 @@ _TIMEOUTS = {
# NOTE(review): lines from before and after the change are interleaved in this
# hunk (no +/- markers). The "_http_manager" statements and http.InitSsl()
# look like the pre-change body, the threading/pycurl statements like the new
# one -- confirm against the raw diff before reading this as a single body.
def Init():
"""Initializes the module-global HTTP client manager.
Must be called before using any RPC function.
Must be called before using any RPC function and while exactly one thread is
running.
"""
global _http_manager # pylint: disable-msg=W0603
assert not _http_manager, "RPC module initialized more than once"
# curl_global_init(3) and curl_global_cleanup(3) must be called with only
# one thread running. This check is just a safety measure -- it doesn't
# cover all cases.
assert threading.activeCount() == 1, \
"Found more than one active thread when initializing pycURL"
http.InitSsl()
# Log the PycURL version string
logging.info("Using PycURL %s", pycurl.version)
_http_manager = http.client.HttpClientManager()
# Global cURL initialization (see the single-thread requirement above)
pycurl.global_init(pycurl.GLOBAL_ALL)
# NOTE(review): this hunk interleaves pre- and post-change lines (no +/-
# markers). The "_http_manager" statements that show up inside _ConfigRpcCurl
# below look like the old Shutdown() body -- confirm against the raw diff.
def Shutdown():
"""Stops the module-global HTTP client manager.
Must be called before quitting the program.
Must be called before quitting the program and while exactly one thread is
running.
"""
global _http_manager # pylint: disable-msg=W0603
# Counterpart of the pycurl.global_init() call made in Init()
pycurl.global_cleanup()
# Sets the RPC-specific options on a cURL handle; this function is passed to
# http.client.HttpClientPool() by _RpcThreadLocal.GetHttpClientPool.
def _ConfigRpcCurl(curl):
# The same certificate file is used below as CA bundle, client certificate
# and client key
noded_cert = str(constants.NODED_CERT_FILE)
if _http_manager:
_http_manager.Shutdown()
_http_manager = None
curl.setopt(pycurl.FOLLOWLOCATION, False)
curl.setopt(pycurl.CAINFO, noded_cert)
# Host verification is set to 0 while peer verification is enabled
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
curl.setopt(pycurl.SSL_VERIFYPEER, True)
curl.setopt(pycurl.SSLCERTTYPE, "PEM")
curl.setopt(pycurl.SSLCERT, noded_cert)
curl.setopt(pycurl.SSLKEYTYPE, "PEM")
curl.setopt(pycurl.SSLKEY, noded_cert)
curl.setopt(pycurl.CONNECTTIMEOUT, _RPC_CONNECT_TIMEOUT)
class _RpcThreadLocal(threading.local):
def GetHttpClientPool(self):
"""Returns a per-thread HTTP client pool.
@rtype: L{http.client.HttpClientPool}
"""
try:
pool = self.hcp
except AttributeError:
pool = http.client.HttpClientPool(_ConfigRpcCurl)
self.hcp = pool
return pool
_thread_local = _RpcThreadLocal()
def _RpcTimeout(secs):
......@@ -218,11 +257,7 @@ class Client:
self.procedure = procedure
self.body = body
self.port = port
self.nc = {}
self._ssl_params = \
http.HttpSslParams(ssl_key_path=constants.NODED_CERT_FILE,
ssl_cert_path=constants.NODED_CERT_FILE)
self._request = {}
def ConnectList(self, node_list, address_list=None, read_timeout=None):
"""Add a list of nodes to the target nodes.
......@@ -260,28 +295,28 @@ class Client:
if read_timeout is None:
read_timeout = _TIMEOUTS[self.procedure]
self.nc[name] = \
http.client.HttpClientRequest(address, self.port, http.HTTP_PUT,
"/%s" % self.procedure,
post_data=self.body,
ssl_params=self._ssl_params,
ssl_verify_peer=True,
self._request[name] = \
http.client.HttpClientRequest(str(address), self.port,
http.HTTP_PUT, str("/%s" % self.procedure),
headers=_RPC_CLIENT_HEADERS,
post_data=str(self.body),
read_timeout=read_timeout)
def GetResults(self):
def GetResults(self, http_pool=None):
"""Call nodes and return results.
@rtype: list
@return: List of RPC results
"""
assert _http_manager, "RPC module not initialized"
if not http_pool:
http_pool = _thread_local.GetHttpClientPool()
_http_manager.ExecRequests(self.nc.values())
http_pool.ProcessRequests(self._request.values())
results = {}
for name, req in self.nc.iteritems():
for name, req in self._request.iteritems():
if req.success and req.resp_status_code == http.HTTP_OK:
results[name] = RpcResult(data=serializer.LoadJson(req.resp_body),
node=name, call=self.procedure)
......
......@@ -87,12 +87,6 @@ class TestMisc(unittest.TestCase):
self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
# NOTE(review): the hunk header reports 12 lines before and 6 after the
# change, so this method appears to be the part the commit removes -- confirm
# against the raw diff.
def testClientSizeLimits(self):
"""Test HTTP client size limits"""
message_reader_class = http.client._HttpServerToClientMessageReader
# Both size limits of the reader class must be positive
self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
def testFormatAuthHeader(self):
self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
"Basic")
......@@ -330,5 +324,76 @@ class TestReadPasswordFile(testutils.GanetiTestCase):
self.assertEqual(users["user2"].options, ["write", "read"])
class TestClientRequest(unittest.TestCase):
  """Tests for L{http.client.HttpClientRequest}.

  """
  def testRepr(self):
    """The repr() of a request starts with "<"."""
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                       headers=[], post_data="Hello World")
    self.assertTrue(repr(cr).startswith("<"))

  def testNoHeaders(self):
    """Passing headers=None results in an empty header list."""
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                       headers=None)
    self.assertTrue(isinstance(cr.headers, list))
    self.assertEqual(cr.headers, [])
    self.assertEqual(cr.url, "https://localhost:1234/version")

  def testOldStyleHeaders(self):
    """Headers given as a dictionary are exposed as "Name: value" strings."""
    headers = {
      "Content-type": "text/plain",
      "Accept": "text/html",
      }
    cr = http.client.HttpClientRequest("localhost", 16481, "GET", "/vg_list",
                                       headers=headers)
    self.assertTrue(isinstance(cr.headers, list))
    # Sorted before comparing, so the order of the header list is not checked
    self.assertEqual(sorted(cr.headers), [
      "Accept: text/html",
      "Content-type: text/plain",
      ])
    self.assertEqual(cr.url, "https://localhost:16481/vg_list")

  def testNewStyleHeaders(self):
    """Headers given as a list of strings are kept as they are."""
    headers = [
      "Accept: text/html",
      "Content-type: text/plain; charset=ascii",
      "Server: httpd 1.0",
      ]
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                       headers=headers)
    self.assertTrue(isinstance(cr.headers, list))
    self.assertEqual(sorted(cr.headers), sorted(headers))
    self.assertEqual(cr.url, "https://localhost:1234/version")

  def testPostData(self):
    """The post_data argument is available as an attribute."""
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                       post_data="Hello World")
    self.assertEqual(cr.post_data, "Hello World")

  def testNoPostData(self):
    """Without post_data the attribute is an empty string."""
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
    self.assertEqual(cr.post_data, "")

  def testIdentity(self):
    """The identity differs per host/port pair but not per path."""
    # These should all use different connections, hence also have a different
    # identity
    cr1 = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
    cr2 = http.client.HttpClientRequest("localhost", 9999, "GET", "/version")
    cr3 = http.client.HttpClientRequest("node1", 1234, "GET", "/version")
    cr4 = http.client.HttpClientRequest("node1", 9999, "GET", "/version")

    self.assertEqual(len(set([cr1.identity, cr2.identity,
                              cr3.identity, cr4.identity])), 4)

    # But this one should have the same
    cr1vglist = http.client.HttpClientRequest("localhost", 1234,
                                              "GET", "/vg_list")
    self.assertEqual(cr1.identity, cr1vglist.identity)
class TestClient(unittest.TestCase):
  """Tests for L{http.client.HttpClientPool}.

  """
  def test(self):
    """A newly created pool has an empty (false) C{_pool} attribute."""
    pool = http.client.HttpClientPool(None)
    self.assertFalse(pool._pool)
# Run the test program only when this file is executed as a script
if __name__ == '__main__':
  testutils.GanetiTestProgram()
#!/usr/bin/python
#
# Copyright (C) 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Script for testing ganeti.rpc"""
import os
import sys
import unittest
from ganeti import constants
from ganeti import compat
from ganeti import rpc
from ganeti import http
from ganeti import errors
from ganeti import serializer
import testutils
class TestTimeouts(unittest.TestCase):
  """Checks the RPC timeout table against the available RPC calls.

  """
  def test(self):
    """Every C{call_*} name on L{rpc.RpcRunner} has a valid timeout entry."""
    # Procedure names are the attribute names with the "call_" prefix removed
    names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
             if name.startswith("call_")]

    self.assertEqual(len(names), len(rpc._TIMEOUTS))

    # A timeout must be either None or a positive number; the list collects
    # offending names and has to be empty
    self.assertFalse([name for name in names
                      if not (rpc._TIMEOUTS[name] is None or
                              rpc._TIMEOUTS[name] > 0)])
class FakeHttpPool:
  """Stand-in for an HTTP client pool which answers requests via a callback.

  @ivar reqcount: Number of requests processed so far

  """
  def __init__(self, response_fn):
    """Initializes this class.

    @param response_fn: Callable invoked with each request object

    """
    self._response_fn = response_fn
    self.reqcount = 0

  def ProcessRequests(self, reqs):
    """Counts each request and passes it to the response function.

    @param reqs: Iterable of request objects

    """
    for req in reqs:
      self.reqcount += 1
      self._response_fn(req)
class TestClient(unittest.TestCase):
  """Tests for L{rpc.Client} using L{FakeHttpPool} instead of real HTTP.

  The C{_Get*Response} helpers act as response functions: they check the
  request they are given and fill in its result attributes.

  """
  def _GetVersionResponse(self, req):
    """Checks the request target and fills in a successful reply."""
    self.assertEqual(req.host, "localhost")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    """A single successful call yields the payload without failure."""
    client = rpc.Client("version", None, 24094)
    client.ConnectNode("localhost")
    pool = FakeHttpPool(self._GetVersionResponse)
    result = client.GetResults(http_pool=pool)
    # list() keeps the comparison independent of the type keys() returns
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)

  def _GetMultiVersionResponse(self, req):
    """Fills in a successful reply for any host named "node*"."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    """Successful calls to 50 nodes yield one result per node."""
    nodes = ["node%s" % i for i in range(50)]
    client = rpc.Client("version", None, 23245)
    client.ConnectList(nodes)

    pool = FakeHttpPool(self._GetMultiVersionResponse)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetVersionResponseFail(self, req):
    """Fills in an HTTP-level success carrying a failed RPC result."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, "Unknown error"))

  def testVersionFailure(self):
    """A (False, message) reply sets fail_msg and makes Raise() fail."""
    client = rpc.Client("version", None, 5903)
    client.ConnectNode("aef9ur4i.example.com")
    pool = FakeHttpPool(self._GetVersionResponseFail)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(list(result.keys()), ["aef9ur4i.example.com"])
    lhresp = result["aef9ur4i.example.com"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "aef9ur4i.example.com")
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
    self.assertEqual(pool.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Fills in a reply depending on which node set the host belongs to.

    @param httperrnodes: Hosts for which the request itself is marked failed
    @param failnodes: Hosts which answer with HTTP status 404
    @param req: Request object to fill in

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      # The payload depends on the host so results can be told apart
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    """Failed requests and 404 replies are reported per node."""
    nodes = ["uaf6pbbv%s" % i for i in range(50)]

    # Split the nodes into three disjoint groups and verify the group sizes
    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    client = rpc.Client("vg_list", None, 15165)
    client.ConnectList(nodes)

    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
                                       httperrnodes, failnodes))
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Fills in a reply whose body is a nine-element sequence."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Fills in a reply whose body is a plain string."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    """Replies of unexpected shape are reported as failures."""
    client = rpc.Client("version", None, 19978)
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      client.ConnectNode("oqo7lanhly.example.com")
      # A fresh pool per response function, hence one request each time
      pool = FakeHttpPool(fn)
      result = client.GetResults(http_pool=pool)
      self.assertEqual(list(result.keys()), ["oqo7lanhly.example.com"])
      lhresp = result["oqo7lanhly.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(pool.reqcount, 1)
# Run the test program only when this file is executed as a script
if __name__ == "__main__":
  testutils.GanetiTestProgram()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment