diff --git a/lib/http/client.py b/lib/http/client.py
index 1186e214d167988e6804d7bb8b58ad14d4d646de..163e4c008df5e926354a867ad91f4658acc41dac 100644
--- a/lib/http/client.py
+++ b/lib/http/client.py
@@ -55,10 +55,6 @@ class HttpClientRequest(object):
         timeout while reading the response from the server
     @type curl_config_fn: callable
     @param curl_config_fn: Function to configure cURL object before request
-                           (Note: if the function configures the connection in
-                           a way where it wouldn't be efficient to reuse them,
-                           an "identity" property should be defined, see
-                           L{HttpClientRequest.identity})
     @type nicename: string
     @param nicename: Name, presentable to a user, to describe this request (no
                      whitespace)
@@ -118,58 +114,77 @@ class HttpClientRequest(object):
     # TODO: Support for non-SSL requests
     return "https://%s%s" % (address, self.path)
 
-  @property
-  def identity(self):
-    """Returns identifier for retrieving a pooled connection for this request.
 
-    This allows cURL client objects to be re-used and to cache information
-    (e.g. SSL session IDs or connections).
+def _StartRequest(curl, req):
+  """Starts a request on a cURL object.
 
-    """
-    parts = [self.host, self.port]
+  @type curl: pycurl.Curl
+  @param curl: cURL object
+  @type req: L{HttpClientRequest}
+  @param req: HTTP request
 
-    if self.curl_config_fn:
-      try:
-        parts.append(self.curl_config_fn.identity)
-      except AttributeError:
-        pass
+  """
+  logging.debug("Starting request %r", req)
 
-    return "/".join(str(i) for i in parts)
+  url = req.url
+  method = req.method
+  post_data = req.post_data
+  headers = req.headers
 
+  # PycURL requires strings to be non-unicode
+  assert isinstance(method, str)
+  assert isinstance(url, str)
+  assert isinstance(post_data, str)
+  assert compat.all(isinstance(i, str) for i in headers)
 
-class _HttpClient(object):
-  def __init__(self, curl_config_fn):
-    """Initializes this class.
+  # Buffer for response
+  resp_buffer = StringIO()
 
-    @type curl_config_fn: callable
-    @param curl_config_fn: Function to configure cURL object after
-                           initialization
+  # Configure client for request
+  curl.setopt(pycurl.VERBOSE, False)
+  curl.setopt(pycurl.NOSIGNAL, True)
+  curl.setopt(pycurl.USERAGENT, http.HTTP_GANETI_VERSION)
+  curl.setopt(pycurl.PROXY, "")
+  curl.setopt(pycurl.CUSTOMREQUEST, str(method))
+  curl.setopt(pycurl.URL, url)
+  curl.setopt(pycurl.POSTFIELDS, post_data)
+  curl.setopt(pycurl.HTTPHEADER, headers)
 
-    """
-    self._req = None
+  if req.read_timeout is None:
+    curl.setopt(pycurl.TIMEOUT, 0)
+  else:
+    curl.setopt(pycurl.TIMEOUT, int(req.read_timeout))
 
-    curl = self._CreateCurlHandle()
-    curl.setopt(pycurl.VERBOSE, False)
-    curl.setopt(pycurl.NOSIGNAL, True)
-    curl.setopt(pycurl.USERAGENT, http.HTTP_GANETI_VERSION)
-    curl.setopt(pycurl.PROXY, "")
+  # Disable SSL session ID caching (pycurl >= 7.16.0)
+  if hasattr(pycurl, "SSL_SESSIONID_CACHE"):
+    curl.setopt(pycurl.SSL_SESSIONID_CACHE, False)
 
-    # Disable SSL session ID caching (pycurl >= 7.16.0)
-    if hasattr(pycurl, "SSL_SESSIONID_CACHE"):
-      curl.setopt(pycurl.SSL_SESSIONID_CACHE, False)
+  curl.setopt(pycurl.WRITEFUNCTION, resp_buffer.write)
 
-    # Pass cURL object to external config function
-    if curl_config_fn:
-      curl_config_fn(curl)
+  # Pass cURL object to external config function
+  if req.curl_config_fn:
+    req.curl_config_fn(curl)
 
-    self._curl = curl
+  return _PendingRequest(curl, req, resp_buffer.getvalue)
 
-  @staticmethod
-  def _CreateCurlHandle():
-    """Returns a new cURL object.
+
+class _PendingRequest:
+  def __init__(self, curl, req, resp_buffer_read):
+    """Initializes this class.
+
+    @type curl: pycurl.Curl
+    @param curl: cURL object
+    @type req: L{HttpClientRequest}
+    @param req: HTTP request
+    @type resp_buffer_read: callable
+    @param resp_buffer_read: Function to read response body
 
     """
-    return pycurl.Curl()
+    assert req.success is None
+
+    self._curl = curl
+    self._req = req
+    self._resp_buffer_read = resp_buffer_read
 
   def GetCurlHandle(self):
     """Returns the cURL object.
@@ -180,53 +195,9 @@ class _HttpClient(object):
   def GetCurrentRequest(self):
     """Returns the current request.
 
-    @rtype: L{HttpClientRequest} or None
-
     """
     return self._req
 
-  def StartRequest(self, req):
-    """Starts a request on this client.
-
-    @type req: L{HttpClientRequest}
-    @param req: HTTP request
-
-    """
-    assert not self._req, "Another request is already started"
-
-    logging.debug("Starting request %r", req)
-
-    self._req = req
-    self._resp_buffer = StringIO()
-
-    url = req.url
-    method = req.method
-    post_data = req.post_data
-    headers = req.headers
-
-    # PycURL requires strings to be non-unicode
-    assert isinstance(method, str)
-    assert isinstance(url, str)
-    assert isinstance(post_data, str)
-    assert compat.all(isinstance(i, str) for i in headers)
-
-    # Configure cURL object for request
-    curl = self._curl
-    curl.setopt(pycurl.CUSTOMREQUEST, str(method))
-    curl.setopt(pycurl.URL, url)
-    curl.setopt(pycurl.POSTFIELDS, post_data)
-    curl.setopt(pycurl.WRITEFUNCTION, self._resp_buffer.write)
-    curl.setopt(pycurl.HTTPHEADER, headers)
-
-    if req.read_timeout is None:
-      curl.setopt(pycurl.TIMEOUT, 0)
-    else:
-      curl.setopt(pycurl.TIMEOUT, int(req.read_timeout))
-
-    # Pass cURL object to external config function
-    if req.curl_config_fn:
-      req.curl_config_fn(curl)
-
   def Done(self, errmsg):
     """Finishes a request.
 
@@ -234,222 +205,29 @@ class _HttpClient(object):
     @param errmsg: Error message if request failed
 
     """
+    curl = self._curl
     req = self._req
 
-    assert req, "No request"
-
-    logging.debug("Request %s finished, errmsg=%s", req, errmsg)
+    assert req.success is None, "Request has already been finalized"
 
-    curl = self._curl
+    logging.debug("Request %s finished, errmsg=%s", req, errmsg)
 
     req.success = not bool(errmsg)
     req.error = errmsg
 
     # Get HTTP response code
    req.resp_status_code = curl.getinfo(pycurl.RESPONSE_CODE)
-    req.resp_body = self._resp_buffer.getvalue()
-
-    # Reset client object
-    self._req = None
-    self._resp_buffer = None
+    req.resp_body = self._resp_buffer_read()
 
     # Ensure no potentially large variables are referenced
-    curl.setopt(pycurl.POSTFIELDS, "")
-    curl.setopt(pycurl.WRITEFUNCTION, lambda _: None)
-
-
-class _PooledHttpClient:
-  """Data structure for HTTP client pool.
-
-  """
-  def __init__(self, identity, client):
-    """Initializes this class.
-
-    @type identity: string
-    @param identity: Client identifier for pool
-    @type client: L{_HttpClient}
-    @param client: HTTP client
-
-    """
-    self.identity = identity
-    self.client = client
-    self.lastused = 0
-
-  def __repr__(self):
-    status = ["%s.%s" % (self.__class__.__module__, self.__class__.__name__),
-              "id=%s" % self.identity,
-              "lastuse=%s" % self.lastused,
-              repr(self.client)]
-
-    return "<%s at %#x>" % (" ".join(status), id(self))
-
-
-class HttpClientPool:
-  """A simple HTTP client pool.
-
-  Supports one pooled connection per identity (see
-  L{HttpClientRequest.identity}).
-
-  """
-  #: After how many generations to drop unused clients
-  _MAX_GENERATIONS_DROP = 25
-
-  def __init__(self, curl_config_fn):
-    """Initializes this class.
-
-    @type curl_config_fn: callable
-    @param curl_config_fn: Function to configure cURL object after
-                           initialization
-
-    """
-    self._curl_config_fn = curl_config_fn
-    self._generation = 0
-    self._pool = {}
-
-    # Create custom logger for HTTP client pool. Change logging level to
-    # C{logging.NOTSET} to get more details.
-    self._logger = logging.getLogger(self.__class__.__name__)
-    self._logger.setLevel(logging.INFO)
-
-  @staticmethod
-  def _GetHttpClientCreator():
-    """Returns callable to create HTTP client.
-
-    """
-    return _HttpClient
-
-  def _Get(self, identity):
-    """Gets an HTTP client from the pool.
-
-    @type identity: string
-    @param identity: Client identifier
-
-    """
     try:
-      pclient = self._pool.pop(identity)
-    except KeyError:
-      # Need to create new client
-      client = self._GetHttpClientCreator()(self._curl_config_fn)
-      pclient = _PooledHttpClient(identity, client)
-      self._logger.debug("Created new client %s", pclient)
+      # Only available in PycURL 7.19.0 and above
+      reset_fn = curl.reset
+    except AttributeError:
+      curl.setopt(pycurl.POSTFIELDS, "")
+      curl.setopt(pycurl.WRITEFUNCTION, lambda _: None)
     else:
-      self._logger.debug("Reusing client %s", pclient)
-
-    assert pclient.identity == identity
-
-    return pclient
-
-  def _StartRequest(self, req):
-    """Starts a request.
-
-    @type req: L{HttpClientRequest}
-    @param req: HTTP request
-
-    """
-    pclient = self._Get(req.identity)
-
-    assert req.identity not in self._pool
-
-    pclient.client.StartRequest(req)
-    pclient.lastused = self._generation
-
-    return pclient
-
-  def _Return(self, pclients):
-    """Returns HTTP clients to the pool.
-
-    """
-    assert not frozenset(pclients) & frozenset(self._pool.values())
-
-    for pc in pclients:
-      self._logger.debug("Returning client %s to pool", pc)
-      assert pc.identity not in self._pool
-      self._pool[pc.identity] = pc
-
-    # Check for unused clients
-    for pc in self._pool.values():
-      if (pc.lastused + self._MAX_GENERATIONS_DROP) < self._generation:
-        self._logger.debug("Removing client %s which hasn't been used"
-                           " for %s generations",
-                           pc, self._MAX_GENERATIONS_DROP)
-        self._pool.pop(pc.identity, None)
-
-    assert compat.all(pc.lastused >= (self._generation -
-                                      self._MAX_GENERATIONS_DROP)
-                      for pc in self._pool.values())
-
-  @staticmethod
-  def _CreateCurlMultiHandle():
-    """Creates new cURL multi handle.
-
-    """
-    return pycurl.CurlMulti()
-
-  def ProcessRequests(self, requests, lock_monitor_cb=None):
-    """Processes any number of HTTP client requests using pooled objects.
-
-    @type requests: list of L{HttpClientRequest}
-    @param requests: List of all requests
-    @param lock_monitor_cb: Callable for registering with lock monitor
-
-    """
-    # For client cleanup
-    self._generation += 1
-
-    assert compat.all((req.error is None and
-                       req.success is None and
-                       req.resp_status_code is None and
-                       req.resp_body is None)
-                      for req in requests)
-
-    curl_to_pclient = {}
-    for req in requests:
-      pclient = self._StartRequest(req)
-      curl_to_pclient[pclient.client.GetCurlHandle()] = pclient
-      assert pclient.client.GetCurrentRequest() == req
-      assert pclient.lastused >= 0
-
-    assert len(curl_to_pclient) == len(requests)
-
-    if lock_monitor_cb:
-      monitor = _PendingRequestMonitor(threading.currentThread(),
-                                       curl_to_pclient.values)
-      lock_monitor_cb(monitor)
-    else:
-      monitor = _NoOpRequestMonitor
-
-    # Process all requests and act based on the returned values
-    for (curl, msg) in _ProcessCurlRequests(self._CreateCurlMultiHandle(),
-                                            curl_to_pclient.keys()):
-      pclient = curl_to_pclient[curl]
-      req = pclient.client.GetCurrentRequest()
-
-      monitor.acquire(shared=0)
-      try:
-        pclient.client.Done(msg)
-      finally:
-        monitor.release()
-
-      assert ((msg is None and req.success and req.error is None) ^
-              (msg is not None and not req.success and req.error == msg))
-
-    assert compat.all(pclient.client.GetCurrentRequest() is None
-                      for pclient in curl_to_pclient.values())
-
-    monitor.acquire(shared=0)
-    try:
-      # Don't try to read information from returned clients
-      monitor.Disable()
-
-      # Return clients to pool
-      self._Return(curl_to_pclient.values())
-    finally:
-      monitor.release()
-
-    assert compat.all(req.error is not None or
-                      (req.success and
-                       req.resp_status_code is not None and
-                       req.resp_body is not None)
-                      for req in requests)
+      reset_fn()
 
 
 class _NoOpRequestMonitor: # pylint: disable=W0232
@@ -479,6 +257,7 @@ class _PendingRequestMonitor:
     self.acquire = self._lock.acquire
     self.release = self._lock.release
 
