From b14f759ee86fc08c50588d30360d386a21df35e0 Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Mon, 10 Nov 2008 12:38:26 +0000
Subject: [PATCH] ganeti.http: Move SSL socket creation into base class

The same code will be used by the HTTP client.

Reviewed-by: iustinp
---
 lib/http.py | 108 ++++++++++++++++++++++++++++++++++------------------
 1 file changed, 71 insertions(+), 37 deletions(-)

diff --git a/lib/http.py b/lib/http.py
index 46838358e..ed86aaa6d 100644
--- a/lib/http.py
+++ b/lib/http.py
@@ -209,6 +209,70 @@ class HTTPJsonConverter:
     return serializer.LoadJson(data)
 
 
+class _HttpSocketBase(object):
+  """Base class for HTTP server and client.
+
+  """
+  def __init__(self):
+    self._using_ssl = None
+    self._ssl_cert = None
+    self._ssl_key = None
+
+  def _CreateSocket(self, ssl_key_path, ssl_cert_path, ssl_verify_peer):
+    """Creates a TCP socket and initializes SSL if needed.
+
+    @type ssl_key_path: string
+    @param ssl_key_path: Path to file containing SSL key in PEM format
+    @type ssl_cert_path: string
+    @param ssl_cert_path: Path to file containing SSL certificate in PEM format
+    @type ssl_verify_peer: bool
+    @param ssl_verify_peer: Whether to require client certificate and compare
+                            it with our certificate
+
+    """
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+    # Should we enable SSL?
+    self._using_ssl = (ssl_cert_path and ssl_key_path)
+
+    if not self._using_ssl:
+      return sock
+
+    ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
+    ctx.set_options(OpenSSL.SSL.OP_NO_SSLv2)
+
+    ssl_key_pem = utils.ReadFile(ssl_key_path)
+    ssl_cert_pem = utils.ReadFile(ssl_cert_path)
+
+    cr = OpenSSL.crypto
+    self._ssl_cert = cr.load_certificate(cr.FILETYPE_PEM, ssl_cert_pem)
+    self._ssl_key = cr.load_privatekey(cr.FILETYPE_PEM, ssl_key_pem)
+    del cr
+
+    ctx.use_privatekey(self._ssl_key)
+    ctx.use_certificate(self._ssl_cert)
+    ctx.check_privatekey()
+
+    if ssl_verify_peer:
+      ctx.set_verify(OpenSSL.SSL.VERIFY_PEER |
+                     OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
+                     self._SSLVerifyCallback)
+
+    return OpenSSL.SSL.Connection(ctx, sock)
+
+  def _SSLVerifyCallback(self, conn, cert, errnum, errdepth, ok):
+    """Verify the certificate provided by the peer
+
+    We only compare fingerprints. The client must use the same certificate as
+    we do on our side.
+
+    """
+    assert self._ssl_cert and self._ssl_key, "SSL not initialized"
+
+    return (self._ssl_cert.digest("sha1") == cert.digest("sha1") and
+            self._ssl_cert.digest("md5") == cert.digest("md5"))
+
+
 class _HttpConnectionHandler(object):
   """Implements server side of HTTP
 
@@ -487,7 +551,7 @@ class _HttpConnectionHandler(object):
     logging.debug("HTTP POST data: %s", self.request_post_data)
 
 
-class HttpServer(object):
+class HttpServer(_HttpSocketBase):
   """Generic HTTP server class
 
   Users of this class must subclass it and override the HandleRequest function.
@@ -514,57 +578,27 @@ class HttpServer(object):
                             it with our certificate
 
     """
+    _HttpSocketBase.__init__(self)
+
     self.mainloop = mainloop
     self.local_address = local_address
     self.port = port
 
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-
-    if ssl_cert_path and ssl_key_path:
-      ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
-      ctx.set_options(OpenSSL.SSL.OP_NO_SSLv2)
-
-      ssl_key_pem = utils.ReadFile(ssl_key_path)
-      ssl_cert_pem = utils.ReadFile(ssl_cert_path)
-
-      cr = OpenSSL.crypto
-      self._ssl_cert = cr.load_certificate(cr.FILETYPE_PEM, ssl_cert_pem)
-      self._ssl_key = cr.load_privatekey(cr.FILETYPE_PEM, ssl_key_pem)
-      del cr
-
-      ctx.use_privatekey(self._ssl_key)
-      ctx.use_certificate(self._ssl_cert)
-      ctx.check_privatekey()
+    self.socket = self._CreateSocket(ssl_key_path, ssl_cert_path, ssl_verify_peer)
 
-      if ssl_verify_peer:
-        ctx.set_verify(OpenSSL.SSL.VERIFY_PEER |
-                       OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
-                       self._VerifyCallback)
+    # Allow port to be reused
+    self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 
-      self.socket = OpenSSL.SSL.Connection(ctx, sock)
+    if self._using_ssl:
       self._fileio_class = _SSLFileObject
     else:
-      self.socket = sock
       self._fileio_class = socket._fileobject
 
-    # Allow port to be reused
-    self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-
     self._children = []
 
     mainloop.RegisterIO(self, self.socket.fileno(), select.POLLIN)
     mainloop.RegisterSignal(self)
 
-  def _VerifyCallback(self, conn, cert, errno, errdepth, ok):
-    """Verify the certificate provided by the peer
-
-    We only compare fingerprints. The client must use the same certificate as
-    we do on the server side.
-
-    """
-    return (self._ssl_cert.digest("sha1") == cert.digest("sha1") and
-            self._ssl_cert.digest("md5") == cert.digest("md5"))
-
   def Start(self):
     self.socket.bind((self.local_address, self.port))
     self.socket.listen(5)
-- 
GitLab