From d8bcfe2136f04c151baf0b918e43359a9f9a0541 Mon Sep 17 00:00:00 2001
From: Manuel Franceschini <livewire@google.com>
Date: Wed, 30 Jun 2010 11:55:18 +0200
Subject: [PATCH] Confd IPv6 support

This patch series basically adds a new parameter 'family' to the constructors
of daemon.AsyncUDPSocket and confd.client.ConfdUDPClient. This enables the
users of these two classes to support IPv6.

In ganeti-confd.ConfdAsyncUDPClient a method to check the address families of
all peers is added.

Furthermore it adds unittests for the added functionality.

Signed-off-by: Manuel Franceschini <livewire@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>
---
 daemons/ganeti-confd                 | 10 +++-
 lib/confd/client.py                  | 19 ++++++-
 lib/daemon.py                        | 12 +++-
 test/ganeti.confd.client_unittest.py | 75 +++++++++++++++++++------
 test/ganeti.daemon_unittest.py       | 82 +++++++++++++++++++---------
 5 files changed, 148 insertions(+), 50 deletions(-)

diff --git a/daemons/ganeti-confd b/daemons/ganeti-confd
index 3a685c78e..804c87a31 100755
--- a/daemons/ganeti-confd
+++ b/daemons/ganeti-confd
@@ -48,6 +48,7 @@ from ganeti.confd import server as confd_server
 from ganeti import constants
 from ganeti import errors
 from ganeti import daemon
+from ganeti import netutils
 
 
 class ConfdAsyncUDPServer(daemon.AsyncUDPSocket):
@@ -58,14 +59,15 @@ class ConfdAsyncUDPServer(daemon.AsyncUDPSocket):
     """Constructor for ConfdAsyncUDPServer
 
     @type bind_address: string
-    @param bind_address: socket bind address ('' for all)
+    @param bind_address: socket bind address
     @type port: int
     @param port: udp port
     @type processor: L{confd.server.ConfdProcessor}
     @param processor: ConfdProcessor to use to handle queries
 
     """
-    daemon.AsyncUDPSocket.__init__(self)
+    daemon.AsyncUDPSocket.__init__(self,
+                                   netutils.GetAddressFamily(bind_address))
     self.bind_address = bind_address
     self.port = port
     self.processor = processor
@@ -251,6 +253,10 @@ def CheckConfd(_, args):
     print >> sys.stderr, "Need HMAC key %s to run" % constants.CONFD_HMAC_KEY
     sys.exit(constants.EXIT_FAILURE)
 
+  # TODO: once we have a cluster param specifying the address family
+  # preference, we need to check if the requested options.bind_address does not
+  # conflict with that. If so, we might warn or EXIT_FAILURE.
+
 
 def ExecConfd(options, _):
   """Main confd function, executed with PID file held
diff --git a/lib/confd/client.py b/lib/confd/client.py
index 288a1da8c..af25c680a 100644
--- a/lib/confd/client.py
+++ b/lib/confd/client.py
@@ -72,14 +72,14 @@ class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
   implement a non-asyncore based client library.
 
   """
-  def __init__(self, client):
+  def __init__(self, client, family):
     """Constructor for ConfdAsyncUDPClient
 
     @type client: L{ConfdClient}
     @param client: client library, to pass the datagrams to
 
     """
-    daemon.AsyncUDPSocket.__init__(self)
+    daemon.AsyncUDPSocket.__init__(self, family)
     self.client = client
 
   # this method is overriding a daemon.AsyncUDPSocket method
@@ -137,8 +137,9 @@ class ConfdClient:
       raise errors.ProgrammerError("callback must be callable")
 
     self.UpdatePeerList(peers)
+    self._SetPeersAddressFamily()
     self._hmac_key = hmac_key
-    self._socket = ConfdAsyncUDPClient(self)
+    self._socket = ConfdAsyncUDPClient(self, self._family)
     self._callback = callback
     self._confd_port = port
     self._logger = logger
@@ -385,6 +386,18 @@ class ConfdClient:
       else:
         return MISSING
 
+  def _SetPeersAddressFamily(self):
+    if not self._peers:
+      raise errors.ConfdClientError("Peer list empty")
+    try:
+      peer = self._peers[0]
+      self._family = netutils.GetAddressFamily(peer)
+      for peer in self._peers[1:]:
+        if netutils.GetAddressFamily(peer) != self._family:
+          raise errors.ConfdClientError("Peers must be of same address family")
+    except errors.GenericError:
+      raise errors.ConfdClientError("Peer address %s invalid" % peer)
+
 
 # UPCALL_REPLY: server reply upcall
 # has all ConfdUpcallPayload fields populated
diff --git a/lib/daemon.py b/lib/daemon.py
index e66d787bf..1fd1b741e 100644
--- a/lib/daemon.py
+++ b/lib/daemon.py
@@ -310,13 +310,14 @@ class AsyncUDPSocket(GanetiBaseAsyncoreDispatcher):
   """An improved asyncore udp socket.
 
   """
-  def __init__(self):
+  def __init__(self, family):
     """Constructor for AsyncUDPSocket
 
     """
     GanetiBaseAsyncoreDispatcher.__init__(self)
     self._out_queue = []
-    self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
+    self._family = family
+    self.create_socket(family, socket.SOCK_DGRAM)
 
   # this method is overriding an asyncore.dispatcher method
   def handle_connect(self):
@@ -331,7 +332,12 @@ class AsyncUDPSocket(GanetiBaseAsyncoreDispatcher):
                                       constants.MAX_UDP_DATA_SIZE)
     if recv_result is not None:
       payload, address = recv_result
-      ip, port = address
+      if self._family == socket.AF_INET6:
+        # we ignore 'flow info' and 'scope id' as we don't need them
+        ip, port, _, _ = address
+      else:
+        ip, port = address
+
       self.handle_datagram(payload, ip, port)
 
   def handle_datagram(self, payload, ip, port):
diff --git a/test/ganeti.confd.client_unittest.py b/test/ganeti.confd.client_unittest.py
index e501c1d2e..7c5c2efa5 100755
--- a/test/ganeti.confd.client_unittest.py
+++ b/test/ganeti.confd.client_unittest.py
@@ -22,6 +22,7 @@
 """Script for unittesting the confd client module"""
 
 
+import socket
 import unittest
 
 from ganeti import confd
@@ -96,8 +97,11 @@ class MockTime(ResettableMock):
     self.mytime += delta
 
 
-class TestClient(unittest.TestCase):
-  """Client tests"""
+class _BaseClientTest:
+  """Base class for client tests"""
+  mc_list = None
+  new_peers = None
+  family = None
 
   def setUp(self):
     self.mock_time = MockTime()
@@ -105,16 +109,6 @@ class TestClient(unittest.TestCase):
     confd.client.ConfdAsyncUDPClient = MockConfdAsyncUDPClient
     self.logger = MockLogger()
     hmac_key = "mykeydata"
-    self.mc_list = ['10.0.0.1',
-                    '10.0.0.2',
-                    '10.0.0.3',
-                    '10.0.0.4',
-                    '10.0.0.5',
-                    '10.0.0.6',
-                    '10.0.0.7',
-                    '10.0.0.8',
-                    '10.0.0.9',
-                   ]
     self.callback = MockCallback()
     self.client = confd.client.ConfdClient(hmac_key, self.mc_list,
                                            self.callback, logger=self.logger)
@@ -178,13 +172,60 @@ class TestClient(unittest.TestCase):
     self.assertEquals(self.callback.call_count, 1)
 
   def testUpdatePeerList(self):
-    new_peers = ['1.2.3.4', '1.2.3.5']
-    self.client.UpdatePeerList(new_peers)
-    self.assertEquals(self.client._peers, new_peers)
+    self.client.UpdatePeerList(self.new_peers)
+    self.assertEquals(self.client._peers, self.new_peers)
     req = confd.client.ConfdClientRequest(type=constants.CONFD_REQ_PING)
     self.client.SendRequest(req)
