From 4e33853308ae1ed4369b17361731b26d6f775c12 Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Mon, 10 May 2010 13:57:04 +0200
Subject: [PATCH] cli: Make PollJob generic to support other protocols

By separating the LUXI-specific code and stdio-related code
into separate classes, we can make cli.PollJob protocol-
agnostic, allowing it to be used with RAPI.

This patch also adds unittests for cli.PollJob.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>
---
 lib/cli.py                  | 236 +++++++++++++++++++++++++++++-------
 test/ganeti.cli_unittest.py | 174 +++++++++++++++++++++++++-
 2 files changed, 364 insertions(+), 46 deletions(-)

diff --git a/lib/cli.py b/lib/cli.py
index 8de292c7b..f598320e6 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -1223,41 +1223,31 @@ def SendJob(ops, cl=None):
   return job_id
 
 
-def PollJob(job_id, cl=None, feedback_fn=None):
-  """Function to poll for the result of a job.
+def GenericPollJob(job_id, cbs, report_cbs):
+  """Generic job-polling function.
 
-  @type job_id: job identified
-  @param job_id: the job to poll for results
-  @type cl: luxi.Client
-  @param cl: the luxi client to use for communicating with the master;
-             if None, a new client will be created
+  @type job_id: number
+  @param job_id: Job ID
+  @type cbs: Instance of L{JobPollCbBase}
+  @param cbs: Data callbacks
+  @type report_cbs: Instance of L{JobPollReportCbBase}
+  @param report_cbs: Reporting callbacks
 
   """
-  if cl is None:
-    cl = GetClient()
-
   prev_job_info = None
   prev_logmsg_serial = None
 
   status = None
 
-  notified_queued = False
-  notified_waitlock = False
-
   while True:
-    result = cl.WaitForJobChangeOnce(job_id, ["status"], prev_job_info,
-                                     prev_logmsg_serial)
+    result = cbs.WaitForJobChangeOnce(job_id, ["status"], prev_job_info,
+                                      prev_logmsg_serial)
     if not result:
       # job not found, go away!
       raise errors.JobLost("Job with id %s lost" % job_id)
-    elif result == constants.JOB_NOTCHANGED:
-      if status is not None and not callable(feedback_fn):
-        if status == constants.JOB_STATUS_QUEUED and not notified_queued:
-          ToStderr("Job %s is waiting in queue", job_id)
-          notified_queued = True
-        elif status == constants.JOB_STATUS_WAITLOCK and not notified_waitlock:
-          ToStderr("Job %s is trying to acquire all necessary locks", job_id)
-          notified_waitlock = True
+
+    if result == constants.JOB_NOTCHANGED:
+      report_cbs.ReportNotChanged(job_id, status)
 
       # Wait again
       continue
@@ -1268,12 +1258,9 @@ def PollJob(job_id, cl=None, feedback_fn=None):
 
     if log_entries:
       for log_entry in log_entries:
-        (serial, timestamp, _, message) = log_entry
-        if callable(feedback_fn):
-          feedback_fn(log_entry[1:])
-        else:
-          encoded = utils.SafeEncode(message)
-          ToStdout("%s %s", time.ctime(utils.MergeTime(timestamp)), encoded)
+        (serial, timestamp, log_type, message) = log_entry
+        report_cbs.ReportLogMessage(job_id, serial, timestamp,
+                                    log_type, message)
         prev_logmsg_serial = max(prev_logmsg_serial, serial)
 
     # TODO: Handle canceled and archived jobs
@@ -1285,30 +1272,189 @@ def PollJob(job_id, cl=None, feedback_fn=None):
 
     prev_job_info = job_info
 
-  jobs = cl.QueryJobs([job_id], ["status", "opstatus", "opresult"])
+  jobs = cbs.QueryJobs([job_id], ["status", "opstatus", "opresult"])
   if not jobs:
     raise errors.JobLost("Job with id %s lost" % job_id)
 
   status, opstatus, result = jobs[0]
+
   if status == constants.JOB_STATUS_SUCCESS:
     return result
-  elif status in (constants.JOB_STATUS_CANCELING,
-                  constants.JOB_STATUS_CANCELED):
+
+  if status in (constants.JOB_STATUS_CANCELING, constants.JOB_STATUS_CANCELED):
     raise errors.OpExecError("Job was canceled")
+
+  has_ok = False
+  for idx, (status, msg) in enumerate(zip(opstatus, result)):
+    if status == constants.OP_STATUS_SUCCESS:
+      has_ok = True
+    elif status == constants.OP_STATUS_ERROR:
+      errors.MaybeRaise(msg)
+
+      if has_ok:
+        raise errors.OpExecError("partial failure (opcode %d): %s" %
+                                 (idx, msg))
+
+      raise errors.OpExecError(str(msg))
+
+  # default failure mode
+  raise errors.OpExecError(result)
+
+
+class JobPollCbBase:
+  """Base class for L{GenericPollJob} callbacks.
+
+  """
+  def __init__(self):
+    """Initializes this class.
+
+    """
+
+  def WaitForJobChangeOnce(self, job_id, fields,
+                           prev_job_info, prev_log_serial):
+    """Waits for changes on a job.
+
+    """
+    raise NotImplementedError()
+
+  def QueryJobs(self, job_ids, fields):
+    """Returns the selected fields for the selected job IDs.
+
+    @type job_ids: list of numbers
+    @param job_ids: Job IDs
+    @type fields: list of strings
+    @param fields: Fields
+
+    """
+    raise NotImplementedError()
+
+
+class JobPollReportCbBase:
+  """Base class for L{GenericPollJob} reporting callbacks.
+
+  """
+  def __init__(self):
+    """Initializes this class.
+
+    """
+
+  def ReportLogMessage(self, job_id, serial, timestamp, log_type, log_msg):
+    """Handles a log message.
+
+    """
+    raise NotImplementedError()
+
+  def ReportNotChanged(self, job_id, status):
+    """Called for if a job hasn't changed in a while.
+
+    @type job_id: number
+    @param job_id: Job ID
+    @type status: string or None
+    @param status: Job status if available
+
+    """
+    raise NotImplementedError()
+
+
+class _LuxiJobPollCb(JobPollCbBase):
+  def __init__(self, cl):
+    """Initializes this class.
+
+    """
+    JobPollCbBase.__init__(self)
+    self.cl = cl
+
+  def WaitForJobChangeOnce(self, job_id, fields,
+                           prev_job_info, prev_log_serial):
+    """Waits for changes on a job.
+
+    """
+    return self.cl.WaitForJobChangeOnce(job_id, fields,
+                                        prev_job_info, prev_log_serial)
+
+  def QueryJobs(self, job_ids, fields):
+    """Returns the selected fields for the selected job IDs.
+
+    """
+    return self.cl.QueryJobs(job_ids, fields)
+
+
+class FeedbackFnJobPollReportCb(JobPollReportCbBase):
+  def __init__(self, feedback_fn):
+    """Initializes this class.
+
+    """
+    JobPollReportCbBase.__init__(self)
+
+    self.feedback_fn = feedback_fn
+
+    assert callable(feedback_fn)
+
+  def ReportLogMessage(self, job_id, serial, timestamp, log_type, log_msg):
+    """Handles a log message.
+
+    """
+    self.feedback_fn((timestamp, log_type, log_msg))
+
+  def ReportNotChanged(self, job_id, status):
+    """Called if a job hasn't changed in a while.
+
+    """
+    # Ignore
+
+
+class StdioJobPollReportCb(JobPollReportCbBase):
+  def __init__(self):
+    """Initializes this class.
+
+    """
+    JobPollReportCbBase.__init__(self)
+
+    self.notified_queued = False
+    self.notified_waitlock = False
+
+  def ReportLogMessage(self, job_id, serial, timestamp, log_type, log_msg):
+    """Handles a log message.
+
+    """
+    ToStdout("%s %s", time.ctime(utils.MergeTime(timestamp)),
+             utils.SafeEncode(log_msg))
+
+  def ReportNotChanged(self, job_id, status):
+    """Called if a job hasn't changed in a while.
+
+    """
+    if status is None:
+      return
+
+    if status == constants.JOB_STATUS_QUEUED and not self.notified_queued:
+      ToStderr("Job %s is waiting in queue", job_id)
+      self.notified_queued = True
+
+    elif status == constants.JOB_STATUS_WAITLOCK and not self.notified_waitlock:
+      ToStderr("Job %s is trying to acquire all necessary locks", job_id)
+      self.notified_waitlock = True
+
+
+def PollJob(job_id, cl=None, feedback_fn=None):
+  """Function to poll for the result of a job.
+
+  @type job_id: job identified
+  @param job_id: the job to poll for results
+  @type cl: luxi.Client
+  @param cl: the luxi client to use for communicating with the master;
+             if None, a new client will be created
+
+  """
+  if cl is None:
+    cl = GetClient()
+
+  if feedback_fn:
+    reporter = FeedbackFnJobPollReportCb(feedback_fn)
   else:
-    has_ok = False
-    for idx, (status, msg) in enumerate(zip(opstatus, result)):
-      if status == constants.OP_STATUS_SUCCESS:
-        has_ok = True
-      elif status == constants.OP_STATUS_ERROR:
-        errors.MaybeRaise(msg)
-        if has_ok:
-          raise errors.OpExecError("partial failure (opcode %d): %s" %
-                                   (idx, msg))
-        else:
-          raise errors.OpExecError(str(msg))
-    # default failure mode
-    raise errors.OpExecError(result)
+    reporter = StdioJobPollReportCb()
+
+  return GenericPollJob(job_id, _LuxiJobPollCb(cl), reporter)
 
 
 def SubmitOpCode(op, cl=None, feedback_fn=None, opts=None):
diff --git a/test/ganeti.cli_unittest.py b/test/ganeti.cli_unittest.py
index b768f4a03..2e5a6e3b6 100755
--- a/test/ganeti.cli_unittest.py
+++ b/test/ganeti.cli_unittest.py
@@ -29,6 +29,8 @@ import testutils
 
 from ganeti import constants
 from ganeti import cli
+from ganeti import errors
+from ganeti import utils
 from ganeti.errors import OpPrereqError, ParameterError
 
 
@@ -100,7 +102,7 @@ class TestIdentKeyVal(unittest.TestCase):
 
 
 class TestToStream(unittest.TestCase):
-  """Thes the ToStream functions"""
+  """Test the ToStream functions"""
 
   def testBasic(self):
     for data in ["foo",
@@ -246,5 +248,175 @@ class TestGenerateTable(unittest.TestCase):
                None, None, "m", exp)
 
 
+class _MockJobPollCb(cli.JobPollCbBase, cli.JobPollReportCbBase):
+  def __init__(self, tc, job_id):
+    self.tc = tc
+    self.job_id = job_id
+    self._wfjcr = []
+    self._jobstatus = []
+    self._expect_notchanged = False
+    self._expect_log = []
+
+  def CheckEmpty(self):
+    self.tc.assertFalse(self._wfjcr)
+    self.tc.assertFalse(self._jobstatus)
+    self.tc.assertFalse(self._expect_notchanged)
+    self.tc.assertFalse(self._expect_log)
+
+  def AddWfjcResult(self, *args):
+    self._wfjcr.append(args)
+
+  def AddQueryJobsResult(self, *args):
+    self._jobstatus.append(args)
+
+  def WaitForJobChangeOnce(self, job_id, fields,
+                           prev_job_info, prev_log_serial):
+    self.tc.assertEqual(job_id, self.job_id)
+    self.tc.assertEqualValues(fields, ["status"])
+    self.tc.assertFalse(self._expect_notchanged)
+    self.tc.assertFalse(self._expect_log)
+
+    (exp_prev_job_info, exp_prev_log_serial, result) = self._wfjcr.pop(0)
+    self.tc.assertEqualValues(prev_job_info, exp_prev_job_info)
+    self.tc.assertEqual(prev_log_serial, exp_prev_log_serial)
+
+    if result == constants.JOB_NOTCHANGED:
+      self._expect_notchanged = True
+    elif result:
+      (_, logmsgs) = result
+      if logmsgs:
+        self._expect_log.extend(logmsgs)
+
+    return result
+
+  def QueryJobs(self, job_ids, fields):
+    self.tc.assertEqual(job_ids, [self.job_id])
+    self.tc.assertEqualValues(fields, ["status", "opstatus", "opresult"])
+    self.tc.assertFalse(self._expect_notchanged)
+    self.tc.assertFalse(self._expect_log)
+
+    result = self._jobstatus.pop(0)
+    self.tc.assertEqual(len(fields), len(result))
+    return [result]
+
+  def ReportLogMessage(self, job_id, serial, timestamp, log_type, log_msg):
+    self.tc.assertEqual(job_id, self.job_id)
+    self.tc.assertEqualValues((serial, timestamp, log_type, log_msg),
+                              self._expect_log.pop(0))
+
+  def ReportNotChanged(self, job_id, status):
+    self.tc.assertEqual(job_id, self.job_id)
+    self.tc.assert_(self._expect_notchanged)
+    self._expect_notchanged = False
+
+
+class TestGenericPollJob(testutils.GanetiTestCase):
+  def testSuccessWithLog(self):
+    job_id = 29609
+    cbs = _MockJobPollCb(self, job_id)
+
+    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
+
+    cbs.AddWfjcResult(None, None,
+                      ((constants.JOB_STATUS_QUEUED, ), None))
+
+    cbs.AddWfjcResult((constants.JOB_STATUS_QUEUED, ), None,
+                      constants.JOB_NOTCHANGED)
+
+    cbs.AddWfjcResult((constants.JOB_STATUS_QUEUED, ), None,
+                      ((constants.JOB_STATUS_RUNNING, ),
+                       [(1, utils.SplitTime(1273491611.0),
+                         constants.ELOG_MESSAGE, "Step 1"),
+                        (2, utils.SplitTime(1273491615.9),
+                         constants.ELOG_MESSAGE, "Step 2"),
+                        (3, utils.SplitTime(1273491625.02),
+                         constants.ELOG_MESSAGE, "Step 3"),
+                        (4, utils.SplitTime(1273491635.05),
+                         constants.ELOG_MESSAGE, "Step 4"),
+                        (37, utils.SplitTime(1273491645.0),
+                         constants.ELOG_MESSAGE, "Step 5"),
+                        (203, utils.SplitTime(127349155.0),
+                         constants.ELOG_MESSAGE, "Step 6")]))
+
+    cbs.AddWfjcResult((constants.JOB_STATUS_RUNNING, ), 203,
+                      ((constants.JOB_STATUS_RUNNING, ),
+                       [(300, utils.SplitTime(1273491711.01),
+                         constants.ELOG_MESSAGE, "Step X"),
+                        (302, utils.SplitTime(1273491815.8),
+                         constants.ELOG_MESSAGE, "Step Y"),
+                        (303, utils.SplitTime(1273491925.32),
+                         constants.ELOG_MESSAGE, "Step Z")]))
+
+    cbs.AddWfjcResult((constants.JOB_STATUS_RUNNING, ), 303,
+                      ((constants.JOB_STATUS_SUCCESS, ), None))
+
+    cbs.AddQueryJobsResult(constants.JOB_STATUS_SUCCESS,
+                           [constants.OP_STATUS_SUCCESS,
+                            constants.OP_STATUS_SUCCESS],
+                           ["Hello World", "Foo man bar"])
+
+    self.assertEqual(["Hello World", "Foo man bar"],
+                     cli.GenericPollJob(job_id, cbs, cbs))
+    cbs.CheckEmpty()
+
+  def testJobLost(self):
+    job_id = 13746
+
+    cbs = _MockJobPollCb(self, job_id)
+    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
+    cbs.AddWfjcResult(None, None, None)
+    self.assertRaises(errors.JobLost, cli.GenericPollJob, job_id, cbs, cbs)
+    cbs.CheckEmpty()
+
+  def testError(self):
+    job_id = 31088
+
+    cbs = _MockJobPollCb(self, job_id)
+    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
+    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
+    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
+                           [constants.OP_STATUS_SUCCESS,
+                            constants.OP_STATUS_ERROR],
+                           ["Hello World", "Error code 123"])
+    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
+    cbs.CheckEmpty()
+
+  def testError2(self):
+    job_id = 22235
+
+    cbs = _MockJobPollCb(self, job_id)
+    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
+    encexc = errors.EncodeException(errors.LockError("problem"))
+    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
+                           [constants.OP_STATUS_ERROR], [encexc])
+    self.assertRaises(errors.LockError, cli.GenericPollJob, job_id, cbs, cbs)
+    cbs.CheckEmpty()
+
+  def testWeirdError(self):
+    job_id = 28847
+
+    cbs = _MockJobPollCb(self, job_id)
+    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
+    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
+                           [constants.OP_STATUS_RUNNING,
+                            constants.OP_STATUS_RUNNING],
+                           [None, None])
+    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
+    cbs.CheckEmpty()
+
+  def testCancel(self):
+    job_id = 4275
+
+    cbs = _MockJobPollCb(self, job_id)
+    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
+    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_CANCELING, ), None))
+    cbs.AddQueryJobsResult(constants.JOB_STATUS_CANCELING,
+                           [constants.OP_STATUS_CANCELING,
+                            constants.OP_STATUS_CANCELING],
+                           [None, None])
+    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
+    cbs.CheckEmpty()
+
+
 if __name__ == '__main__':
   testutils.GanetiTestProgram()
-- 
GitLab