diff --git a/lib/daemon.py b/lib/daemon.py index 5346e47da294178206d2ffd18b39027b669767c6..6a13670920e90e287c93aa501351da11b9d5705a 100644 --- a/lib/daemon.py +++ b/lib/daemon.py @@ -143,7 +143,7 @@ class AsyncStreamServer(GanetiBaseAsyncoreDispatcher): # is passed in from accept anyway client_address = netutils.GetSocketCredentials(connected_socket) logging.info("Accepted connection from %s", - netutils.FormatAddress(self.family, client_address)) + netutils.FormatAddress(client_address, family=self.family)) self.handle_connection(connected_socket, client_address) def handle_connection(self, connected_socket, client_address): @@ -274,7 +274,7 @@ class AsyncTerminatedMessageStream(asynchat.async_chat): def close_log(self): logging.info("Closing connection from %s", - netutils.FormatAddress(self.family, self.peer_address)) + netutils.FormatAddress(self.peer_address, family=self.family)) self.close() # this method is overriding an asyncore.dispatcher method diff --git a/lib/http/client.py b/lib/http/client.py index 891865f39fb414862ab68f679c8f9cf33b2d404f..67456378e85b89973e635a53c154a791d52419bc 100644 --- a/lib/http/client.py +++ b/lib/http/client.py @@ -106,8 +106,7 @@ class HttpClientRequest(object): """ if netutils.IPAddress.IsValid(self.host): - family = netutils.IPAddress.GetAddressFamily(self.host) - address = netutils.FormatAddress(family, (self.host, self.port)) + address = netutils.FormatAddress((self.host, self.port)) else: address = "%s:%s" % (self.host, self.port) # TODO: Support for non-SSL requests diff --git a/lib/netutils.py b/lib/netutils.py index 509329eb7a23582e481e01959f951b7e037e9d38..a763fad4da50ceeb7e9ceecb8c33c0e1d7fe0f05 100644 --- a/lib/netutils.py +++ b/lib/netutils.py @@ -490,15 +490,21 @@ class IP6Address(IPAddress): return address_int -def FormatAddress(family, address): +def FormatAddress(address, family=None): """Format a socket address - @type family: integer - @param family: socket family (one of socket.AF_*) @type address: family specific (usually tuple) @param address: address, as reported by this class + @type family: integer + @param family: socket family (one of socket.AF_*) or None """ + if family is None: + try: + family = IPAddress.GetAddressFamily(address[0]) + except errors.IPAddressError: + raise errors.ParameterError(address) + if family == socket.AF_UNIX and len(address) == 3: return "pid=%s, uid=%s, gid=%s" % address diff --git a/test/ganeti.netutils_unittest.py b/test/ganeti.netutils_unittest.py index 0fcddb7221b0fe0f71c86a1a6da4e1cdfdc32e22..e9ed0dbaebc0a49f60b8d080423587c168bfe39b 100755 --- a/test/ganeti.netutils_unittest.py +++ b/test/ganeti.netutils_unittest.py @@ -388,30 +388,38 @@ class TestFormatAddress(unittest.TestCase): """Testcase for FormatAddress""" def testFormatAddressUnixSocket(self): - res1 = netutils.FormatAddress(socket.AF_UNIX, ("12352", 0, 0)) + res1 = netutils.FormatAddress(("12352", 0, 0), family=socket.AF_UNIX) self.assertEqual(res1, "pid=12352, uid=0, gid=0") def testFormatAddressIP4(self): - res1 = netutils.FormatAddress(socket.AF_INET, ("127.0.0.1", 1234)) + res1 = netutils.FormatAddress(("127.0.0.1", 1234), family=socket.AF_INET) self.assertEqual(res1, "127.0.0.1:1234") - res2 = netutils.FormatAddress(socket.AF_INET, ("192.0.2.32", None)) + res2 = netutils.FormatAddress(("192.0.2.32", None), family=socket.AF_INET) self.assertEqual(res2, "192.0.2.32") def testFormatAddressIP6(self): - res1 = netutils.FormatAddress(socket.AF_INET6, ("::1", 1234)) + res1 = netutils.FormatAddress(("::1", 1234), family=socket.AF_INET6) self.assertEqual(res1, "[::1]:1234") - res2 = netutils.FormatAddress(socket.AF_INET6, ("::1", None)) + res2 = netutils.FormatAddress(("::1", None), family=socket.AF_INET6) self.assertEqual(res2, "[::1]") - res2 = netutils.FormatAddress(socket.AF_INET6, ("2001:db8::beef", "80")) + res2 = netutils.FormatAddress(("2001:db8::beef", "80"), + family=socket.AF_INET6) self.assertEqual(res2, "[2001:db8::beef]:80") + def testFormatAddressWithoutFamily(self): + res1 = netutils.FormatAddress(("127.0.0.1", 1234)) + self.assertEqual(res1, "127.0.0.1:1234") + res2 = netutils.FormatAddress(("::1", 1234)) + self.assertEqual(res2, "[::1]:1234") + + def testInvalidFormatAddress(self): - self.assertRaises(errors.ParameterError, - netutils.FormatAddress, None, ("::1", None)) - self.assertRaises(errors.ParameterError, - netutils.FormatAddress, socket.AF_INET, "127.0.0.1") - self.assertRaises(errors.ParameterError, - netutils.FormatAddress, socket.AF_INET, ("::1")) + self.assertRaises(errors.ParameterError, netutils.FormatAddress, + "127.0.0.1") + self.assertRaises(errors.ParameterError, netutils.FormatAddress, + "127.0.0.1", family=socket.AF_INET) + self.assertRaises(errors.ParameterError, netutils.FormatAddress, + ("::1"), family=socket.AF_INET ) if __name__ == "__main__":