diff --git a/test/ganeti.daemon_unittest.py b/test/ganeti.daemon_unittest.py index 977a073f9d36188f902de153e2862e06aceb4847..86d1a47208ceffb1c5288cf29c2f1394399b254a 100755 --- a/test/ganeti.daemon_unittest.py +++ b/test/ganeti.daemon_unittest.py @@ -26,6 +26,8 @@ import signal import os import socket import time +import tempfile +import shutil from ganeti import daemon from ganeti import errors @@ -246,5 +248,241 @@ class TestAsyncUDPSocket(testutils.GanetiTestCase): self.assertEquals(self.server.received, ["p1", "p2", "terminate"]) +class _MyAsyncStreamServer(daemon.AsyncStreamServer): + + def __init__(self, family, address, handle_connection_fn): + daemon.AsyncStreamServer.__init__(self, family, address) + self.handle_connection_fn = handle_connection_fn + self.error_count = 0 + self.expt_count = 0 + + def handle_connection(self, connected_socket, client_address): + self.handle_connection_fn(connected_socket, client_address) + + def handle_error(self): + self.error_count += 1 + self.close() + raise + + def handle_expt(self): + self.expt_count += 1 + self.close() + + +class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream): + + def __init__(self, connected_socket, client_address, terminator, family, + message_fn, client_id): + daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket, + client_address, + terminator, family) + self.message_fn = message_fn + self.client_id = client_id + self.error_count = 0 + + def handle_message(self, message, message_id): + self.message_fn(self, message, message_id) + + def handle_error(self): + self.error_count += 1 + raise + + +class TestAsyncStreamServerTCP(testutils.GanetiTestCase): + """Test daemon.AsyncStreamServer with a TCP connection""" + + family = socket.AF_INET + + def setUp(self): + testutils.GanetiTestCase.setUp(self) + self.mainloop = daemon.Mainloop() + self.address = self.getAddress() + self.server = _MyAsyncStreamServer(self.family, self.address, + self.handle_connection) + self.client_handler = _MyMessageStreamHandler + self.terminator = "\3" + self.address = self.server.getsockname() + self.clients = [] + self.connections = [] + self.messages = {} + self.connect_terminate_count = 0 + self.message_terminate_count = 0 + self.next_client_id = 0 + # Save utils.IgnoreSignals so we can do evil things to it... + self.saved_utils_ignoresignals = utils.IgnoreSignals + + def tearDown(self): + for c in self.clients: + c.close() + for c in self.connections: + c.close() + self.server.close() + # ...and restore it as well + utils.IgnoreSignals = self.saved_utils_ignoresignals + testutils.GanetiTestCase.tearDown(self) + + def getAddress(self): + return ("127.0.0.1", 0) + + def countTerminate(self, name): + value = getattr(self, name) + if value is not None: + value -= 1 + setattr(self, name, value) + if value <= 0: + os.kill(os.getpid(), signal.SIGTERM) + + def handle_connection(self, connected_socket, client_address): + client_id = self.next_client_id + self.next_client_id += 1 + client_handler = self.client_handler(connected_socket, client_address, + self.terminator, self.family, + self.handle_message, + client_id) + self.connections.append(client_handler) + self.countTerminate("connect_terminate_count") + + def handle_message(self, handler, message, message_id): + self.messages.setdefault(handler.client_id, []) + # We should just check that the message_ids are monotonically increasing. + # If in the unit tests we never remove messages from the received queue, + # though, we can just require that the queue length is the same as the + # message id, before pushing the message to it. This forces a more + # restrictive check, but we can live with this for now. + self.assertEquals(len(self.messages[handler.client_id]), message_id) + self.messages[handler.client_id].append(message) + if message == "error": + raise errors.GenericError("error") + self.countTerminate("message_terminate_count") + + def getClient(self): + client = socket.socket(self.family, socket.SOCK_STREAM) + client.connect(self.address) + self.clients.append(client) + return client + + def tearDown(self): + testutils.GanetiTestCase.tearDown(self) + self.server.close() + + def testConnect(self): + self.getClient() + self.mainloop.Run() + self.assertEquals(len(self.connections), 1) + self.getClient() + self.mainloop.Run() + self.assertEquals(len(self.connections), 2) + self.connect_terminate_count = 4 + self.getClient() + self.getClient() + self.getClient() + self.getClient() + self.mainloop.Run() + self.assertEquals(len(self.connections), 6) + + def testBasicMessage(self): + self.connect_terminate_count = None + client = self.getClient() + client.send("ciao\3") + self.mainloop.Run() + self.assertEquals(len(self.connections), 1) + self.assertEquals(len(self.messages[0]), 1) + self.assertEquals(self.messages[0][0], "ciao") + + def testDoubleMessage(self): + self.connect_terminate_count = None + client = self.getClient() + client.send("ciao\3") + self.mainloop.Run() + client.send("foobar\3") + self.mainloop.Run() + self.assertEquals(len(self.connections), 1) + self.assertEquals(len(self.messages[0]), 2) + self.assertEquals(self.messages[0][1], "foobar") + + def testComposedMessage(self): + self.connect_terminate_count = None + self.message_terminate_count = 3 + client = self.getClient() + client.send("one\3composed\3message\3") + self.mainloop.Run() + self.assertEquals(len(self.messages[0]), 3) + self.assertEquals(self.messages[0], ["one", "composed", "message"]) + + def testLongTerminator(self): + self.terminator = "\0\1\2" + self.connect_terminate_count = None + self.message_terminate_count = 3 + client = self.getClient() + client.send("one\0\1\2composed\0\1\2message\0\1\2") + self.mainloop.Run() + self.assertEquals(len(self.messages[0]), 3) + self.assertEquals(self.messages[0], ["one", "composed", "message"]) + + def testErrorHandling(self): + self.connect_terminate_count = None + self.message_terminate_count = None + client = self.getClient() + client.send("one\3two\3error\3three\3") + self.assertRaises(errors.GenericError, self.mainloop.Run) + self.assertEquals(self.connections[0].error_count, 1) + self.assertEquals(self.messages[0], ["one", "two", "error"]) + client.send("error\3") + self.assertRaises(errors.GenericError, self.mainloop.Run) + self.assertEquals(self.connections[0].error_count, 2) + self.assertEquals(self.messages[0], ["one", "two", "error", "three", + "error"]) + + def testDoubleClient(self): + self.connect_terminate_count = None + self.message_terminate_count = 2 + client1 = self.getClient() + client2 = self.getClient() + client1.send("c1m1\3") + client2.send("c2m1\3") + self.mainloop.Run() + self.assertEquals(self.messages[0], ["c1m1"]) + self.assertEquals(self.messages[1], ["c2m1"]) + + def testUnterminatedMessage(self): + self.connect_terminate_count = None + self.message_terminate_count = 3 + client1 = self.getClient() + client2 = self.getClient() + client1.send("message\3unterminated") + client2.send("c2m1\3c2m2\3") + self.mainloop.Run() + self.assertEquals(self.messages[0], ["message"]) + self.assertEquals(self.messages[1], ["c2m1", "c2m2"]) + client1.send("message\3") + self.mainloop.Run() + self.assertEquals(self.messages[0], ["message", "unterminatedmessage"]) + + def testSignaledWhileAccepting(self): + utils.IgnoreSignals = lambda fn, *args, **kwargs: None + client1 = self.getClient() + self.server.handle_accept() + # When interrupted while accepting we don't have a connection, but we + # didn't crash either. + self.assertEquals(len(self.connections), 0) + utils.IgnoreSignals = self.saved_utils_ignoresignals + self.mainloop.Run() + self.assertEquals(len(self.connections), 1) + + +class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP): + """Test daemon.AsyncStreamServer with a Unix path connection""" + + family = socket.AF_UNIX + + def getAddress(self): + self.tmpdir = tempfile.mkdtemp() + return os.path.join(self.tmpdir, "server.sock") + + def tearDown(self): + shutil.rmtree(self.tmpdir) + TestAsyncStreamServerTCP.tearDown(self) + + if __name__ == "__main__": testutils.GanetiTestProgram()