workerpool.py 17.8 KB
Newer Older
1
2
3
#
#

4
# Copyright (C) 2008, 2009, 2010 Google Inc.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Base classes for worker pools.

"""

import logging
import threading
28
import heapq
29
import itertools
30

31
from ganeti import compat
32
from ganeti import errors
33

34

35
_TERMINATE = object()
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
_DEFAULT_PRIORITY = 0


class DeferTask(Exception):
  """Exception used to postpone the execution of a task.

  L{BaseWorker.RunTask} may raise this to put the current task back into the
  pool's queue for a later run. A new priority for the re-queued task can be
  supplied; by default the task keeps its current priority.

  """
  def __init__(self, priority=None):
    """Initializes this class.

    @type priority: number
    @param priority: New task priority (None means no change)

    """
    super(DeferTask, self).__init__()
    # Stored so the worker loop can re-queue the task at this priority
    self.priority = priority
55
56


57
58
59
60
61
62
class NoSuchTask(Exception):
  """Raised when a task referred to by its ID cannot be located in the pool.

  """


63
64
65
66
67
68
class BaseWorker(threading.Thread, object):
  """Base worker class for worker pools.

  Users of a worker pool must override RunTask in a subclass.

  """
  # pylint: disable=W0212
  def __init__(self, pool, worker_id):
    """Constructor for BaseWorker thread.

    @param pool: the parent worker pool
    @param worker_id: identifier for this worker

    """
    # The worker ID doubles as the initial thread name
    super(BaseWorker, self).__init__(name=worker_id)
    self.pool = pool
    self._worker_id = worker_id
    # Task tuple currently being run, None while idle; protected by
    # pool._lock (see _HasRunningTaskUnlocked)
    self._current_task = None

    assert self.getName() == worker_id

  def ShouldTerminate(self):
    """Returns whether this worker should terminate.

    Should only be called from within L{RunTask}.

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()
      return self.pool._ShouldWorkerTerminateUnlocked(self)
    finally:
      self.pool._lock.release()

  def GetCurrentPriority(self):
    """Returns the priority of the current task.

    Should only be called from within L{RunTask}.

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()

      # Task tuples are (priority, order ID, task ID, arguments)
      (priority, _, _, _) = self._current_task

      return priority
    finally:
      self.pool._lock.release()

  def SetTaskName(self, taskname):
    """Sets the name of the current task.

    Should only be called from within L{RunTask}.

    @type taskname: string
    @param taskname: Task's name

    """
    if taskname:
      name = "%s/%s" % (self._worker_id, taskname)
    else:
      name = self._worker_id

    # Set thread name
    self.setName(name)

  def _HasRunningTaskUnlocked(self):
    """Returns whether this worker is currently running a task.

    Caller must hold the pool's lock.

    """
    return (self._current_task is not None)

  def run(self):
    """Main thread function.

    Waits for new tasks to show up in the queue.

    """
    pool = self.pool

    while True:
      assert self._current_task is None

      # Set when the task raises DeferTask; checked in the finally block to
      # re-queue the task
      defer = None
      try:
        # Wait on lock to be told either to terminate or to do a task
        pool._lock.acquire()
        try:
          task = pool._WaitForTaskUnlocked(self)

          if task is _TERMINATE:
            # Told to terminate
            break

          if task is None:
            # Spurious notification, ignore
            continue

          self._current_task = task

          # No longer needed, dispose of reference
          del task

          assert self._HasRunningTaskUnlocked()

        finally:
          pool._lock.release()

        # Task executed outside the lock so other workers can proceed
        (priority, _, _, args) = self._current_task
        try:
          # Run the actual task
          assert defer is None
          logging.debug("Starting task %r, priority %s", args, priority)
          assert self.getName() == self._worker_id
          try:
            self.RunTask(*args) # pylint: disable=W0142
          finally:
            # Restore plain worker name in case RunTask used SetTaskName
            self.SetTaskName(None)
          logging.debug("Done with task %r, priority %s", args, priority)
        except DeferTask, err:
          defer = err

          if defer.priority is None:
            # Use same priority
            defer.priority = priority

          logging.debug("Deferring task %r, new priority %s",
                        args, defer.priority)

          assert self._HasRunningTaskUnlocked()
        except: # pylint: disable=W0702
          # Unhandled exceptions must not kill the worker thread
          logging.exception("Caught unhandled exception")

        assert self._HasRunningTaskUnlocked()
      finally:
        # Notify pool
        pool._lock.acquire()
        try:
          if defer:
            assert self._current_task
            # Schedule again for later run
            (_, _, _, args) = self._current_task
            # NOTE: the task is re-added without its original task ID
            pool._AddTaskUnlocked(args, defer.priority, None)

          if self._current_task:
            self._current_task = None
            # Wake anyone waiting in Quiesce for running tasks to finish
            pool._worker_to_pool.notifyAll()
        finally:
          pool._lock.release()

      assert not self._HasRunningTaskUnlocked()

    logging.debug("Terminates")

  def RunTask(self, *args):
    """Function called to start a task.

    This needs to be implemented by child classes.

    """
    raise NotImplementedError()


class WorkerPool(object):
  """Worker pool with a queue.

  This class is thread-safe.

  Tasks are guaranteed to be started in the order in which they're
  added to the pool. Due to the nature of threading, they're not
  guaranteed to finish in the same order.

  @type _tasks: list of tuples
  @ivar _tasks: Each tuple has the format (priority, order ID, task ID,
    arguments). Priority and order ID are numeric and essentially control the
    sort order. The order ID is an increasing number denoting the order in
    which tasks are added to the queue. The task ID is controlled by user of
    workerpool, see L{AddTask} for details. The task arguments are C{None} for
    abandoned tasks, otherwise a sequence of arguments to be passed to
    L{BaseWorker.RunTask}). The list must fulfill the heap property (for use by
    the C{heapq} module).
  @type _taskdata: dict; (task IDs as keys, tuples as values)
  @ivar _taskdata: Mapping from task IDs to entries in L{_tasks}

  """
  def __init__(self, name, num_workers, worker_class):
    """Constructor for worker pool.

    @param name: Prefix used for worker thread names
    @param num_workers: number of workers to be started
        (dynamic resizing is not yet implemented)
    @param worker_class: the class to be instantiated for workers;
        should derive from L{BaseWorker}

    """
    # Some of these variables are accessed by BaseWorker
    # All three conditions share the same lock, so waiting on one releases it
    # for all users
    self._lock = threading.Lock()
    self._pool_to_pool = threading.Condition(self._lock)
    self._pool_to_worker = threading.Condition(self._lock)
    self._worker_to_pool = threading.Condition(self._lock)
    self._worker_class = worker_class
    self._name = name
    self._last_worker_id = 0
    self._workers = []
    # While True, AddTask/AddManyTasks block (see Quiesce)
    self._quiescing = False
    # While False, queued tasks are not handed out to workers (see SetActive)
    self._active = True

    # Terminating workers
    self._termworkers = []

    # Queued tasks
    self._counter = itertools.count()
    self._tasks = []
    self._taskdata = {}

    # Start workers
    self.Resize(num_workers)

  # TODO: Implement dynamic resizing?

  def _WaitWhileQuiescingUnlocked(self):
    """Wait until the worker pool has finished quiescing.

    Caller must hold the pool's lock.

    """
    while self._quiescing:
      self._pool_to_pool.wait()

  def _AddTaskUnlocked(self, args, priority, task_id):
    """Adds a task to the internal queue.

    Caller must hold the pool's lock.

    @type args: sequence
    @param args: Arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority
    @param task_id: Task ID

    """
    assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
    assert isinstance(priority, (int, long)), "Priority must be numeric"

    # A list (not tuple) so ChangeTaskPriority can abandon the entry in place
    task = [priority, self._counter.next(), task_id, args]

    if task_id is not None:
      assert task_id not in self._taskdata
      # Keep a reference to change priority later if necessary
      self._taskdata[task_id] = task

    # A counter is used to ensure elements are processed in their incoming
    # order. For processing they're sorted by priority and then counter.
    heapq.heappush(self._tasks, task)

    # Notify a waiting worker
    self._pool_to_worker.notify()

  def AddTask(self, args, priority=_DEFAULT_PRIORITY, task_id=None):
    """Adds a task to the queue.

    @type args: sequence
    @param args: arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority
    @param task_id: Task ID
    @note: The task ID can be essentially anything that can be used as a
      dictionary key. Callers, however, must ensure a task ID is unique while a
      task is in the pool or while it might return to the pool due to deferring
      using L{DeferTask}.

    """
    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()
      self._AddTaskUnlocked(args, priority, task_id)
    finally:
      self._lock.release()

  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY, task_id=None):
    """Add a list of tasks to the queue.

    @type tasks: list of tuples
    @param tasks: list of args passed to L{BaseWorker.RunTask}
    @type priority: number or list of numbers
    @param priority: Priority for all added tasks or a list with the priority
                     for each task
    @type task_id: list
    @param task_id: List with the ID for each task
    @note: See L{AddTask} for a note on task IDs.

    """
    assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
           "Each task must be a sequence"
    assert (isinstance(priority, (int, long)) or
            compat.all(isinstance(prio, (int, long)) for prio in priority)), \
           "Priority must be numeric or be a list of numeric values"
    assert task_id is None or isinstance(task_id, (tuple, list)), \
           "Task IDs must be in a sequence"

    # Normalize a scalar priority to one priority per task
    if isinstance(priority, (int, long)):
      priority = [priority] * len(tasks)
    elif len(priority) != len(tasks):
      raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(priority), len(tasks)))

    # Normalize a missing ID list to one (empty) ID per task
    if task_id is None:
      task_id = [None] * len(tasks)
    elif len(task_id) != len(tasks):
      raise errors.ProgrammerError("Number of task IDs (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(task_id), len(tasks)))

    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()

      assert compat.all(isinstance(prio, (int, long)) for prio in priority)
      assert len(tasks) == len(priority)
      assert len(tasks) == len(task_id)

      for (args, prio, tid) in zip(tasks, priority, task_id):
        self._AddTaskUnlocked(args, prio, tid)
    finally:
      self._lock.release()

  def ChangeTaskPriority(self, task_id, priority):
    """Changes a task's priority.

    @param task_id: Task ID
    @type priority: number
    @param priority: New task priority
    @raise NoSuchTask: When the task referred by C{task_id} can not be found
      (it may never have existed, may have already been processed, or is
      currently running)

    """
    assert isinstance(priority, (int, long)), "Priority must be numeric"

    self._lock.acquire()
    try:
      logging.debug("About to change priority of task %s to %s",
                    task_id, priority)

      # Find old task
      oldtask = self._taskdata.get(task_id, None)
      if oldtask is None:
        msg = "Task '%s' was not found" % task_id
        logging.debug(msg)
        raise NoSuchTask(msg)

      # Prepare new task
      newtask = [priority] + oldtask[1:]

      # Mark old entry as abandoned (this doesn't change the sort order and
      # therefore doesn't invalidate the heap property of L{self._tasks}).
      # See also <http://docs.python.org/library/heapq.html#priority-queue-
      # implementation-notes>.
      oldtask[-1] = None

      # Change reference to new task entry and forget the old one
      assert task_id is not None
      self._taskdata[task_id] = newtask

      # Add a new task with the old number and arguments
      heapq.heappush(self._tasks, newtask)

      # Notify a waiting worker
      self._pool_to_worker.notify()
    finally:
      self._lock.release()

  def SetActive(self, active):
    """Enable/disable processing of tasks.

    This is different from L{Quiesce} in the sense that this function just
    changes an internal flag and doesn't wait for the queue to be empty. Tasks
    already being processed continue normally, but no new tasks will be
    started. New tasks can still be added.

    @type active: bool
    @param active: Whether tasks should be processed

    """
    self._lock.acquire()
    try:
      self._active = active

      if active:
        # Tell all workers to continue processing
        self._pool_to_worker.notifyAll()
    finally:
      self._lock.release()

  def _WaitForTaskUnlocked(self, worker):
    """Waits for a task for a worker.

    Caller must hold the pool's lock. Returns L{_TERMINATE} when the worker
    should exit, otherwise a task tuple.

    @type worker: L{BaseWorker}
    @param worker: Worker thread

    """
    while True:
      if self._ShouldWorkerTerminateUnlocked(worker):
        return _TERMINATE

      # If there's a pending task, return it immediately
      if self._active and self._tasks:
        # Get task from queue and tell pool about it
        try:
          task = heapq.heappop(self._tasks)
        finally:
          self._worker_to_pool.notifyAll()

        (_, _, task_id, args) = task

        # If the priority was changed, "args" is None
        if args is None:
          # Try again
          logging.debug("Found abandoned task (%r)", task)
          continue

        # Delete reference
        if task_id is not None:
          del self._taskdata[task_id]

        return task

      logging.debug("Waiting for tasks")

      # wait() releases the lock and sleeps until notified
      self._pool_to_worker.wait()

      logging.debug("Notified while waiting")

  def _ShouldWorkerTerminateUnlocked(self, worker):
    """Returns whether a worker should terminate.

    Caller must hold the pool's lock.

    """
    return (worker in self._termworkers)

  def _HasRunningTasksUnlocked(self):
    """Checks whether there's a task running in a worker.

    Caller must hold the pool's lock.

    """
    for worker in self._workers + self._termworkers:
      if worker._HasRunningTaskUnlocked(): # pylint: disable=W0212
        return True
    return False

  def HasRunningTasks(self):
    """Checks whether there's at least one task running.

    """
    self._lock.acquire()
    try:
      return self._HasRunningTasksUnlocked()
    finally:
      self._lock.release()

  def Quiesce(self):
    """Waits until the task queue is empty.

    While quiescing, L{AddTask} and L{AddManyTasks} block.

    """
    self._lock.acquire()
    try:
      self._quiescing = True

      # Wait while there are tasks pending or running
      while self._tasks or self._HasRunningTasksUnlocked():
        self._worker_to_pool.wait()

    finally:
      self._quiescing = False

      # Make sure AddTasks continues in case it was waiting
      self._pool_to_pool.notifyAll()

      self._lock.release()

  def _NewWorkerIdUnlocked(self):
    """Return an identifier for a new worker.

    Caller must hold the pool's lock.

    """
    self._last_worker_id += 1

    return "%s%d" % (self._name, self._last_worker_id)

  def _ResizeUnlocked(self, num_workers):
    """Changes the number of workers.

    Caller must hold the pool's lock. Only growing from any size and shrinking
    to zero workers are implemented.

    """
    assert num_workers >= 0, "num_workers must be >= 0"

    logging.debug("Resizing to %s workers", num_workers)

    current_count = len(self._workers)

    if current_count == num_workers:
      # Nothing to do
      pass

    elif current_count > num_workers:
      if num_workers == 0:
        # Create copy of list to iterate over while lock isn't held.
        termworkers = self._workers[:]
        del self._workers[:]
      else:
        # TODO: Implement partial downsizing
        raise NotImplementedError()
        #termworkers = ...

      self._termworkers += termworkers

      # Notify workers that something has changed
      self._pool_to_worker.notifyAll()

      # Join all terminating workers
      # Lock is released so workers can finish their current task
      self._lock.release()
      try:
        for worker in termworkers:
          logging.debug("Waiting for thread %s", worker.getName())
          worker.join()
      finally:
        self._lock.acquire()

      # Remove terminated threads. This could be done in a more efficient way
      # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
      # don't leave zombie threads around.
      for worker in termworkers:
        assert worker in self._termworkers, ("Worker not in list of"
                                             " terminating workers")
        if not worker.isAlive():
          self._termworkers.remove(worker)

      assert not self._termworkers, "Zombie worker detected"

    elif current_count < num_workers:
      # Create (num_workers - current_count) new workers
      for _ in range(num_workers - current_count):
        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
        self._workers.append(worker)
        worker.start()

  def Resize(self, num_workers):
    """Changes the number of workers in the pool.

    @param num_workers: the new number of workers

    """
    self._lock.acquire()
    try:
      return self._ResizeUnlocked(num_workers)
    finally:
      self._lock.release()

  def TerminateWorkers(self):
    """Terminate all worker threads.

    Unstarted tasks will be ignored.

    """
    logging.debug("Terminating all workers")

    self._lock.acquire()
    try:
      self._ResizeUnlocked(0)

      if self._tasks:
        logging.debug("There are %s tasks left", len(self._tasks))
    finally:
      self._lock.release()

    logging.debug("All workers terminated")