Skip to content
Snippets Groups Projects
ganeti.daemon_unittest.py 24 KiB
Newer Older
#!/usr/bin/python
#

# Copyright (C) 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Script for unittesting the daemon module"""

import os
import shutil
import signal
import socket
import tempfile
import time
import unittest

from ganeti import constants
from ganeti import daemon
from ganeti import errors
from ganeti import utils

import testutils


class TestMainloop(testutils.GanetiTestCase):
  """Test daemon.Mainloop"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    # Signals sent by _SendSig, in the order they were sent
    self.sendsig_events = []
    # Signals reported to OnSignal once this object is registered
    self.onsignal_events = []

  def _CancelEvent(self, handle):
    """Cancel a previously scheduled scheduler event.

    @param handle: event handle as returned by scheduler.enter/enterabs

    """
    self.mainloop.scheduler.cancel(handle)

  def _SendSig(self, sig):
    """Record a signal and send it to the current process.

    @param sig: signal number to send

    """
    self.sendsig_events.append(sig)
    os.kill(os.getpid(), sig)

  def OnSignal(self, signum):
    """Record a signal the mainloop reported to this registered object."""
    self.onsignal_events.append(signum)

  def testRunAndTermBySched(self):
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run() # terminates by _SendSig being scheduled
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])

  def testTerminatingSignals(self):
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGINT])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT,
                                            signal.SIGTERM])

  def testSchedulerCancel(self):
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
                                           [signal.SIGTERM])
    self.mainloop.scheduler.cancel(handle)
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])

  def testRegisterSignal(self):
    self.mainloop.RegisterSignal(self)
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
                                           [signal.SIGTERM])
    self.mainloop.scheduler.cancel(handle)
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    # ...not delivered because they are scheduled after TERM
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events,
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEquals(self.onsignal_events, self.sendsig_events)

  def testDeferredCancel(self):
    self.mainloop.RegisterSignal(self)
    now = time.time()
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
                                     [signal.SIGCHLD])
    handle1 = self.mainloop.scheduler.enterabs(now + 0.3, 2, self._SendSig,
                                               [signal.SIGCHLD])
    handle2 = self.mainloop.scheduler.enterabs(now + 0.4, 2, self._SendSig,
                                               [signal.SIGCHLD])
    # Both events above are cancelled by other events before they are due
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
                                     [handle1])
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
                                     [handle2])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
    self.assertEquals(self.onsignal_events, self.sendsig_events)

  def testReRun(self):
    self.mainloop.RegisterSignal(self)
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events,
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEquals(self.onsignal_events, self.sendsig_events)
    # The events left over from the first run are delivered in the second
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events,
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM,
                       signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEquals(self.onsignal_events, self.sendsig_events)

  def testPriority(self):
    # For events scheduled at the same time, the one with the lower priority
    # number (i.e. the higher priority) executes first
    now = time.time()
    self.mainloop.scheduler.enterabs(now + 0.1, 2, self._SendSig,
                                     [signal.SIGCHLD])
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
                                     [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events,
                      [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])


class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
  """UDP socket recording received datagrams and error-handler calls."""

  def __init__(self, family):
    daemon.AsyncUDPSocket.__init__(self, family)
    # Payloads passed to handle_datagram, in arrival order
    self.received = []
    # Number of times handle_error has been called
    self.error_count = 0

  def handle_datagram(self, payload, ip, port):
    """Record the payload and react to the special payloads.

    "terminate" sends SIGTERM to the current process; "error" raises
    errors.GenericError. Both are recorded before acting on them.

    """
    self.received.append(payload)
    if payload == "terminate":
      os.kill(os.getpid(), signal.SIGTERM)
    elif payload == "error":
      raise errors.GenericError("error")

  def handle_error(self):
    """Count calls to the error handler."""
    self.error_count += 1


class _BaseAsyncUDPSocketTest:
  """Base class for AsyncUDPSocket tests.

  Concrete subclasses must also inherit from testutils.GanetiTestCase, set
  the family and address class attributes, and call
  testutils.GanetiTestCase.setUp/tearDown themselves around this class'
  setUp/tearDown.

  """

  # Address family of the sockets under test (e.g. socket.AF_INET)
  family = None
  # Address the server socket binds to
  address = None

  def setUp(self):
    self.mainloop = daemon.Mainloop()
    self.server = _MyAsyncUDPSocket(self.family)
    self.client = _MyAsyncUDPSocket(self.family)
    # Port 0: let the OS pick a port, then remember which one it chose
    self.server.bind((self.address, 0))
    self.port = self.server.getsockname()[1]
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    self.server.close()
    self.client.close()
    # ...and restore it as well
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    # testutils.GanetiTestCase.tearDown is intentionally not called here:
    # the concrete subclasses call it, and calling it here as well would
    # run it twice (setUp likewise leaves GanetiTestCase.setUp to them).

  def testNoDoubleBind(self):
    self.assertRaises(socket.error, self.client.bind,
                      (self.address, self.port))

  def testAsyncClientServer(self):
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    self.client.enqueue_send(self.address, self.port, "terminate")
    self.mainloop.Run()
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])

  def testSyncClientServer(self):
    # Presumably checks that writing with an empty send queue is harmless
    self.client.handle_write()
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEquals(self.server.received, ["p1"])
    self.server.process_next_packet()
    self.assertEquals(self.server.received, ["p1", "p2"])
    self.client.enqueue_send(self.address, self.port, "p3")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEquals(self.server.received, ["p1", "p2", "p3"])

  def testErrorHandling(self):
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    self.client.enqueue_send(self.address, self.port, "error")
    self.client.enqueue_send(self.address, self.port, "p3")
    self.client.enqueue_send(self.address, self.port, "error")
    self.client.enqueue_send(self.address, self.port, "terminate")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.server.received,
                      ["p1", "p2", "error"])
    self.assertEquals(self.server.error_count, 1)
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.server.received,
                      ["p1", "p2", "error", "p3", "error"])
    self.assertEquals(self.server.error_count, 2)
    self.mainloop.Run()
    self.assertEquals(self.server.received,
                      ["p1", "p2", "error", "p3", "error", "terminate"])
    self.assertEquals(self.server.error_count, 2)

  def testSignaledWhileReceiving(self):
    # Make IgnoreSignals a no-op returning None, so the receive done by
    # handle_read yields nothing (simulating an interrupted receive)
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    self.server.handle_read()
    self.assertEquals(self.server.received, [])
    self.client.enqueue_send(self.address, self.port, "terminate")
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])

  def testOversizedDatagram(self):
    oversized_data = (constants.MAX_UDP_DATA_SIZE + 1) * "a"
    self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send,
                      self.address, self.port, oversized_data)


class TestAsyncIP4UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
  """Test IP4 daemon.AsyncUDPSocket"""

  family = socket.AF_INET
  address = "127.0.0.1"

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    _BaseAsyncUDPSocketTest.setUp(self)

  def tearDown(self):
    # Undo setUp in reverse order: sockets first, then generic cleanup
    _BaseAsyncUDPSocketTest.tearDown(self)
    testutils.GanetiTestCase.tearDown(self)


class TestAsyncIP6UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
  """Test IP6 daemon.AsyncUDPSocket"""

  family = socket.AF_INET6
  address = "::1"

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    _BaseAsyncUDPSocketTest.setUp(self)

  def tearDown(self):
    # Undo setUp in reverse order: sockets first, then generic cleanup
    _BaseAsyncUDPSocketTest.tearDown(self)
    testutils.GanetiTestCase.tearDown(self)


class _MyAsyncStreamServer(daemon.AsyncStreamServer):
  """Stream server handing each accepted connection to a callback.

  Calls to handle_error and handle_expt are counted; both close the server.

  """

  def __init__(self, family, address, handle_connection_fn):
    daemon.AsyncStreamServer.__init__(self, family, address)
    # Called as fn(connected_socket, client_address) for every connection
    self.handle_connection_fn = handle_connection_fn
    self.error_count = 0
    self.expt_count = 0

  def handle_connection(self, connected_socket, client_address):
    """Forward the new connection to the configured callback."""
    callback = self.handle_connection_fn
    callback(connected_socket, client_address)

  def handle_error(self):
    """Count the error, close the server, then re-raise the exception."""
    self.error_count += 1
    self.close()
    raise

  def handle_expt(self):
    """Count the exceptional condition and close the server."""
    self.expt_count += 1
    self.close()


class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
  """Message stream handler delegating every message to a callback."""

  def __init__(self, connected_socket, client_address, terminator, family,
               message_fn, client_id, unhandled_limit):
    daemon.AsyncTerminatedMessageStream.__init__(
      self, connected_socket, client_address, terminator, family,
      unhandled_limit)
    # Called as fn(handler, message, message_id) for every message
    self.message_fn = message_fn
    # Identifier the test uses to group messages per client
    self.client_id = client_id
    self.error_count = 0

  def handle_message(self, message, message_id):
    """Pass the message, together with this handler, to message_fn."""
    callback = self.message_fn
    callback(self, message, message_id)

  def handle_error(self):
    """Count the error and re-raise the exception being handled."""
    self.error_count += 1
    raise


class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
  """Test daemon.AsyncStreamServer with a TCP connection"""

  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.address = self.getAddress()
    self.server = _MyAsyncStreamServer(self.family, self.address,
                                       self.handle_connection)
    self.client_handler = _MyMessageStreamHandler
    self.unhandled_limit = None
    self.terminator = "\3"
    # Replace the requested address with the one actually bound (for TCP
    # this resolves port 0 to the port the OS picked)
    self.address = self.server.getsockname()
    self.clients = []
    self.connections = []
    # Received messages, keyed by the handler's client_id
    self.messages = {}
    # Countdowns used by countTerminate; None disables SIGTERM for that
    # event type, and starting at 0 means the first event sends SIGTERM
    self.connect_terminate_count = 0
    self.message_terminate_count = 0
    self.next_client_id = 0
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    # Close every socket opened by the test so no dispatcher outlives it
    for c in self.clients:
      c.close()
    for c in self.connections:
      c.close()
    self.server.close()
    # ...and restore it as well
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    testutils.GanetiTestCase.tearDown(self)

  def getAddress(self):
    """Return the address the server should listen on."""
    return ("127.0.0.1", 0)

  def countTerminate(self, name):
    """Decrement the counter attribute C{name} and maybe stop the mainloop.

    Does nothing if the attribute is None. Otherwise it is decremented, and
    SIGTERM is sent to the current process once it is zero or below.

    """
    value = getattr(self, name)
    if value is not None:
      value -= 1
      setattr(self, name, value)
      if value <= 0:
        os.kill(os.getpid(), signal.SIGTERM)

  def handle_connection(self, connected_socket, client_address):
    """Wrap a new connection in a message handler with a fresh client id."""
    client_id = self.next_client_id
    self.next_client_id += 1
    client_handler = self.client_handler(connected_socket, client_address,
                                         self.terminator, self.family,
                                         self.handle_message,
                                         client_id, self.unhandled_limit)
    self.connections.append(client_handler)
    self.countTerminate("connect_terminate_count")

  def handle_message(self, handler, message, message_id):
    """Record a message per client; raise GenericError on "error"."""
    self.messages.setdefault(handler.client_id, [])
    # We should just check that the message_ids are monotonically increasing.
    # If in the unit tests we never remove messages from the received queue,
    # though, we can just require that the queue length is the same as the
    # message id, before pushing the message to it. This forces a more
    # restrictive check, but we can live with this for now.
    self.assertEquals(len(self.messages[handler.client_id]), message_id)
    self.messages[handler.client_id].append(message)
    if message == "error":
      raise errors.GenericError("error")
    self.countTerminate("message_terminate_count")

  def getClient(self):
    """Connect a new client socket to the server; it is closed in tearDown."""
    client = socket.socket(self.family, socket.SOCK_STREAM)
    client.connect(self.address)
    self.clients.append(client)
    return client

  def testConnect(self):
    self.getClient()
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
    self.getClient()
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 2)
    self.connect_terminate_count = 4
    self.getClient()
    self.getClient()
    self.getClient()
    self.getClient()
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 6)

  def testBasicMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
    self.assertEquals(len(self.messages[0]), 1)
    self.assertEquals(self.messages[0][0], "ciao")

  def testDoubleMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    client.send("foobar\3")
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
    self.assertEquals(len(self.messages[0]), 2)
    self.assertEquals(self.messages[0][1], "foobar")

  def testComposedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\3composed\3message\3")
    self.mainloop.Run()
    self.assertEquals(len(self.messages[0]), 3)
    self.assertEquals(self.messages[0], ["one", "composed", "message"])

  def testLongTerminator(self):
    self.terminator = "\0\1\2"
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
    self.mainloop.Run()
    self.assertEquals(len(self.messages[0]), 3)
    self.assertEquals(self.messages[0], ["one", "composed", "message"])

  def testErrorHandling(self):
    self.connect_terminate_count = None
    self.message_terminate_count = None
    client = self.getClient()
    client.send("one\3two\3error\3three\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.connections[0].error_count, 1)
    self.assertEquals(self.messages[0], ["one", "two", "error"])
    client.send("error\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.connections[0].error_count, 2)
    self.assertEquals(self.messages[0], ["one", "two", "error", "three",
                                         "error"])

  def testDoubleClient(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 2
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("c1m1\3")
    client2.send("c2m1\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["c1m1"])
    self.assertEquals(self.messages[1], ["c2m1"])

  def testUnterminatedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("message\3unterminated")
    client2.send("c2m1\3c2m2\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["message"])
    self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
    client1.send("message\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])

  def testSignaledWhileAccepting(self):
    # Make IgnoreSignals a no-op returning None, simulating an accept that
    # was interrupted by a signal
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.getClient()
    self.server.handle_accept()
    # When interrupted while accepting we don't have a connection, but we
    # didn't crash either.
    self.assertEquals(len(self.connections), 0)
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)

  def testSendMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("one\3composed\3message\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
    self.assertFalse(self.connections[0].writable())
    self.assertFalse(self.connections[1].writable())
    self.connections[0].send_message("r0")
    self.assertTrue(self.connections[0].writable())
    self.assertFalse(self.connections[1].writable())
    self.connections[0].send_message("r1")
    self.connections[0].send_message("r2")
    # We currently have no way to terminate the mainloop on write events, but
    # let's assume handle_write will be called if writable() is True.
    while self.connections[0].writable():
      self.connections[0].handle_write()
    client1.setblocking(0)
    client2.setblocking(0)
    self.assertEquals(client1.recv(4096), "r0\3r1\3r2\3")
    self.assertRaises(socket.error, client2.recv, 4096)

  def testLimitedUnhandledMessages(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    self.unhandled_limit = 2
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("one\3composed\3long\3message\3")
    client2.send("c2one\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["one", "composed"])
    self.assertEquals(self.messages[1], ["c2one"])
    self.assertFalse(self.connections[0].readable())
    self.assertTrue(self.connections[1].readable())
    self.connections[0].send_message("r0")
    self.message_terminate_count = None
    client1.send("another\3")
    # when we write replies messages queued also get handled, but not the ones
    # in the socket.
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertFalse(self.connections[0].readable())
    self.assertEquals(self.messages[0], ["one", "composed", "long"])
    self.connections[0].send_message("r1")
    self.connections[0].send_message("r2")
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertEquals(self.messages[0], ["one", "composed", "long", "message"])
    self.assertTrue(self.connections[0].readable())

  def testLimitedUnhandledMessagesOne(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 2
    self.unhandled_limit = 1
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("one\3composed\3message\3")
    client2.send("c2one\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["one"])
    self.assertEquals(self.messages[1], ["c2one"])
    self.assertFalse(self.connections[0].readable())
    self.assertFalse(self.connections[1].readable())
    self.connections[0].send_message("r0")
    self.message_terminate_count = None
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertFalse(self.connections[0].readable())
    self.assertEquals(self.messages[0], ["one", "composed"])
    self.connections[0].send_message("r2")
    self.connections[0].send_message("r3")
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
    self.assertTrue(self.connections[0].readable())


class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
  """Test daemon.AsyncStreamServer with a Unix path connection"""

  family = socket.AF_UNIX

  def getAddress(self):
    """Return a socket path inside a new temporary directory.

    The directory is stored in self.tmpdir and removed in tearDown.

    """
    self.tmpdir = tempfile.mkdtemp()
    return os.path.join(self.tmpdir, "server.sock")

  def tearDown(self):
    # Close the sockets before deleting the directory holding the server's
    # socket file, and delete it even if the parent tearDown fails
    try:
      TestAsyncStreamServerTCP.tearDown(self)
    finally:
      shutil.rmtree(self.tmpdir)


class TestAsyncAwaker(testutils.GanetiTestCase):
  """Test daemon.AsyncAwaker"""

  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
    self.signal_count = 0
    # Number of signal_fn calls after which SIGTERM is sent to ourselves
    self.signal_terminate_count = 1

  def tearDown(self):
    # Mirror setUp: the generic cleanup runs even if closing fails
    try:
      self.awaker.close()
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def handle_signal(self):
    """signal_fn callback: count calls, SIGTERM once the limit is reached."""
    self.signal_count += 1
    self.signal_terminate_count -= 1
    if self.signal_terminate_count <= 0:
      os.kill(os.getpid(), signal.SIGTERM)

  def testBasicSignaling(self):
    self.awaker.signal()
    self.mainloop.Run()
    self.assertEquals(self.signal_count, 1)

  def testDoubleSignaling(self):
    self.awaker.signal()
    self.awaker.signal()
    self.mainloop.Run()
    # The second signal is never delivered
    self.assertEquals(self.signal_count, 1)

  def testReallyDoubleSignaling(self):
    self.assertTrue(self.awaker.readable())
    self.awaker.signal()
    # Let's suppose two threads overlap, and both find need_signal True
    self.awaker.need_signal = True
    self.awaker.signal()
    self.mainloop.Run()
    # We still get only one signaling
    self.assertEquals(self.signal_count, 1)

  def testNoSignalFnArgument(self):
    myawaker = daemon.AsyncAwaker()
    # Close the awaker even if one of the assertions fails
    try:
      self.assertRaises(socket.error, myawaker.handle_read)
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
      myawaker.signal()
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
    finally:
      myawaker.close()

  def testWrongSignalFnArgument(self):
    self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")


if __name__ == "__main__":
  # Script entry point: hand control to Ganeti's test program runner
  testutils.GanetiTestProgram()