diff --git a/lib/query.py b/lib/query.py index b922bf97440e2155f6f81f71df1b2393e38fe9a5..01b434b00b1112590d6016404a58ec8c1f64bb12 100644 --- a/lib/query.py +++ b/lib/query.py @@ -269,13 +269,16 @@ class _FilterHints: if op != qlang.OP_OR: self._NeedAllNames() - def NoteUnaryOp(self, op): # pylint: disable=W0613 + def NoteUnaryOp(self, op, datakind): # pylint: disable=W0613 """Called when handling an unary operation. @type op: string @param op: Operator """ + if datakind is not None: + self._datakinds.add(datakind) + self._NeedAllNames() def NoteBinaryOp(self, op, datakind, name, value): @@ -557,19 +560,22 @@ class _FilterCompilerHelper: """ assert op_fn is None - if hints_fn: - hints_fn(op) - if len(operands) != 1: raise errors.ParameterError("Unary operator '%s' expects exactly one" " operand" % op) if op == qlang.OP_TRUE: - (_, _, _, retrieval_fn) = self._LookupField(operands[0]) + (_, datakind, _, retrieval_fn) = self._LookupField(operands[0]) + + if hints_fn: + hints_fn(op, datakind) op_fn = operator.truth arg = retrieval_fn elif op == qlang.OP_NOT: + if hints_fn: + hints_fn(op, None) + op_fn = operator.not_ arg = self._Compile(operands[0], level + 1) else: diff --git a/test/ganeti.query_unittest.py b/test/ganeti.query_unittest.py index f1e848b60b8f95dbdb9e8dd0732f88ed36fb64e0..024102aac3982bfe5e0682f65704599859c0983b 100755 --- a/test/ganeti.query_unittest.py +++ b/test/ganeti.query_unittest.py @@ -1521,6 +1521,19 @@ class TestQueryFilter(unittest.TestCase): self.assertEqual(q.RequestedData(), set([DK_B])) self.assertEqual(q.Query(data), [[]]) + # Data type in boolean operator + q = query.Query(fielddefs, [], namefield="name", + qfilter=["?", "name"]) + self.assertTrue(q.RequestedNames() is None) + self.assertEqual(q.RequestedData(), set([DK_A])) + self.assertEqual(q.Query(data), [[], [], []]) + + q = query.Query(fielddefs, [], namefield="name", + qfilter=["!", ["?", "name"]]) + self.assertTrue(q.RequestedNames() is None) + self.assertEqual(q.RequestedData(), set([DK_A])) + self.assertEqual(q.Query(data), []) + def testFilterContains(self): fielddefs = query._PrepareFieldList([ (query._MakeField("name", "Name", constants.QFT_TEXT, "Name"),