From 04cdf663c4fb1ba296173d86c93d9de0e72d9094 Mon Sep 17 00:00:00 2001
From: Guido Trotter <ultrotter@google.com>
Date: Mon, 15 Mar 2010 11:42:12 +0000
Subject: [PATCH] ConfdCountingCallback

This new confd callback counts received replies for the registered
queries.

Signed-off-by: Guido Trotter <ultrotter@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
---
 lib/confd/client.py | 62 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/lib/confd/client.py b/lib/confd/client.py
index b9026760c..efcb68d33 100644
--- a/lib/confd/client.py
+++ b/lib/confd/client.py
@@ -432,3 +432,65 @@ class ConfdFilterCallback:
 
     if not filter_upcall:
       self._callback(up)
+
+
+class ConfdCountingCallback:
+  """Callback that calls another callback, and counts the answers
+
+  """
+  def __init__(self, callback, logger=None):
+    """Constructor for ConfdCountingCallback
+
+    @type callback: f(L{ConfdUpcallPayload})
+    @param callback: function to call when getting answers
+    @type logger: logging.Logger
+    @param logger: optional logger for internal conditions
+
+    """
+    if not callable(callback):
+      raise errors.ProgrammerError("callback must be callable")
+
+    self._callback = callback
+    self._logger = logger
+    # answers contains a dict of salt -> count
+    self._answers = {}
+
+  def RegisterQuery(self, salt):
+    if salt in self._answers:
+      raise errors.ProgrammerError("query already registered")
+    self._answers[salt] = 0
+
+  def AllAnswered(self):
+    """Have all the registered queries received at least an answer?
+
+    """
+    return utils.all(self._answers.values())
+
+  def _HandleExpire(self, up):
+    # if we have no answer we have received none, before the expiration.
+    if up.salt in self._answers:
+      del self._answers[up.salt]
+
+  def _HandleReply(self, up):
+    """Handle a single confd reply, and decide whether to filter it.
+
+    @rtype: boolean
+    @return: True if the reply should be filtered, False if it should be passed
+             on to the up-callback
+
+    """
+    if up.salt in self._answers:
+      self._answers[up.salt] += 1
+
+  def __call__(self, up):
+    """Filtering callback
+
+    @type up: L{ConfdUpcallPayload}
+    @param up: upper callback
+
+    """
+    if up.type == UPCALL_REPLY:
+      self._HandleReply(up)
+    elif up.type == UPCALL_EXPIRE:
+      self._HandleExpire(up)
+    self._callback(up)
-- 
GitLab