From 981732fbc3f81a108bc4e083b5dfd592f5287149 Mon Sep 17 00:00:00 2001 From: Manuel Franceschini Date: Mon, 23 Aug 2010 12:26:55 +0200 Subject: [PATCH] Make family argument in FormatAddress optional By doing this we delegate the task of finding the correct address family to the FormatAddress method. Signed-off-by: Manuel Franceschini Reviewed-by: Iustin Pop --- lib/daemon.py | 4 ++-- lib/http/client.py | 3 +-- lib/netutils.py | 12 +++++++++--- test/ganeti.netutils_unittest.py | 32 ++++++++++++++++++++------------ 4 files changed, 32 insertions(+), 19 deletions(-) diff --git a/lib/daemon.py b/lib/daemon.py index 5346e47da..6a1367092 100644 --- a/lib/daemon.py +++ b/lib/daemon.py @@ -143,7 +143,7 @@ class AsyncStreamServer(GanetiBaseAsyncoreDispatcher): # is passed in from accept anyway client_address = netutils.GetSocketCredentials(connected_socket) logging.info("Accepted connection from %s", - netutils.FormatAddress(self.family, client_address)) + netutils.FormatAddress(client_address, family=self.family)) self.handle_connection(connected_socket, client_address) def handle_connection(self, connected_socket, client_address): @@ -274,7 +274,7 @@ class AsyncTerminatedMessageStream(asynchat.async_chat): def close_log(self): logging.info("Closing connection from %s", - netutils.FormatAddress(self.family, self.peer_address)) + netutils.FormatAddress(self.peer_address, family=self.family)) self.close() # this method is overriding an asyncore.dispatcher method diff --git a/lib/http/client.py b/lib/http/client.py index 891865f39..67456378e 100644 --- a/lib/http/client.py +++ b/lib/http/client.py @@ -106,8 +106,7 @@ class HttpClientRequest(object): """ if netutils.IPAddress.IsValid(self.host): - family = netutils.IPAddress.GetAddressFamily(self.host) - address = netutils.FormatAddress(family, (self.host, self.port)) + address = netutils.FormatAddress((self.host, self.port)) else: address = "%s:%s" % (self.host, self.port) # TODO: Support for non-SSL requests diff --git a/lib/netutils.py b/lib/netutils.py index 509329eb7..a763fad4d 100644 --- a/lib/netutils.py +++ b/lib/netutils.py @@ -490,15 +490,21 @@ class IP6Address(IPAddress): return address_int -def FormatAddress(family, address): +def FormatAddress(address, family=None): """Format a socket address - @type family: integer - @param family: socket family (one of socket.AF_*) @type address: family specific (usually tuple) @param address: address, as reported by this class + @type family: integer + @param family: socket family (one of socket.AF_*) or None """ + if family is None: + try: + family = IPAddress.GetAddressFamily(address[0]) + except errors.IPAddressError: + raise errors.ParameterError(address) + if family == socket.AF_UNIX and len(address) == 3: return "pid=%s, uid=%s, gid=%s" % address diff --git a/test/ganeti.netutils_unittest.py b/test/ganeti.netutils_unittest.py index 0fcddb722..e9ed0dbae 100755 --- a/test/ganeti.netutils_unittest.py +++ b/test/ganeti.netutils_unittest.py @@ -388,30 +388,38 @@ class TestFormatAddress(unittest.TestCase): """Testcase for FormatAddress""" def testFormatAddressUnixSocket(self): - res1 = netutils.FormatAddress(socket.AF_UNIX, ("12352", 0, 0)) + res1 = netutils.FormatAddress(("12352", 0, 0), family=socket.AF_UNIX) self.assertEqual(res1, "pid=12352, uid=0, gid=0") def testFormatAddressIP4(self): - res1 = netutils.FormatAddress(socket.AF_INET, ("127.0.0.1", 1234)) + res1 = netutils.FormatAddress(("127.0.0.1", 1234), family=socket.AF_INET) self.assertEqual(res1, "127.0.0.1:1234") - res2 = netutils.FormatAddress(socket.AF_INET, ("192.0.2.32", None)) + res2 = netutils.FormatAddress(("192.0.2.32", None), family=socket.AF_INET) self.assertEqual(res2, "192.0.2.32") def testFormatAddressIP6(self): - res1 = netutils.FormatAddress(socket.AF_INET6, ("::1", 1234)) + res1 = netutils.FormatAddress(("::1", 1234), family=socket.AF_INET6) self.assertEqual(res1, "[::1]:1234") - res2 = netutils.FormatAddress(socket.AF_INET6, ("::1", None)) + res2 = netutils.FormatAddress(("::1", None), family=socket.AF_INET6) self.assertEqual(res2, "[::1]") - res2 = netutils.FormatAddress(socket.AF_INET6, ("2001:db8::beef", "80")) + res2 = netutils.FormatAddress(("2001:db8::beef", "80"), + family=socket.AF_INET6) self.assertEqual(res2, "[2001:db8::beef]:80") + def testFormatAddressWithoutFamily(self): + res1 = netutils.FormatAddress(("127.0.0.1", 1234)) + self.assertEqual(res1, "127.0.0.1:1234") + res2 = netutils.FormatAddress(("::1", 1234)) + self.assertEqual(res2, "[::1]:1234") + + def testInvalidFormatAddress(self): - self.assertRaises(errors.ParameterError, - netutils.FormatAddress, None, ("::1", None)) - self.assertRaises(errors.ParameterError, - netutils.FormatAddress, socket.AF_INET, "127.0.0.1") - self.assertRaises(errors.ParameterError, - netutils.FormatAddress, socket.AF_INET, ("::1")) + self.assertRaises(errors.ParameterError, netutils.FormatAddress, + "127.0.0.1") + self.assertRaises(errors.ParameterError, netutils.FormatAddress, + "127.0.0.1", family=socket.AF_INET) + self.assertRaises(errors.ParameterError, netutils.FormatAddress, + ("::1"), family=socket.AF_INET ) if __name__ == "__main__": -- GitLab