From 697f49d5d0410a772ea9d942d83093327c5695ba Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Fri, 8 Jul 2011 21:33:16 +0200
Subject: [PATCH] ht: Add new check for numbers

Places which receive floats can usually also deal with integers, e.g.
OpTestDelay. Tests are added and the new check function is used for the
aforementioned opcode and verifying query results.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>
---
 lib/ht.py                  |  3 +++
 lib/opcodes.py             |  2 +-
 lib/query.py               |  2 +-
 test/ganeti.ht_unittest.py | 10 ++++++++++
 4 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/lib/ht.py b/lib/ht.py
index 57dcdcaf2..54b3ea07e 100644
--- a/lib/ht.py
+++ b/lib/ht.py
@@ -178,6 +178,9 @@ TPositiveInt = TAnd(TInt, lambda v: v >= 0)
 #: a strictly positive integer
 TStrictPositiveInt = TAnd(TInt, lambda v: v > 0)
 
+#: Number
+TNumber = TOr(TInt, TFloat)
+
 
 def TListOf(my_type):
   """Checks if a given value is a list with all elements of the same type.
diff --git a/lib/opcodes.py b/lib/opcodes.py
index b89444d3e..61f29e81a 100644
--- a/lib/opcodes.py
+++ b/lib/opcodes.py
@@ -1206,7 +1206,7 @@ class OpTestDelay(OpCode):
   """
   OP_DSC_FIELD = "duration"
   OP_PARAMS = [
-    ("duration", ht.NoDefault, ht.TFloat),
+    ("duration", ht.NoDefault, ht.TNumber),
     ("on_master", True, ht.TBool),
     ("on_nodes", ht.EmptyList, ht.TListOf(ht.TNonEmptyString)),
     ("repeat", 0, ht.TPositiveInt)
diff --git a/lib/query.py b/lib/query.py
index 1bb014d14..84a8394eb 100644
--- a/lib/query.py
+++ b/lib/query.py
@@ -101,7 +101,7 @@ _VERIFY_FN = {
   QFT_BOOL: ht.TBool,
   QFT_NUMBER: ht.TInt,
   QFT_UNIT: ht.TInt,
-  QFT_TIMESTAMP: ht.TOr(ht.TInt, ht.TFloat),
+  QFT_TIMESTAMP: ht.TNumber,
   QFT_OTHER: lambda _: True,
   }
 
diff --git a/test/ganeti.ht_unittest.py b/test/ganeti.ht_unittest.py
index 34ae6715e..a2cdf3b04 100755
--- a/test/ganeti.ht_unittest.py
+++ b/test/ganeti.ht_unittest.py
@@ -53,6 +53,7 @@ class TestTypeChecks(unittest.TestCase):
   def testInt(self):
     for val in [-100, -3, 0, 16, 128, 923874]:
       self.assertTrue(ht.TInt(val))
+      self.assertTrue(ht.TNumber(val))
 
     for val in [False, True, None, "", [], "Hello", 0.0, 0.23, -3818.163]:
       self.assertFalse(ht.TInt(val))
@@ -76,10 +77,19 @@ class TestTypeChecks(unittest.TestCase):
   def testFloat(self):
     for val in [-100.21, -3.0, 0.0, 16.12, 128.3433, 923874.928]:
       self.assertTrue(ht.TFloat(val))
+      self.assertTrue(ht.TNumber(val))
 
     for val in [False, True, None, "", [], "Hello", 0, 28, -1, -3281]:
       self.assertFalse(ht.TFloat(val))
 
+  def testNumber(self):
+    for val in [-100, -3, 0, 16, 128, 923874,
+                -100.21, -3.0, 0.0, 16.12, 128.3433, 923874.928]:
+      self.assertTrue(ht.TNumber(val))
+
+    for val in [False, True, None, "", [], "Hello", "1"]:
+      self.assertFalse(ht.TNumber(val))
+
   def testString(self):
     for val in ["", "abc", "Hello World", "123",
                 u"", u"\u272C", u"abc"]:
-- 
GitLab