Skip to content
Snippets Groups Projects
ganeti.daemon_unittest.py 19.8 KiB
Newer Older
#!/usr/bin/python
#

# Copyright (C) 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Script for unittesting the daemon module"""

import os
import shutil
import signal
import socket
import tempfile
import time
import unittest

from ganeti import constants
from ganeti import daemon
from ganeti import errors
from ganeti import utils

import testutils


class TestMainloop(testutils.GanetiTestCase):
  """Test daemon.Mainloop"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    # Signals sent to this process by _SendSig, in sending order
    self.sendsig_events = []
    # Signals reported back to OnSignal, in arrival order
    self.onsignal_events = []

  def _CancelEvent(self, handle):
    """Cancel a scheduled event; meant to be scheduled as an event itself."""
    self.mainloop.scheduler.cancel(handle)

  def _SendSig(self, sig):
    """Record C{sig} and send it to the current process."""
    self.sendsig_events.append(sig)
    os.kill(os.getpid(), sig)

  def OnSignal(self, signum):
    """Record a signal reported by the mainloop (see RegisterSignal)."""
    self.onsignal_events.append(signum)

  def testRunAndTermBySched(self):
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run() # terminates by _SendSig being scheduled
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])

  def testTerminatingSignals(self):
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGINT])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT,
                                            signal.SIGTERM])

  def testSchedulerCancel(self):
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
                                           [signal.SIGTERM])
    self.mainloop.scheduler.cancel(handle)
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])

  def testRegisterSignal(self):
    self.mainloop.RegisterSignal(self)
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
                                           [signal.SIGTERM])
    self.mainloop.scheduler.cancel(handle)
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    # ...not delivered because they are scheduled after TERM
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events,
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEquals(self.onsignal_events, self.sendsig_events)

  def testDeferredCancel(self):
    self.mainloop.RegisterSignal(self)
    now = time.time()
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
                                     [signal.SIGCHLD])
    handle1 = self.mainloop.scheduler.enterabs(now + 0.3, 2, self._SendSig,
                                               [signal.SIGCHLD])
    handle2 = self.mainloop.scheduler.enterabs(now + 0.4, 2, self._SendSig,
                                               [signal.SIGCHLD])
    # Events running at +0.2 cancel the two later ones before they fire
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
                                     [handle1])
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
                                     [handle2])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
    self.assertEquals(self.onsignal_events, self.sendsig_events)

  def testReRun(self):
    self.mainloop.RegisterSignal(self)
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events,
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEquals(self.onsignal_events, self.sendsig_events)
    # The events left over from the first run fire during the second one
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events,
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM,
                       signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEquals(self.onsignal_events, self.sendsig_events)

  def testPriority(self):
    # For events at the same time, the one with the lower priority value
    # (i.e. the higher priority) executes first
    now = time.time()
    self.mainloop.scheduler.enterabs(now + 0.1, 2, self._SendSig,
                                     [signal.SIGCHLD])
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
                                     [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events,
                      [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])


class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
  """UDP socket for tests: records payloads and counts handled errors.

  The payload "terminate" sends SIGTERM to the current process, and the
  payload "error" raises L{errors.GenericError}.

  """

  def __init__(self):
    daemon.AsyncUDPSocket.__init__(self)
    # Payloads received so far, in arrival order
    self.received = []
    # Number of times handle_error has been called
    self.error_count = 0

  def handle_datagram(self, payload, ip, port):
    """Record C{payload} and react to the special control payloads."""
    self.received.append(payload)
    if payload == "terminate":
      os.kill(os.getpid(), signal.SIGTERM)
    elif payload == "error":
      raise errors.GenericError("error")

  def handle_error(self):
    """Count the error; nothing is re-raised here."""
    self.error_count += 1


class TestAsyncUDPSocket(testutils.GanetiTestCase):
  """Test daemon.AsyncUDPSocket"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.server = _MyAsyncUDPSocket()
    self.client = _MyAsyncUDPSocket()
    # Bind to an ephemeral port and remember which one was assigned
    self.server.bind(("127.0.0.1", 0))
    self.port = self.server.getsockname()[1]
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    try:
      self.server.close()
      self.client.close()
    finally:
      # ...and restore it as well, even if closing a socket failed, so the
      # monkeypatch can never leak into other tests
      utils.IgnoreSignals = self.saved_utils_ignoresignals
      testutils.GanetiTestCase.tearDown(self)

  def testNoDoubleBind(self):
    self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port))

  def testAsyncClientServer(self):
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
    self.mainloop.Run()
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])

  def testSyncClientServer(self):
    # Writing with an empty send queue must not fail
    self.client.handle_write()
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEquals(self.server.received, ["p1"])
    self.server.process_next_packet()
    self.assertEquals(self.server.received, ["p1", "p2"])
    self.client.enqueue_send("127.0.0.1", self.port, "p3")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEquals(self.server.received, ["p1", "p2", "p3"])

  def testErrorHandling(self):
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
    self.client.enqueue_send("127.0.0.1", self.port, "error")
    self.client.enqueue_send("127.0.0.1", self.port, "p3")
    self.client.enqueue_send("127.0.0.1", self.port, "error")
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.server.received,
                      ["p1", "p2", "error"])
    self.assertEquals(self.server.error_count, 1)
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.server.received,
                      ["p1", "p2", "error", "p3", "error"])
    self.assertEquals(self.server.error_count, 2)
    self.mainloop.Run()
    self.assertEquals(self.server.received,
                      ["p1", "p2", "error", "p3", "error", "terminate"])
    self.assertEquals(self.server.error_count, 2)

  def testSignaledWhileReceiving(self):
    # Stub out IgnoreSignals so the wrapped call is never made and None is
    # returned, presumably mimicking a receive interrupted by a signal
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
    self.server.handle_read()
    self.assertEquals(self.server.received, [])
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])

  def testOversizedDatagram(self):
    oversized_data = (constants.MAX_UDP_DATA_SIZE + 1) * "a"
    self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send,
                      "127.0.0.1", self.port, oversized_data)

class _MyAsyncStreamServer(daemon.AsyncStreamServer):
  """Stream server for tests that delegates new connections to a callback.

  It also counts calls to its error and exceptional-condition handlers,
  closing itself on either.

  """

  def __init__(self, family, address, handle_connection_fn):
    daemon.AsyncStreamServer.__init__(self, family, address)
    # Fault counters, bumped by handle_expt and handle_error respectively
    self.expt_count = 0
    self.error_count = 0
    self.handle_connection_fn = handle_connection_fn

  def handle_connection(self, connected_socket, client_address):
    callback = self.handle_connection_fn
    callback(connected_socket, client_address)

  def handle_error(self):
    self.error_count = self.error_count + 1
    self.close()
    # Bare raise: propagate the exception currently being handled
    raise

  def handle_expt(self):
    self.expt_count = self.expt_count + 1
    self.close()


class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
  """Terminated-message stream that forwards every message to a callback.

  The callback is called as C{message_fn(handler, message, message_id)}, so it
  can use C{handler.client_id} to tell connections apart.

  """

  def __init__(self, connected_socket, client_address, terminator, family,
               message_fn, client_id):
    daemon.AsyncTerminatedMessageStream.__init__(
      self, connected_socket, client_address, terminator, family)
    self.client_id = client_id
    self.message_fn = message_fn
    # Number of times handle_error has been called
    self.error_count = 0

  def handle_message(self, message, message_id):
    callback = self.message_fn
    callback(self, message, message_id)

  def handle_error(self):
    self.error_count = self.error_count + 1
    # Bare raise: propagate the exception currently being handled
    raise


