From 76094e37ce7478f4e035e049898ad509fb6c0252 Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Fri, 4 Jul 2008 15:34:46 +0000
Subject: [PATCH] Add generic worker pool implementation

Reviewed-by: ultrotter
---
 Makefile.am                        |   4 +-
 lib/workerpool.py                  | 314 +++++++++++++++++++++++++++++
 test/ganeti.workerpool_unittest.py | 139 +++++++++++++
 3 files changed, 456 insertions(+), 1 deletion(-)
 create mode 100644 lib/workerpool.py
 create mode 100755 test/ganeti.workerpool_unittest.py

diff --git a/Makefile.am b/Makefile.am
index 28512049b..e96a50b3a 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -75,7 +75,8 @@ pkgpython_PYTHON = \
 	lib/serializer.py \
 	lib/ssconf.py \
 	lib/ssh.py \
-	lib/utils.py
+	lib/utils.py \
+	lib/workerpool.py
 
 hypervisor_PYTHON = \
 	lib/hypervisor/__init__.py \
@@ -172,6 +173,7 @@ dist_TESTS = \
 	test/ganeti.ssh_unittest.py \
 	test/ganeti.locking_unittest.py \
 	test/ganeti.serializer_unittest.py \
+	test/ganeti.workerpool_unittest.py \
 	test/ganeti.constants_unittest.py
 
 nodist_TESTS =