-    self.assertEquals(self.client._socket.send_count, len(new_peers))
-    self.assert_(self.client._socket.last_address in new_peers)
+    self.assertEquals(self.client._socket.send_count, len(self.new_peers))
+    self.assert_(self.client._socket.last_address in self.new_peers)
+
+  def testSetPeersFamily(self):
+    self.client._SetPeersAddressFamily()
+    self.assertEquals(self.client._family, self.family)
+    mixed_peers = ["1.2.3.6", "2001:db8:beef::13"]
+    self.client.UpdatePeerList(mixed_peers)
+    self.assertRaises(errors.ConfdClientError,
+                      self.client._SetPeersAddressFamily)
+
+
+class TestIP4Client(unittest.TestCase, _BaseClientTest):
+  """Client tests"""
+  mc_list = ["10.0.0.1",
+             "10.0.0.2",
+             "10.0.0.3",
+             "10.0.0.4",
+             "10.0.0.5",
+             "10.0.0.6",
+             "10.0.0.7",
+             "10.0.0.8",
+             "10.0.0.9",
+            ]
+  new_peers = ["1.2.3.4", "1.2.3.5"]
+  family = socket.AF_INET
+
+  def setUp(self):
+    unittest.TestCase.setUp(self)
+    _BaseClientTest.setUp(self)
+
+
+class TestIP6Client(unittest.TestCase, _BaseClientTest):
+  """Client tests"""
+  mc_list = ["2001:db8::1",
+             "2001:db8::2",
+             "2001:db8::3",
+             "2001:db8::4",
+             "2001:db8::5",
+             "2001:db8::6",
+             "2001:db8::7",
+             "2001:db8::8",
+             "2001:db8::9",
+            ]
+  new_peers = ["2001:db8:beef::11", "2001:db8:beef::12"]
+  family = socket.AF_INET6
+
+  def setUp(self):
+    unittest.TestCase.setUp(self)
+    _BaseClientTest.setUp(self)
 
 
 if __name__ == '__main__':
diff --git a/test/ganeti.daemon_unittest.py b/test/ganeti.daemon_unittest.py
index 1343130a1..5be22ddf4 100755
--- a/test/ganeti.daemon_unittest.py
+++ b/test/ganeti.daemon_unittest.py
@@ -149,8 +149,8 @@ class TestMainloop(testutils.GanetiTestCase):
 
 class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
 
-  def __init__(self):
-    daemon.AsyncUDPSocket.__init__(self)
+  def __init__(self, family):
+    daemon.AsyncUDPSocket.__init__(self, family)
     self.received = []
     self.error_count = 0
 
@@ -166,15 +166,17 @@ class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
     raise
 
 
-class TestAsyncUDPSocket(testutils.GanetiTestCase):
-  """Test daemon.AsyncUDPSocket"""
+class _BaseAsyncUDPSocketTest:
+  """Base class for  AsyncUDPSocket tests"""
+
+  family = None
+  address = None
 
   def setUp(self):
-    testutils.GanetiTestCase.setUp(self)
     self.mainloop = daemon.Mainloop()
-    self.server = _MyAsyncUDPSocket()
-    self.client = _MyAsyncUDPSocket()
-    self.server.bind(("127.0.0.1", 0))
+    self.server = _MyAsyncUDPSocket(self.family)
+    self.client = _MyAsyncUDPSocket(self.family)
+    self.server.bind((self.address, 0))
     self.port = self.server.getsockname()[1]
     # Save utils.IgnoreSignals so we can do evil things to it...
     self.saved_utils_ignoresignals = utils.IgnoreSignals
@@ -187,38 +189,38 @@ class TestAsyncUDPSocket(testutils.GanetiTestCase):
     testutils.GanetiTestCase.tearDown(self)
 
   def testNoDoubleBind(self):
-    self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port))
+    self.assertRaises(socket.error, self.client.bind, (self.address, self.port))
 
   def testAsyncClientServer(self):
-    self.client.enqueue_send("127.0.0.1", self.port, "p1")
-    self.client.enqueue_send("127.0.0.1", self.port, "p2")
-    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
+    self.client.enqueue_send(self.address, self.port, "p1")
+    self.client.enqueue_send(self.address, self.port, "p2")
+    self.client.enqueue_send(self.address, self.port, "terminate")
     self.mainloop.Run()
     self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
 
   def testSyncClientServer(self):
     self.client.handle_write()
