diff --git a/daemons/ganeti-confd b/daemons/ganeti-confd index 3a685c78e2682eeda7eb46d0e26664b5df003ca2..804c87a319bcbf911c0ece06fa2498b58a072ae8 100755 --- a/daemons/ganeti-confd +++ b/daemons/ganeti-confd @@ -48,6 +48,7 @@ from ganeti.confd import server as confd_server from ganeti import constants from ganeti import errors from ganeti import daemon +from ganeti import netutils class ConfdAsyncUDPServer(daemon.AsyncUDPSocket): @@ -58,14 +59,15 @@ class ConfdAsyncUDPServer(daemon.AsyncUDPSocket): """Constructor for ConfdAsyncUDPServer @type bind_address: string - @param bind_address: socket bind address ('' for all) + @param bind_address: socket bind address @type port: int @param port: udp port @type processor: L{confd.server.ConfdProcessor} @param processor: ConfdProcessor to use to handle queries """ - daemon.AsyncUDPSocket.__init__(self) + daemon.AsyncUDPSocket.__init__(self, + netutils.GetAddressFamily(bind_address)) self.bind_address = bind_address self.port = port self.processor = processor @@ -251,6 +253,10 @@ def CheckConfd(_, args): print >> sys.stderr, "Need HMAC key %s to run" % constants.CONFD_HMAC_KEY sys.exit(constants.EXIT_FAILURE) + # TODO: once we have a cluster param specifying the address family + # preference, we need to check if the requested options.bind_address does not + # conflict with that. If so, we might warn or EXIT_FAILURE. + def ExecConfd(options, _): """Main confd function, executed with PID file held diff --git a/lib/confd/client.py b/lib/confd/client.py index 288a1da8c797be768c62b951e380c5f021c279e6..af25c680acbf9ecf7e61647e6bd6631e0d966d9b 100644 --- a/lib/confd/client.py +++ b/lib/confd/client.py @@ -72,14 +72,14 @@ class ConfdAsyncUDPClient(daemon.AsyncUDPSocket): implement a non-asyncore based client library. """ - def __init__(self, client): + def __init__(self, client, family): """Constructor for ConfdAsyncUDPClient @type client: L{ConfdClient} @param client: client library, to pass the datagrams to """ - daemon.AsyncUDPSocket.__init__(self) + daemon.AsyncUDPSocket.__init__(self, family) self.client = client # this method is overriding a daemon.AsyncUDPSocket method @@ -137,8 +137,9 @@ class ConfdClient: raise errors.ProgrammerError("callback must be callable") self.UpdatePeerList(peers) + self._SetPeersAddressFamily() self._hmac_key = hmac_key - self._socket = ConfdAsyncUDPClient(self) + self._socket = ConfdAsyncUDPClient(self, self._family) self._callback = callback self._confd_port = port self._logger = logger @@ -385,6 +386,18 @@ class ConfdClient: else: return MISSING + def _SetPeersAddressFamily(self): + if not self._peers: + raise errors.ConfdClientError("Peer list empty") + try: + peer = self._peers[0] + self._family = netutils.GetAddressFamily(peer) + for peer in self._peers[1:]: + if netutils.GetAddressFamily(peer) != self._family: + raise errors.ConfdClientError("Peers must be of same address family") + except errors.GenericError: + raise errors.ConfdClientError("Peer address %s invalid" % peer) + # UPCALL_REPLY: server reply upcall # has all ConfdUpcallPayload fields populated diff --git a/lib/daemon.py b/lib/daemon.py index e66d787bf9abb2023a284aa232e682b98a1c05d0..1fd1b741eb928a3f7e8d96f9ad9fe45173993d96 100644 --- a/lib/daemon.py +++ b/lib/daemon.py @@ -310,13 +310,14 @@ class AsyncUDPSocket(GanetiBaseAsyncoreDispatcher): """An improved asyncore udp socket. """ - def __init__(self): + def __init__(self, family): """Constructor for AsyncUDPSocket """ GanetiBaseAsyncoreDispatcher.__init__(self) self._out_queue = [] - self.create_socket(socket.AF_INET, socket.SOCK_DGRAM) + self._family = family + self.create_socket(family, socket.SOCK_DGRAM) # this method is overriding an asyncore.dispatcher method def handle_connect(self): @@ -331,7 +332,12 @@ class AsyncUDPSocket(GanetiBaseAsyncoreDispatcher): constants.MAX_UDP_DATA_SIZE) if recv_result is not None: payload, address = recv_result - ip, port = address + if self._family == socket.AF_INET6: + # we ignore 'flow info' and 'scope id' as we don't need them + ip, port, _, _ = address + else: + ip, port = address + self.handle_datagram(payload, ip, port) def handle_datagram(self, payload, ip, port): diff --git a/test/ganeti.confd.client_unittest.py b/test/ganeti.confd.client_unittest.py index e501c1d2ecc5122d498f9431f99175adf95578c3..7c5c2efa5ccacb75138b3b957d467b99f3e81208 100755 --- a/test/ganeti.confd.client_unittest.py +++ b/test/ganeti.confd.client_unittest.py @@ -22,6 +22,7 @@ """Script for unittesting the confd client module""" +import socket import unittest from ganeti import confd @@ -96,8 +97,11 @@ class MockTime(ResettableMock): self.mytime += delta -class TestClient(unittest.TestCase): - """Client tests""" +class _BaseClientTest: + """Base class for client tests""" + mc_list = None + new_peers = None + family = None def setUp(self): self.mock_time = MockTime() @@ -105,16 +109,6 @@ class TestClient(unittest.TestCase): confd.client.ConfdAsyncUDPClient = MockConfdAsyncUDPClient self.logger = MockLogger() hmac_key = "mykeydata" - self.mc_list = ['10.0.0.1', - '10.0.0.2', - '10.0.0.3', - '10.0.0.4', - '10.0.0.5', - '10.0.0.6', - '10.0.0.7', - '10.0.0.8', - '10.0.0.9', - ] self.callback = MockCallback() self.client = confd.client.ConfdClient(hmac_key, self.mc_list, self.callback, logger=self.logger) @@ -178,13 +172,60 @@ class TestClient(unittest.TestCase): self.assertEquals(self.callback.call_count, 1) def testUpdatePeerList(self): - new_peers = ['1.2.3.4', '1.2.3.5'] - self.client.UpdatePeerList(new_peers) - self.assertEquals(self.client._peers, new_peers) + self.client.UpdatePeerList(self.new_peers) + self.assertEquals(self.client._peers, self.new_peers) req = confd.client.ConfdClientRequest(type=constants.CONFD_REQ_PING) self.client.SendRequest(req) - self.assertEquals(self.client._socket.send_count, len(new_peers)) - self.assert_(self.client._socket.last_address in new_peers) + self.assertEquals(self.client._socket.send_count, len(self.new_peers)) + self.assert_(self.client._socket.last_address in self.new_peers) + + def testSetPeersFamily(self): + self.client._SetPeersAddressFamily() + self.assertEquals(self.client._family, self.family) + mixed_peers = ["1.2.3.6", "2001:db8:beef::13"] + self.client.UpdatePeerList(mixed_peers) + self.assertRaises(errors.ConfdClientError, + self.client._SetPeersAddressFamily) + + +class TestIP4Client(unittest.TestCase, _BaseClientTest): + """Client tests""" + mc_list = ["10.0.0.1", + "10.0.0.2", + "10.0.0.3", + "10.0.0.4", + "10.0.0.5", + "10.0.0.6", + "10.0.0.7", + "10.0.0.8", + "10.0.0.9", + ] + new_peers = ["1.2.3.4", "1.2.3.5"] + family = socket.AF_INET + + def setUp(self): + unittest.TestCase.setUp(self) + _BaseClientTest.setUp(self) + + +class TestIP6Client(unittest.TestCase, _BaseClientTest): + """Client tests""" + mc_list = ["2001:db8::1", + "2001:db8::2", + "2001:db8::3", + "2001:db8::4", + "2001:db8::5", + "2001:db8::6", + "2001:db8::7", + "2001:db8::8", + "2001:db8::9", + ] + new_peers = ["2001:db8:beef::11", "2001:db8:beef::12"] + family = socket.AF_INET6 + + def setUp(self): + unittest.TestCase.setUp(self) + _BaseClientTest.setUp(self) if __name__ == '__main__': diff --git a/test/ganeti.daemon_unittest.py b/test/ganeti.daemon_unittest.py index 1343130a10348a27485cebd1306883dafa086488..5be22ddf49b7b5b045c1d0456dbcd8d2ef1cdbca 100755 --- a/test/ganeti.daemon_unittest.py +++ b/test/ganeti.daemon_unittest.py @@ -149,8 +149,8 @@ class TestMainloop(testutils.GanetiTestCase): class _MyAsyncUDPSocket(daemon.AsyncUDPSocket): - def __init__(self): - daemon.AsyncUDPSocket.__init__(self) + def __init__(self, family): + daemon.AsyncUDPSocket.__init__(self, family) self.received = [] self.error_count = 0 @@ -166,15 +166,17 @@ class _MyAsyncUDPSocket(daemon.AsyncUDPSocket): raise -class TestAsyncUDPSocket(testutils.GanetiTestCase): - """Test daemon.AsyncUDPSocket""" +class _BaseAsyncUDPSocketTest: + """Base class for AsyncUDPSocket tests""" + + family = None + address = None def setUp(self): - testutils.GanetiTestCase.setUp(self) self.mainloop = daemon.Mainloop() - self.server = _MyAsyncUDPSocket() - self.client = _MyAsyncUDPSocket() - self.server.bind(("127.0.0.1", 0)) + self.server = _MyAsyncUDPSocket(self.family) + self.client = _MyAsyncUDPSocket(self.family) + self.server.bind((self.address, 0)) self.port = self.server.getsockname()[1] # Save utils.IgnoreSignals so we can do evil things to it... self.saved_utils_ignoresignals = utils.IgnoreSignals @@ -187,38 +189,38 @@ class TestAsyncUDPSocket(testutils.GanetiTestCase): testutils.GanetiTestCase.tearDown(self) def testNoDoubleBind(self): - self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port)) + self.assertRaises(socket.error, self.client.bind, (self.address, self.port)) def testAsyncClientServer(self): - self.client.enqueue_send("127.0.0.1", self.port, "p1") - self.client.enqueue_send("127.0.0.1", self.port, "p2") - self.client.enqueue_send("127.0.0.1", self.port, "terminate") + self.client.enqueue_send(self.address, self.port, "p1") + self.client.enqueue_send(self.address, self.port, "p2") + self.client.enqueue_send(self.address, self.port, "terminate") self.mainloop.Run() self.assertEquals(self.server.received, ["p1", "p2", "terminate"]) def testSyncClientServer(self): self.client.handle_write() - self.client.enqueue_send("127.0.0.1", self.port, "p1") - self.client.enqueue_send("127.0.0.1", self.port, "p2") + self.client.enqueue_send(self.address, self.port, "p1") + self.client.enqueue_send(self.address, self.port, "p2") while self.client.writable(): self.client.handle_write() self.server.process_next_packet() self.assertEquals(self.server.received, ["p1"]) self.server.process_next_packet() self.assertEquals(self.server.received, ["p1", "p2"]) - self.client.enqueue_send("127.0.0.1", self.port, "p3") + self.client.enqueue_send(self.address, self.port, "p3") while self.client.writable(): self.client.handle_write() self.server.process_next_packet() self.assertEquals(self.server.received, ["p1", "p2", "p3"]) def testErrorHandling(self): - self.client.enqueue_send("127.0.0.1", self.port, "p1") - self.client.enqueue_send("127.0.0.1", self.port, "p2") - self.client.enqueue_send("127.0.0.1", self.port, "error") - self.client.enqueue_send("127.0.0.1", self.port, "p3") - self.client.enqueue_send("127.0.0.1", self.port, "error") - self.client.enqueue_send("127.0.0.1", self.port, "terminate") + self.client.enqueue_send(self.address, self.port, "p1") + self.client.enqueue_send(self.address, self.port, "p2") + self.client.enqueue_send(self.address, self.port, "error") + self.client.enqueue_send(self.address, self.port, "p3") + self.client.enqueue_send(self.address, self.port, "error") + self.client.enqueue_send(self.address, self.port, "terminate") self.assertRaises(errors.GenericError, self.mainloop.Run) self.assertEquals(self.server.received, ["p1", "p2", "error"]) @@ -234,11 +236,11 @@ class TestAsyncUDPSocket(testutils.GanetiTestCase): def testSignaledWhileReceiving(self): utils.IgnoreSignals = lambda fn, *args, **kwargs: None - self.client.enqueue_send("127.0.0.1", self.port, "p1") - self.client.enqueue_send("127.0.0.1", self.port, "p2") + self.client.enqueue_send(self.address, self.port, "p1") + self.client.enqueue_send(self.address, self.port, "p2") self.server.handle_read() self.assertEquals(self.server.received, []) - self.client.enqueue_send("127.0.0.1", self.port, "terminate") + self.client.enqueue_send(self.address, self.port, "terminate") utils.IgnoreSignals = self.saved_utils_ignoresignals self.mainloop.Run() self.assertEquals(self.server.received, ["p1", "p2", "terminate"]) @@ -246,7 +248,37 @@ class TestAsyncUDPSocket(testutils.GanetiTestCase): def testOversizedDatagram(self): oversized_data = (constants.MAX_UDP_DATA_SIZE + 1) * "a" self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send, - "127.0.0.1", self.port, oversized_data) + self.address, self.port, oversized_data) + + +class TestAsyncIP4UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest): + """Test IP4 daemon.AsyncUDPSocket""" + + family = socket.AF_INET + address = "127.0.0.1" + + def setUp(self): + testutils.GanetiTestCase.setUp(self) + _BaseAsyncUDPSocketTest.setUp(self) + + def tearDown(self): + testutils.GanetiTestCase.tearDown(self) + _BaseAsyncUDPSocketTest.tearDown(self) + + +class TestAsyncIP6UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest): + """Test IP6 daemon.AsyncUDPSocket""" + + family = socket.AF_INET6 + address = "::1" + + def setUp(self): + testutils.GanetiTestCase.setUp(self) + _BaseAsyncUDPSocketTest.setUp(self) + + def tearDown(self): + testutils.GanetiTestCase.tearDown(self) + _BaseAsyncUDPSocketTest.tearDown(self) class _MyAsyncStreamServer(daemon.AsyncStreamServer):