Commit 8b489e97 authored by Christos Stavrakakis's avatar Christos Stavrakakis
Browse files

snf_django: Handle unicode errors in JSON bodies

Catch unicode errors when doing JSON de-serialization of the body of a request.
Also, rename 'get_request_dict' function to 'get_json_body', as it is a
more representative name for what the function is doing.
parent 5bd8300e
...@@ -91,7 +91,7 @@ def authenticate(request): ...@@ -91,7 +91,7 @@ def authenticate(request):
d = defaultdict(dict) d = defaultdict(dict)
if not public_mode: if not public_mode:
req = utils.get_request_dict(request) req = utils.get_json_body(request)
uuid = None uuid = None
try: try:
......
...@@ -545,7 +545,8 @@ class TokensApiTest(TestCase): ...@@ -545,7 +545,8 @@ class TokensApiTest(TestCase):
r = client.post(url, "not json", content_type='application/json') r = client.post(url, "not json", content_type='application/json')
self.assertEqual(r.status_code, 400) self.assertEqual(r.status_code, 400)
body = json.loads(r.content) body = json.loads(r.content)
self.assertEqual(body['badRequest']['message'], 'Invalid JSON data') self.assertEqual(body['badRequest']['message'],
'Could not decode request body as JSON')
# Check auth with token # Check auth with token
post_data = """{"auth":{"token": {"id":"%s"}, post_data = """{"auth":{"token": {"id":"%s"},
......
...@@ -61,7 +61,7 @@ def get_cyclades_stats(request): ...@@ -61,7 +61,7 @@ def get_cyclades_stats(request):
images = True images = True
backend = None backend = None
if request.body: if request.body:
req = utils.get_request_dict(request) req = utils.get_json_body(request)
req_stats = utils.get_attribute(req, "stats", required=True, req_stats = utils.get_attribute(req, "stats", required=True,
attr_type=dict) attr_type=dict)
# Check backend # Check backend
......
...@@ -93,7 +93,7 @@ def floating_ip_demux(request, floating_ip_id): ...@@ -93,7 +93,7 @@ def floating_ip_demux(request, floating_ip_id):
serializations=["json"]) serializations=["json"])
def floating_ip_action_demux(request, floating_ip_id): def floating_ip_action_demux(request, floating_ip_id):
userid = request.user_uniq userid = request.user_uniq
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.debug('floating_ip_action %s %s', floating_ip_id, req) log.debug('floating_ip_action %s %s', floating_ip_id, req)
if len(req) != 1: if len(req) != 1:
raise faults.BadRequest('Malformed request.') raise faults.BadRequest('Malformed request.')
...@@ -165,7 +165,7 @@ def get_floating_ip(request, floating_ip_id): ...@@ -165,7 +165,7 @@ def get_floating_ip(request, floating_ip_id):
@transaction.commit_on_success @transaction.commit_on_success
def allocate_floating_ip(request): def allocate_floating_ip(request):
"""Allocate a floating IP.""" """Allocate a floating IP."""
req = utils.get_request_dict(request) req = utils.get_json_body(request)
floating_ip_dict = api.utils.get_attribute(req, "floatingip", floating_ip_dict = api.utils.get_attribute(req, "floatingip",
required=True, attr_type=dict) required=True, attr_type=dict)
userid = request.user_uniq userid = request.user_uniq
...@@ -225,7 +225,7 @@ def update_floating_ip(request, floating_ip_id): ...@@ -225,7 +225,7 @@ def update_floating_ip(request, floating_ip_id):
#userid = request.user_uniq #userid = request.user_uniq
#log.info("update_floating_ip '%s'. User '%s'.", floating_ip_id, userid) #log.info("update_floating_ip '%s'. User '%s'.", floating_ip_id, userid)
#req = utils.get_request_dict(request) #req = utils.get_json_body(request)
#info = api.utils.get_attribute(req, "floatingip", required=True) #info = api.utils.get_attribute(req, "floatingip", required=True)
#device_id = api.utils.get_attribute(info, "device_id", required=False) #device_id = api.utils.get_attribute(info, "device_id", required=False)
......
...@@ -247,7 +247,7 @@ def update_metadata(request, image_id): ...@@ -247,7 +247,7 @@ def update_metadata(request, image_id):
# badMediaType(415), # badMediaType(415),
# overLimit (413) # overLimit (413)
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.info('update_image_metadata %s %s', image_id, req) log.info('update_image_metadata %s %s', image_id, req)
with backend.PlanktonBackend(request.user_uniq) as b: with backend.PlanktonBackend(request.user_uniq) as b:
image = b.get_image(image_id) image = b.get_image(image_id)
...@@ -296,7 +296,7 @@ def create_metadata_item(request, image_id, key): ...@@ -296,7 +296,7 @@ def create_metadata_item(request, image_id, key):
# badMediaType(415), # badMediaType(415),
# overLimit (413) # overLimit (413)
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.info('create_image_metadata_item %s %s %s', image_id, key, req) log.info('create_image_metadata_item %s %s %s', image_id, key, req)
try: try:
metadict = req['meta'] metadict = req['meta']
......
...@@ -86,7 +86,7 @@ def network_demux(request, network_id): ...@@ -86,7 +86,7 @@ def network_demux(request, network_id):
@api.api_method(http_method='POST', user_required=True, logger=log) @api.api_method(http_method='POST', user_required=True, logger=log)
def network_action_demux(request, network_id): def network_action_demux(request, network_id):
req = utils.get_request_dict(request) req = utils.get_json_body(request)
network = util.get_network(network_id, request.user_uniq, for_update=True) network = util.get_network(network_id, request.user_uniq, for_update=True)
action = req.keys()[0] action = req.keys()[0]
try: try:
...@@ -126,7 +126,7 @@ def list_networks(request, detail=True): ...@@ -126,7 +126,7 @@ def list_networks(request, detail=True):
@api.api_method(http_method='POST', user_required=True, logger=log) @api.api_method(http_method='POST', user_required=True, logger=log)
def create_network(request): def create_network(request):
userid = request.user_uniq userid = request.user_uniq
req = api.utils.get_request_dict(request) req = api.utils.get_json_body(request)
log.info('create_network user: %s request: %s', userid, req) log.info('create_network user: %s request: %s', userid, req)
network_dict = api.utils.get_attribute(req, "network", network_dict = api.utils.get_attribute(req, "network",
...@@ -163,7 +163,7 @@ def get_network_details(request, network_id): ...@@ -163,7 +163,7 @@ def get_network_details(request, network_id):
@api.api_method(http_method='PUT', user_required=True, logger=log) @api.api_method(http_method='PUT', user_required=True, logger=log)
def update_network(request, network_id): def update_network(request, network_id):
info = api.utils.get_request_dict(request) info = api.utils.get_json_body(request)
network = api.utils.get_attribute(info, "network", attr_type=dict, network = api.utils.get_attribute(info, "network", attr_type=dict,
required=True) required=True)
......
...@@ -104,7 +104,7 @@ def list_ports(request, detail=True): ...@@ -104,7 +104,7 @@ def list_ports(request, detail=True):
@transaction.commit_on_success @transaction.commit_on_success
def create_port(request): def create_port(request):
user_id = request.user_uniq user_id = request.user_uniq
req = api.utils.get_request_dict(request) req = api.utils.get_json_body(request)
log.info('create_port user: %s request: %s', user_id, req) log.info('create_port user: %s request: %s', user_id, req)
port_dict = api.utils.get_attribute(req, "port", attr_type=dict) port_dict = api.utils.get_attribute(req, "port", attr_type=dict)
...@@ -200,7 +200,7 @@ def update_port(request, port_id): ...@@ -200,7 +200,7 @@ def update_port(request, port_id):
You can update only name, security_groups You can update only name, security_groups
''' '''
port = util.get_port(port_id, request.user_uniq, for_update=True) port = util.get_port(port_id, request.user_uniq, for_update=True)
req = api.utils.get_request_dict(request) req = api.utils.get_json_body(request)
port_info = api.utils.get_attribute(req, "port", required=True, port_info = api.utils.get_attribute(req, "port", required=True,
attr_type=dict) attr_type=dict)
......
...@@ -401,7 +401,7 @@ def create_server(request): ...@@ -401,7 +401,7 @@ def create_server(request):
# badRequest (400), # badRequest (400),
# serverCapacityUnavailable (503), # serverCapacityUnavailable (503),
# overLimit (413) # overLimit (413)
req = utils.get_request_dict(request) req = utils.get_json_body(request)
user_id = request.user_uniq user_id = request.user_uniq
log.info('create_server user: %s request: %s', user_id, req) log.info('create_server user: %s request: %s', user_id, req)
...@@ -539,7 +539,7 @@ def update_server_name(request, server_id): ...@@ -539,7 +539,7 @@ def update_server_name(request, server_id):
# buildInProgress (409), # buildInProgress (409),
# overLimit (413) # overLimit (413)
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.info('update_server_name %s %s', server_id, req) log.info('update_server_name %s %s', server_id, req)
req = utils.get_attribute(req, "server", attr_type=dict, required=True) req = utils.get_attribute(req, "server", attr_type=dict, required=True)
...@@ -591,7 +591,7 @@ def key_to_action(key): ...@@ -591,7 +591,7 @@ def key_to_action(key):
@api.api_method(http_method='POST', user_required=True, logger=log) @api.api_method(http_method='POST', user_required=True, logger=log)
@transaction.commit_on_success @transaction.commit_on_success
def demux_server_action(request, server_id): def demux_server_action(request, server_id):
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.debug('server_action %s %s', server_id, req) log.debug('server_action %s %s', server_id, req)
if not isinstance(req, dict) and len(req) != 1: if not isinstance(req, dict) and len(req) != 1:
...@@ -690,7 +690,7 @@ def update_metadata(request, server_id): ...@@ -690,7 +690,7 @@ def update_metadata(request, server_id):
# badMediaType(415), # badMediaType(415),
# overLimit (413) # overLimit (413)
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.info('update_server_metadata %s %s', server_id, req) log.info('update_server_metadata %s %s', server_id, req)
vm = util.get_vm(server_id, request.user_uniq, non_suspended=True) vm = util.get_vm(server_id, request.user_uniq, non_suspended=True)
metadata = utils.get_attribute(req, "metadata", required=True, metadata = utils.get_attribute(req, "metadata", required=True,
...@@ -739,7 +739,7 @@ def create_metadata_item(request, server_id, key): ...@@ -739,7 +739,7 @@ def create_metadata_item(request, server_id, key):
# badMediaType(415), # badMediaType(415),
# overLimit (413) # overLimit (413)
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.info('create_server_metadata_item %s %s %s', server_id, key, req) log.info('create_server_metadata_item %s %s %s', server_id, key, req)
vm = util.get_vm(server_id, request.user_uniq, non_suspended=True) vm = util.get_vm(server_id, request.user_uniq, non_suspended=True)
try: try:
...@@ -1095,7 +1095,7 @@ def get_volume_info(request, server_id, volume_id): ...@@ -1095,7 +1095,7 @@ def get_volume_info(request, server_id, volume_id):
@api.api_method(http_method='POST', user_required=True, logger=log) @api.api_method(http_method='POST', user_required=True, logger=log)
def attach_volume(request, server_id): def attach_volume(request, server_id):
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.debug("attach_volume server_id %s request", server_id, req) log.debug("attach_volume server_id %s request", server_id, req)
user_id = request.user_uniq user_id = request.user_uniq
vm = util.get_vm(server_id, user_id, for_update=True) vm = util.get_vm(server_id, user_id, for_update=True)
......
...@@ -101,7 +101,7 @@ def create_subnet(request): ...@@ -101,7 +101,7 @@ def create_subnet(request):
network_id and the desired cidr are mandatory, everything else is optional network_id and the desired cidr are mandatory, everything else is optional
""" """
dictionary = utils.get_request_dict(request) dictionary = utils.get_json_body(request)
user_id = request.user_uniq user_id = request.user_uniq
log.info('create subnet user: %s request: %s', user_id, dictionary) log.info('create subnet user: %s request: %s', user_id, dictionary)
...@@ -188,7 +188,7 @@ def update_subnet(request, sub_id): ...@@ -188,7 +188,7 @@ def update_subnet(request, sub_id):
""" """
dictionary = utils.get_request_dict(request) dictionary = utils.get_json_body(request)
user_id = request.user_uniq user_id = request.user_uniq
try: try:
......
...@@ -415,13 +415,15 @@ def update_image_members(request, image_id): ...@@ -415,13 +415,15 @@ def update_image_members(request, image_id):
""" """
log.debug('update_image_members %s', image_id) log.debug('update_image_members %s', image_id)
data = api.utils.get_json_body(request)
members = [] members = []
try:
data = json.loads(request.body) memberships = api.utils.get_attribute(data, "memberships", attr_type=list)
for member in data['memberships']: for member in memberships:
members.append(member['member_id']) if not isinstance(member, dict):
except (ValueError, KeyError, TypeError): raise faults.BadRequest("Invalid 'memberships' field")
return HttpResponse(status=400) member = api.utils.get_attribute(member, "member_id")
members.append(member)
with PlanktonBackend(request.user_uniq) as backend: with PlanktonBackend(request.user_uniq) as backend:
backend.replace_users(image_id, members) backend.replace_users(image_id, members)
......
...@@ -100,7 +100,7 @@ def get_volume_attachments(volume): ...@@ -100,7 +100,7 @@ def get_volume_attachments(volume):
def create_volume(request): def create_volume(request):
"""Create a new Volume.""" """Create a new Volume."""
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.debug("create_volume %s", req) log.debug("create_volume %s", req)
user_id = request.user_uniq user_id = request.user_uniq
...@@ -196,7 +196,7 @@ def get_volume(request, volume_id): ...@@ -196,7 +196,7 @@ def get_volume(request, volume_id):
@api.api_method(http_method="PUT", user_required=True, logger=log) @api.api_method(http_method="PUT", user_required=True, logger=log)
def update_volume(request, volume_id): def update_volume(request, volume_id):
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.debug('update_volume volume_id: %s, request: %s', volume_id, req) log.debug('update_volume volume_id: %s, request: %s', volume_id, req)
volume = util.get.volume(request.user_uniq, volume_id, for_update=True) volume = util.get.volume(request.user_uniq, volume_id, for_update=True)
...@@ -243,7 +243,7 @@ def snapshot_to_dict(snapshot, detail=True): ...@@ -243,7 +243,7 @@ def snapshot_to_dict(snapshot, detail=True):
def create_snapshot(request): def create_snapshot(request):
"""Create a new Snapshot.""" """Create a new Snapshot."""
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.debug("create_snapshot %s", req) log.debug("create_snapshot %s", req)
user_id = request.user_uniq user_id = request.user_uniq
...@@ -327,7 +327,7 @@ def get_snapshot(request, snapshot_id): ...@@ -327,7 +327,7 @@ def get_snapshot(request, snapshot_id):
@api.api_method(http_method="PUT", user_required=True, logger=log) @api.api_method(http_method="PUT", user_required=True, logger=log)
def update_snapshot(request, snapshot_id): def update_snapshot(request, snapshot_id):
req = utils.get_request_dict(request) req = utils.get_json_body(request)
log.debug('update_snapshot snapshot_id: %s, request: %s', snapshot_id, req) log.debug('update_snapshot snapshot_id: %s, request: %s', snapshot_id, req)
snapshot = util.get_snapshot(request.user_uniq, snapshot_id) snapshot = util.get_snapshot(request.user_uniq, snapshot_id)
......
...@@ -88,10 +88,11 @@ def isoparse(s): ...@@ -88,10 +88,11 @@ def isoparse(s):
return utc_since return utc_since
def get_request_dict(request): def get_json_body(request):
"""Return data sent by the client as python dictionary. """Get the JSON request body as a Python object.
Only JSON format is supported Check that the content type is json and deserialize the body of the
request that contains a JSON document to a Python object.
""" """
data = request.body data = request.body
...@@ -101,8 +102,10 @@ def get_request_dict(request): ...@@ -101,8 +102,10 @@ def get_request_dict(request):
if content_type.startswith("application/json"): if content_type.startswith("application/json"):
try: try:
return json.loads(data) return json.loads(data)
except UnicodeDecodeError:
raise faults.BadRequest("Could not decode request as UTF-8 string")
except ValueError: except ValueError:
raise faults.BadRequest("Invalid JSON data") raise faults.BadRequest("Could not decode request body as JSON")
else: else:
raise faults.BadRequest("Unsupported Content-type: '%s'" % raise faults.BadRequest("Unsupported Content-type: '%s'" %
content_type) content_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment