Commit 981732fb authored by Manuel Franceschini's avatar Manuel Franceschini
Browse files

Make family argument in FormatAddress optional



By doing this we delegate the task of finding the correct address family
to the FormatAddress method.
Signed-off-by: default avatarManuel Franceschini <livewire@google.com>
Reviewed-by: default avatarIustin Pop <iustin@google.com>
parent 51b13ce9
...@@ -143,7 +143,7 @@ class AsyncStreamServer(GanetiBaseAsyncoreDispatcher): ...@@ -143,7 +143,7 @@ class AsyncStreamServer(GanetiBaseAsyncoreDispatcher):
# is passed in from accept anyway # is passed in from accept anyway
client_address = netutils.GetSocketCredentials(connected_socket) client_address = netutils.GetSocketCredentials(connected_socket)
logging.info("Accepted connection from %s", logging.info("Accepted connection from %s",
netutils.FormatAddress(self.family, client_address)) netutils.FormatAddress(client_address, family=self.family))
self.handle_connection(connected_socket, client_address) self.handle_connection(connected_socket, client_address)
def handle_connection(self, connected_socket, client_address): def handle_connection(self, connected_socket, client_address):
...@@ -274,7 +274,7 @@ class AsyncTerminatedMessageStream(asynchat.async_chat): ...@@ -274,7 +274,7 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
def close_log(self): def close_log(self):
logging.info("Closing connection from %s", logging.info("Closing connection from %s",
netutils.FormatAddress(self.family, self.peer_address)) netutils.FormatAddress(self.peer_address, family=self.family))
self.close() self.close()
# this method is overriding an asyncore.dispatcher method # this method is overriding an asyncore.dispatcher method
......
...@@ -106,8 +106,7 @@ class HttpClientRequest(object): ...@@ -106,8 +106,7 @@ class HttpClientRequest(object):
""" """
if netutils.IPAddress.IsValid(self.host): if netutils.IPAddress.IsValid(self.host):
family = netutils.IPAddress.GetAddressFamily(self.host) address = netutils.FormatAddress((self.host, self.port))
address = netutils.FormatAddress(family, (self.host, self.port))
else: else:
address = "%s:%s" % (self.host, self.port) address = "%s:%s" % (self.host, self.port)
# TODO: Support for non-SSL requests # TODO: Support for non-SSL requests
......
...@@ -490,15 +490,21 @@ class IP6Address(IPAddress): ...@@ -490,15 +490,21 @@ class IP6Address(IPAddress):
return address_int return address_int
def FormatAddress(family, address): def FormatAddress(address, family=None):
"""Format a socket address """Format a socket address
@type family: integer
@param family: socket family (one of socket.AF_*)
@type address: family specific (usually tuple) @type address: family specific (usually tuple)
@param address: address, as reported by this class @param address: address, as reported by this class
@type family: integer
@param family: socket family (one of socket.AF_*) or None
""" """
if family is None:
try:
family = IPAddress.GetAddressFamily(address[0])
except errors.IPAddressError:
raise errors.ParameterError(address)
if family == socket.AF_UNIX and len(address) == 3: if family == socket.AF_UNIX and len(address) == 3:
return "pid=%s, uid=%s, gid=%s" % address return "pid=%s, uid=%s, gid=%s" % address
......
...@@ -388,30 +388,38 @@ class TestFormatAddress(unittest.TestCase): ...@@ -388,30 +388,38 @@ class TestFormatAddress(unittest.TestCase):
"""Testcase for FormatAddress""" """Testcase for FormatAddress"""
def testFormatAddressUnixSocket(self): def testFormatAddressUnixSocket(self):
res1 = netutils.FormatAddress(socket.AF_UNIX, ("12352", 0, 0)) res1 = netutils.FormatAddress(("12352", 0, 0), family=socket.AF_UNIX)
self.assertEqual(res1, "pid=12352, uid=0, gid=0") self.assertEqual(res1, "pid=12352, uid=0, gid=0")
def testFormatAddressIP4(self): def testFormatAddressIP4(self):
res1 = netutils.FormatAddress(socket.AF_INET, ("127.0.0.1", 1234)) res1 = netutils.FormatAddress(("127.0.0.1", 1234), family=socket.AF_INET)
self.assertEqual(res1, "127.0.0.1:1234") self.assertEqual(res1, "127.0.0.1:1234")
res2 = netutils.FormatAddress(socket.AF_INET, ("192.0.2.32", None)) res2 = netutils.FormatAddress(("192.0.2.32", None), family=socket.AF_INET)
self.assertEqual(res2, "192.0.2.32") self.assertEqual(res2, "192.0.2.32")
def testFormatAddressIP6(self): def testFormatAddressIP6(self):
res1 = netutils.FormatAddress(socket.AF_INET6, ("::1", 1234)) res1 = netutils.FormatAddress(("::1", 1234), family=socket.AF_INET6)
self.assertEqual(res1, "[::1]:1234") self.assertEqual(res1, "[::1]:1234")
res2 = netutils.FormatAddress(socket.AF_INET6, ("::1", None)) res2 = netutils.FormatAddress(("::1", None), family=socket.AF_INET6)
self.assertEqual(res2, "[::1]") self.assertEqual(res2, "[::1]")
res2 = netutils.FormatAddress(socket.AF_INET6, ("2001:db8::beef", "80")) res2 = netutils.FormatAddress(("2001:db8::beef", "80"),
family=socket.AF_INET6)
self.assertEqual(res2, "[2001:db8::beef]:80") self.assertEqual(res2, "[2001:db8::beef]:80")
def testFormatAddressWithoutFamily(self):
res1 = netutils.FormatAddress(("127.0.0.1", 1234))
self.assertEqual(res1, "127.0.0.1:1234")
res2 = netutils.FormatAddress(("::1", 1234))
self.assertEqual(res2, "[::1]:1234")
def testInvalidFormatAddress(self): def testInvalidFormatAddress(self):
self.assertRaises(errors.ParameterError, self.assertRaises(errors.ParameterError, netutils.FormatAddress,
netutils.FormatAddress, None, ("::1", None)) "127.0.0.1")
self.assertRaises(errors.ParameterError, self.assertRaises(errors.ParameterError, netutils.FormatAddress,
netutils.FormatAddress, socket.AF_INET, "127.0.0.1") "127.0.0.1", family=socket.AF_INET)
self.assertRaises(errors.ParameterError, self.assertRaises(errors.ParameterError, netutils.FormatAddress,
netutils.FormatAddress, socket.AF_INET, ("::1")) ("::1"), family=socket.AF_INET )
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment