From a5eba783d458af8103d89e5e5b541bbf835704d8 Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Fri, 30 Jul 2010 20:44:07 +0200
Subject: [PATCH] RAPI client: Don't re-use PycURL object

With this patch, a new PycURL object will be created for each request.
This should make the RAPI client safe for simultaneous calls from
multiple threads. Unittests are adjusted accordingly.

An unnecessary variable assignment is also removed from the unittest
script.

This patch survived a small QA and unittests.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Guido Trotter <ultrotter@google.com>
---
 lib/rapi/client.py                  | 50 ++++++++++------
 test/ganeti.rapi.client_unittest.py | 89 +++++++++++++++++------------
 2 files changed, 85 insertions(+), 54 deletions(-)

diff --git a/lib/rapi/client.py b/lib/rapi/client.py
index 444a1d233..457337563 100644
--- a/lib/rapi/client.py
+++ b/lib/rapi/client.py
@@ -243,7 +243,7 @@ class GanetiRapiClient(object):
 
   def __init__(self, host, port=GANETI_RAPI_PORT,
                username=None, password=None, logger=logging,
-               curl_config_fn=None, curl=None):
+               curl_config_fn=None, curl_factory=None):
     """Initializes this class.
 
     @type host: string
@@ -259,14 +259,28 @@ class GanetiRapiClient(object):
     @param logger: Logging object
 
     """
-    self._host = host
-    self._port = port
+    self._username = username
+    self._password = password
     self._logger = logger
+    self._curl_config_fn = curl_config_fn
+    self._curl_factory = curl_factory
 
     self._base_url = "https://%s:%s" % (host, port)
 
-    # Create pycURL object if not supplied
-    if not curl:
+    if username is not None:
+      if password is None:
+        raise Error("Password not specified")
+    elif password:
+      raise Error("Specified password without username")
+
+  def _CreateCurl(self):
+    """Creates a cURL object.
+
+    """
+    # Create pycURL object if no factory is provided
+    if self._curl_factory:
+      curl = self._curl_factory()
+    else:
       curl = pycurl.Curl()
 
     # Default cURL settings
@@ -282,20 +296,20 @@ class GanetiRapiClient(object):
       "Content-type: %s" % HTTP_APP_JSON,
       ])
 
-    # Setup authentication
-    if username is not None:
-      if password is None:
-        raise Error("Password not specified")
+    assert ((self._username is None and self._password is None) ^
+            (self._username is not None and self._password is not None))
+
+    if self._username:
+      # Setup authentication
       curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
-      curl.setopt(pycurl.USERPWD, str("%s:%s" % (username, password)))
-    elif password:
-      raise Error("Specified password without username")
+      curl.setopt(pycurl.USERPWD,
+                  str("%s:%s" % (self._username, self._password)))
 
     # Call external configuration function
-    if curl_config_fn:
-      curl_config_fn(curl, logger)
+    if self._curl_config_fn:
+      self._curl_config_fn(curl, self._logger)
 
-    self._curl = curl
+    return curl
 
   @staticmethod
   def _EncodeQuery(query):