diff --git a/lib/workerpool.py b/lib/workerpool.py
new file mode 100644
index 000000000..63d85cc8c
--- /dev/null
+++ b/lib/workerpool.py
@@ -0,0 +1,314 @@
+#
+#
+
+# Copyright (C) 2008 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+
+"""Base classes for worker pools.
+
+"""
+
+import collections
+import logging
+import threading
+
+from ganeti import errors
+from ganeti import utils
+
+
+class BaseWorker(threading.Thread, object):
+  """Base worker class for worker pools.
+
+  Users of a worker pool must override RunTask in a subclass.
+
+  """
+  def __init__(self, pool, worker_id):
+    """Constructor for BaseWorker thread.
+
+    Args:
+    - pool: Parent worker pool
+    - worker_id: Identifier for this worker
+
+    """
+    super(BaseWorker, self).__init__()
+    self.pool = pool
+    self.worker_id = worker_id
+
+    # Also used by WorkerPool
+    self._current_task = None
+
+  def ShouldTerminate(self):
+    """Returns whether a worker should terminate.
+
+    """
+    return self.pool.ShouldWorkerTerminate(self)
+
+  def run(self):
+    """Main thread function.
+
+    Waits for new tasks to show up in the queue.
+
+    """
+    pool = self.pool
+
+    assert self._current_task is None
+
+    while True:
+      try:
+        # We wait on lock to be told either terminate or do a task.
+        pool._lock.acquire()
+        try:
+          if pool._ShouldWorkerTerminateUnlocked(self):
+            break
+
+          # We only wait if there's no task for us.
+          if not pool._tasks:
+            # wait() releases the lock and sleeps until notified
+            pool._lock.wait()
+
+            # Were we woken up in order to terminate?
+            if pool._ShouldWorkerTerminateUnlocked(self):
+              break
+
+            if not pool._tasks:
+              # Spurious notification, ignore
+              continue
+
+          # Get task from queue and tell pool about it
+          try:
+            self._current_task = pool._tasks.popleft()
+          finally:
+            pool._lock.notifyAll()
+        finally:
+          pool._lock.release()
+
+        # Run the actual task
+        try:
+          self.RunTask(*self._current_task)
+        except:
+          logging.error("Worker %s: Caught unhandled exception",
+                        self.worker_id, exc_info=True)
+      finally:
+        self._current_task = None
+
+        # Notify pool
+        pool._lock.acquire()
+        try:
+          pool._lock.notifyAll()
+        finally:
+          pool._lock.release()
+
+  def RunTask(self, *args):
+    """Function called to start a task.
+
+    """
+    raise NotImplementedError()
+
+
+class WorkerPool(object):
+  """Worker pool with a queue.
+
+  This class is thread-safe.
+
+  Tasks are guaranteed to be started in the order in which they're added to the
+  pool. Due to the nature of threading, they're not guaranteed to finish in the
+  same order.
+
+  """
+  def __init__(self, num_workers, worker_class):
+    """Constructor for worker pool.
+
+    Args:
+    - num_workers: Number of workers to be started (dynamic resizing is not
+                   yet implemented)
+    - worker_class: Class to be instantiated for workers; should derive from
+                    BaseWorker
+
+    """
+    # Some of these variables are accessed by BaseWorker
+    self._lock = threading.Condition(threading.Lock())
+    self._worker_class = worker_class
+    self._last_worker_id = 0
+    self._workers = []
+    self._quiescing = False
+
+    # Terminating workers
+    self._termworkers = []
+
+    # Queued tasks
+    self._tasks = collections.deque()
+
+    # Start workers
+    self.Resize(num_workers)
+
+  # TODO: Implement dynamic resizing?
+
+  def AddTask(self, *args):
+    """Adds a task to the queue.
+
+    Args:
+    - *args: Arguments passed to BaseWorker.RunTask
+
+    """
+    self._lock.acquire()
+    try:
+      # Don't add new tasks while we're quiescing
+      while self._quiescing:
+        self._lock.wait()
+
+      # Add task to internal queue
+      self._tasks.append(args)
+      self._lock.notify()
+    finally:
+      self._lock.release()
+
+  def _ShouldWorkerTerminateUnlocked(self, worker):
+    """Returns whether a worker should terminate.
+
+    """
+    return (worker in self._termworkers)
+
+  def ShouldWorkerTerminate(self, worker):
+    """Returns whether a worker should terminate.
+
+    """
+    self._lock.acquire()
+    try:
+      return self._ShouldWorkerTerminateUnlocked(worker)
+    finally:
+      self._lock.release()
+
+  def _HasRunningTasksUnlocked(self):
+    """Checks whether there's a task running in a worker.
+
+    """
+    for worker in self._workers + self._termworkers:
+      if worker._current_task is not None:
+        return True
+    return False
+
+  def Quiesce(self):
+    """Waits until the task queue is empty.
+
+    """
+    self._lock.acquire()
+    try:
+      self._quiescing = True
+
+      # Wait while there are tasks pending or running
+      while self._tasks or self._HasRunningTasksUnlocked():
+        self._lock.wait()
+
+    finally:
+      self._quiescing = False
+
+      # Make sure AddTasks continues in case it was waiting
+      self._lock.notifyAll()
+
+      self._lock.release()
+
+  def _NewWorkerIdUnlocked(self):
+    self._last_worker_id += 1
+    return self._last_worker_id
+
+  def _ResizeUnlocked(self, num_workers):
+    """Changes the number of workers.
+
+    """
+    assert num_workers >= 0, "num_workers must be >= 0"
+
+    logging.debug("Resizing to %s workers", num_workers)
+
+    current_count = len(self._workers)
+
+    if current_count == num_workers:
+      # Nothing to do
+      pass
+
+    elif current_count > num_workers:
+      if num_workers == 0:
+        # Create copy of list to iterate over while lock isn't held.
+        termworkers = self._workers[:]
+        del self._workers[:]
+      else:
+        # TODO: Implement partial downsizing
+        raise NotImplementedError()
+        #termworkers = ...
+
+      self._termworkers += termworkers
+
+      # Notify workers that something has changed
+      self._lock.notifyAll()
+
+      # Join all terminating workers
+      self._lock.release()
+      try:
+        for worker in termworkers:
+          worker.join()
+      finally:
+        self._lock.acquire()
+
+      # Remove terminated threads. This could be done in a more efficient way
+      # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
+      # don't leave zombie threads around.
+      for worker in termworkers:
+        assert worker in self._termworkers, ("Worker not in list of"
+                                             " terminating workers")
+        if not worker.isAlive():
+          self._termworkers.remove(worker)
+
+      assert not self._termworkers, "Zombie worker detected"
+
+    elif current_count < num_workers:
+      # Create (num_workers - current_count) new workers
+      for i in xrange(num_workers - current_count):
+        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
+        self._workers.append(worker)
+        worker.start()
+
+  def Resize(self, num_workers):
+    """Changes the number of workers in the pool.
+
+    Args:
+    - num_workers: New number of workers
+
+    """
+    self._lock.acquire()
+    try:
+      return self._ResizeUnlocked(num_workers)
+    finally:
+      self._lock.release()
+
+  def TerminateWorkers(self):
+    """Terminate all worker threads.
+
+    Unstarted tasks will be ignored.
+
+    """
+    logging.debug("Terminating all workers")
+
+    self._lock.acquire()
+    try:
+      self._ResizeUnlocked(0)
+
+      if self._tasks:
+        logging.debug("There are %s tasks left", len(self._tasks))
+    finally:
+      self._lock.release()
+
+    logging.debug("All workers terminated")
diff --git a/test/ganeti.workerpool_unittest.py b/test/ganeti.workerpool_unittest.py
new file mode 100755
index 000000000..bee3824ea
--- /dev/null
+++ b/test/ganeti.workerpool_unittest.py
@@ -0,0 +1,139 @@
+#!/usr/bin/python
+#
+
+# Copyright (C) 2008 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+
+"""Script for unittesting the workerpool module"""
+
+import unittest
+import threading
+import time
+import sys
+import zlib
+
+from ganeti import workerpool
+
+
+class DummyBaseWorker(workerpool.BaseWorker):
+  def RunTask(self, text):
+    pass
+
+
+class ChecksumContext:
+  CHECKSUM_START = zlib.adler32("")
+
+  def __init__(self):
+    self.lock = threading.Condition(threading.Lock())
+    self.checksum = self.CHECKSUM_START
+
+  @staticmethod
+  def UpdateChecksum(current, value):
+    return zlib.adler32(str(value), current)
+
+
+class ChecksumBaseWorker(workerpool.BaseWorker):
+  def RunTask(self, ctx, number):
+    ctx.lock.acquire()
+    try:
+      ctx.checksum = ctx.UpdateChecksum(ctx.checksum, number)
+    finally:
+      ctx.lock.release()
+
+
+class TestWorkerpool(unittest.TestCase):
+  """Workerpool tests"""
+
+  def testDummy(self):
+    wp = workerpool.WorkerPool(3, DummyBaseWorker)
+    try:
+      self._CheckWorkerCount(wp, 3)
+
+      for i in xrange(10):
+        wp.AddTask("Hello world %s" % i)
+
+      wp.Quiesce()
+    finally:
+      wp.TerminateWorkers()
+      self._CheckWorkerCount(wp, 0)
+
+  def testNoTasks(self):
+    wp = workerpool.WorkerPool(3, DummyBaseWorker)
+    try:
+      self._CheckWorkerCount(wp, 3)
+      self._CheckNoTasks(wp)
+    finally:
+      wp.TerminateWorkers()
+      self._CheckWorkerCount(wp, 0)
+
+  def testNoTasksQuiesce(self):
+    wp = workerpool.WorkerPool(3, DummyBaseWorker)
+    try:
+      self._CheckWorkerCount(wp, 3)
+      self._CheckNoTasks(wp)
+      wp.Quiesce()
+      self._CheckNoTasks(wp)
+    finally:
+      wp.TerminateWorkers()
+      self._CheckWorkerCount(wp, 0)
+
+  def testChecksum(self):
+    # Tests whether all tasks are run and, since we're only using a single
+    # thread, whether everything is started in order.
+    wp = workerpool.WorkerPool(1, ChecksumBaseWorker)
+    try:
+      self._CheckWorkerCount(wp, 1)
+
+      ctx = ChecksumContext()
+      checksum = ChecksumContext.CHECKSUM_START
+      for i in xrange(1, 100):
+        checksum = ChecksumContext.UpdateChecksum(checksum, i)
+        wp.AddTask(ctx, i)
+
+      wp.Quiesce()
+
+      self._CheckNoTasks(wp)
+
+      # Check sum
+      ctx.lock.acquire()
+      try:
+        self.assertEqual(checksum, ctx.checksum)
+      finally:
+        ctx.lock.release()
+    finally:
+      wp.TerminateWorkers()
+      self._CheckWorkerCount(wp, 0)
+
+  def _CheckNoTasks(self, wp):
+    wp._lock.acquire()
+    try:
+      # The task queue must be empty now
+      self.failUnless(not wp._tasks)
+    finally:
+      wp._lock.release()
+
+  def _CheckWorkerCount(self, wp, num_workers):
+    wp._lock.acquire()
+    try:
+      self.assertEqual(len(wp._workers), num_workers)
+    finally:
+      wp._lock.release()
+
+
+if __name__ == '__main__':
+  unittest.main()
-- 
GitLab