Commit e818eff8 authored by Christos Stavrakakis's avatar Christos Stavrakakis
Browse files

cyclades: Validate fields of API requests

Make sure that all fields of a request (user input) have a valid type.
This commit validates 'ports', 'servers' and 'floating_ips' APIs.

Refs #4979
parent f807cbfb
......@@ -140,7 +140,7 @@ def allocate_floating_ip(request):
"""Allocate a floating IP."""
req = utils.get_request_dict(request)
floating_ip_dict = api.utils.get_attribute(req, "floatingip",
required=True)
required=True, attr_type=dict)
log.info('allocate_floating_ip %s', req)
userid = request.user_uniq
......@@ -148,7 +148,8 @@ def allocate_floating_ip(request):
# the network_pool is a mandatory field
network_id = api.utils.get_attribute(floating_ip_dict,
"floating_network_id",
required=False)
required=False,
attr_type=(basestring, int))
if network_id is None:
floating_ip = ips.create_floating_ip(userid)
else:
......@@ -161,7 +162,8 @@ def allocate_floating_ip(request):
non_deleted=True)
address = api.utils.get_attribute(floating_ip_dict,
"floating_ip_address",
required=False)
required=False,
attr_type=basestring)
floating_ip = ips.create_floating_ip(userid, network, address)
log.info("User '%s' allocated floating IP '%s'", userid, floating_ip)
......
......@@ -107,22 +107,28 @@ def create_port(request):
req = api.utils.get_request_dict(request)
log.info('create_port %s', req)
port_dict = api.utils.get_attribute(req, "port")
net_id = api.utils.get_attribute(port_dict, "network_id")
port_dict = api.utils.get_attribute(req, "port", attr_type=dict)
net_id = api.utils.get_attribute(port_dict, "network_id",
attr_type=(basestring, int))
device_id = api.utils.get_attribute(port_dict, "device_id", required=False)
device_id = api.utils.get_attribute(port_dict, "device_id", required=False,
attr_type=(basestring, int))
vm = None
if device_id is not None:
vm = util.get_vm(device_id, user_id, for_update=True, non_deleted=True,
non_suspended=True)
# Check if the request contains a valid IPv4 address
fixed_ips = api.utils.get_attribute(port_dict, "fixed_ips", required=False)
fixed_ips = api.utils.get_attribute(port_dict, "fixed_ips", required=False,
attr_type=list)
if fixed_ips is not None and len(fixed_ips) > 0:
if len(fixed_ips) > 1:
msg = "'fixed_ips' attribute must contain only one fixed IP."
raise faults.BadRequest(msg)
fixed_ip_address = fixed_ips[0].get("ip_address")
fixed_ip = fixed_ips[0]
if not isinstance(fixed_ip, dict):
raise faults.BadRequest("Invalid 'fixed_ips' field.")
fixed_ip_address = fixed_ip.get("ip_address")
if fixed_ip_address is not None:
try:
ip = ipaddr.IPAddress(fixed_ip_address)
......@@ -153,19 +159,24 @@ def create_port(request):
ipaddress = ips.allocate_ip(network, user_id,
address=fixed_ip_address)
name = api.utils.get_attribute(port_dict, "name", required=False)
name = api.utils.get_attribute(port_dict, "name", required=False,
attr_type=basestring)
if name is None:
name = ""
security_groups = api.utils.get_attribute(port_dict,
"security_groups",
required=False)
required=False,
attr_type=list)
#validate security groups
# like get security group from db
sg_list = []
if security_groups:
for gid in security_groups:
sg = util.get_security_group(int(gid))
try:
sg = util.get_security_group(int(gid))
except (KeyError, ValueError):
raise faults.BadRequest("Invalid 'security_groups' field.")
sg_list.append(sg)
new_port = servers.create_port(user_id, network, use_ipaddress=ipaddress,
......@@ -191,19 +202,25 @@ def update_port(request, port_id):
port = util.get_port(port_id, request.user_uniq, for_update=True)
req = api.utils.get_request_dict(request)
port_info = api.utils.get_attribute(req, "port", required=True)
name = api.utils.get_attribute(port_info, "name", required=False)
port_info = api.utils.get_attribute(req, "port", required=True,
attr_type=dict)
name = api.utils.get_attribute(port_info, "name", required=False,
attr_type=basestring)
if name:
port.name = name
security_groups = api.utils.get_attribute(port_info, "security_groups",
required=False)
required=False, attr_type=list)
if security_groups:
sg_list = []
#validate security groups
for gid in security_groups:
sg = util.get_security_group(int(gid))
try:
sg = util.get_security_group(int(gid))
except (KeyError, ValueError):
raise faults.BadRequest("Invalid 'security_groups' field.")
sg_list.append(sg)
#clear the old security groups
......
......@@ -447,10 +447,9 @@ def update_server_name(request, server_id):
req = utils.get_request_dict(request)
log.info('update_server_name %s %s', server_id, req)
try:
name = req['server']['name']
except (TypeError, KeyError):
raise faults.BadRequest("Malformed request")
req = utils.get_attribute(req, "server", attr_type=dict, required=True)
name = utils.get_attribute(req, "name", attr_type=basestring,
required=True)
vm = util.get_vm(server_id, request.user_uniq, for_update=True,
non_suspended=True)
......@@ -500,7 +499,7 @@ def demux_server_action(request, server_id):
req = utils.get_request_dict(request)
log.debug('server_action %s %s', server_id, req)
if len(req) != 1:
if not isinstance(req, dict) and len(req) != 1:
raise faults.BadRequest("Malformed request")
# Do not allow any action on deleted or suspended VMs
......@@ -508,14 +507,14 @@ def demux_server_action(request, server_id):
non_deleted=True, non_suspended=True)
action = req.keys()[0]
if not isinstance(action, basestring):
raise faults.BadRequest("Malformed Request. Invalid action.")
if key_to_action(action) not in [x[0] for x in VirtualMachine.ACTIONS]:
if action not in ARBITRARY_ACTIONS:
raise faults.BadRequest("Action %s not supported" % action)
action_args = req[action]
if not isinstance(action_args, dict):
raise faults.BadRequest("Invalid argument")
action_args = utils.get_attribute(req, action, required=True,
attr_type=dict)
return server_actions[action](request, vm, action_args)
......@@ -530,7 +529,8 @@ def list_addresses(request, server_id):
# overLimit (413)
log.debug('list_addresses %s', server_id)
vm = util.get_vm(server_id, request.user_uniq, prefetch_related="nics__ips")
vm = util.get_vm(server_id, request.user_uniq,
prefetch_related="nics__ips")
attachments = [nic_to_attachments(nic)
for nic in vm.nics.filter(state="ACTIVE")]
addresses = attachments_to_addresses(attachments)
......@@ -598,13 +598,13 @@ def update_metadata(request, server_id):
req = utils.get_request_dict(request)
log.info('update_server_metadata %s %s', server_id, req)
vm = util.get_vm(server_id, request.user_uniq, non_suspended=True)
try:
metadata = req['metadata']
assert isinstance(metadata, dict)
except (KeyError, AssertionError):
raise faults.BadRequest("Malformed request")
metadata = utils.get_attribute(req, "metadata", required=True,
attr_type=dict)
for key, val in metadata.items():
if not isinstance(key, (basestring, int)) or\
not isinstance(val, (basestring, int)):
raise faults.BadRequest("Malformed Request. Invalid metadata.")
meta, created = vm.metadata.get_or_create(meta_key=key)
meta.meta_value = val
meta.save()
......
......@@ -117,7 +117,8 @@ class PortTest(BaseAPITest):
# self.assertEqual(res['port']['security_groups'],
# [str(sg2.id), str(sg3.id)])
def test_create_port_no_network(self):
def test_create_port_invalid(self):
# No network
request = {
"port": {
"device_id": "123",
......@@ -127,6 +128,16 @@ class PortTest(BaseAPITest):
}
response = self.post(PORTS_URL, params=json.dumps(request))
self.assertEqual(response.status_code, 404)
net = dbmf.NetworkFactory(public=True)
request = {
"port": {
"name": "port1",
"network_id": net.id,
"fixed_ips": ["lala"]
}
}
response = self.post(PORTS_URL, params=json.dumps(request))
self.assertEqual(response.status_code, 400, response.content)
@patch("synnefo.db.models.get_rapi_client")
def test_create_port_private_net(self, mrapi):
......
......@@ -544,6 +544,18 @@ class ServerCreateAPITest(ComputeAPITest):
self.assertEqual(len(vm.nics.all()), 3)
def test_create_server_with_port(self, mrapi):
# Test invalid networks
request = deepcopy(self.request)
request["server"]["networks"] = {"foo": "lala"}
with override_settings(settings, **self.network_settings):
response = self.mypost("servers", "dummy_user",
json.dumps(request), 'json')
self.assertBadRequest(response)
request["server"]["networks"] = ['1', '2']
with override_settings(settings, **self.network_settings):
response = self.mypost("servers", "dummy_user",
json.dumps(request), 'json')
self.assertBadRequest(response)
mrapi().CreateInstance.return_value = 42
ip = mfactory.IPv4AddressFactory(nic__machine=None)
port1 = ip.nic
......
......@@ -699,10 +699,14 @@ def create_ports_for_request(user_id, networks):
IPs.
"""
if not isinstance(networks, list):
raise faults.BadRequest("Malformed request. Invalid 'networks' field")
return [_port_for_request(user_id, network) for network in networks]
def _port_for_request(user_id, network_dict):
if not isinstance(network_dict, dict):
raise faults.BadRequest("Malformed request. Invalid 'networks' field")
port_id = network_dict.get("port")
network_id = network_dict.get("uuid")
if port_id is not None:
......
......@@ -135,9 +135,13 @@ def filter_modified_since(request, objects):
return objects.filter(deleted=False)
def get_attribute(request, attribute, required=True):
def get_attribute(request, attribute, attr_type=None, required=True):
value = request.get(attribute, None)
if required and value is None:
raise faults.BadRequest("Malformed request. Missing attribute '%s'." %
attribute)
if attr_type is not None and value is not None\
and not isinstance(value, attr_type):
raise faults.BadRequest("Malformed request. Invalid '%s' field"
% attribute)
return value
......@@ -250,30 +250,29 @@ class BaseAPITest(TestCase):
return response
def assertSuccess(self, response):
self.assertTrue(response.status_code in [200, 202, 203, 204])
self.assertTrue(response.status_code in [200, 202, 203, 204],
msg=response.content)
def assertSuccess201(self, response):
self.assertEqual(response.status_code, 201)
self.assertEqual(response.status_code, 201, msg=response.content)
def assertFault(self, response, status_code, name):
self.assertEqual(response.status_code, status_code)
def assertFault(self, response, status_code, name, msg=''):
self.assertEqual(response.status_code, status_code,
msg=msg)
fault = json.loads(response.content)
self.assertEqual(fault.keys(), [name])
def assertBadRequest(self, response):
self.assertFault(response, 400, 'badRequest')
self.assertFault(response, 400, 'badRequest', msg=response.content)
def assertConflict(self, response):
self.assertFault(response, 409, 'conflict')
self.assertFault(response, 409, 'conflict', msg=response.content)
def assertItemNotFound(self, response):
self.assertFault(response, 404, 'itemNotFound')
def assertConflict(self, response):
self.assertFault(response, 409, "conflict")
self.assertFault(response, 404, 'itemNotFound', msg=response.content)
def assertMethodNotAllowed(self, response):
self.assertFault(response, 405, 'notAllowed')
self.assertFault(response, 405, 'notAllowed', msg=response.content)
self.assertTrue('Allow' in response)
try:
error = json.loads(response.content)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment