From 76094e37ce7478f4e035e049898ad509fb6c0252 Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Fri, 4 Jul 2008 15:34:46 +0000
Subject: [PATCH] Add generic worker pool implementation

Reviewed-by: ultrotter
---
 Makefile.am                        |   4 +-
 lib/workerpool.py                  | 314 +++++++++++++++++++++++++++++
 test/ganeti.workerpool_unittest.py | 139 +++++++++++++
 3 files changed, 456 insertions(+), 1 deletion(-)
 create mode 100644 lib/workerpool.py
 create mode 100755 test/ganeti.workerpool_unittest.py

diff --git a/Makefile.am b/Makefile.am
index 28512049b..e96a50b3a 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -75,7 +75,8 @@ pkgpython_PYTHON = \
 	lib/serializer.py \
 	lib/ssconf.py \
 	lib/ssh.py \
-	lib/utils.py
+	lib/utils.py \
+	lib/workerpool.py
 
 hypervisor_PYTHON = \
 	lib/hypervisor/__init__.py \
@@ -172,6 +173,7 @@ dist_TESTS = \
 	test/ganeti.ssh_unittest.py \
 	test/ganeti.locking_unittest.py \
 	test/ganeti.serializer_unittest.py \
+	test/ganeti.workerpool_unittest.py \
 	test/ganeti.constants_unittest.py
 
 nodist_TESTS =
diff --git a/lib/workerpool.py b/lib/workerpool.py
new file mode 100644
index 000000000..63d85cc8c
--- /dev/null
+++ b/lib/workerpool.py
@@ -0,0 +1,314 @@
+#
+#
+
+# Copyright (C) 2008 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+
+"""Base classes for worker pools.
+
+"""
+
+import collections
+import logging
+import threading
+
+from ganeti import errors
+from ganeti import utils
+
+
+class BaseWorker(threading.Thread, object):
+  """Base worker class for worker pools.
+
+  Users of a worker pool must override RunTask in a subclass.
+
+  """
+  def __init__(self, pool, worker_id):
+    """Constructor for BaseWorker thread.
+
+    Args:
+    - pool: Parent worker pool
+    - worker_id: Identifier for this worker
+
+    """
+    super(BaseWorker, self).__init__()
+    self.pool = pool
+    self.worker_id = worker_id
+
+    # Task being run (None while idle); also read by WorkerPool
+    self._current_task = None
+
+  def ShouldTerminate(self):
+    """Asks the parent pool whether this worker should terminate.
+
+    """
+    return self.pool.ShouldWorkerTerminate(self)
+
+  def run(self):
+    """Main thread function.
+
+    Runs queued tasks one at a time until the pool asks us to terminate.
+
+    """
+    pool = self.pool
+
+    assert self._current_task is None
+
+    while True:
+      try:
+        # We wait on the lock to be told to either terminate or do a task.
+        pool._lock.acquire()
+        try:
+          if pool._ShouldWorkerTerminateUnlocked(self):
+            break
+
+          # We only wait if there's no task for us.
+          if not pool._tasks:
+            # wait() releases the lock and sleeps until notified
+            pool._lock.wait()
+
+            # Were we woken up in order to terminate?
+            if pool._ShouldWorkerTerminateUnlocked(self):
+              break
+
+            if not pool._tasks:
+              # Another worker took the task (or spurious wakeup), wait again
+              continue
+
+          # Get task from queue and tell pool about it
+          try:
+            self._current_task = pool._tasks.popleft()
+          finally:
+            pool._lock.notifyAll()
+        finally:
+          pool._lock.release()
+
+        # Run the actual task; exceptions are logged, the worker keeps running
+        try:
+          self.RunTask(*self._current_task)
+        except:
+          logging.error("Worker %s: Caught unhandled exception",
+                        self.worker_id, exc_info=True)
+      finally:
+        self._current_task = None
+
+        # Wake threads waiting on the pool lock (e.g. Quiesce): we're idle now
+        pool._lock.acquire()
+        try:
+          pool._lock.notifyAll()
+        finally:
+          pool._lock.release()
+
+  def RunTask(self, *args):
+    """Runs a task with the arguments given to AddTask; must be overridden.
+
+    """
+    raise NotImplementedError()
+
+
+class WorkerPool(object):
+  """Worker pool with a queue.
+
+  This class is thread-safe.
+
+  Tasks are guaranteed to be started in the order in which they're added to the
+  pool. Due to the nature of threading, they're not guaranteed to finish in the
+  same order.
+
+  """
+  def __init__(self, num_workers, worker_class):
+    """Constructor for worker pool.
+
+    Args:
+    - num_workers: Number of workers to be started (dynamic resizing is not
+                   yet implemented)
+    - worker_class: Class to be instantiated for workers; should derive from
+                    BaseWorker
+
+    """
+    # _lock and _tasks are also accessed directly by BaseWorker.run
+    self._lock = threading.Condition(threading.Lock())
+    self._worker_class = worker_class
+    self._last_worker_id = 0
+    self._workers = []
+    self._quiescing = False
+
+    # Workers asked to terminate but not yet joined
+    self._termworkers = []
+
+    # Queued tasks
+    self._tasks = collections.deque()
+
+    # Start workers
+    self.Resize(num_workers)
+
+  # TODO: Implement dynamic resizing?
+
+  def AddTask(self, *args):
+    """Adds a task to the queue; blocks while Quiesce is in progress.
+
+    Args:
+    - *args: Arguments passed to BaseWorker.RunTask
+
+    """
+    self._lock.acquire()
+    try:
+      # Don't add new tasks while we're quiescing
+      while self._quiescing:
+        self._lock.wait()
+
+      # Add task to internal queue
+      self._tasks.append(args)
+      self._lock.notify()
+    finally:
+      self._lock.release()
+
+  def _ShouldWorkerTerminateUnlocked(self, worker):
+    """Returns whether a worker should terminate; caller holds the pool lock.
+
+    """
+    return (worker in self._termworkers)
+
+  def ShouldWorkerTerminate(self, worker):
+    """Returns whether the given worker should terminate (takes the pool lock).
+
+    """
+    self._lock.acquire()
+    try:
+      return self._ShouldWorkerTerminateUnlocked(worker)
+    finally:
+      self._lock.release()
+
+  def _HasRunningTasksUnlocked(self):
+    """Checks whether there's a task running in a worker.
+
+    """
+    for worker in self._workers + self._termworkers:
+      if worker._current_task is not None:
+        return True
+    return False
+
+  def Quiesce(self):
+    """Waits until the task queue is empty and no task is running.
+
+    """
+    self._lock.acquire()
+    try:
+      self._quiescing = True
+
+      # Wait while there are tasks pending or running
+      while self._tasks or self._HasRunningTasksUnlocked():
+        self._lock.wait()
+
+    finally:
+      self._quiescing = False
+
+      # Make sure AddTask continues in case it was waiting
+      self._lock.notifyAll()
+
+      self._lock.release()
+
+  def _NewWorkerIdUnlocked(self):
+    self._last_worker_id += 1
+    return self._last_worker_id
+
+  def _ResizeUnlocked(self, num_workers):
+    """Changes the number of workers; only growing or going to 0 is supported.
+
+    """
+    assert num_workers >= 0, "num_workers must be >= 0"
+
+    logging.debug("Resizing to %s workers", num_workers)
+
+    current_count = len(self._workers)
+
+    if current_count == num_workers:
+      # Nothing to do
+      pass
+
+    elif current_count > num_workers:
+      if num_workers == 0:
+        # Create copy of list to iterate over while lock isn't held.
+        termworkers = self._workers[:]
+        del self._workers[:]
+      else:
+        # TODO: Implement partial downsizing
+        raise NotImplementedError()
+        #termworkers = ...
+
+      self._termworkers += termworkers
+
+      # Notify workers that something has changed
+      self._lock.notifyAll()
+
+      # Join all terminating workers; the lock is dropped so they can exit
+      self._lock.release()
+      try:
+        for worker in termworkers:
+          worker.join()
+      finally:
+        self._lock.acquire()
+
+      # Remove terminated threads. This could be done in a more efficient way
+      # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
+      # don't leave zombie threads around.
+      for worker in termworkers:
+        assert worker in self._termworkers, ("Worker not in list of"
+                                             " terminating workers")
+        if not worker.isAlive():
+          self._termworkers.remove(worker)
+
+      assert not self._termworkers, "Zombie worker detected"
+
+    elif current_count < num_workers:
+      # Create (num_workers - current_count) new workers
+      for i in xrange(num_workers - current_count):
+        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
+        self._workers.append(worker)
+        worker.start()
+
+  def Resize(self, num_workers):
+    """Changes the number of workers in the pool.
+
+    Args:
+    - num_workers: New number of workers
+
+    """
+    self._lock.acquire()
+    try:
+      return self._ResizeUnlocked(num_workers)
+    finally:
+      self._lock.release()
+
+  def TerminateWorkers(self):
+    """Terminate all worker threads.
+
+    Unstarted tasks will be ignored.
+
+    """
+    logging.debug("Terminating all workers")
+
+    self._lock.acquire()
+    try:
+      self._ResizeUnlocked(0)
+
+      if self._tasks:
+        logging.debug("There are %s tasks left", len(self._tasks))
+    finally:
+      self._lock.release()
+
+    logging.debug("All workers terminated")
diff --git a/test/ganeti.workerpool_unittest.py b/test/ganeti.workerpool_unittest.py
new file mode 100755
index 000000000..bee3824ea
--- /dev/null
+++ b/test/ganeti.workerpool_unittest.py
@@ -0,0 +1,139 @@
+#!/usr/bin/python
+#
+
+# Copyright (C) 2008 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+
+"""Script for unittesting the workerpool module"""
+
+import unittest
+import threading
+import time
+import sys
+import zlib
+
+from ganeti import workerpool
+
+
+class DummyBaseWorker(workerpool.BaseWorker):
+  def RunTask(self, text):
+    pass  # Intentionally a no-op; the tests only need tasks to be consumed
+
+
+class ChecksumContext:
+  CHECKSUM_START = zlib.adler32("")  # Adler-32 of the empty string
+
+  def __init__(self):
+    self.lock = threading.Condition(threading.Lock())  # Guards self.checksum
+    self.checksum = self.CHECKSUM_START
+
+  @staticmethod
+  def UpdateChecksum(current, value):
+    return zlib.adler32(str(value), current)  # Fold str(value) into current
+
+
+class ChecksumBaseWorker(workerpool.BaseWorker):
+  def RunTask(self, ctx, number):
+    ctx.lock.acquire()  # Serialize updates of the shared running checksum
+    try:
+      ctx.checksum = ctx.UpdateChecksum(ctx.checksum, number)
+    finally:
+      ctx.lock.release()
+
+
+class TestWorkerpool(unittest.TestCase):
+  """Tests for workerpool.WorkerPool"""
+
+  def testDummy(self):
+    wp = workerpool.WorkerPool(3, DummyBaseWorker)
+    try:
+      self._CheckWorkerCount(wp, 3)
+
+      for i in xrange(10):
+        wp.AddTask("Hello world %s" % i)
+
+      wp.Quiesce()
+    finally:
+      wp.TerminateWorkers()
+      self._CheckWorkerCount(wp, 0)
+
+  def testNoTasks(self):
+    wp = workerpool.WorkerPool(3, DummyBaseWorker)
+    try:
+      self._CheckWorkerCount(wp, 3)
+      self._CheckNoTasks(wp)
+    finally:
+      wp.TerminateWorkers()
+      self._CheckWorkerCount(wp, 0)
+
+  def testNoTasksQuiesce(self):
+    wp = workerpool.WorkerPool(3, DummyBaseWorker)
+    try:
+      self._CheckWorkerCount(wp, 3)
+      self._CheckNoTasks(wp)
+      wp.Quiesce()
+      self._CheckNoTasks(wp)
+    finally:
+      wp.TerminateWorkers()
+      self._CheckWorkerCount(wp, 0)
+
+  def testChecksum(self):
+    # Tests whether all tasks are run and, since we're only using a single
+    # thread, whether everything is started in order.
+    wp = workerpool.WorkerPool(1, ChecksumBaseWorker)
+    try:
+      self._CheckWorkerCount(wp, 1)
+
+      ctx = ChecksumContext()
+      checksum = ChecksumContext.CHECKSUM_START
+      for i in xrange(1, 100):
+        checksum = ChecksumContext.UpdateChecksum(checksum, i)
+        wp.AddTask(ctx, i)
+
+      wp.Quiesce()
+
+      self._CheckNoTasks(wp)
+
+      # Compare the worker's checksum with the one computed locally above
+      ctx.lock.acquire()
+      try:
+        self.assertEqual(checksum, ctx.checksum)
+      finally:
+        ctx.lock.release()
+    finally:
+      wp.TerminateWorkers()
+      self._CheckWorkerCount(wp, 0)
+
+  def _CheckNoTasks(self, wp):
+    wp._lock.acquire()
+    try:
+      # The task queue must be empty now
+      self.failUnless(not wp._tasks)
+    finally:
+      wp._lock.release()
+
+  def _CheckWorkerCount(self, wp, num_workers):
+    wp._lock.acquire()
+    try:
+      self.assertEqual(len(wp._workers), num_workers)  # Active workers only
+    finally:
+      wp._lock.release()
+
+
+if __name__ == '__main__':
+  unittest.main()  # Run this module's tests when executed as a script
-- 
GitLab