-
Manuel Franceschini authored
Signed-off-by: Manuel Franceschini <livewire@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>
b43dcc5a
ganeti.rpc_unittest.py 10.00 KiB
#!/usr/bin/python
#
# Copyright (C) 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Script for testing ganeti.rpc"""
import os
import sys
import unittest
from ganeti import constants
from ganeti import compat
from ganeti import rpc
from ganeti import http
from ganeti import errors
from ganeti import serializer
import testutils
class TestTimeouts(unittest.TestCase):
  """Checks the RPC timeout table C{rpc._TIMEOUTS}."""

  def test(self):
    """Every C{call_*} method needs a timeout entry that is None or > 0."""
    prefix = "call_"
    rpc_names = [attr[len(prefix):] for attr in dir(rpc.RpcRunner)
                 if attr.startswith(prefix)]

    # The table must have exactly as many entries as there are call methods
    self.assertEqual(len(rpc_names), len(rpc._TIMEOUTS))

    bad_entries = []
    for rpc_name in rpc_names:
      timeout = rpc._TIMEOUTS[rpc_name]
      # None is accepted; any other value has to be strictly positive
      if timeout is not None and not timeout > 0:
        bad_entries.append(rpc_name)
    self.assertFalse(bad_entries)
class FakeHttpPool:
  """Stand-in for an HTTP pool that answers requests through a callback.

  Each request handed to L{ProcessRequests} is passed to C{response_fn},
  which in this module fills in the reply fields on the request object.
  C{reqcount} holds the total number of requests processed so far.

  """

  def __init__(self, response_fn):
    self.reqcount = 0
    self._response_fn = response_fn

  def ProcessRequests(self, reqs):
    """Count each request, then pass it to the response callback."""
    handler = self._response_fn
    for request in reqs:
      self.reqcount = self.reqcount + 1
      handler(request)
def GetFakeSimpleStoreClass(fn):
  """Build a minimal simple-store class for passing as C{ssc}.

  @param fn: installed as C{GetNodePrimaryIPList}; when it is a plain
      function it is bound as a method and receives the store instance
  @return: a new class whose C{GetPrimaryIPFamily} method returns None

  """
  def _NoFamily(_):
    return None

  class FakeSimpleStore:
    GetPrimaryIPFamily = _NoFamily
    GetNodePrimaryIPList = fn

  return FakeSimpleStore
class TestClient(unittest.TestCase):
def _FakeAddressLookup(self, map):
return lambda node_list: [map.get(node) for node in node_list]
def _GetVersionResponse(self, req):
self.assertEqual(req.host, "localhost")
self.assertEqual(req.port, 24094)
self.assertEqual(req.path, "/version")
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, 123))
def testVersionSuccess(self):
fn = self._FakeAddressLookup({"localhost": "localhost"})
client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
client.ConnectNode("localhost")
pool = FakeHttpPool(self._GetVersionResponse)
result = client.GetResults(http_pool=pool)
self.assertEqual(result.keys(), ["localhost"])
lhresp = result["localhost"]
self.assertFalse(lhresp.offline)
self.assertEqual(lhresp.node, "localhost")
self.assertFalse(lhresp.fail_msg)
self.assertEqual(lhresp.payload, 123)
self.assertEqual(lhresp.call, "version")
lhresp.Raise("should not raise")
self.assertEqual(pool.reqcount, 1)
def _GetMultiVersionResponse(self, req):
self.assert_(req.host.startswith("node"))
self.assertEqual(req.port, 23245)
self.assertEqual(req.path, "/version")
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, 987))
def testMultiVersionSuccess(self):
nodes = ["node%s" % i for i in range(50)]
fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
client.ConnectList(nodes)
pool = FakeHttpPool(self._GetMultiVersionResponse)
result = client.GetResults(http_pool=pool)
self.assertEqual(sorted(result.keys()), sorted(nodes))
for name in nodes:
lhresp = result[name]
self.assertFalse(lhresp.offline)
self.assertEqual(lhresp.node, name)
self.assertFalse(lhresp.fail_msg)
self.assertEqual(lhresp.payload, 987)
self.assertEqual(lhresp.call, "version")
lhresp.Raise("should not raise")
self.assertEqual(pool.reqcount, len(nodes))
def _GetVersionResponseFail(self, req):
self.assertEqual(req.path, "/version")
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((False, "Unknown error"))
def testVersionFailure(self):
lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
fn = self._FakeAddressLookup(lookup_map)
client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
client.ConnectNode("aef9ur4i.example.com")
pool = FakeHttpPool(self._GetVersionResponseFail)
result = client.GetResults(http_pool=pool)
self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
lhresp = result["aef9ur4i.example.com"]
self.assertFalse(lhresp.offline)
self.assertEqual(lhresp.node, "aef9ur4i.example.com")
self.assert_(lhresp.fail_msg)
self.assertFalse(lhresp.payload)
self.assertEqual(lhresp.call, "version")
self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
self.assertEqual(pool.reqcount, 1)
def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
self.assertEqual(req.path, "/vg_list")
self.assertEqual(req.port, 15165)
if req.host in httperrnodes:
req.success = False
req.error = "Node set up for HTTP errors"
elif req.host in failnodes:
req.success = True
req.resp_status_code = 404
req.resp_body = serializer.DumpJson({
"code": 404,
"message": "Method not found",
"explain": "Explanation goes here",
})
else:
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, hash(req.host)))
def testHttpError(self):
nodes = ["uaf6pbbv%s" % i for i in range(50)]
fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
httperrnodes = set(nodes[1::7])
self.assertEqual(len(httperrnodes), 7)
failnodes = set(nodes[2::3]) - httperrnodes
self.assertEqual(len(failnodes), 14)
self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
client.ConnectList(nodes)
pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
httperrnodes, failnodes))
result = client.GetResults(http_pool=pool)
self.assertEqual(sorted(result.keys()), sorted(nodes))
for name in nodes:
lhresp = result[name]
self.assertFalse(lhresp.offline)
self.assertEqual(lhresp.node, name)
self.assertEqual(lhresp.call, "vg_list")
if name in httperrnodes:
self.assert_(lhresp.fail_msg)
self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
elif name in failnodes:
self.assert_(lhresp.fail_msg)
self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
prereq=True, ecode=errors.ECODE_INVAL)
else:
self.assertFalse(lhresp.fail_msg)
self.assertEqual(lhresp.payload, hash(name))
lhresp.Raise("should not raise")
self.assertEqual(pool.reqcount, len(nodes))
def _GetInvalidResponseA(self, req):
self.assertEqual(req.path, "/version")
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
"response", "!", 1, 2, 3))
def _GetInvalidResponseB(self, req):
self.assertEqual(req.path, "/version")
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson("invalid response")
def testInvalidResponse(self):
lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
fn = self._FakeAddressLookup(lookup_map)
client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
client.ConnectNode("oqo7lanhly.example.com")
pool = FakeHttpPool(fn)
result = client.GetResults(http_pool=pool)
self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
lhresp = result["oqo7lanhly.example.com"]
self.assertFalse(lhresp.offline)
self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
self.assert_(lhresp.fail_msg)
self.assertFalse(lhresp.payload)
self.assertEqual(lhresp.call, "version")
self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
self.assertEqual(pool.reqcount, 1)
def testAddressLookupSimpleStore(self):
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
result = rpc._AddressLookup(node_list, ssc=ssc)
self.assertEqual(result, addr_list)
def testAddressLookupNSLookup(self):
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
ssc = GetFakeSimpleStoreClass(lambda _: [])
node_addr_map = dict(zip(node_list, addr_list))
nslookup_fn = lambda name, family=None: node_addr_map.get(name)
result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
self.assertEqual(result, addr_list)
def testAddressLookupBoth(self):
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
n = len(addr_list) / 2
node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
nslookup_fn = lambda name, family=None: node_addr_map.get(name)
result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
self.assertEqual(result, addr_list)
def testAddressLookupIPv6(self):
addr_list = ["2001:db8::%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
result = rpc._AddressLookup(node_list, ssc=ssc)
self.assertEqual(result, addr_list)
if __name__ == "__main__":
  # Hand control to Ganeti's test entry point (testutils.GanetiTestProgram)
  testutils.GanetiTestProgram()