From fbc263a9ca2f8166f35c377fa00b52792272eb31 Mon Sep 17 00:00:00 2001
From: Michael Hanselmann <hansmi@google.com>
Date: Tue, 1 Mar 2011 18:25:04 +0100
Subject: [PATCH] query: Fix bug when names are specified
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

If the client/caller would specify names through the use of a filter,
the result would be sorted. This is a regression over earlier Ganeti
versions and verified in QA. This patch adds an optional parameter to
control the sorting and provides unittests.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: RenΓ© Nussbaumer <rn@google.com>
---
 lib/cmdlib.py                 |  9 +++-
 lib/query.py                  | 27 ++++++----
 test/ganeti.query_unittest.py | 93 +++++++++++++++++++++++++++++++++++
 3 files changed, 118 insertions(+), 11 deletions(-)

diff --git a/lib/cmdlib.py b/lib/cmdlib.py
index 2b3f2e015..2e59a0b44 100644
--- a/lib/cmdlib.py
+++ b/lib/cmdlib.py
@@ -464,6 +464,9 @@ class _QueryBase:
     self.requested_data = self.query.RequestedData()
     self.names = self.query.RequestedNames()
 
+    # Sort only if no names were requested
+    self.sort_by_name = not self.names
+
     self.do_locking = None
     self.wanted = None
 
@@ -530,13 +533,15 @@ class _QueryBase:
     """Collect data and execute query.
 
     """
-    return query.GetQueryResponse(self.query, self._GetQueryData(lu))
+    return query.GetQueryResponse(self.query, self._GetQueryData(lu),
+                                  sort_by_name=self.sort_by_name)
 
   def OldStyleQuery(self, lu):
     """Collect data and execute query.
 
     """
-    return self.query.OldStyleQuery(self._GetQueryData(lu))
+    return self.query.OldStyleQuery(self._GetQueryData(lu),
+                                    sort_by_name=self.sort_by_name)
 
 
 def _GetWantedNodes(lu, nodes):
diff --git a/lib/query.py b/lib/query.py
index 40a03bec3..6497c6076 100644
--- a/lib/query.py
+++ b/lib/query.py
@@ -629,13 +629,18 @@ class Query:
     """
     return GetAllFields(self._fields)
 
-  def Query(self, ctx):
+  def Query(self, ctx, sort_by_name=True):
     """Execute a query.
 
     @param ctx: Data container passed to field retrieval functions, must
       support iteration using C{__iter__}
+    @type sort_by_name: boolean
+    @param sort_by_name: Whether to sort by name or keep the input data's
+      ordering
 
     """
+    sort = (self._name_fn and sort_by_name)
+
     result = []
 
     for idx, item in enumerate(ctx):
@@ -648,15 +653,16 @@ class Query:
       if __debug__:
         _VerifyResultRow(self._fields, row)
 
-      if self._name_fn:
+      if sort:
         (status, name) = _ProcessResult(self._name_fn(ctx, item))
         assert status == constants.RS_NORMAL
         # TODO: Are there cases where we wouldn't want to use NiceSort?
-        sortname = utils.NiceSortKey(name)
+        result.append((utils.NiceSortKey(name), idx, row))
       else:
-        sortname = None
+        result.append(row)
 
-      result.append((sortname, idx, row))
+    if not sort:
+      return result
 
     # TODO: Would "heapq" be more efficient than sorting?
 
@@ -667,7 +673,7 @@ class Query:
 
     return map(operator.itemgetter(2), result)
 
-  def OldStyleQuery(self, ctx):
+  def OldStyleQuery(self, ctx, sort_by_name=True):
     """Query with "old" query result format.
 
     See L{Query.Query} for arguments.
@@ -681,7 +687,7 @@ class Query:
                                  errors.ECODE_INVAL)
 
     return [[value for (_, value) in row]
-            for row in self.Query(ctx)]
+            for row in self.Query(ctx, sort_by_name=sort_by_name)]
 
 
 def _ProcessResult(value):
@@ -776,14 +782,17 @@ def _PrepareFieldList(fields, aliases):
   return result
 
 
-def GetQueryResponse(query, ctx):
+def GetQueryResponse(query, ctx, sort_by_name=True):
   """Prepares the response for a query.
 
   @type query: L{Query}
   @param ctx: Data container, see L{Query.Query}
+  @type sort_by_name: boolean
+  @param sort_by_name: Whether to sort by name or keep the input data's
+    ordering
 
   """
-  return objects.QueryResponse(data=query.Query(ctx),
+  return objects.QueryResponse(data=query.Query(ctx, sort_by_name=sort_by_name),
                                fields=query.GetFields()).ToDict()
 
 
diff --git a/test/ganeti.query_unittest.py b/test/ganeti.query_unittest.py
index cbbe490e8..38e67f44e 100755
--- a/test/ganeti.query_unittest.py
+++ b/test/ganeti.query_unittest.py
@@ -1118,6 +1118,99 @@ class TestQueryFilter(unittest.TestCase):
       ["node1", "node44"],
       ])
 
+    # Name field, but no sorting, result must be in incoming order
+    q = query.Query(fielddefs, ["pnode", "snode"], namefield="pnode")
+    self.assertFalse(q.RequestedData())
+    self.assertEqual(q.Query(data, sort_by_name=False),
+      [[(constants.RS_NORMAL, "node1"), (constants.RS_NORMAL, "node44")],
+       [(constants.RS_NORMAL, "node30"), (constants.RS_NORMAL, "node90")],
+       [(constants.RS_NORMAL, "node25"), (constants.RS_NORMAL, "node1")],
+       [(constants.RS_NORMAL, "node20"), (constants.RS_NORMAL, "node1")]])
+    self.assertEqual(q.OldStyleQuery(data, sort_by_name=False), [
+      ["node1", "node44"],
+      ["node30", "node90"],
+      ["node25", "node1"],
+      ["node20", "node1"],
+      ])
+    self.assertEqual(q.Query(reversed(data), sort_by_name=False),
+      [[(constants.RS_NORMAL, "node20"), (constants.RS_NORMAL, "node1")],
+       [(constants.RS_NORMAL, "node25"), (constants.RS_NORMAL, "node1")],
+       [(constants.RS_NORMAL, "node30"), (constants.RS_NORMAL, "node90")],
+       [(constants.RS_NORMAL, "node1"), (constants.RS_NORMAL, "node44")]])
+    self.assertEqual(q.OldStyleQuery(reversed(data), sort_by_name=False), [
+      ["node20", "node1"],
+      ["node25", "node1"],
+      ["node30", "node90"],
+      ["node1", "node44"],
+      ])
+
+  def testEqualNamesOrder(self):
+    fielddefs = query._PrepareFieldList([
+      (query._MakeField("pnode", "PNode", constants.QFT_TEXT, "Primary"),
+       None, 0, lambda ctx, item: item["pnode"]),
+      (query._MakeField("num", "Num", constants.QFT_NUMBER, "Num"),
+       None, 0, lambda ctx, item: item["num"]),
+      ], [])
+
+    data = [
+      { "pnode": "node1", "num": 100, },
+      { "pnode": "node1", "num": 25, },
+      { "pnode": "node2", "num": 90, },
+      { "pnode": "node2", "num": 30, },
+      ]
+
+    q = query.Query(fielddefs, ["pnode", "num"], namefield="pnode",
+                    filter_=["|", ["=", "pnode", "node1"],
+                                  ["=", "pnode", "node2"],
+                                  ["=", "pnode", "node1"]])
+    self.assertEqual(q.RequestedNames(), ["node1", "node2"],
+                     msg="Did not return unique names")
+    self.assertFalse(q.RequestedData())
+    self.assertEqual(q.Query(data),
+      [[(constants.RS_NORMAL, "node1"), (constants.RS_NORMAL, 100)],
+       [(constants.RS_NORMAL, "node1"), (constants.RS_NORMAL, 25)],
+       [(constants.RS_NORMAL, "node2"), (constants.RS_NORMAL, 90)],
+       [(constants.RS_NORMAL, "node2"), (constants.RS_NORMAL, 30)]])
+    self.assertEqual(q.Query(data, sort_by_name=False),
+      [[(constants.RS_NORMAL, "node1"), (constants.RS_NORMAL, 100)],
+       [(constants.RS_NORMAL, "node1"), (constants.RS_NORMAL, 25)],
+       [(constants.RS_NORMAL, "node2"), (constants.RS_NORMAL, 90)],
+       [(constants.RS_NORMAL, "node2"), (constants.RS_NORMAL, 30)]])
+
+    data = [
+      { "pnode": "nodeX", "num": 50, },
+      { "pnode": "nodeY", "num": 40, },
+      { "pnode": "nodeX", "num": 30, },
+      { "pnode": "nodeX", "num": 20, },
+      { "pnode": "nodeM", "num": 10, },
+      ]
+
+    q = query.Query(fielddefs, ["pnode", "num"], namefield="pnode",
+                    filter_=["|", ["=", "pnode", "nodeX"],
+                                  ["=", "pnode", "nodeY"],
+                                  ["=", "pnode", "nodeY"],
+                                  ["=", "pnode", "nodeY"],
+                                  ["=", "pnode", "nodeM"]])
+    self.assertEqual(q.RequestedNames(), ["nodeX", "nodeY", "nodeM"],
+                     msg="Did not return unique names")
+    self.assertFalse(q.RequestedData())
+
+    # First sorted by name, then input order
+    self.assertEqual(q.Query(data, sort_by_name=True),
+      [[(constants.RS_NORMAL, "nodeM"), (constants.RS_NORMAL, 10)],
+       [(constants.RS_NORMAL, "nodeX"), (constants.RS_NORMAL, 50)],
+       [(constants.RS_NORMAL, "nodeX"), (constants.RS_NORMAL, 30)],
+       [(constants.RS_NORMAL, "nodeX"), (constants.RS_NORMAL, 20)],
+       [(constants.RS_NORMAL, "nodeY"), (constants.RS_NORMAL, 40)]])
+
+    # Input order
+    self.assertEqual(q.Query(data, sort_by_name=False),
+      [[(constants.RS_NORMAL, "nodeX"), (constants.RS_NORMAL, 50)],
+       [(constants.RS_NORMAL, "nodeY"), (constants.RS_NORMAL, 40)],
+       [(constants.RS_NORMAL, "nodeX"), (constants.RS_NORMAL, 30)],
+       [(constants.RS_NORMAL, "nodeX"), (constants.RS_NORMAL, 20)],
+       [(constants.RS_NORMAL, "nodeM"), (constants.RS_NORMAL, 10)]])
+
   def testFilter(self):
     (DK_A, DK_B) = range(1000, 1002)
 
-- 
GitLab