diff --git a/lib/http/server.py b/lib/http/server.py index 1e564c10806943eb9e4f6a8d8ead5de2ff2855ec..3a05a43e479ee15f8ae889caf8aeb2864e638076 100644 --- a/lib/http/server.py +++ b/lib/http/server.py @@ -74,12 +74,12 @@ class _HttpServerRequest(object): """Data structure for HTTP request on server side. """ - def __init__(self, request_msg): + def __init__(self, method, path, headers, body): # Request attributes - self.request_method = request_msg.start_line.method - self.request_path = request_msg.start_line.path - self.request_headers = request_msg.headers - self.request_body = request_msg.decoded_body + self.request_method = method + self.request_path = path + self.request_headers = headers + self.request_body = body # Response attributes self.resp_headers = {} @@ -308,7 +308,10 @@ class HttpServerRequestExecutor(object): """Calls the handler function for the current request. """ - handler_context = _HttpServerRequest(self.request_msg) + handler_context = _HttpServerRequest(self.request_msg.start_line.method, + self.request_msg.start_line.path, + self.request_msg.headers, + self.request_msg.decoded_body) try: try: diff --git a/test/ganeti.http_unittest.py b/test/ganeti.http_unittest.py index ebb00a923a551a9d27e9ea4afe80e0edd19eb70f..7d0d477cef9a44d08f7861cdc124f8275de5ed96 100755 --- a/test/ganeti.http_unittest.py +++ b/test/ganeti.http_unittest.py @@ -25,6 +25,7 @@ import os import unittest import time +import tempfile from ganeti import http @@ -70,10 +71,7 @@ class TestMisc(unittest.TestCase): def testHttpServerRequest(self): """Test ganeti.http.server._HttpServerRequest""" - fake_request = http.HttpMessage() - fake_request.start_line = \ - http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1") - server_request = http.server._HttpServerRequest(fake_request) + server_request = http.server._HttpServerRequest("GET", "/", None, None) # These are expected by users of the HTTP server self.assert_(hasattr(server_request, "request_method")) @@ -95,16 +93,44 @@ class TestMisc(unittest.TestCase): self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0) self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0) + def testFormatAuthHeader(self): + self.assertEqual(http.auth._FormatAuthHeader("Basic", {}), + "Basic") + self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }), + "Basic foo=bar") + self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }), + "Basic foo=\"\"") + self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }), + "Basic foo=\"x,y\"") + params = { + "foo": "x,y", + "realm": "secure", + } + # It's a dict whose order isn't guaranteed, hence checking a list + self.assert_(http.auth._FormatAuthHeader("Digest", params) in + ("Digest foo=\"x,y\" realm=secure", + "Digest realm=secure foo=\"x,y\"")) + class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication): - def __init__(self, realm): + def __init__(self, realm, authreq, authenticate_fn): http.auth.HttpServerRequestAuthentication.__init__(self) self.realm = realm + self.authreq = authreq + self.authenticate_fn = authenticate_fn + + def AuthenticationRequired(self, req): + return self.authreq def GetAuthRealm(self, req): return self.realm + def Authenticate(self, *args): + if self.authenticate_fn: + return self.authenticate_fn(*args) + raise NotImplementedError() + class TestAuth(unittest.TestCase): """Authentication tests""" @@ -112,17 +138,16 @@ class TestAuth(unittest.TestCase): hsra = http.auth.HttpServerRequestAuthentication def testConstants(self): - self.assertEqual(self.hsra._CLEARTEXT_SCHEME, - self.hsra._CLEARTEXT_SCHEME.upper()) - self.assertEqual(self.hsra._HA1_SCHEME, - self.hsra._HA1_SCHEME.upper()) + for scheme in [self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME]: + self.assertEqual(scheme, scheme.upper()) + self.assert_(scheme.startswith("{")) + self.assert_(scheme.endswith("}")) def _testVerifyBasicAuthPassword(self, realm, user, password, expected): - ra = _FakeRequestAuth(realm) + ra = _FakeRequestAuth(realm, False, None) return ra.VerifyBasicAuthPassword(None, user, password, expected) - def testVerifyBasicAuthPassword(self): tvbap = self._testVerifyBasicAuthPassword @@ -164,5 +189,146 @@ class TestAuth(unittest.TestCase): "{HA1}92ea58ae804481498c257b2f65561a17")) +class _SimpleAuthenticator: + def __init__(self, user, password): + self.user = user + self.password = password + self.called = False + + def __call__(self, req, user, password): + self.called = True + return self.user == user and self.password == password + + +class TestHttpServerRequestAuthentication(unittest.TestCase): + def testNoAuth(self): + req = http.server._HttpServerRequest("GET", "/", None, None) + _FakeRequestAuth("area1", False, None).PreHandleRequest(req) + + def testNoRealm(self): + headers = { http.HTTP_AUTHORIZATION: "", } + req = http.server._HttpServerRequest("GET", "/", headers, None) + ra = _FakeRequestAuth(None, False, None) + self.assertRaises(AssertionError, ra.PreHandleRequest, req) + + def testNoScheme(self): + headers = { http.HTTP_AUTHORIZATION: "", } + req = http.server._HttpServerRequest("GET", "/", headers, None) + ra = _FakeRequestAuth("area1", False, None) + self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req) + + def testUnknownScheme(self): + headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", } + req = http.server._HttpServerRequest("GET", "/", headers, None) + ra = _FakeRequestAuth("area1", False, None) + self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req) + + def testInvalidBase64(self): + headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", } + req = http.server._HttpServerRequest("GET", "/", headers, None) + ra = _FakeRequestAuth("area1", False, None) + self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req) + + def testAuthForPublicResource(self): + headers = { + http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ), + } + req = http.server._HttpServerRequest("GET", "/", headers, None) + ra = _FakeRequestAuth("area1", False, None) + self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req) + + def testAuthForPublicResource(self): + headers = { + http.HTTP_AUTHORIZATION: + "Basic %s" % ("foo:bar".encode("base64").strip(), ), + } + req = http.server._HttpServerRequest("GET", "/", headers, None) + ac = _SimpleAuthenticator("foo", "bar") + ra = _FakeRequestAuth("area1", False, ac) + ra.PreHandleRequest(req) + + req = http.server._HttpServerRequest("GET", "/", headers, None) + ac = _SimpleAuthenticator("something", "else") + ra = _FakeRequestAuth("area1", False, ac) + self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req) + + def testInvalidRequestHeader(self): + checks = { + http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest", + "basic %s" % "foobar".encode("base64").strip()], + http.HttpBadRequest: ["Basic"], + } + + for exc, headers in checks.items(): + for i in headers: + headers = { http.HTTP_AUTHORIZATION: i, } + req = http.server._HttpServerRequest("GET", "/", headers, None) + ra = _FakeRequestAuth("area1", False, None) + self.assertRaises(exc, ra.PreHandleRequest, req) + + def testBasicAuth(self): + for user in ["", "joe", "user name with spaces"]: + for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###", + "foo:bar:baz"]: + for wrong_pw in [True, False]: + basic_auth = "%s:%s" % (user, pw) + if wrong_pw: + basic_auth += "WRONG" + headers = { + http.HTTP_AUTHORIZATION: + "Basic %s" % (basic_auth.encode("base64").strip(), ), + } + req = http.server._HttpServerRequest("GET", "/", headers, None) + + ac = _SimpleAuthenticator(user, pw) + self.assertFalse(ac.called) + ra = _FakeRequestAuth("area1", True, ac) + if wrong_pw: + try: + ra.PreHandleRequest(req) + except http.HttpUnauthorized, err: + www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE] + self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH)) + else: + self.fail("Didn't raise HttpUnauthorized") + else: + ra.PreHandleRequest(req) + self.assert_(ac.called) + + +class TestReadPasswordFile(testutils.GanetiTestCase): + def setUp(self): + testutils.GanetiTestCase.setUp(self) + + self.tmpfile = tempfile.NamedTemporaryFile() + + def testSimple(self): + self.tmpfile.write("user1 password") + self.tmpfile.flush() + + users = http.auth.ReadPasswordFile(self.tmpfile.name) + self.assertEqual(len(users), 1) + self.assertEqual(users["user1"].password, "password") + self.assertEqual(len(users["user1"].options), 0) + + def testOptions(self): + self.tmpfile.write("# Passwords\n") + self.tmpfile.write("user1 password\n") + self.tmpfile.write("\n") + self.tmpfile.write("# Comment\n") + self.tmpfile.write("user2 pw write,read\n") + self.tmpfile.write(" \t# Another comment\n") + self.tmpfile.write("invalidline\n") + self.tmpfile.flush() + + users = http.auth.ReadPasswordFile(self.tmpfile.name) + self.assertEqual(len(users), 2) + self.assertEqual(users["user1"].password, "password") + self.assertEqual(len(users["user1"].options), 0) + + self.assertEqual(users["user2"].password, "pw") + self.assertEqual(users["user2"].options, ["write", "read"]) + + if __name__ == '__main__': testutils.GanetiTestProgram()