class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
  """Test daemon.AsyncStreamServer with a TCP connection"""

  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.address = self.getAddress()
    self.server = _MyAsyncStreamServer(self.family, self.address,
                                       self.handle_connection)
    self.client_handler = _MyMessageStreamHandler
    self.terminator = "\3"
    # Use the address the server actually listens on (e.g. the real port
    # when getAddress asked for port 0)
    self.address = self.server.getsockname()
    self.clients = []
    self.connections = []
    self.messages = {}
    # Countdowns consumed by countTerminate: SIGTERM is sent once one drops
    # to zero or below; None disables the countdown
    self.connect_terminate_count = 0
    self.message_terminate_count = 0
    self.next_client_id = 0
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    # This is the only tearDown of the class: a second definition used to
    # shadow it, so sockets were never closed and utils.IgnoreSignals was
    # never restored.
    try:
      for c in self.clients:
        c.close()
      for c in self.connections:
        c.close()
      self.server.close()
    finally:
      # ...and restore it as well, even if closing a socket failed
      utils.IgnoreSignals = self.saved_utils_ignoresignals
      testutils.GanetiTestCase.tearDown(self)

  def getAddress(self):
    """Return the address the server is asked to bind to."""
    return ("127.0.0.1", 0)

  def countTerminate(self, name):
    """Decrement the countdown attribute C{name}; SIGTERM when it hits 0.

    A countdown set to None is left untouched.

    """
    value = getattr(self, name)
    if value is not None:
      value -= 1
      setattr(self, name, value)
      if value <= 0:
        os.kill(os.getpid(), signal.SIGTERM)

  def handle_connection(self, connected_socket, client_address):
    """Wrap a new connection in a message handler with a sequential id."""
    client_id = self.next_client_id
    self.next_client_id += 1
    client_handler = self.client_handler(connected_socket, client_address,
                                         self.terminator, self.family,
                                         self.handle_message,
                                         client_id)
    self.connections.append(client_handler)
    self.countTerminate("connect_terminate_count")

  def handle_message(self, handler, message, message_id):
    """Store C{message} per client; the message "error" raises GenericError."""
    self.messages.setdefault(handler.client_id, [])
    # We should just check that the message_ids are monotonically increasing.
    # If in the unit tests we never remove messages from the received queue,
    # though, we can just require that the queue length is the same as the
    # message id, before pushing the message to it. This forces a more
    # restrictive check, but we can live with this for now.
    self.assertEquals(len(self.messages[handler.client_id]), message_id)
    self.messages[handler.client_id].append(message)
    if message == "error":
      raise errors.GenericError("error")
    self.countTerminate("message_terminate_count")

  def getClient(self):
    """Connect a new client socket to the server; it is closed in tearDown."""
    client = socket.socket(self.family, socket.SOCK_STREAM)
    client.connect(self.address)
    self.clients.append(client)
    return client

  def testConnect(self):
    self.getClient()
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
    self.getClient()
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 2)
    self.connect_terminate_count = 4
    self.getClient()
    self.getClient()
    self.getClient()
    self.getClient()
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 6)

  def testBasicMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
    self.assertEquals(len(self.messages[0]), 1)
    self.assertEquals(self.messages[0][0], "ciao")

  def testDoubleMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    client.send("foobar\3")
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
    self.assertEquals(len(self.messages[0]), 2)
    self.assertEquals(self.messages[0][1], "foobar")

  def testComposedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\3composed\3message\3")
    self.mainloop.Run()
    self.assertEquals(len(self.messages[0]), 3)
    self.assertEquals(self.messages[0], ["one", "composed", "message"])

  def testLongTerminator(self):
    self.terminator = "\0\1\2"
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
    self.mainloop.Run()
    self.assertEquals(len(self.messages[0]), 3)
    self.assertEquals(self.messages[0], ["one", "composed", "message"])

  def testErrorHandling(self):
    self.connect_terminate_count = None
    self.message_terminate_count = None
    client = self.getClient()
    client.send("one\3two\3error\3three\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.connections[0].error_count, 1)
    self.assertEquals(self.messages[0], ["one", "two", "error"])
    client.send("error\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.connections[0].error_count, 2)
    self.assertEquals(self.messages[0], ["one", "two", "error", "three",
                                         "error"])

  def testDoubleClient(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 2
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("c1m1\3")
    client2.send("c2m1\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["c1m1"])
    self.assertEquals(self.messages[1], ["c2m1"])

  def testUnterminatedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("message\3unterminated")
    client2.send("c2m1\3c2m2\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["message"])
    self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
    client1.send("message\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])

  def testSignaledWhileAccepting(self):
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.getClient()
    self.server.handle_accept()
    # When interrupted while accepting we don't have a connection, but we
    # didn't crash either.
    self.assertEquals(len(self.connections), 0)
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)


class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
  """Test daemon.AsyncStreamServer with a Unix path connection"""

  family = socket.AF_UNIX

  def getAddress(self):
    """Return a socket path inside a fresh temporary directory."""
    # The directory is remembered so tearDown can remove it
    self.tmpdir = tempfile.mkdtemp()
    return os.path.join(self.tmpdir, "server.sock")

  def tearDown(self):
    # Close sockets and restore patched state first. The temporary directory
    # is always removed afterwards, even if the parent cleanup raises; a
    # failing rmtree can no longer skip the parent cleanup either.
    try:
      TestAsyncStreamServerTCP.tearDown(self)
    finally:
      shutil.rmtree(self.tmpdir)


class TestAsyncAwaker(testutils.GanetiTestCase):
  """Test daemon.AsyncAwaker"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
    self.signal_count = 0
    # handle_signal sends SIGTERM (ending mainloop.Run) once this reaches 0
    self.signal_terminate_count = 1

  def tearDown(self):
    # Chain to the base class like the other test cases in this module do
    try:
      self.awaker.close()
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def handle_signal(self):
    """Count a delivered signal; SIGTERM once the countdown is exhausted."""
    self.signal_count += 1
    self.signal_terminate_count -= 1
    if self.signal_terminate_count <= 0:
      os.kill(os.getpid(), signal.SIGTERM)

  def testBasicSignaling(self):
    self.awaker.signal()
    self.mainloop.Run()
    self.assertEquals(self.signal_count, 1)

  def testDoubleSignaling(self):
    self.awaker.signal()
    self.awaker.signal()
    self.mainloop.Run()
    # The second signal is never delivered
    self.assertEquals(self.signal_count, 1)

  def testReallyDoubleSignaling(self):
    self.assert_(self.awaker.readable())
    self.awaker.signal()
    # Let's suppose two threads overlap, and both find need_signal True
    self.awaker.need_signal = True
    self.awaker.signal()
    self.mainloop.Run()
    # We still get only one signaling
    self.assertEquals(self.signal_count, 1)

  def testNoSignalFnArgument(self):
    myawaker = daemon.AsyncAwaker()
    # Close the awaker even when an assertion fails, so it cannot leak
    try:
      # handle_read fails with socket.error when no signal is pending
      self.assertRaises(socket.error, myawaker.handle_read)
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
      # Two signals are consumed by a single handle_read
      myawaker.signal()
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
    finally:
      myawaker.close()

  def testWrongSignalFnArgument(self):
    self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")


if __name__ == "__main__":
  # Only runs when this file is executed directly, not when it is imported
  testutils.GanetiTestProgram()