@@ -349,7 +363,7 @@ class GanetiRapiClient(object):
     """
     assert path.startswith("/")
 
-    curl = self._curl
+    curl = self._CreateCurl()
 
     if content is not None:
       encoded_content = self._json_encoder.encode(content)
@@ -364,8 +378,8 @@ class GanetiRapiClient(object):
 
     url = "".join(urlparts)
 
-    self._logger.debug("Sending request %s %s to %s:%s (content=%r)",
-                       method, url, self._host, self._port, encoded_content)
+    self._logger.debug("Sending request %s %s (content=%r)",
+                       method, url, encoded_content)
 
     # Buffer for response
     encoded_resp_body = StringIO()
diff --git a/test/ganeti.rapi.client_unittest.py b/test/ganeti.rapi.client_unittest.py
index c97b384e6..eaa29e55b 100755
--- a/test/ganeti.rapi.client_unittest.py
+++ b/test/ganeti.rapi.client_unittest.py
@@ -170,11 +170,11 @@ def _FakeGnuTlsPycurlVersion():
 
 class TestExtendedConfig(unittest.TestCase):
   def testAuth(self):
-    curl = FakeCurl(RapiMock())
     cl = client.GanetiRapiClient("master.example.com",
                                  username="user", password="pw",
-                                 curl=curl)
+                                 curl_factory=lambda: FakeCurl(RapiMock()))
 
+    curl = cl._CreateCurl()
     self.assertEqual(curl.getopt(pycurl.HTTPAUTH), pycurl.HTTPAUTH_BASIC)
     self.assertEqual(curl.getopt(pycurl.USERPWD), "user:pw")
 
@@ -209,10 +209,12 @@ class TestExtendedConfig(unittest.TestCase):
                                              verify_hostname=verify_hostname,
                                              _pycurl_version_fn=pcverfn)
 
-            curl = FakeCurl(RapiMock())
+            curl_factory = lambda: FakeCurl(RapiMock())
             cl = client.GanetiRapiClient("master.example.com",
-                                         curl_config_fn=cfgfn, curl=curl)
+                                         curl_config_fn=cfgfn,
+                                         curl_factory=curl_factory)
 
+            curl = cl._CreateCurl()
             self.assertEqual(curl.getopt(pycurl.PROXY), proxy)
             self.assertEqual(curl.getopt(pycurl.NOSIGNAL), not use_signal)
 
@@ -224,10 +226,11 @@ class TestExtendedConfig(unittest.TestCase):
   def testNoCertVerify(self):
     cfgfn = client.GenericCurlConfig()
 
-    curl = FakeCurl(RapiMock())
+    curl_factory = lambda: FakeCurl(RapiMock())
     cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
-                                 curl=curl)
+                                 curl_factory=curl_factory)
 
+    curl = cl._CreateCurl()
     self.assertFalse(curl.getopt(pycurl.SSL_VERIFYPEER))
     self.assertFalse(curl.getopt(pycurl.CAINFO))
     self.assertFalse(curl.getopt(pycurl.CAPATH))
@@ -235,10 +238,11 @@ class TestExtendedConfig(unittest.TestCase):
   def testCertVerifyCurlBundle(self):
     cfgfn = client.GenericCurlConfig(use_curl_cabundle=True)
 
-    curl = FakeCurl(RapiMock())
+    curl_factory = lambda: FakeCurl(RapiMock())
     cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
-                                 curl=curl)
+                                 curl_factory=curl_factory)
 
+    curl = cl._CreateCurl()
     self.assert_(curl.getopt(pycurl.SSL_VERIFYPEER))
     self.assertFalse(curl.getopt(pycurl.CAINFO))
     self.assertFalse(curl.getopt(pycurl.CAPATH))
@@ -247,10 +251,11 @@ class TestExtendedConfig(unittest.TestCase):
     mycert = "/tmp/some/UNUSED/cert/file.pem"
     cfgfn = client.GenericCurlConfig(cafile=mycert)
 
-    curl = FakeCurl(RapiMock())
+    curl_factory = lambda: FakeCurl(RapiMock())
     cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
-                                 curl=curl)
+                                 curl_factory=curl_factory)
 
+    curl = cl._CreateCurl()
     self.assert_(curl.getopt(pycurl.SSL_VERIFYPEER))
     self.assertEqual(curl.getopt(pycurl.CAINFO), mycert)
     self.assertFalse(curl.getopt(pycurl.CAPATH))
@@ -261,10 +266,11 @@ class TestExtendedConfig(unittest.TestCase):
     cfgfn = client.GenericCurlConfig(capath=certdir,
                                      _pycurl_version_fn=pcverfn)
 
-    curl = FakeCurl(RapiMock())
+    curl_factory = lambda: FakeCurl(RapiMock())
     cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
-                                 curl=curl)
+                                 curl_factory=curl_factory)
 
+    curl = cl._CreateCurl()
     self.assert_(curl.getopt(pycurl.SSL_VERIFYPEER))
     self.assertEqual(curl.getopt(pycurl.CAPATH), certdir)
     self.assertFalse(curl.getopt(pycurl.CAINFO))
@@ -275,9 +281,11 @@ class TestExtendedConfig(unittest.TestCase):
     cfgfn = client.GenericCurlConfig(capath=certdir,
                                      _pycurl_version_fn=pcverfn)
 
-    curl = FakeCurl(RapiMock())
-    self.assertRaises(client.Error, client.GanetiRapiClient,
-                      "master.example.com", curl_config_fn=cfgfn, curl=curl)
+    curl_factory = lambda: FakeCurl(RapiMock())
+    cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
+                                 curl_factory=curl_factory)
+
+    self.assertRaises(client.Error, cl._CreateCurl)
 
   def testCertVerifyNoSsl(self):
     certdir = "/tmp/some/UNUSED/cert/directory"
@@ -285,9 +293,11 @@ class TestExtendedConfig(unittest.TestCase):
     cfgfn = client.GenericCurlConfig(capath=certdir,
                                      _pycurl_version_fn=pcverfn)
 
-    curl = FakeCurl(RapiMock())
-    self.assertRaises(client.Error, client.GanetiRapiClient,
-                      "master.example.com", curl_config_fn=cfgfn, curl=curl)
+    curl_factory = lambda: FakeCurl(RapiMock())
+    cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
+                                 curl_factory=curl_factory)
+
+    self.assertRaises(client.Error, cl._CreateCurl)
 
   def testCertVerifyFancySsl(self):
     certdir = "/tmp/some/UNUSED/cert/directory"
@@ -295,9 +305,11 @@ class TestExtendedConfig(unittest.TestCase):
     cfgfn = client.GenericCurlConfig(capath=certdir,
                                      _pycurl_version_fn=pcverfn)
 
-    curl = FakeCurl(RapiMock())
-    self.assertRaises(NotImplementedError, client.GanetiRapiClient,
-                      "master.example.com", curl_config_fn=cfgfn, curl=curl)
+    curl_factory = lambda: FakeCurl(RapiMock())
+    cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
+                                 curl_factory=curl_factory)
+
+    self.assertRaises(NotImplementedError, cl._CreateCurl)
 
   def testCertVerifyCapath(self):
     for connect_timeout in [None, 1, 5, 10, 30, 60, 300]:
@@ -305,10 +317,11 @@ class TestExtendedConfig(unittest.TestCase):
         cfgfn = client.GenericCurlConfig(connect_timeout=connect_timeout,
                                          timeout=timeout)
 
-        curl = FakeCurl(RapiMock())
+        curl_factory = lambda: FakeCurl(RapiMock())
         cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
-                                     curl=curl)
+                                     curl_factory=curl_factory)
 
+        curl = cl._CreateCurl()
         self.assertEqual(curl.getopt(pycurl.CONNECTTIMEOUT), connect_timeout)
         self.assertEqual(curl.getopt(pycurl.TIMEOUT), timeout)
 
@@ -320,18 +333,7 @@ class GanetiRapiClientTests(testutils.GanetiTestCase):
     self.rapi = RapiMock()
     self.curl = FakeCurl(self.rapi)
     self.client = client.GanetiRapiClient("master.example.com",
-                                          curl=self.curl)
-
-    # Signals should be disabled by default
-    self.assert_(self.curl.getopt(pycurl.NOSIGNAL))
-
-    # No auth and no proxy
-    self.assertFalse(self.curl.getopt(pycurl.USERPWD))
-    self.assert_(self.curl.getopt(pycurl.PROXY) is None)
-
-    # Content-type is required for requests
-    headers = self.curl.getopt(pycurl.HTTPHEADER)
-    self.assert_("Content-type: application/json" in headers)
+                                          curl_factory=lambda: self.curl)
 
   def assertHandler(self, handler_cls):
     self.failUnless(isinstance(self.rapi.GetLastHandler(), handler_cls))
@@ -372,6 +374,22 @@ class GanetiRapiClientTests(testutils.GanetiTestCase):
     for i in [[1, 2, 3], {"moo": "boo"}, (1, 2, 3)]:
       self.assertRaises(ValueError, self.client._EncodeQuery, [("x", i)])
 
+  def testCurlSettings(self):
+    self.rapi.AddResponse("2")
+    self.assertEqual(2, self.client.GetVersion())
+    self.assertHandler(rlib2.R_version)
+
+    # Signals should be disabled by default
+    self.assert_(self.curl.getopt(pycurl.NOSIGNAL))
+
+    # No auth and no proxy
+    self.assertFalse(self.curl.getopt(pycurl.USERPWD))
+    self.assert_(self.curl.getopt(pycurl.PROXY) is None)
+
+    # Content-type is required for requests
+    headers = self.curl.getopt(pycurl.HTTPHEADER)
+    self.assert_("Content-type: application/json" in headers)
+
   def testHttpError(self):
     self.rapi.AddResponse(None, code=404)
     try:
@@ -382,7 +400,6 @@ class GanetiRapiClientTests(testutils.GanetiTestCase):
       self.fail("Didn't raise exception")
 
   def testGetVersion(self):
-    self.client._version = None
     self.rapi.AddResponse("2")
     self.assertEqual(2, self.client.GetVersion())
     self.assertHandler(rlib2.R_version)
-- 
GitLab