#
#

# Copyright (C) 2008, 2009, 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Base classes for worker pools.

"""

import logging
import threading
28
import heapq
29
import itertools
30

31
from ganeti import compat
32
from ganeti import errors
33

34

35
_TERMINATE = object()
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
_DEFAULT_PRIORITY = 0


class DeferTask(Exception):
  """Special exception class to defer a task.

  Raising this exception from within L{BaseWorker.RunTask} postpones the
  current task: it is put back into the queue instead of being considered
  finished. The task's priority can optionally be changed at the same time.

  """
  def __init__(self, priority=None):
    """Initializes this class.

    @type priority: number
    @param priority: New task priority (None means no change)

    """
    super(DeferTask, self).__init__()
    # New priority for the re-queued task; None keeps the current one
    self.priority = priority


class BaseWorker(threading.Thread, object):
  """Base worker class for worker pools.

  Users of a worker pool must override RunTask in a subclass.

  """
  # Workers deliberately access private members of their pool (lock, queue
  # helpers), hence the disabled protected-access warning
  # pylint: disable=W0212
  def __init__(self, pool, worker_id):
    """Constructor for BaseWorker thread.

    @param pool: the parent worker pool
    @param worker_id: identifier for this worker

    """
    super(BaseWorker, self).__init__(name=worker_id)
    # Parent pool; its lock and condition variables are used directly
    self.pool = pool
    # Base name of this worker, also used as prefix in L{SetTaskName}
    self._worker_id = worker_id
    # Task currently being processed (None while idle); when set, it is a
    # [priority, order ID, task ID, arguments] list as built by
    # WorkerPool._AddTaskUnlocked
    self._current_task = None

    assert self.getName() == worker_id

  def ShouldTerminate(self):
    """Returns whether this worker should terminate.

    Should only be called from within L{RunTask}.

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()
      return self.pool._ShouldWorkerTerminateUnlocked(self)
    finally:
      self.pool._lock.release()

  def GetCurrentPriority(self):
    """Returns the priority of the current task.

    Should only be called from within L{RunTask}.

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()

      (priority, _, _, _) = self._current_task

      return priority
    finally:
      self.pool._lock.release()

  def SetTaskName(self, taskname):
    """Sets the name of the current task.

    Should only be called from within L{RunTask}.

    @type taskname: string
    @param taskname: Task's name

    """
    if taskname:
      name = "%s/%s" % (self._worker_id, taskname)
    else:
      name = self._worker_id

    # Set thread name
    self.setName(name)

  def _HasRunningTaskUnlocked(self):
    """Returns whether this worker is currently running a task.

    Must only be called while the pool's lock is held (per the "Unlocked"
    naming convention used throughout this module).

    """
    return (self._current_task is not None)

  def run(self):
    """Main thread function.

    Waits for new tasks to show up in the queue.

    """
    pool = self.pool

    while True:
      assert self._current_task is None

      # Holds the DeferTask instance if the task asked to be re-queued
      defer = None
      try:
        # Wait on lock to be told either to terminate or to do a task
        pool._lock.acquire()
        try:
          task = pool._WaitForTaskUnlocked(self)

          if task is _TERMINATE:
            # Told to terminate
            break

          if task is None:
            # Spurious notification, ignore
            continue

          self._current_task = task

          # No longer needed, dispose of reference
          del task

          assert self._HasRunningTaskUnlocked()
        finally:
          pool._lock.release()

        (priority, _, _, args) = self._current_task
        try:
          # Run the actual task
          assert defer is None
          logging.debug("Starting task %r, priority %s", args, priority)
          assert self.getName() == self._worker_id
          try:
            self.RunTask(*args) # pylint: disable=W0142
          finally:
            # Restore the plain worker name in case RunTask changed it via
            # SetTaskName
            self.SetTaskName(None)
          logging.debug("Done with task %r, priority %s", args, priority)
        except DeferTask, err:
          # Task asked to be postponed, possibly with a new priority
          defer = err

          if defer.priority is None:
            # Use same priority
            defer.priority = priority

          logging.debug("Deferring task %r, new priority %s",
                        args, defer.priority)

          assert self._HasRunningTaskUnlocked()
        except: # pylint: disable=W0702
          # Any other exception is logged but must not kill the worker thread
          logging.exception("Caught unhandled exception")

        assert self._HasRunningTaskUnlocked()
      finally:
        # Notify pool
        pool._lock.acquire()
        try:
          if defer:
            assert self._current_task
            # Schedule again for later run; re-added without a task ID, as
            # the original ID was already removed from the pool's task data
            # when the task was handed out (see _WaitForTaskUnlocked)
            (_, _, _, args) = self._current_task
            pool._AddTaskUnlocked(args, defer.priority, None)

          if self._current_task:
            self._current_task = None
            # Wake up threads waiting on worker_to_pool (e.g. Quiesce)
            pool._worker_to_pool.notifyAll()
        finally:
          pool._lock.release()

      assert not self._HasRunningTaskUnlocked()

    logging.debug("Terminates")

  def RunTask(self, *args):
    """Function called to start a task.

    This needs to be implemented by child classes.

    """
    raise NotImplementedError()


class WorkerPool(object):
  """Worker pool with a queue.

  This class is thread-safe.

  Tasks are guaranteed to be started in the order in which they're
  added to the pool. Due to the nature of threading, they're not
  guaranteed to finish in the same order.

  @type _tasks: list of tuples
  @ivar _tasks: Each tuple has the format (priority, order ID, task ID,
    arguments). Priority and order ID are numeric and essentially control the
    sort order. The order ID is an increasing number denoting the order in
    which tasks are added to the queue. The task ID is controlled by user of
    workerpool, see L{AddTask} for details. The task arguments are C{None} for
    abandoned tasks, otherwise a sequence of arguments to be passed to
    L{BaseWorker.RunTask}). The list must fulfill the heap property (for use by
    the C{heapq} module).
  @type _taskdata: dict; (task IDs as keys, tuples as values)
  @ivar _taskdata: Mapping from task IDs to entries in L{_tasks}

  """
  def __init__(self, name, num_workers, worker_class):
    """Constructor for worker pool.

    @param name: name prefix used when building worker thread identifiers
        (see L{_NewWorkerIdUnlocked})
    @param num_workers: number of workers to be started
        (dynamic resizing is not yet implemented)
    @param worker_class: the class to be instantiated for workers;
        should derive from L{BaseWorker}

    """
    # Some of these variables are accessed by BaseWorker
    self._lock = threading.Lock()
    # All three condition variables share the same lock; each carries one
    # kind of notification:
    #  - _pool_to_pool: quiescing finished (AddTask/AddManyTasks wait on it)
    #  - _pool_to_worker: new task queued or pool state changed (workers wait)
    #  - _worker_to_pool: a worker picked up or finished a task (Quiesce waits)
    self._pool_to_pool = threading.Condition(self._lock)
    self._pool_to_worker = threading.Condition(self._lock)
    self._worker_to_pool = threading.Condition(self._lock)
    self._worker_class = worker_class
    self._name = name
    self._last_worker_id = 0
    self._workers = []
    self._quiescing = False
    # While False, workers don't pick up new tasks (see SetActive)
    self._active = True

    # Terminating workers
    self._termworkers = []

    # Queued tasks
    self._counter = itertools.count()
    self._tasks = []
    self._taskdata = {}

    # Start workers
    self.Resize(num_workers)

  # TODO: Implement dynamic resizing?

  def _WaitWhileQuiescingUnlocked(self):
    """Wait until the worker pool has finished quiescing.

    Must be called with the pool's lock held.

    """
    while self._quiescing:
      self._pool_to_pool.wait()

  def _AddTaskUnlocked(self, args, priority, task_id):
    """Adds a task to the internal queue.

    Must be called with the pool's lock held.

    @type args: sequence
    @param args: Arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority
    @param task_id: Task ID

    """
    assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
    assert isinstance(priority, (int, long)), "Priority must be numeric"

    # A list (not a tuple) so the priority can be changed in place later via
    # the reference kept in _taskdata
    task = [priority, self._counter.next(), task_id, args]

    if task_id is not None:
      assert task_id not in self._taskdata
      # Keep a reference to change priority later if necessary
      self._taskdata[task_id] = task

    # A counter is used to ensure elements are processed in their incoming
    # order. For processing they're sorted by priority and then counter.
    heapq.heappush(self._tasks, task)

    # Notify a waiting worker
    self._pool_to_worker.notify()

  def AddTask(self, args, priority=_DEFAULT_PRIORITY, task_id=None):
    """Adds a task to the queue.

    @type args: sequence
    @param args: arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority
    @param task_id: Task ID
    @note: The task ID can be essentially anything that can be used as a
      dictionary key. Callers, however, must ensure a task ID is unique while a
      task is in the pool or while it might return to the pool due to deferring
      using L{DeferTask}.

    """
    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()
      self._AddTaskUnlocked(args, priority, task_id)
    finally:
      self._lock.release()

  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY, task_id=None):
    """Add a list of tasks to the queue.

    @type tasks: list of tuples
    @param tasks: list of args passed to L{BaseWorker.RunTask}
    @type priority: number or list of numbers
    @param priority: Priority for all added tasks or a list with the priority
                     for each task
    @type task_id: list
    @param task_id: List with the ID for each task
    @note: See L{AddTask} for a note on task IDs.

    """
    assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
           "Each task must be a sequence"
    assert (isinstance(priority, (int, long)) or
            compat.all(isinstance(prio, (int, long)) for prio in priority)), \
           "Priority must be numeric or be a list of numeric values"
    assert task_id is None or isinstance(task_id, (tuple, list)), \
           "Task IDs must be in a sequence"

    # Normalize a single priority to a per-task list
    if isinstance(priority, (int, long)):
      priority = [priority] * len(tasks)
    elif len(priority) != len(tasks):
      raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(priority), len(tasks)))

    # Normalize missing task IDs to a per-task list of None
    if task_id is None:
      task_id = [None] * len(tasks)
    elif len(task_id) != len(tasks):
      raise errors.ProgrammerError("Number of task IDs (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(task_id), len(tasks)))

    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()

      assert compat.all(isinstance(prio, (int, long)) for prio in priority)
      assert len(tasks) == len(priority)
      assert len(tasks) == len(task_id)

      for (args, prio, tid) in zip(tasks, priority, task_id):
        self._AddTaskUnlocked(args, prio, tid)
    finally:
      self._lock.release()

  def SetActive(self, active):
    """Enable/disable processing of tasks.

    This is different from L{Quiesce} in the sense that this function just
    changes an internal flag and doesn't wait for the queue to be empty. Tasks
    already being processed continue normally, but no new tasks will be
    started. New tasks can still be added.

    @type active: bool
    @param active: Whether tasks should be processed

    """
    self._lock.acquire()
    try:
      self._active = active

      if active:
        # Tell all workers to continue processing
        self._pool_to_worker.notifyAll()
    finally:
      self._lock.release()

  def _WaitForTaskUnlocked(self, worker):
    """Waits for a task for a worker.

    Must be called with the pool's lock held. Returns L{_TERMINATE} when the
    worker should exit, otherwise the next task from the queue.

    @type worker: L{BaseWorker}
    @param worker: Worker thread

    """
    while True:
      if self._ShouldWorkerTerminateUnlocked(worker):
        return _TERMINATE

      # If there's a pending task, return it immediately
      if self._active and self._tasks:
        # Get task from queue and tell pool about it
        try:
          task = heapq.heappop(self._tasks)
        finally:
          self._worker_to_pool.notifyAll()

        # Delete reference
        (_, _, task_id, _) = task
        if task_id is not None:
          del self._taskdata[task_id]

        return task

      logging.debug("Waiting for tasks")

      # wait() releases the lock and sleeps until notified
      self._pool_to_worker.wait()

      logging.debug("Notified while waiting")

  def _ShouldWorkerTerminateUnlocked(self, worker):
    """Returns whether a worker should terminate.

    Must be called with the pool's lock held.

    """
    return (worker in self._termworkers)

  def _HasRunningTasksUnlocked(self):
    """Checks whether there's a task running in a worker.

    Must be called with the pool's lock held.

    """
    for worker in self._workers + self._termworkers:
      if worker._HasRunningTaskUnlocked(): # pylint: disable=W0212
        return True
    return False

  def HasRunningTasks(self):
    """Checks whether there's at least one task running.

    """
    self._lock.acquire()
    try:
      return self._HasRunningTasksUnlocked()
    finally:
      self._lock.release()

  def Quiesce(self):
    """Waits until the task queue is empty.

    While quiescing, calls to L{AddTask} and L{AddManyTasks} block.

    """
    self._lock.acquire()
    try:
      self._quiescing = True

      # Wait while there are tasks pending or running
      while self._tasks or self._HasRunningTasksUnlocked():
        self._worker_to_pool.wait()

    finally:
      self._quiescing = False

      # Make sure AddTasks continues in case it was waiting
      self._pool_to_pool.notifyAll()

      self._lock.release()

  def _NewWorkerIdUnlocked(self):
    """Return an identifier for a new worker.

    Must be called with the pool's lock held.

    """
    self._last_worker_id += 1

    return "%s%d" % (self._name, self._last_worker_id)

  def _ResizeUnlocked(self, num_workers):
    """Changes the number of workers.

    Must be called with the pool's lock held. Note that the lock is
    temporarily released while waiting for terminating workers to exit.

    """
    assert num_workers >= 0, "num_workers must be >= 0"

    logging.debug("Resizing to %s workers", num_workers)

    current_count = len(self._workers)

    if current_count == num_workers:
      # Nothing to do
      pass

    elif current_count > num_workers:
      if num_workers == 0:
        # Create copy of list to iterate over while lock isn't held.
        termworkers = self._workers[:]
        del self._workers[:]
      else:
        # TODO: Implement partial downsizing
        raise NotImplementedError()
        #termworkers = ...

      self._termworkers += termworkers

      # Notify workers that something has changed
      self._pool_to_worker.notifyAll()

      # Join all terminating workers
      self._lock.release()
      try:
        for worker in termworkers:
          logging.debug("Waiting for thread %s", worker.getName())
          worker.join()
      finally:
        self._lock.acquire()

      # Remove terminated threads. This could be done in a more efficient way
      # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
      # don't leave zombie threads around.
      for worker in termworkers:
        assert worker in self._termworkers, ("Worker not in list of"
                                             " terminating workers")
        if not worker.isAlive():
          self._termworkers.remove(worker)

      assert not self._termworkers, "Zombie worker detected"

    elif current_count < num_workers:
      # Create (num_workers - current_count) new workers
      for _ in range(num_workers - current_count):
        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
        self._workers.append(worker)
        worker.start()

  def Resize(self, num_workers):
    """Changes the number of workers in the pool.

    @param num_workers: the new number of workers

    """
    self._lock.acquire()
    try:
      return self._ResizeUnlocked(num_workers)
    finally:
      self._lock.release()

  def TerminateWorkers(self):
    """Terminate all worker threads.

    Unstarted tasks will be ignored.

    """
    logging.debug("Terminating all workers")

    self._lock.acquire()
    try:
      self._ResizeUnlocked(0)

      if self._tasks:
        logging.debug("There are %s tasks left", len(self._tasks))
    finally:
      self._lock.release()

    logging.debug("All workers terminated")