+  @locking.ssynchronized(_LOCK)
   def Disable(self):
     """Disable monitor.
 
@@ -501,8 +280,8 @@ class _PendingRequestMonitor:
     if self._pending_fn:
       owner_name = self._owner.getName()
 
-      for pclient in self._pending_fn():
-        req = pclient.client.GetCurrentRequest()
+      for client in self._pending_fn():
+        req = client.GetCurrentRequest()
         if req:
           if req.nicename is None:
             name = "%s%s" % (req.host, req.path)
@@ -559,3 +338,53 @@ def _ProcessCurlRequests(multi, requests):
     # timeouts, which are only evaluated in multi.perform, aren't
     # unnecessarily delayed.
     multi.select(1.0)
+
+
+def ProcessRequests(requests, lock_monitor_cb=None, _curl=pycurl.Curl,
+                    _curl_multi=pycurl.CurlMulti,
+                    _curl_process=_ProcessCurlRequests):
+  """Processes any number of HTTP client requests.
+
+  @type requests: list of L{HttpClientRequest}
+  @param requests: List of all requests
+  @param lock_monitor_cb: Callable for registering with lock monitor
+
+  """
+  assert compat.all((req.error is None and
+                     req.success is None and
+                     req.resp_status_code is None and
+                     req.resp_body is None)
+                    for req in requests)
+
+  # Prepare all requests
+  curl_to_client = \
+    dict((client.GetCurlHandle(), client)
+         for client in map(lambda req: _StartRequest(_curl(), req), requests))
+
+  assert len(curl_to_client) == len(requests)
+
+  if lock_monitor_cb:
+    monitor = _PendingRequestMonitor(threading.currentThread(),
+                                     curl_to_client.values)
+    lock_monitor_cb(monitor)
+  else:
+    monitor = _NoOpRequestMonitor
+
+  # Process all requests and act based on the returned values
+  for (curl, msg) in _curl_process(_curl_multi(), curl_to_client.keys()):
+    monitor.acquire(shared=0)
+    try:
+      curl_to_client.pop(curl).Done(msg)
+    finally:
+      monitor.release()
+
+  assert not curl_to_client, "Not all requests were processed"
+
+  # Don't try to read information anymore as all requests have been processed
+  monitor.Disable()
+
+  assert compat.all(req.error is not None or
+                    (req.success and
+                     req.resp_status_code is not None and
+                     req.resp_body is not None)
+                    for req in requests)
diff --git a/lib/rpc.py b/lib/rpc.py
index 44ef760bd478309eba344c16fb87204a600ce908..3a1cd9b1346b1f5bd9a0a6b824e29e7526aff0c2 100644
--- a/lib/rpc.py
+++ b/lib/rpc.py
@@ -374,7 +374,8 @@ class _RpcProcessor:
                                       headers=_RPC_CLIENT_HEADERS,
                                       post_data=body,
                                       read_timeout=read_timeout,
-                                      nicename="%s/%s" % (name, procedure))
+                                      nicename="%s/%s" % (name, procedure),
+                                      curl_config_fn=_ConfigRpcCurl)
 
     return (results, requests)
 
@@ -402,7 +403,8 @@ class _RpcProcessor:
 
     return results
 
-  def __call__(self, hosts, procedure, body, read_timeout=None, http_pool=None):
+  def __call__(self, hosts, procedure, body, read_timeout=None,
+               _req_process_fn=http.client.ProcessRequests):
     """Makes an RPC request to a number of nodes.
 
     @type hosts: sequence
