Commit 136532c7 authored by Christos Stavrakakis's avatar Christos Stavrakakis
Browse files

cyclades: Fix permissions when looking up ports

Look up of ports should check the port userid and not the network
userid.
parent 069f5e3e
......@@ -35,8 +35,7 @@ from optparse import make_option
from django.core.management.base import BaseCommand, CommandError
from synnefo.management.common import convert_api_faults
from synnefo.api.util import get_port
from synnefo.management import pprint
from synnefo.management import pprint, common
class Command(BaseCommand):
......@@ -63,7 +62,7 @@ class Command(BaseCommand):
if len(args) != 1:
raise CommandError("Please provide a port ID")
port = get_port(args[0], None)
port = common.get_port(args[0])
pprint.pprint_port(port, stdout=self.stdout)
self.stdout.write('\n\n')
......
......@@ -31,7 +31,6 @@
from optparse import make_option
from django.core.management.base import BaseCommand, CommandError
from synnefo.logic import servers
from synnefo.api.util import get_port
from synnefo.management import common
from snf_django.management.utils import parse_bool
......@@ -54,7 +53,7 @@ class Command(BaseCommand):
if len(args) < 1:
raise CommandError("Please provide a port ID")
port = get_port(args[0], None, for_update=True)
port = common.get_port(args[0], for_update=True)
servers.delete_port(port)
......
......@@ -61,7 +61,7 @@ class PortTest(BaseAPITest):
def test_get_port(self):
nic = dbmf.NetworkInterfaceFactory()
url = join_urls(PORTS_URL, str(nic.id))
response = self.get(url, user=nic.network.userid)
response = self.get(url, user=nic.userid)
self.assertEqual(response.status_code, 200)
@patch("synnefo.db.models.get_rapi_client")
......@@ -70,10 +70,10 @@ class PortTest(BaseAPITest):
url = join_urls(PORTS_URL, str(nic.id))
mrapi().ModifyInstance.return_value = 42
with override_settings(settings, GANETI_USE_HOTPLUG=True):
response = self.delete(url, user=nic.network.userid)
response = self.delete(url, user=nic.userid)
self.assertEqual(response.status_code, 204)
with override_settings(settings, GANETI_USE_HOTPLUG=False):
response = self.delete(url, user=nic.network.userid)
response = self.delete(url, user=nic.userid)
self.assertEqual(response.status_code, 400)
def test_remove_nic_malformed(self):
......@@ -86,7 +86,7 @@ class PortTest(BaseAPITest):
url = join_urls(PORTS_URL, str(nic.id))
request = {'port': {"name": "test-name"}}
response = self.put(url, params=json.dumps(request),
user=nic.network.userid)
user=nic.userid)
self.assertEqual(response.status_code, 200)
res = json.loads(response.content)
self.assertEqual(res['port']['name'], "test-name")
......@@ -99,7 +99,7 @@ class PortTest(BaseAPITest):
url = join_urls(PORTS_URL, str(nic.id))
request = {'port': {"security_groups": ["123"]}}
response = self.put(url, params=json.dumps(request),
user=nic.network.userid)
user=nic.userid)
self.assertEqual(response.status_code, 404)
def test_update_port_sg(self):
......@@ -112,7 +112,7 @@ class PortTest(BaseAPITest):
url = join_urls(PORTS_URL, str(nic.id))
request = {'port': {"security_groups": [str(sg2.id), str(sg3.id)]}}
response = self.put(url, params=json.dumps(request),
user=nic.network.userid)
user=nic.userid)
res = json.loads(response.content)
self.assertEqual(res['port']['security_groups'],
[str(sg2.id), str(sg3.id)])
......
......@@ -408,7 +408,8 @@ class ServerCreateAPITest(ComputeAPITest):
def test_create_server_with_port(self, mrapi):
mrapi().CreateInstance.return_value = 42
port1 = mfactory.IPv4AddressFactory(nic=None)
ip = mfactory.IPv4AddressFactory(nic__machine=None)
port1 = ip.nic
request = deepcopy(self.request)
request["server"]["networks"] = [{"port": port1.id}]
with mocked_quotaholder():
......@@ -424,7 +425,8 @@ class ServerCreateAPITest(ComputeAPITest):
json.dumps(request), 'json')
self.assertConflict(response)
# Test permissions
port2 = mfactory.IPv4AddressFactory(userid="user1")
ip = mfactory.IPv4AddressFactory(userid="user1", nic__userid="user1")
port2 = ip.nic
request["server"]["networks"] = [{"port": port2.id}]
with mocked_quotaholder():
response = self.mypost("servers", "user2",
......
......@@ -231,21 +231,14 @@ def get_port(port_id, user_id, for_update=False):
Return a NetworkInteface instance or raise ItemNotFound.
"""
try:
objects = NetworkInterface.objects
objects = NetworkInterface.objects.filter(userid=user_id)
if for_update:
objects = objects.select_for_update()
if not user_id:
port = objects.get(id=port_id)
else:
port = objects.get(network__userid=user_id, id=port_id)
# if (port.device_owner != "vm") and for_update:
# raise faults.BadRequest('Can not update non vm port')
return port
return objects.get(id=port_id)
except (ValueError, NetworkInterface.DoesNotExist):
raise faults.ItemNotFound('Port not found.')
raise faults.ItemNotFound("Port '%s' not found." % port_id)
def get_security_group(sg_id):
......
......@@ -176,6 +176,7 @@ class BackendNetworkFactory(factory.DjangoModelFactory):
class NetworkInterfaceFactory(factory.DjangoModelFactory):
FACTORY_FOR = models.NetworkInterface
userid = factory.Sequence(user_seq())
name = factory.LazyAttribute(lambda self: random_string(30))
machine = factory.SubFactory(VirtualMachineFactory, operstate="STARTED")
network = factory.SubFactory(NetworkFactory, state="ACTIVE")
......@@ -231,6 +232,7 @@ class IPv4AddressFactory(factory.DjangoModelFactory):
factory.LazyAttributeSequence(lambda self, n: self.subnet.cidr[:-4] +
'{0}'.format(int(n) + 2))
nic = factory.SubFactory(NetworkInterfaceFactory,
userid=factory.SelfAttribute('..userid'),
network=factory.SelfAttribute('..network'))
......
......@@ -261,7 +261,7 @@ def create_instance_nics(vm, userid, networks=[], floating_ips=[]):
ports.append(port)
for floating_ip_id in floating_ips:
floating_ip = util.get_floating_ip_by_id(vm.userid, floating_ip_id,
floating_ip = util.get_floating_ip_by_id(userid, floating_ip_id,
for_update=True)
port = _create_port(userid, network=floating_ip.network,
use_ipaddress=floating_ip)
......
......@@ -34,7 +34,8 @@
from django.core.management import CommandError
from synnefo.db.models import (Backend, VirtualMachine, Network,
Flavor, IPAddress, Subnet,
BridgePoolTable, MacPrefixPoolTable)
BridgePoolTable, MacPrefixPoolTable,
NetworkInterface)
from functools import wraps
from snf_django.lib.api import faults
......@@ -132,13 +133,29 @@ def get_network(network_id, for_update=True):
def get_subnet(subnet_id, for_update=True):
"""Get a Subnet object by its ID."""
try:
return Subnet.objects.get(id=subnet_id)
subnets = Subnet.objects
if for_update:
subnets.select_for_update()
return subnets.get(id=subnet_id)
except Subnet.DoesNotExist:
raise CommandError("Subnet with ID %s not found in DB."
" Use snf-manage subnet-list to find out"
" available subnet IDs" % subnet_id)
def get_port(port_id, for_update=True):
"""Get a port object by its ID."""
try:
ports = NetworkInterface.objects
if for_update:
ports.select_for_update()
return ports.get(id=port_id)
except NetworkInterface.DoesNotExist:
raise CommandError("Port with ID %s not found in DB."
" Use snf-manage port-list to find out"
" available port IDs" % port_id)
def get_flavor(flavor_id):
try:
flavor_id = int(flavor_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment