diff --git a/lib/http.py b/lib/http.py index ed86aaa6d6f88a5a97c2d8fd40dda156af43ff42..50012eb4afb99b4d72cb3dd563af3320783bce8d 100644 --- a/lib/http.py +++ b/lib/http.py @@ -94,7 +94,12 @@ class SocketClosed(socket.error): pass -class ResponseError(Exception): +class _HttpClientError(Exception): + """Internal exception for HTTP client errors. + + This should only be used for internal error reporting. + + """ pass @@ -687,7 +692,8 @@ class HttpServer(_HttpSocketBase): class HttpClientRequest(object): - def __init__(self, host, port, method, path, headers=None, post_data=None): + def __init__(self, host, port, method, path, headers=None, post_data=None, + ssl_key_path=None, ssl_cert_path=None, ssl_verify_peer=False): """Describes an HTTP request. @type host: string @@ -712,6 +718,9 @@ class HttpClientRequest(object): self.host = host self.port = port + self.ssl_key_path = ssl_key_path + self.ssl_cert_path = ssl_cert_path + self.ssl_verify_peer = ssl_verify_peer self.method = method self.path = path self.headers = headers @@ -728,7 +737,7 @@ class HttpClientRequest(object): self.resp_body = None -class HttpClientRequestExecutor(object): +class HttpClientRequestExecutor(_HttpSocketBase): # Default headers DEFAULT_HEADERS = { HTTP_USER_AGENT: HTTP_GANETI_VERSION, @@ -740,7 +749,7 @@ class HttpClientRequestExecutor(object): STATUS_LINE_LENGTH_MAX = 512 HEADER_LENGTH_MAX = 4 * 1024 - # Timeouts in seconds + # Timeouts in seconds for socket layer # TODO: Make read timeout configurable per OpCode CONNECT_TIMEOUT = 5.0 WRITE_TIMEOUT = 10 @@ -753,6 +762,12 @@ class HttpClientRequestExecutor(object): PS_BODY = "body" PS_COMPLETE = "complete" + # Socket operations + (OP_SEND, + OP_RECV, + OP_CLOSE_CHECK, + OP_SHUTDOWN) = range(4) + def __init__(self, req): """Initializes the HttpClientRequestExecutor class. @@ -760,6 +775,8 @@ class HttpClientRequestExecutor(object): @param req: Request object """ + _HttpSocketBase.__init__(self) + self.request = req self.parser_status = self.PS_STATUS_LINE @@ -770,11 +787,11 @@ class HttpClientRequestExecutor(object): self.poller = select.poll() - # TODO: SSL - try: # TODO: Implement connection caching/keep-alive - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock = self._CreateSocket(req.ssl_key_path, + req.ssl_cert_path, + req.ssl_verify_peer) # Disable Python's timeout self.sock.settimeout(None) @@ -801,7 +818,7 @@ class HttpClientRequestExecutor(object): req.success = True req.error = None - except ResponseError, err: + except _HttpClientError, err: req.success = False req.error = str(err) @@ -849,7 +866,7 @@ class HttpClientRequestExecutor(object): line = self.request.resp_status_line if not line: - raise ResponseError("Empty status line") + raise _HttpClientError("Empty status line") try: [version, status, reason] = line.split(None, 2) @@ -866,8 +883,8 @@ class HttpClientRequestExecutor(object): if version not in (HTTP_1_0, HTTP_1_1): # We do not support HTTP/0.9, despite the specification requiring it # (RFC2616, section 19.6) - raise ResponseError("Only HTTP/1.0 and HTTP/1.1 are supported (%r)" % - line) + raise _HttpClientError("Only HTTP/1.0 and HTTP/1.1 are supported (%r)" % + line) # The status code is a three-digit number try: @@ -878,7 +895,7 @@ class HttpClientRequestExecutor(object): status = -1 if status == -1: - raise ResponseError("Invalid status code (%r)" % line) + raise _HttpClientError("Invalid status code (%r)" % line) self.request.resp_version = version self.request.resp_status = status @@ -949,13 +966,13 @@ class HttpClientRequestExecutor(object): def _CheckStatusLineLength(self, length): if length > self.STATUS_LINE_LENGTH_MAX: - raise ResponseError("Status line longer than %d chars" % - self.STATUS_LINE_LENGTH_MAX) + raise _HttpClientError("Status line longer than %d chars" % + self.STATUS_LINE_LENGTH_MAX) def _CheckHeaderLength(self, length): if length > self.HEADER_LENGTH_MAX: - raise ResponseError("Headers longer than %d chars" % - self.HEADER_LENGTH_MAX) + raise _HttpClientError("Headers longer than %d chars" % + self.HEADER_LENGTH_MAX) def _ParseBuffer(self, buf, eof): """Main function for HTTP response state machine. @@ -1053,14 +1070,127 @@ class HttpClientRequestExecutor(object): finally: self.poller.unregister(self.sock) + def _SocketOperation(self, op, arg1, error_msg, timeout_msg): + """Wrapper around socket functions. + + This function abstracts error handling for socket operations, especially + for the complicated interaction with OpenSSL. + + """ + if op == self.OP_SEND: + event_poll = select.POLLOUT + event_check = select.POLLOUT + timeout = self.WRITE_TIMEOUT + + elif op in (self.OP_RECV, self.OP_CLOSE_CHECK): + event_poll = select.POLLIN + event_check = select.POLLIN | select.POLLPRI + if op == self.OP_CLOSE_CHECK: + timeout = self.CLOSE_TIMEOUT + else: + timeout = self.READ_TIMEOUT + + elif op == self.OP_SHUTDOWN: + event_poll = None + event_check = None + + # The timeout is only used when OpenSSL requests polling for a condition. + # It is not advisable to have no timeout for shutdown. + timeout = self.WRITE_TIMEOUT + + else: + raise AssertionError("Invalid socket operation") + + # No override by default + event_override = 0 + + while True: + # Poll only for certain operations and when asked for by an override + if (event_override or + op in (self.OP_SEND, self.OP_RECV, self.OP_CLOSE_CHECK)): + if event_override: + wait_for_event = event_override + else: + wait_for_event = event_poll + + event = self._WaitForCondition(wait_for_event, timeout) + if event is None: + raise _HttpClientTimeout(timeout_msg) + + if (op == self.OP_RECV and + event & (select.POLLNVAL | select.POLLHUP | select.POLLERR)): + return "" + + if not event & wait_for_event: + continue + + # Reset override + event_override = 0 + + try: + try: + if op == self.OP_SEND: + return self.sock.send(arg1) + + elif op in (self.OP_RECV, self.OP_CLOSE_CHECK): + return self.sock.recv(arg1) + + elif op == self.OP_SHUTDOWN: + if self._using_ssl: + # PyOpenSSL's shutdown() doesn't take arguments + return self.sock.shutdown() + else: + return self.sock.shutdown(arg1) + + except OpenSSL.SSL.WantWriteError: + # OpenSSL wants to write, poll for POLLOUT + event_override = select.POLLOUT + continue + + except OpenSSL.SSL.WantReadError: + # OpenSSL wants to read, poll for POLLIN + event_override = select.POLLIN | select.POLLPRI + continue + + except OpenSSL.SSL.WantX509LookupError: + continue + + except OpenSSL.SSL.SysCallError, err: + if op == self.OP_SEND: + # arg1 is the data when writing + if err.args and err.args[0] == -1 and arg1 == "": + # errors when writing empty strings are expected + # and can be ignored + return 0 + + elif op == self.OP_RECV: + if err.args == (-1, _SSL_UNEXPECTED_EOF): + return "" + + raise socket.error(err.args) + + except OpenSSL.SSL.Error, err: + raise socket.error(err.args) + + except socket.error, err: + if err.args and err.args[0] == errno.EAGAIN: + # Ignore EAGAIN + continue + + raise _HttpClientError("%s: %s" % (error_msg, str(err))) + def _Connect(self): """Non-blocking connect to host with timeout. """ connected = False while True: - connect_error = self.sock.connect_ex((self.request.host, - self.request.port)) + try: + connect_error = self.sock.connect_ex((self.request.host, + self.request.port)) + except socket.gaierror, err: + raise _HttpClientError("Connection failed: %s" % str(err)) + if connect_error == errno.EINTR: # Mask signals pass @@ -1074,20 +1204,20 @@ class HttpClientRequestExecutor(object): # Connection started break - raise ResponseError("Connection failed (%s: %s)" % - (connect_error, os.strerror(connect_error))) + raise _HttpClientError("Connection failed (%s: %s)" % + (connect_error, os.strerror(connect_error))) if not connected: # Wait for connection event = self._WaitForCondition(select.POLLOUT, self.CONNECT_TIMEOUT) if event is None: - raise ResponseError("Timeout while connecting to server") + raise _HttpClientError("Timeout while connecting to server") # Get error code connect_error = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if connect_error != 0: - raise ResponseError("Connection failed (%s: %s)" % - (connect_error, os.strerror(connect_error))) + raise _HttpClientError("Connection failed (%s: %s)" % + (connect_error, os.strerror(connect_error))) # Enable TCP keep-alive self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) @@ -1103,21 +1233,12 @@ class HttpClientRequestExecutor(object): buf = self._BuildRequest() while buf: - event = self._WaitForCondition(select.POLLOUT, self.WRITE_TIMEOUT) - if event is None: - raise ResponseError("Timeout while sending request") + # Send only 4 KB at a time + data = buf[:4096] - try: - # Only send 4 KB at a time - data = buf[:4096] - - sent = self.sock.send(data) - except socket.error, err: - if err.args and err.args[0] == errno.EAGAIN: - # Ignore EAGAIN - continue - - raise ResponseError("Sending request failed: %s" % str(err)) + sent = self._SocketOperation(self.OP_SEND, data, + "Error while sending request", + "Timeout while sending request") # Remove sent bytes buf = buf[sent:] @@ -1133,26 +1254,13 @@ class HttpClientRequestExecutor(object): buf = "" eof = False while self.parser_status != self.PS_COMPLETE: - event = self._WaitForCondition(select.POLLIN, self.READ_TIMEOUT) - if event is None: - raise ResponseError("Timeout while reading response") - - if event & (select.POLLIN | select.POLLPRI): - try: - data = self.sock.recv(4096) - except socket.error, err: - if err.args and err.args[0] == errno.EAGAIN: - # Ignore EAGAIN - continue + data = self._SocketOperation(self.OP_RECV, 4096, + "Error while reading response", + "Timeout while reading response") - raise ResponseError("Reading response failed: %s" % str(err)) - - if data: - buf += data - else: - eof = True - - if event & (select.POLLNVAL | select.POLLHUP | select.POLLERR): + if data: + buf += data + else: eof = True # Do some parsing and error checking while more data arrives @@ -1162,7 +1270,7 @@ class HttpClientRequestExecutor(object): if (eof and self.parser_status in (self.PS_STATUS_LINE, self.PS_HEADERS)): - raise ResponseError("Connection closed prematurely") + raise _HttpClientError("Connection closed prematurely") # Parse rest buf = self._ParseBuffer(buf, True) @@ -1176,21 +1284,19 @@ class HttpClientRequestExecutor(object): """ if self.server_will_close and not force: # Wait for server to close - event = self._WaitForCondition(select.POLLIN, self.CLOSE_TIMEOUT) - if event is None: - # Server didn't close connection within CLOSE_TIMEOUT + try: + # Check whether it's actually closed + if not self._SocketOperation(self.OP_CLOSE_CHECK, 1, + "Error", "Timeout"): + return + except (socket.error, _HttpClientError): + # Ignore errors at this stage pass - else: - try: - # Check whether it's actually closed - if not self.sock.recv(1): - return - except socket.error, err: - # Ignore errors at this stage - pass # Close the connection from our side - self.sock.shutdown(socket.SHUT_RDWR) + self._SocketOperation(self.OP_SHUTDOWN, socket.SHUT_RDWR, + "Error while shutting down connection", + "Timeout while shutting down connection") class HttpClientWorker(workerpool.BaseWorker):