@@ -417,9 +419,6 @@ class _RpcProcessor:
     """
    assert procedure in _TIMEOUTS, "RPC call not declared in the timeouts table"
 
-    if not http_pool:
-      http_pool = http.client.HttpClientPool(_ConfigRpcCurl)
-
     if read_timeout is None:
      read_timeout = _TIMEOUTS[procedure]
 
@@ -427,8 +426,7 @@ class _RpcProcessor:
       self._PrepareRequests(self._resolver(hosts), self._port, procedure,
                             str(body), read_timeout)
 
-    http_pool.ProcessRequests(requests.values(),
-                              lock_monitor_cb=self._lock_monitor_cb)
+    _req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb)
 
     assert not frozenset(results).intersection(requests)
diff --git a/test/ganeti.http_unittest.py b/test/ganeti.http_unittest.py
index 76577ea6950632f994ebd27f9d4dcb217409c497..7bffe6fe97d3a5d69dd7c644db3805deaad4dfdb 100755
--- a/test/ganeti.http_unittest.py
+++ b/test/ganeti.http_unittest.py
@@ -26,9 +26,13 @@ import os
 import unittest
 import time
 import tempfile
+import pycurl
+import itertools
+import threading
 
 from cStringIO import StringIO
 
 from ganeti import http
+from ganeti import compat
 
 import ganeti.http.server
 import ganeti.http.client
@@ -330,6 +334,14 @@ class TestClientRequest(unittest.TestCase):
     self.assertEqual(cr.headers, [])
     self.assertEqual(cr.url, "https://localhost:1234/version")
 
+  def testPlainAddressIPv4(self):
+    cr = http.client.HttpClientRequest("192.0.2.9", 19956, "GET", "/version")
+    self.assertEqual(cr.url, "https://192.0.2.9:19956/version")
+
+  def testPlainAddressIPv6(self):
+    cr = http.client.HttpClientRequest("2001:db8::cafe", 15110, "GET", "/info")
+    self.assertEqual(cr.url, "https://[2001:db8::cafe]:15110/info")
+
   def testOldStyleHeaders(self):
     headers = {
       "Content-type": "text/plain",
@@ -365,27 +377,339 @@ class TestClientRequest(unittest.TestCase):
     cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
     self.assertEqual(cr.post_data, "")
 
-  def testIdentity(self):
-    # These should all use different connections, hence also have a different
-    # identity
-    cr1 = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
-    cr2 = http.client.HttpClientRequest("localhost", 9999, "GET", "/version")
-    cr3 = http.client.HttpClientRequest("node1", 1234, "GET", "/version")
-    cr4 = http.client.HttpClientRequest("node1", 9999, "GET", "/version")
-    self.assertEqual(len(set([cr1.identity, cr2.identity,
-                              cr3.identity, cr4.identity])), 4)
 
+class _FakeCurl:
+  def __init__(self):
+    self.opts = {}
+    self.info = NotImplemented
+
+  def setopt(self, opt, value):
+    assert opt not in self.opts, "Option set more than once"
+    self.opts[opt] = value
+
+  def getinfo(self, info):
+    return self.info.pop(info)
+
+
+class TestClientStartRequest(unittest.TestCase):
+  @staticmethod
+  def _TestCurlConfig(curl):
+    curl.setopt(pycurl.SSLKEYTYPE, "PEM")
+
+  def test(self):
+    for method in [http.HTTP_GET, http.HTTP_PUT, "CUSTOM"]:
+      for port in [8761, 29796, 19528]:
+        for curl_config_fn in [None, self._TestCurlConfig]:
+          for read_timeout in [None, 0, 1, 123, 36000]:
+            self._TestInner(method, port, curl_config_fn, read_timeout)
+
+  def _TestInner(self, method, port, curl_config_fn, read_timeout):
+    for response_code in [http.HTTP_OK, http.HttpNotFound.code,
+                          http.HTTP_NOT_MODIFIED]:
+      for response_body in [None, "Hello World",
+                            "Very Long\tContent here\n" * 171]:
+        for errmsg in [None, "error"]:
+          req = http.client.HttpClientRequest("localhost", port, method,
+                                              "/version",
+                                              curl_config_fn=curl_config_fn,
+                                              read_timeout=read_timeout)
+          curl = _FakeCurl()
+          pending = http.client._StartRequest(curl, req)
+          self.assertEqual(pending.GetCurlHandle(), curl)
+          self.assertEqual(pending.GetCurrentRequest(), req)
+
+          # Check options
+          opts = curl.opts
+          self.assertEqual(opts.pop(pycurl.CUSTOMREQUEST), method)
+          self.assertEqual(opts.pop(pycurl.URL),
+                           "https://localhost:%s/version" % port)
+          if read_timeout is None:
+            self.assertEqual(opts.pop(pycurl.TIMEOUT), 0)
+          else:
+            self.assertEqual(opts.pop(pycurl.TIMEOUT), read_timeout)
+          self.assertFalse(opts.pop(pycurl.VERBOSE))
+          self.assertTrue(opts.pop(pycurl.NOSIGNAL))
+          self.assertEqual(opts.pop(pycurl.USERAGENT),
+                           http.HTTP_GANETI_VERSION)
+          self.assertEqual(opts.pop(pycurl.PROXY), "")
+          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
+          self.assertFalse(opts.pop(pycurl.HTTPHEADER))
+          write_fn = opts.pop(pycurl.WRITEFUNCTION)
+          self.assertTrue(callable(write_fn))
+          if hasattr(pycurl, "SSL_SESSIONID_CACHE"):
+            self.assertFalse(opts.pop(pycurl.SSL_SESSIONID_CACHE))
+          if curl_config_fn:
+            self.assertEqual(opts.pop(pycurl.SSLKEYTYPE), "PEM")
+          else:
+            self.assertFalse(pycurl.SSLKEYTYPE in opts)
+          self.assertFalse(opts)
+
+          if response_body is not None:
+            offset = 0
+            while offset < len(response_body):
+              piece = response_body[offset:offset + 10]
+              write_fn(piece)
+              offset += len(piece)
+
+          curl.info = {
+            pycurl.RESPONSE_CODE: response_code,
+            }
+
+          # Finalize request
+          pending.Done(errmsg)
+
+          self.assertFalse(curl.info)
+
+          # Can only finalize once
+          self.assertRaises(AssertionError, pending.Done, True)
+
+          if errmsg:
+            self.assertFalse(req.success)
+          else:
+            self.assertTrue(req.success)
+          self.assertEqual(req.error, errmsg)
+          self.assertEqual(req.resp_status_code, response_code)
+          if response_body is None:
+            self.assertEqual(req.resp_body, "")
+          else:
+            self.assertEqual(req.resp_body, response_body)
+
+          # Check if resetting worked
+          assert not hasattr(curl, "reset")
+          opts = curl.opts
+          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
+          self.assertTrue(callable(opts.pop(pycurl.WRITEFUNCTION)))
+          self.assertFalse(opts)
+
+          self.assertFalse(curl.opts,
+                           msg="Previous checks did not consume all options")
+          assert id(opts) == id(curl.opts)
 
-    # But this one should have the same
-    cr1vglist = http.client.HttpClientRequest("localhost", 1234,
-                                              "GET", "/vg_list")
-    self.assertEqual(cr1.identity, cr1vglist.identity)
+  def _TestWrongTypes(self, *args, **kwargs):
+    req = http.client.HttpClientRequest(*args, **kwargs)
+    self.assertRaises(AssertionError, http.client._StartRequest,
+                      _FakeCurl(), req)
 
+  def testWrongHostType(self):
+    self._TestWrongTypes(unicode("localhost"), 8080, "GET", "/version")
+
+  def testWrongUrlType(self):
+    self._TestWrongTypes("localhost", 8080, "GET", unicode("/version"))
+
+  def testWrongMethodType(self):
+    self._TestWrongTypes("localhost", 8080, unicode("GET"), "/version")
+
+  def testWrongHeaderType(self):
+    self._TestWrongTypes("localhost", 8080, "GET", "/version",
+                         headers={
+                           unicode("foo"): "bar",
+                           })
+
+  def testWrongPostDataType(self):
+    self._TestWrongTypes("localhost", 8080, "GET", "/version",
+                         post_data=unicode("verylongdata" * 100))
+
+
+class _EmptyCurlMulti:
+  def perform(self):
+    return (pycurl.E_MULTI_OK, 0)
+
+  def info_read(self):
+    return (0, [], [])
+
+
+class TestClientProcessRequests(unittest.TestCase):
+  def testEmpty(self):
+    requests = []
+    http.client.ProcessRequests(requests, _curl=NotImplemented,
+                                _curl_multi=_EmptyCurlMulti)
+    self.assertEqual(requests, [])
+
+
+class TestProcessCurlRequests(unittest.TestCase):
+  class _FakeCurlMulti:
+    def __init__(self):
+      self.handles = []
+      self.will_fail = []
+      self._expect = ["perform"]
+      self._counter = itertools.count()
+
+    def add_handle(self, curl):
+      assert curl not in self.handles
+      self.handles.append(curl)
+      if self._counter.next() % 3 == 0:
+        self.will_fail.append(curl)
+
+    def remove_handle(self, curl):
+      self.handles.remove(curl)
+
+    def perform(self):
+      assert self._expect.pop(0) == "perform"
+
+      if self._counter.next() % 2 == 0:
+        self._expect.append("perform")
+        return (pycurl.E_CALL_MULTI_PERFORM, None)
+
+      self._expect.append("info_read")
+
+      return (pycurl.E_MULTI_OK, len(self.handles))
+
+    def info_read(self):
+      assert self._expect.pop(0) == "info_read"
+      successful = []
+      failed = []
+
+      if self.handles:
+        if self._counter.next() % 17 == 0:
+          curl = self.handles[0]
+          if curl in self.will_fail:
+            failed.append((curl, -1, "test error"))
+          else:
+            successful.append(curl)
+
+        remaining_messages = len(self.handles) % 3
+        if remaining_messages > 0:
+          self._expect.append("info_read")
+        else:
+          self._expect.append("select")
+      else:
+        remaining_messages = 0
+        self._expect.append("select")
+
+      return (remaining_messages, successful, failed)
+
+    def select(self, timeout):
+      # Never compare floats for equality
+      assert timeout >= 0.95 and timeout <= 1.05
+      assert self._expect.pop(0) == "select"
+      self._expect.append("perform")
 
-class TestClient(unittest.TestCase):
   def test(self):
-    pool = http.client.HttpClientPool(None)
-    self.assertFalse(pool._pool)
+    requests = [_FakeCurl() for _ in range(10)]
+    multi = self._FakeCurlMulti()
+    for (curl, errmsg) in http.client._ProcessCurlRequests(multi, requests):
+      self.assertTrue(curl not in multi.handles)
+      if curl in multi.will_fail:
+        self.assertTrue("test error" in errmsg)
+      else:
+        self.assertTrue(errmsg is None)
+    self.assertFalse(multi.handles)
+    self.assertEqual(multi._expect, ["select"])
+
+
+class TestProcessRequests(unittest.TestCase):
+  class _DummyCurlMulti:
+    pass
+
+  def testNoMonitor(self):
+    self._Test(False)
+
+  def testWithMonitor(self):
+    self._Test(True)
+
+  class _MonitorChecker:
+    def __init__(self):
+      self._monitor = None
+
+    def GetMonitor(self):
+      return self._monitor
+
+    def __call__(self, monitor):
+      assert callable(monitor.GetLockInfo)
+      self._monitor = monitor
+
+  def _Test(self, use_monitor):
+    def cfg_fn(port, curl):
+      curl.opts["__port__"] = port
+
+    def _LockCheckReset(monitor, curl):
+      self.assertTrue(monitor._lock.is_owned(shared=0),
+                      msg="Lock must be owned in exclusive mode")
+      curl.opts["__lockcheck__"] = True
+
+    requests = \
+      [http.client.HttpClientRequest("localhost", i, "POST", "/version%s" % i,
+                                     curl_config_fn=compat.partial(cfg_fn, i))
+       for i in range(15176, 15501)]
+    requests_count = len(requests)
+
+    if use_monitor:
+      lock_monitor_cb = self._MonitorChecker()
+    else:
+      lock_monitor_cb = None
+
+    def _ProcessRequests(multi, handles):
+      self.assertTrue(isinstance(multi, self._DummyCurlMulti))
+      self.assertEqual(len(requests), len(handles))
+      self.assertTrue(compat.all(isinstance(curl, _FakeCurl)
+                                 for curl in handles))
+
+      for idx, curl in enumerate(handles):
+        port = curl.opts["__port__"]
+
+        if use_monitor:
+          # Check if lock information is correct
+          lock_info = lock_monitor_cb.GetMonitor().GetLockInfo(None)
+          expected = \
+            [("rpc/localhost/version%s" % handle.opts["__port__"], None,
+              [threading.currentThread().getName()], None)
+             for handle in handles[idx:]]
+          self.assertEqual(sorted(lock_info), sorted(expected))
+
+        if port % 3 == 0:
+          response_code = http.HTTP_OK
+          msg = None
+        else:
+          response_code = http.HttpNotFound.code
+          msg = "test error"
+
+        curl.info = {
+          pycurl.RESPONSE_CODE: response_code,
+          }
+
+        # Unset options which will be reset
+        assert not hasattr(curl, "reset")
+        if use_monitor:
+          setattr(curl, "reset",
+                  compat.partial(_LockCheckReset, lock_monitor_cb.GetMonitor(),
+                                 curl))
+        else:
+          self.assertFalse(curl.opts.pop(pycurl.POSTFIELDS))
+          self.assertTrue(callable(curl.opts.pop(pycurl.WRITEFUNCTION)))
+
+        yield (curl, msg)
+
+      if use_monitor:
+        self.assertTrue(compat.all(curl.opts["__lockcheck__"]
+                                   for curl in handles))
+
+    http.client.ProcessRequests(requests, lock_monitor_cb=lock_monitor_cb,
+                                _curl=_FakeCurl,
+                                _curl_multi=self._DummyCurlMulti,
+                                _curl_process=_ProcessRequests)
+
+    for req in requests:
+      if req.port % 3 == 0:
+        self.assertTrue(req.success)
+        self.assertEqual(req.error, None)
+      else:
+        self.assertFalse(req.success)
+        self.assertTrue("test error" in req.error)
+
+    # See if monitor was disabled
+    if use_monitor:
+      monitor = lock_monitor_cb.GetMonitor()
+      self.assertEqual(monitor._pending_fn, None)
+      self.assertEqual(monitor.GetLockInfo(None), [])
+    else:
+      self.assertEqual(lock_monitor_cb, None)
+
+    self.assertEqual(len(requests), requests_count)
+
+  def testBadRequest(self):
+    bad_request = http.client.HttpClientRequest("localhost", 27784,
+                                                "POST", "/version")
+    bad_request.success = False
+
+    self.assertRaises(AssertionError, http.client.ProcessRequests,
+                      [bad_request], _curl=NotImplemented,
+                      _curl_multi=NotImplemented, _curl_process=NotImplemented)
 
 
 if __name__ == '__main__':
diff --git a/test/ganeti.rpc_unittest.py b/test/ganeti.rpc_unittest.py
index 6dee12ed8467b45790e4ef5f06db01b24b4e0ac8..68d950d42c5d90ee6a7291fc6787f07404c2fca3 100755
--- a/test/ganeti.rpc_unittest.py
+++ b/test/ganeti.rpc_unittest.py
@@ -46,12 +46,13 @@ class TestTimeouts(unittest.TestCase):
                       rpc._TIMEOUTS[name] > 0)])
 
 
-class FakeHttpPool:
+class _FakeRequestProcessor:
   def __init__(self, response_fn):
     self._response_fn = response_fn
     self.reqcount = 0
 
-  def ProcessRequests(self, reqs, lock_monitor_cb=None):
+  def __call__(self, reqs, lock_monitor_cb=None):
+    assert lock_monitor_cb is None or callable(lock_monitor_cb)
     for req in reqs:
       self.reqcount += 1
       self._response_fn(req)
@@ -80,9 +81,9 @@ class TestRpcProcessor(unittest.TestCase):
 
   def testVersionSuccess(self):
     resolver = rpc._StaticResolver(["127.0.0.1"])
-    pool = FakeHttpPool(self._GetVersionResponse)
+    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
     proc = rpc._RpcProcessor(resolver, 24094)
-    result = proc(["localhost"], "version", None, http_pool=pool)
+    result = proc(["localhost"], "version", None, _req_process_fn=http_proc)
     self.assertEqual(result.keys(), ["localhost"])
     lhresp = result["localhost"]
     self.assertFalse(lhresp.offline)
@@ -91,7 +92,7 @@ class TestRpcProcessor(unittest.TestCase):
     self.assertEqual(lhresp.payload, 123)
     self.assertEqual(lhresp.call, "version")
     lhresp.Raise("should not raise")
-    self.assertEqual(pool.reqcount, 1)
+    self.assertEqual(http_proc.reqcount, 1)
 
   def _ReadTimeoutResponse(self, req):
     self.assertEqual(req.host, "192.0.2.13")
@@ -104,9 +105,9 @@ class TestRpcProcessor(unittest.TestCase):
 
   def testReadTimeout(self):
     resolver = rpc._StaticResolver(["192.0.2.13"])
-    pool = FakeHttpPool(self._ReadTimeoutResponse)
+    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
     proc = rpc._RpcProcessor(resolver, 19176)
-    result = proc(["node31856"], "version", None, http_pool=pool,
+    result = proc(["node31856"], "version", None, _req_process_fn=http_proc,
                   read_timeout=12356)
     self.assertEqual(result.keys(), ["node31856"])
     lhresp = result["node31856"]
@@ -116,13 +117,13 @@ class TestRpcProcessor(unittest.TestCase):
     self.assertEqual(lhresp.payload, -1)
     self.assertEqual(lhresp.call, "version")
     lhresp.Raise("should not raise")
-    self.assertEqual(pool.reqcount, 1)
+    self.assertEqual(http_proc.reqcount, 1)
 
   def testOfflineNode(self):
     resolver = rpc._StaticResolver([rpc._OFFLINE])
-    pool = FakeHttpPool(NotImplemented)
+    http_proc = _FakeRequestProcessor(NotImplemented)
     proc = rpc._RpcProcessor(resolver, 30668)
-    result = proc(["n17296"], "version", None, http_pool=pool)
+    result = proc(["n17296"], "version", None, _req_process_fn=http_proc)
     self.assertEqual(result.keys(), ["n17296"])
     lhresp = result["n17296"]
     self.assertTrue(lhresp.offline)
@@ -137,7 +138,7 @@ class TestRpcProcessor(unittest.TestCase):
     # No message
     self.assertRaises(errors.OpExecError, lhresp.Raise, None)
 
-    self.assertEqual(pool.reqcount, 0)
+    self.assertEqual(http_proc.reqcount, 0)
 
   def _GetMultiVersionResponse(self, req):
     self.assert_(req.host.startswith("node"))
@@ -150,9 +151,9 @@ class TestRpcProcessor(unittest.TestCase):
   def testMultiVersionSuccess(self):
     nodes = ["node%s" % i for i in range(50)]
     resolver = rpc._StaticResolver(nodes)
-    pool = FakeHttpPool(self._GetMultiVersionResponse)
+    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
     proc = rpc._RpcProcessor(resolver, 23245)
-    result = proc(nodes, "version", None, http_pool=pool)
+    result = proc(nodes, "version", None, _req_process_fn=http_proc)
     self.assertEqual(sorted(result.keys()), sorted(nodes))
 
     for name in nodes:
@@ -164,7 +165,7 @@ class TestRpcProcessor(unittest.TestCase):
       self.assertEqual(lhresp.call, "version")
       lhresp.Raise("should not raise")
 
-    self.assertEqual(pool.reqcount, len(nodes))
+    self.assertEqual(http_proc.reqcount, len(nodes))
 
   def _GetVersionResponseFail(self, errinfo, req):
     self.assertEqual(req.path, "/version")
@@ -176,8 +177,11 @@ class TestRpcProcessor(unittest.TestCase):
     resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
     proc = rpc._RpcProcessor(resolver, 5903)
     for errinfo in [None, "Unknown error"]:
-      pool = FakeHttpPool(compat.partial(self._GetVersionResponseFail, errinfo))
-      result = proc(["aef9ur4i.example.com"], "version", None, http_pool=pool)
+      http_proc = \
+        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
+                                             errinfo))
+      result = proc(["aef9ur4i.example.com"], "version", None,
+                    _req_process_fn=http_proc)
       self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
       lhresp = result["aef9ur4i.example.com"]
       self.assertFalse(lhresp.offline)
@@ -186,7 +190,7 @@ class TestRpcProcessor(unittest.TestCase):
       self.assertFalse(lhresp.payload)
       self.assertEqual(lhresp.call, "version")
       self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
-      self.assertEqual(pool.reqcount, 1)
+      self.assertEqual(http_proc.reqcount, 1)
 
   def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
     self.assertEqual(req.path, "/vg_list")
@@ -222,9 +226,10 @@ class TestRpcProcessor(unittest.TestCase):
     self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
 
     proc = rpc._RpcProcessor(resolver, 15165)
-    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
-                                       httperrnodes, failnodes))
-    result = proc(nodes, "vg_list", None, http_pool=pool)
+    http_proc = \
+      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
+                                           httperrnodes, failnodes))
+    result = proc(nodes, "vg_list", None, _req_process_fn=http_proc)
     self.assertEqual(sorted(result.keys()), sorted(nodes))
 
     for name in nodes:
@@ -245,7 +250,7 @@ class TestRpcProcessor(unittest.TestCase):
         self.assertEqual(lhresp.payload, hash(name))
         lhresp.Raise("should not raise")
 
-    self.assertEqual(pool.reqcount, len(nodes))
+    self.assertEqual(http_proc.reqcount, len(nodes))
 
   def _GetInvalidResponseA(self, req):
     self.assertEqual(req.path, "/version")
@@ -265,8 +270,9 @@ class TestRpcProcessor(unittest.TestCase):
     proc = rpc._RpcProcessor(resolver, 19978)
 
     for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
-      pool = FakeHttpPool(fn)
-      result = proc(["oqo7lanhly.example.com"], "version", None, http_pool=pool)
+      http_proc = _FakeRequestProcessor(fn)
+      result = proc(["oqo7lanhly.example.com"], "version", None,
+                    _req_process_fn=http_proc)
       self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
       lhresp = result["oqo7lanhly.example.com"]
       self.assertFalse(lhresp.offline)
@@ -275,7 +281,7 @@ class TestRpcProcessor(unittest.TestCase):
       self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
       self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
-      self.assertEqual(pool.reqcount, 1)
+      self.assertEqual(http_proc.reqcount, 1)
 
   def _GetBodyTestResponse(self, test_data, req):
     self.assertEqual(req.host, "192.0.2.84")
@@ -292,10 +298,11 @@ class TestRpcProcessor(unittest.TestCase):
       "xyz": range(10),
       }
     resolver = rpc._StaticResolver(["192.0.2.84"])
-    pool = FakeHttpPool(compat.partial(self._GetBodyTestResponse, test_data))
+    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
+                                                     test_data))
     proc = rpc._RpcProcessor(resolver, 18700)
     body = serializer.DumpJson(test_data)
-    result = proc(["node19759"], "upload_file", body, http_pool=pool)
+    result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc)
     self.assertEqual(result.keys(), ["node19759"])
     lhresp = result["node19759"]
     self.assertFalse(lhresp.offline)
@@ -304,7 +311,7 @@ class TestRpcProcessor(unittest.TestCase):
     self.assertEqual(lhresp.payload, None)
     self.assertEqual(lhresp.call, "upload_file")
     lhresp.Raise("should not raise")
-    self.assertEqual(pool.reqcount, 1)
+    self.assertEqual(http_proc.reqcount, 1)
 
 
 class TestSsconfResolver(unittest.TestCase):