Commit 18215385 authored by Guido Trotter's avatar Guido Trotter
Browse files

Test the new streaming daemon classes



Unittests cover AsyncStreamServer and AsyncTerminatedMessageStream with
both tcp and unix sockets.
Signed-off-by: default avatarGuido Trotter <ultrotter@google.com>
Reviewed-by: default avatarMichael Hanselmann <hansmi@google.com>
parent b66ab629
......@@ -26,6 +26,8 @@ import signal
import os
import socket
import time
import tempfile
import shutil
from ganeti import daemon
from ganeti import errors
......@@ -246,5 +248,241 @@ class TestAsyncUDPSocket(testutils.GanetiTestCase):
self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
class _MyAsyncStreamServer(daemon.AsyncStreamServer):
def __init__(self, family, address, handle_connection_fn):
daemon.AsyncStreamServer.__init__(self, family, address)
self.handle_connection_fn = handle_connection_fn
self.error_count = 0
self.expt_count = 0
def handle_connection(self, connected_socket, client_address):
self.handle_connection_fn(connected_socket, client_address)
def handle_error(self):
self.error_count += 1
self.close()
raise
def handle_expt(self):
self.expt_count += 1
self.close()
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
def __init__(self, connected_socket, client_address, terminator, family,
message_fn, client_id):
daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
client_address,
terminator, family)
self.message_fn = message_fn
self.client_id = client_id
self.error_count = 0
def handle_message(self, message, message_id):
self.message_fn(self, message, message_id)
def handle_error(self):
self.error_count += 1
raise
class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
"""Test daemon.AsyncStreamServer with a TCP connection"""
family = socket.AF_INET
def setUp(self):
testutils.GanetiTestCase.setUp(self)
self.mainloop = daemon.Mainloop()
self.address = self.getAddress()
self.server = _MyAsyncStreamServer(self.family, self.address,
self.handle_connection)
self.client_handler = _MyMessageStreamHandler
self.terminator = "\3"
self.address = self.server.getsockname()
self.clients = []
self.connections = []
self.messages = {}
self.connect_terminate_count = 0
self.message_terminate_count = 0
self.next_client_id = 0
# Save utils.IgnoreSignals so we can do evil things to it...
self.saved_utils_ignoresignals = utils.IgnoreSignals
def tearDown(self):
for c in self.clients:
c.close()
for c in self.connections:
c.close()
self.server.close()
# ...and restore it as well
utils.IgnoreSignals = self.saved_utils_ignoresignals
testutils.GanetiTestCase.tearDown(self)
def getAddress(self):
return ("127.0.0.1", 0)
def countTerminate(self, name):
value = getattr(self, name)
if value is not None:
value -= 1
setattr(self, name, value)
if value <= 0:
os.kill(os.getpid(), signal.SIGTERM)
def handle_connection(self, connected_socket, client_address):
client_id = self.next_client_id
self.next_client_id += 1
client_handler = self.client_handler(connected_socket, client_address,
self.terminator, self.family,
self.handle_message,
client_id)
self.connections.append(client_handler)
self.countTerminate("connect_terminate_count")
def handle_message(self, handler, message, message_id):
self.messages.setdefault(handler.client_id, [])
# We should just check that the message_ids are monotonically increasing.
# If in the unit tests we never remove messages from the received queue,
# though, we can just require that the queue length is the same as the
# message id, before pushing the message to it. This forces a more
# restrictive check, but we can live with this for now.
self.assertEquals(len(self.messages[handler.client_id]), message_id)
self.messages[handler.client_id].append(message)
if message == "error":
raise errors.GenericError("error")
self.countTerminate("message_terminate_count")
def getClient(self):
client = socket.socket(self.family, socket.SOCK_STREAM)
client.connect(self.address)
self.clients.append(client)
return client
def tearDown(self):
testutils.GanetiTestCase.tearDown(self)
self.server.close()
def testConnect(self):
self.getClient()
self.mainloop.Run()
self.assertEquals(len(self.connections), 1)
self.getClient()
self.mainloop.Run()
self.assertEquals(len(self.connections), 2)
self.connect_terminate_count = 4
self.getClient()
self.getClient()
self.getClient()
self.getClient()
self.mainloop.Run()
self.assertEquals(len(self.connections), 6)
def testBasicMessage(self):
self.connect_terminate_count = None
client = self.getClient()
client.send("ciao\3")
self.mainloop.Run()
self.assertEquals(len(self.connections), 1)
self.assertEquals(len(self.messages[0]), 1)
self.assertEquals(self.messages[0][0], "ciao")
def testDoubleMessage(self):
self.connect_terminate_count = None
client = self.getClient()
client.send("ciao\3")
self.mainloop.Run()
client.send("foobar\3")
self.mainloop.Run()
self.assertEquals(len(self.connections), 1)
self.assertEquals(len(self.messages[0]), 2)
self.assertEquals(self.messages[0][1], "foobar")
def testComposedMessage(self):
self.connect_terminate_count = None
self.message_terminate_count = 3
client = self.getClient()
client.send("one\3composed\3message\3")
self.mainloop.Run()
self.assertEquals(len(self.messages[0]), 3)
self.assertEquals(self.messages[0], ["one", "composed", "message"])
def testLongTerminator(self):
self.terminator = "\0\1\2"
self.connect_terminate_count = None
self.message_terminate_count = 3
client = self.getClient()
client.send("one\0\1\2composed\0\1\2message\0\1\2")
self.mainloop.Run()
self.assertEquals(len(self.messages[0]), 3)
self.assertEquals(self.messages[0], ["one", "composed", "message"])
def testErrorHandling(self):
self.connect_terminate_count = None
self.message_terminate_count = None
client = self.getClient()
client.send("one\3two\3error\3three\3")
self.assertRaises(errors.GenericError, self.mainloop.Run)
self.assertEquals(self.connections[0].error_count, 1)
self.assertEquals(self.messages[0], ["one", "two", "error"])
client.send("error\3")
self.assertRaises(errors.GenericError, self.mainloop.Run)
self.assertEquals(self.connections[0].error_count, 2)
self.assertEquals(self.messages[0], ["one", "two", "error", "three",
"error"])
def testDoubleClient(self):
self.connect_terminate_count = None
self.message_terminate_count = 2
client1 = self.getClient()
client2 = self.getClient()
client1.send("c1m1\3")
client2.send("c2m1\3")
self.mainloop.Run()
self.assertEquals(self.messages[0], ["c1m1"])
self.assertEquals(self.messages[1], ["c2m1"])
def testUnterminatedMessage(self):
self.connect_terminate_count = None
self.message_terminate_count = 3
client1 = self.getClient()
client2 = self.getClient()
client1.send("message\3unterminated")
client2.send("c2m1\3c2m2\3")
self.mainloop.Run()
self.assertEquals(self.messages[0], ["message"])
self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
client1.send("message\3")
self.mainloop.Run()
self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])
def testSignaledWhileAccepting(self):
utils.IgnoreSignals = lambda fn, *args, **kwargs: None
client1 = self.getClient()
self.server.handle_accept()
# When interrupted while accepting we don't have a connection, but we
# didn't crash either.
self.assertEquals(len(self.connections), 0)
utils.IgnoreSignals = self.saved_utils_ignoresignals
self.mainloop.Run()
self.assertEquals(len(self.connections), 1)
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
"""Test daemon.AsyncStreamServer with a Unix path connection"""
family = socket.AF_UNIX
def getAddress(self):
self.tmpdir = tempfile.mkdtemp()
return os.path.join(self.tmpdir, "server.sock")
def tearDown(self):
shutil.rmtree(self.tmpdir)
TestAsyncStreamServerTCP.tearDown(self)
if __name__ == "__main__":
testutils.GanetiTestProgram()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment