Commit 981732fb authored by Manuel Franceschini's avatar Manuel Franceschini

Make family argument in FormatAddress optional

By doing this we delegate the task of finding the correct address family
to the FormatAddress method.
Signed-off-by: default avatarManuel Franceschini <livewire@google.com>
Reviewed-by: default avatarIustin Pop <iustin@google.com>
parent 51b13ce9
......@@ -143,7 +143,7 @@ class AsyncStreamServer(GanetiBaseAsyncoreDispatcher):
# is passed in from accept anyway
client_address = netutils.GetSocketCredentials(connected_socket)
logging.info("Accepted connection from %s",
netutils.FormatAddress(self.family, client_address))
netutils.FormatAddress(client_address, family=self.family))
self.handle_connection(connected_socket, client_address)
def handle_connection(self, connected_socket, client_address):
......@@ -274,7 +274,7 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
def close_log(self):
logging.info("Closing connection from %s",
netutils.FormatAddress(self.family, self.peer_address))
netutils.FormatAddress(self.peer_address, family=self.family))
self.close()
# this method is overriding an asyncore.dispatcher method
......
......@@ -106,8 +106,7 @@ class HttpClientRequest(object):
"""
if netutils.IPAddress.IsValid(self.host):
family = netutils.IPAddress.GetAddressFamily(self.host)
address = netutils.FormatAddress(family, (self.host, self.port))
address = netutils.FormatAddress((self.host, self.port))
else:
address = "%s:%s" % (self.host, self.port)
# TODO: Support for non-SSL requests
......
......@@ -490,15 +490,21 @@ class IP6Address(IPAddress):
return address_int
def FormatAddress(family, address):
def FormatAddress(address, family=None):
"""Format a socket address
@type family: integer
@param family: socket family (one of socket.AF_*)
@type address: family specific (usually tuple)
@param address: address, as reported by this class
@type family: integer
@param family: socket family (one of socket.AF_*) or None
"""
if family is None:
try:
family = IPAddress.GetAddressFamily(address[0])
except errors.IPAddressError:
raise errors.ParameterError(address)
if family == socket.AF_UNIX and len(address) == 3:
return "pid=%s, uid=%s, gid=%s" % address
......
......@@ -388,30 +388,38 @@ class TestFormatAddress(unittest.TestCase):
"""Testcase for FormatAddress"""
def testFormatAddressUnixSocket(self):
res1 = netutils.FormatAddress(socket.AF_UNIX, ("12352", 0, 0))
res1 = netutils.FormatAddress(("12352", 0, 0), family=socket.AF_UNIX)
self.assertEqual(res1, "pid=12352, uid=0, gid=0")
def testFormatAddressIP4(self):
res1 = netutils.FormatAddress(socket.AF_INET, ("127.0.0.1", 1234))
res1 = netutils.FormatAddress(("127.0.0.1", 1234), family=socket.AF_INET)
self.assertEqual(res1, "127.0.0.1:1234")
res2 = netutils.FormatAddress(socket.AF_INET, ("192.0.2.32", None))
res2 = netutils.FormatAddress(("192.0.2.32", None), family=socket.AF_INET)
self.assertEqual(res2, "192.0.2.32")
def testFormatAddressIP6(self):
res1 = netutils.FormatAddress(socket.AF_INET6, ("::1", 1234))
res1 = netutils.FormatAddress(("::1", 1234), family=socket.AF_INET6)
self.assertEqual(res1, "[::1]:1234")
res2 = netutils.FormatAddress(socket.AF_INET6, ("::1", None))
res2 = netutils.FormatAddress(("::1", None), family=socket.AF_INET6)
self.assertEqual(res2, "[::1]")
res2 = netutils.FormatAddress(socket.AF_INET6, ("2001:db8::beef", "80"))
res2 = netutils.FormatAddress(("2001:db8::beef", "80"),
family=socket.AF_INET6)
self.assertEqual(res2, "[2001:db8::beef]:80")
def testFormatAddressWithoutFamily(self):
res1 = netutils.FormatAddress(("127.0.0.1", 1234))
self.assertEqual(res1, "127.0.0.1:1234")
res2 = netutils.FormatAddress(("::1", 1234))
self.assertEqual(res2, "[::1]:1234")
def testInvalidFormatAddress(self):
self.assertRaises(errors.ParameterError,
netutils.FormatAddress, None, ("::1", None))
self.assertRaises(errors.ParameterError,
netutils.FormatAddress, socket.AF_INET, "127.0.0.1")
self.assertRaises(errors.ParameterError,
netutils.FormatAddress, socket.AF_INET, ("::1"))
self.assertRaises(errors.ParameterError, netutils.FormatAddress,
"127.0.0.1")
self.assertRaises(errors.ParameterError, netutils.FormatAddress,
"127.0.0.1", family=socket.AF_INET)
self.assertRaises(errors.ParameterError, netutils.FormatAddress,
("::1"), family=socket.AF_INET )
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment