Skip to content
Snippets Groups Projects
Commit a8950eb7 authored by Michael Hanselmann's avatar Michael Hanselmann
Browse files

Provide unittests for http.auth


To simplify writing unittests, one data structure class in http.server is
also changed. According to the coverage utility, this provides 95%
coverage.

Signed-off-by: default avatarMichael Hanselmann <hansmi@google.com>
Reviewed-by: default avatarIustin Pop <iustin@google.com>
parent 23ccba04
No related branches found
No related tags found
No related merge requests found
......@@ -74,12 +74,12 @@ class _HttpServerRequest(object):
"""Data structure for HTTP request on server side.
"""
def __init__(self, request_msg):
def __init__(self, method, path, headers, body):
# Request attributes
self.request_method = request_msg.start_line.method
self.request_path = request_msg.start_line.path
self.request_headers = request_msg.headers
self.request_body = request_msg.decoded_body
self.request_method = method
self.request_path = path
self.request_headers = headers
self.request_body = body
# Response attributes
self.resp_headers = {}
......@@ -308,7 +308,10 @@ class HttpServerRequestExecutor(object):
"""Calls the handler function for the current request.
"""
handler_context = _HttpServerRequest(self.request_msg)
handler_context = _HttpServerRequest(self.request_msg.start_line.method,
self.request_msg.start_line.path,
self.request_msg.headers,
self.request_msg.decoded_body)
try:
try:
......
......@@ -25,6 +25,7 @@
import os
import unittest
import time
import tempfile
from ganeti import http
......@@ -70,10 +71,7 @@ class TestMisc(unittest.TestCase):
def testHttpServerRequest(self):
"""Test ganeti.http.server._HttpServerRequest"""
fake_request = http.HttpMessage()
fake_request.start_line = \
http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
server_request = http.server._HttpServerRequest(fake_request)
server_request = http.server._HttpServerRequest("GET", "/", None, None)
# These are expected by users of the HTTP server
self.assert_(hasattr(server_request, "request_method"))
......@@ -95,16 +93,44 @@ class TestMisc(unittest.TestCase):
self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
def testFormatAuthHeader(self):
self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
"Basic")
self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }),
"Basic foo=bar")
self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }),
"Basic foo=\"\"")
self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }),
"Basic foo=\"x,y\"")
params = {
"foo": "x,y",
"realm": "secure",
}
# It's a dict whose order isn't guaranteed, hence checking a list
self.assert_(http.auth._FormatAuthHeader("Digest", params) in
("Digest foo=\"x,y\" realm=secure",
"Digest realm=secure foo=\"x,y\""))
class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
def __init__(self, realm):
def __init__(self, realm, authreq, authenticate_fn):
http.auth.HttpServerRequestAuthentication.__init__(self)
self.realm = realm
self.authreq = authreq
self.authenticate_fn = authenticate_fn
def AuthenticationRequired(self, req):
return self.authreq
def GetAuthRealm(self, req):
return self.realm
def Authenticate(self, *args):
if self.authenticate_fn:
return self.authenticate_fn(*args)
raise NotImplementedError()
class TestAuth(unittest.TestCase):
"""Authentication tests"""
......@@ -112,17 +138,16 @@ class TestAuth(unittest.TestCase):
hsra = http.auth.HttpServerRequestAuthentication
def testConstants(self):
self.assertEqual(self.hsra._CLEARTEXT_SCHEME,
self.hsra._CLEARTEXT_SCHEME.upper())
self.assertEqual(self.hsra._HA1_SCHEME,
self.hsra._HA1_SCHEME.upper())
for scheme in [self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME]:
self.assertEqual(scheme, scheme.upper())
self.assert_(scheme.startswith("{"))
self.assert_(scheme.endswith("}"))
def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
ra = _FakeRequestAuth(realm)
ra = _FakeRequestAuth(realm, False, None)
return ra.VerifyBasicAuthPassword(None, user, password, expected)
def testVerifyBasicAuthPassword(self):
tvbap = self._testVerifyBasicAuthPassword
......@@ -164,5 +189,146 @@ class TestAuth(unittest.TestCase):
"{HA1}92ea58ae804481498c257b2f65561a17"))
class _SimpleAuthenticator:
def __init__(self, user, password):
self.user = user
self.password = password
self.called = False
def __call__(self, req, user, password):
self.called = True
return self.user == user and self.password == password
class TestHttpServerRequestAuthentication(unittest.TestCase):
def testNoAuth(self):
req = http.server._HttpServerRequest("GET", "/", None, None)
_FakeRequestAuth("area1", False, None).PreHandleRequest(req)
def testNoRealm(self):
headers = { http.HTTP_AUTHORIZATION: "", }
req = http.server._HttpServerRequest("GET", "/", headers, None)
ra = _FakeRequestAuth(None, False, None)
self.assertRaises(AssertionError, ra.PreHandleRequest, req)
def testNoScheme(self):
headers = { http.HTTP_AUTHORIZATION: "", }
req = http.server._HttpServerRequest("GET", "/", headers, None)
ra = _FakeRequestAuth("area1", False, None)
self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
def testUnknownScheme(self):
headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
req = http.server._HttpServerRequest("GET", "/", headers, None)
ra = _FakeRequestAuth("area1", False, None)
self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
def testInvalidBase64(self):
headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
req = http.server._HttpServerRequest("GET", "/", headers, None)
ra = _FakeRequestAuth("area1", False, None)
self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
def testAuthForPublicResource(self):
headers = {
http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
}
req = http.server._HttpServerRequest("GET", "/", headers, None)
ra = _FakeRequestAuth("area1", False, None)
self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
def testAuthForPublicResource(self):
headers = {
http.HTTP_AUTHORIZATION:
"Basic %s" % ("foo:bar".encode("base64").strip(), ),
}
req = http.server._HttpServerRequest("GET", "/", headers, None)
ac = _SimpleAuthenticator("foo", "bar")
ra = _FakeRequestAuth("area1", False, ac)
ra.PreHandleRequest(req)
req = http.server._HttpServerRequest("GET", "/", headers, None)
ac = _SimpleAuthenticator("something", "else")
ra = _FakeRequestAuth("area1", False, ac)
self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
def testInvalidRequestHeader(self):
checks = {
http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
"basic %s" % "foobar".encode("base64").strip()],
http.HttpBadRequest: ["Basic"],
}
for exc, headers in checks.items():
for i in headers:
headers = { http.HTTP_AUTHORIZATION: i, }
req = http.server._HttpServerRequest("GET", "/", headers, None)
ra = _FakeRequestAuth("area1", False, None)
self.assertRaises(exc, ra.PreHandleRequest, req)
def testBasicAuth(self):
for user in ["", "joe", "user name with spaces"]:
for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
"foo:bar:baz"]:
for wrong_pw in [True, False]:
basic_auth = "%s:%s" % (user, pw)
if wrong_pw:
basic_auth += "WRONG"
headers = {
http.HTTP_AUTHORIZATION:
"Basic %s" % (basic_auth.encode("base64").strip(), ),
}
req = http.server._HttpServerRequest("GET", "/", headers, None)
ac = _SimpleAuthenticator(user, pw)
self.assertFalse(ac.called)
ra = _FakeRequestAuth("area1", True, ac)
if wrong_pw:
try:
ra.PreHandleRequest(req)
except http.HttpUnauthorized, err:
www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
else:
self.fail("Didn't raise HttpUnauthorized")
else:
ra.PreHandleRequest(req)
self.assert_(ac.called)
class TestReadPasswordFile(testutils.GanetiTestCase):
def setUp(self):
testutils.GanetiTestCase.setUp(self)
self.tmpfile = tempfile.NamedTemporaryFile()
def testSimple(self):
self.tmpfile.write("user1 password")
self.tmpfile.flush()
users = http.auth.ReadPasswordFile(self.tmpfile.name)
self.assertEqual(len(users), 1)
self.assertEqual(users["user1"].password, "password")
self.assertEqual(len(users["user1"].options), 0)
def testOptions(self):
self.tmpfile.write("# Passwords\n")
self.tmpfile.write("user1 password\n")
self.tmpfile.write("\n")
self.tmpfile.write("# Comment\n")
self.tmpfile.write("user2 pw write,read\n")
self.tmpfile.write(" \t# Another comment\n")
self.tmpfile.write("invalidline\n")
self.tmpfile.flush()
users = http.auth.ReadPasswordFile(self.tmpfile.name)
self.assertEqual(len(users), 2)
self.assertEqual(users["user1"].password, "password")
self.assertEqual(len(users["user1"].options), 0)
self.assertEqual(users["user2"].password, "pw")
self.assertEqual(users["user2"].options, ["write", "read"])
if __name__ == '__main__':
testutils.GanetiTestProgram()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment