Commit 37e62cb9 authored by Guido Trotter's avatar Guido Trotter
Browse files

AsyncTerminatedMessageStream: limit message count



Currently the message stream can process any number of messages in
parallel (if they get dispatched to different threads or processes).
In order to limit their number we only handle messages and read from
the socket if we're under a certain limit of unanswered ones.
Signed-off-by: default avatarGuido Trotter <ultrotter@google.com>
Reviewed-by: default avatarMichael Hanselmann <hansmi@google.com>
parent 1e063ccd
......@@ -175,7 +175,8 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
separator. For each complete message handle_message is called.
"""
def __init__(self, connected_socket, peer_address, terminator, family):
def __init__(self, connected_socket, peer_address, terminator, family,
unhandled_limit):
"""AsyncTerminatedMessageStream constructor.
@type connected_socket: socket.socket
......@@ -185,6 +186,8 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
@param terminator: terminator separating messages in the stream
@type family: integer
@param family: socket family
@type unhandled_limit: integer or None
@param unhandled_limit: maximum unanswered messages
"""
# python 2.4/2.5 uses conn=... while 2.6 has sock=... we have to cheat by
......@@ -197,22 +200,36 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
self.family = family
self.peer_address = peer_address
self.terminator = terminator
self.unhandled_limit = unhandled_limit
self.set_terminator(terminator)
self.ibuffer = []
self.next_incoming_message = 0
self.receive_count = 0
self.send_count = 0
self.oqueue = collections.deque()
self.iqueue = collections.deque()
# this method is overriding an asynchat.async_chat method
def collect_incoming_data(self, data):
self.ibuffer.append(data)
def _can_handle_message(self):
return (self.unhandled_limit is None or
(self.receive_count < self.send_count + self.unhandled_limit) and
not self.iqueue)
# this method is overriding an asynchat.async_chat method
def found_terminator(self):
message = "".join(self.ibuffer)
self.ibuffer = []
message_id = self.next_incoming_message
self.next_incoming_message += 1
self.handle_message(message, message_id)
message_id = self.receive_count
# We need to increase the receive_count after checking if the message can
# be handled, but before calling handle_message
can_handle = self._can_handle_message()
self.receive_count += 1
if can_handle:
self.handle_message(message, message_id)
else:
self.iqueue.append((message, message_id))
def handle_message(self, message, message_id):
"""Handle a terminated message.
......@@ -240,9 +257,16 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
"""
# If we just append the message we received to the output queue, this
# function can be safely called by multiple threads at the same time, and
# we don't need locking, since deques are thread safe.
# we don't need locking, since deques are thread safe. handle_write in the
# asyncore thread will handle the next input message if there are any
# enqueued.
self.oqueue.append(message)
# this method is overriding an asyncore.dispatcher method
def readable(self):
# read from the socket if we can handle the next requests
return self._can_handle_message() and asynchat.async_chat.readable(self)
# this method is overriding an asyncore.dispatcher method
def writable(self):
# the output queue may become full just after we called writable. This only
......@@ -253,8 +277,14 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
# this method is overriding an asyncore.dispatcher method
def handle_write(self):
if self.oqueue:
# if we have data in the output queue, then send_message was called.
# this means we can process one more message from the input queue, if
# there are any.
data = self.oqueue.popleft()
self.push(data + self.terminator)
self.send_count += 1
if self.iqueue:
self.handle_message(*self.iqueue.popleft())
self.initiate_send()
def close_log(self):
......
......@@ -273,10 +273,11 @@ class _MyAsyncStreamServer(daemon.AsyncStreamServer):
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
def __init__(self, connected_socket, client_address, terminator, family,
message_fn, client_id):
message_fn, client_id, unhandled_limit):
daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
client_address,
terminator, family)
terminator, family,
unhandled_limit)
self.message_fn = message_fn
self.client_id = client_id
self.error_count = 0
......@@ -301,6 +302,7 @@ class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
self.server = _MyAsyncStreamServer(self.family, self.address,
self.handle_connection)
self.client_handler = _MyMessageStreamHandler
self.unhandled_limit = None
self.terminator = "\3"
self.address = self.server.getsockname()
self.clients = []
......@@ -339,7 +341,7 @@ class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
client_handler = self.client_handler(connected_socket, client_address,
self.terminator, self.family,
self.handle_message,
client_id)
client_id, self.unhandled_limit)
self.connections.append(client_handler)
self.countTerminate("connect_terminate_count")
......@@ -494,6 +496,61 @@ class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
self.assertEquals(client1.recv(4096), "r0\3r1\3r2\3")
self.assertRaises(socket.error, client2.recv, 4096)
def testLimitedUnhandledMessages(self):
self.connect_terminate_count = None
self.message_terminate_count = 3
self.unhandled_limit = 2
client1 = self.getClient()
client2 = self.getClient()
client1.send("one\3composed\3long\3message\3")
client2.send("c2one\3")
self.mainloop.Run()
self.assertEquals(self.messages[0], ["one", "composed"])
self.assertEquals(self.messages[1], ["c2one"])
self.assertFalse(self.connections[0].readable())
self.assert_(self.connections[1].readable())
self.connections[0].send_message("r0")
self.message_terminate_count = None
client1.send("another\3")
# when we write replies messages queued also get handled, but not the ones
# in the socket.
while self.connections[0].writable():
self.connections[0].handle_write()
self.assertFalse(self.connections[0].readable())
self.assertEquals(self.messages[0], ["one", "composed", "long"])
self.connections[0].send_message("r1")
self.connections[0].send_message("r2")
while self.connections[0].writable():
self.connections[0].handle_write()
self.assertEquals(self.messages[0], ["one", "composed", "long", "message"])
self.assert_(self.connections[0].readable())
def testLimitedUnhandledMessagesOne(self):
self.connect_terminate_count = None
self.message_terminate_count = 2
self.unhandled_limit = 1
client1 = self.getClient()
client2 = self.getClient()
client1.send("one\3composed\3message\3")
client2.send("c2one\3")
self.mainloop.Run()
self.assertEquals(self.messages[0], ["one"])
self.assertEquals(self.messages[1], ["c2one"])
self.assertFalse(self.connections[0].readable())
self.assertFalse(self.connections[1].readable())
self.connections[0].send_message("r0")
self.message_terminate_count = None
while self.connections[0].writable():
self.connections[0].handle_write()
self.assertFalse(self.connections[0].readable())
self.assertEquals(self.messages[0], ["one", "composed"])
self.connections[0].send_message("r2")
self.connections[0].send_message("r3")
while self.connections[0].writable():
self.connections[0].handle_write()
self.assertEquals(self.messages[0], ["one", "composed", "message"])
self.assert_(self.connections[0].readable())
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
"""Test daemon.AsyncStreamServer with a Unix path connection"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment