diff --git a/test/ganeti.daemon_unittest.py b/test/ganeti.daemon_unittest.py index 374d1c37b2cb261d377cf1a76a6a1f96584155c0..e9f10bd2818cad37991c7311321c986c788b5790 100755 --- a/test/ganeti.daemon_unittest.py +++ b/test/ganeti.daemon_unittest.py @@ -24,6 +24,7 @@ import unittest import signal import os +import socket from ganeti import daemon @@ -94,5 +95,83 @@ class TestMainloop(testutils.GanetiTestCase): self.assertEquals(self.onsignal_events, self.sendsig_events) +class _MyAsyncUDPSocket(daemon.AsyncUDPSocket): + + def __init__(self): + daemon.AsyncUDPSocket.__init__(self) + self.received = [] + self.error_count = 0 + + def handle_datagram(self, payload, ip, port): + self.received.append((payload)) + if payload == "terminate": + os.kill(os.getpid(), signal.SIGTERM) + elif payload == "error": + raise errors.GenericError("error") + + def handle_error(self): + self.error_count += 1 + + +class TestAsyncUDPSocket(testutils.GanetiTestCase): + """Test daemon.AsyncUDPSocket""" + + def setUp(self): + testutils.GanetiTestCase.setUp(self) + self.mainloop = daemon.Mainloop() + self.server = _MyAsyncUDPSocket() + self.client = _MyAsyncUDPSocket() + self.server.bind(("127.0.0.1", 0)) + self.port = self.server.getsockname()[1] + + def tearDown(self): + self.server.close() + self.client.close() + testutils.GanetiTestCase.tearDown(self) + + def testNoDoubleBind(self): + self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port)) + + def _ThreadedClient(self, payload): + self.client.enqueue_send("127.0.0.1", self.port, payload) + print "sending %s" % payload + while self.client.writable(): + self.client.handle_write() + + def testAsyncClientServer(self): + self.client.enqueue_send("127.0.0.1", self.port, "p1") + self.client.enqueue_send("127.0.0.1", self.port, "p2") + self.client.enqueue_send("127.0.0.1", self.port, "terminate") + self.mainloop.Run() + self.assertEquals(self.server.received, ["p1", "p2", "terminate"]) + + def testSyncClientServer(self): + self.client.enqueue_send("127.0.0.1", self.port, "p1") + self.client.enqueue_send("127.0.0.1", self.port, "p2") + while self.client.writable(): + self.client.handle_write() + self.server.process_next_packet() + self.assertEquals(self.server.received, ["p1"]) + self.server.process_next_packet() + self.assertEquals(self.server.received, ["p1", "p2"]) + self.client.enqueue_send("127.0.0.1", self.port, "p3") + while self.client.writable(): + self.client.handle_write() + self.server.process_next_packet() + self.assertEquals(self.server.received, ["p1", "p2", "p3"]) + + def testErrorHandling(self): + self.client.enqueue_send("127.0.0.1", self.port, "p1") + self.client.enqueue_send("127.0.0.1", self.port, "p2") + self.client.enqueue_send("127.0.0.1", self.port, "error") + self.client.enqueue_send("127.0.0.1", self.port, "p3") + self.client.enqueue_send("127.0.0.1", self.port, "error") + self.client.enqueue_send("127.0.0.1", self.port, "terminate") + self.mainloop.Run() + self.assertEquals(self.server.received, + ["p1", "p2", "error", "p3", "error", "terminate"]) + self.assertEquals(self.server.error_count, 2) + + if __name__ == "__main__": testutils.GanetiTestProgram()