-    self.client.enqueue_send("127.0.0.1", self.port, "p1")
-    self.client.enqueue_send("127.0.0.1", self.port, "p2")
+    self.client.enqueue_send(self.address, self.port, "p1")
+    self.client.enqueue_send(self.address, self.port, "p2")
     while self.client.writable():
       self.client.handle_write()
     self.server.process_next_packet()
     self.assertEquals(self.server.received, ["p1"])
     self.server.process_next_packet()
     self.assertEquals(self.server.received, ["p1", "p2"])
-    self.client.enqueue_send("127.0.0.1", self.port, "p3")
+    self.client.enqueue_send(self.address, self.port, "p3")
     while self.client.writable():
       self.client.handle_write()
     self.server.process_next_packet()
     self.assertEquals(self.server.received, ["p1", "p2", "p3"])
 
   def testErrorHandling(self):
-    self.client.enqueue_send("127.0.0.1", self.port, "p1")
-    self.client.enqueue_send("127.0.0.1", self.port, "p2")
-    self.client.enqueue_send("127.0.0.1", self.port, "error")
-    self.client.enqueue_send("127.0.0.1", self.port, "p3")
-    self.client.enqueue_send("127.0.0.1", self.port, "error")
-    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
+    self.client.enqueue_send(self.address, self.port, "p1")
+    self.client.enqueue_send(self.address, self.port, "p2")
+    self.client.enqueue_send(self.address, self.port, "error")
+    self.client.enqueue_send(self.address, self.port, "p3")
+    self.client.enqueue_send(self.address, self.port, "error")
+    self.client.enqueue_send(self.address, self.port, "terminate")
     self.assertRaises(errors.GenericError, self.mainloop.Run)
     self.assertEquals(self.server.received,
                       ["p1", "p2", "error"])
@@ -234,11 +236,11 @@ class TestAsyncUDPSocket(testutils.GanetiTestCase):
 
   def testSignaledWhileReceiving(self):
     utils.IgnoreSignals = lambda fn, *args, **kwargs: None
-    self.client.enqueue_send("127.0.0.1", self.port, "p1")
-    self.client.enqueue_send("127.0.0.1", self.port, "p2")
+    self.client.enqueue_send(self.address, self.port, "p1")
+    self.client.enqueue_send(self.address, self.port, "p2")
     self.server.handle_read()
     self.assertEquals(self.server.received, [])
-    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
+    self.client.enqueue_send(self.address, self.port, "terminate")
     utils.IgnoreSignals = self.saved_utils_ignoresignals
     self.mainloop.Run()
     self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
@@ -246,7 +248,37 @@ class TestAsyncUDPSocket(testutils.GanetiTestCase):
   def testOversizedDatagram(self):
     oversized_data = (constants.MAX_UDP_DATA_SIZE + 1) * "a"
     self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send,
-                      "127.0.0.1", self.port, oversized_data)
+                      self.address, self.port, oversized_data)
+
+
+class TestAsyncIP4UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
+  """Test IP4 daemon.AsyncUDPSocket"""
+
+  family = socket.AF_INET
+  address = "127.0.0.1"
+
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+    _BaseAsyncUDPSocketTest.setUp(self)
+
+  def tearDown(self):
+    testutils.GanetiTestCase.tearDown(self)
+    _BaseAsyncUDPSocketTest.tearDown(self)
+
+
+class TestAsyncIP6UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
+  """Test IP6 daemon.AsyncUDPSocket"""
+
+  family = socket.AF_INET6
+  address = "::1"
+
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+    _BaseAsyncUDPSocketTest.setUp(self)
+
+  def tearDown(self):
+    testutils.GanetiTestCase.tearDown(self)
+    _BaseAsyncUDPSocketTest.tearDown(self)
 
 
 class _MyAsyncStreamServer(daemon.AsyncStreamServer):
-- 
GitLab