Commit 8b312c1d authored by Manuel Franceschini's avatar Manuel Franceschini

Introduce new IPAddress classes

This patch unifies the netutils functions dealing with IP addresses to
three classes:
- IPAddress: Common IP address functionality
- IPv4Address: IPv4 specific functionality
- IPv6address: IPv6-specific functionality

Furthermore it adds methods to check whether an address is a loopback
address, replacing the .startswith("127") for IPv4 and adding IPv6
support.

It also provides the basis for future IPv6 address handling. Methods to
convert IP strings to their corresponding interger values will allow to
canonicalize IPv6 addresses.
Signed-off-by: default avatarManuel Franceschini <livewire@google.com>
Reviewed-by: default avatarIustin Pop <iustin@google.com>
parent 1eb85930
#!/usr/bin/python
#
# Copyright (C) 2009, Google Inc.
# Copyright (C) 2009, 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
......@@ -66,8 +66,8 @@ class ConfdAsyncUDPServer(daemon.AsyncUDPSocket):
@param processor: ConfdProcessor to use to handle queries
"""
daemon.AsyncUDPSocket.__init__(self,
netutils.GetAddressFamily(bind_address))
family = netutils.IPAddress.GetAddressFamily(bind_address)
daemon.AsyncUDPSocket.__init__(self, family)
self.bind_address = bind_address
self.port = port
self.processor = processor
......
#!/usr/bin/python
#
# Copyright (C) 2006, 2007 Google Inc.
# Copyright (C) 2006, 2007, 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
......@@ -598,7 +598,7 @@ class NodeHttpServer(http.server.HttpServer):
"""Checks if a node has the given ip address.
"""
return netutils.OwnIpAddress(params[0])
return netutils.IPAddress.Own(params[0])
@staticmethod
def perspective_node_info(params):
......
#
#
# Copyright (C) 2006, 2007 Google Inc.
# Copyright (C) 2006, 2007, 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
......@@ -279,7 +279,7 @@ def StartMaster(start_daemons, no_voting):
# or activate the IP
else:
if netutils.TcpPing(master_ip, constants.DEFAULT_NODED_PORT):
if netutils.OwnIpAddress(master_ip):
if netutils.IPAddress.Own(master_ip):
# we already have the ip:
logging.debug("Master IP already configured, doing nothing")
else:
......
......@@ -1301,13 +1301,13 @@ class DRBD8(BaseDRBD):
# about its peer.
cls._SetMinorSyncSpeed(minor, constants.SYNC_SPEED)
if netutils.IsValidIP6(lhost):
if not netutils.IsValidIP6(rhost):
if netutils.IP6Address.IsValid(lhost):
if not netutils.IP6Address.IsValid(rhost):
_ThrowError("drbd%d: can't connect ip %s to ip %s" %
(minor, lhost, rhost))
family = "ipv6"
elif netutils.IsValidIP4(lhost):
if not netutils.IsValidIP4(rhost):
elif netutils.IP4Address.IsValid(lhost):
if not netutils.IP4Address.IsValid(rhost):
_ThrowError("drbd%d: can't connect ip %s to ip %s" %
(minor, lhost, rhost))
family = "ipv4"
......
......@@ -245,13 +245,13 @@ def InitCluster(cluster_name, mac_prefix,
hostname = netutils.GetHostInfo()
if hostname.ip.startswith("127."):
raise errors.OpPrereqError("This host's IP resolves to the private"
" range (%s). Please fix DNS or %s." %
if netutils.IP4Address.IsLoopback(hostname.ip):
raise errors.OpPrereqError("This host's IP (%s) resolves to a loopback"
" address. Please fix DNS or %s." %
(hostname.ip, constants.ETC_HOSTS),
errors.ECODE_ENVIRON)
if not netutils.OwnIpAddress(hostname.ip):
if not netutils.IPAddress.Own(hostname.ip):
raise errors.OpPrereqError("Inconsistency: this host's name resolves"
" to %s,\nbut this ip address does not"
" belong to this host. Aborting." %
......@@ -266,11 +266,11 @@ def InitCluster(cluster_name, mac_prefix,
errors.ECODE_NOTUNIQUE)
if secondary_ip:
if not netutils.IsValidIP4(secondary_ip):
if not netutils.IP4Address.IsValid(secondary_ip):
raise errors.OpPrereqError("Invalid secondary ip given",
errors.ECODE_INVAL)
if (secondary_ip != hostname.ip and
not netutils.OwnIpAddress(secondary_ip)):
not netutils.IPAddress.Own(secondary_ip)):
raise errors.OpPrereqError("You gave %s as secondary IP,"
" but it does not belong to this host." %
secondary_ip, errors.ECODE_ENVIRON)
......
......@@ -3712,7 +3712,7 @@ class LUAddNode(LogicalUnit):
primary_ip = self.op.primary_ip = dns_data.ip
if self.op.secondary_ip is None:
self.op.secondary_ip = primary_ip
if not netutils.IsValidIP4(self.op.secondary_ip):
if not netutils.IP4Address.IsValid(self.op.secondary_ip):
raise errors.OpPrereqError("Invalid secondary IP given",
errors.ECODE_INVAL)
secondary_ip = self.op.secondary_ip
......@@ -7015,7 +7015,7 @@ class LUCreateInstance(LogicalUnit):
errors.ECODE_INVAL)
nic_ip = self.hostname1.ip
else:
if not netutils.IsValidIP4(ip):
if not netutils.IP4Address.IsValid(ip):
raise errors.OpPrereqError("Given IP address '%s' doesn't look"
" like a valid IP" % ip,
errors.ECODE_INVAL)
......@@ -8690,7 +8690,7 @@ class LUSetInstanceParams(LogicalUnit):
if nic_ip.lower() == constants.VALUE_NONE:
nic_dict['ip'] = None
else:
if not netutils.IsValidIP4(nic_ip):
if not netutils.IP4Address.IsValid(nic_ip):
raise errors.OpPrereqError("Invalid IP address '%s'" % nic_ip,
errors.ECODE_INVAL)
......
#
#
# Copyright (C) 2009 Google Inc.
# Copyright (C) 2009, 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
......@@ -391,11 +391,11 @@ class ConfdClient:
raise errors.ConfdClientError("Peer list empty")
try:
peer = self._peers[0]
self._family = netutils.GetAddressFamily(peer)
self._family = netutils.IPAddress.GetAddressFamily(peer)
for peer in self._peers[1:]:
if netutils.GetAddressFamily(peer) != self._family:
if netutils.IPAddress.GetAddressFamily(peer) != self._family:
raise errors.ConfdClientError("Peers must be of same address family")
except errors.GenericError:
except errors.IPAddressError:
raise errors.ConfdClientError("Peer address %s invalid" % peer)
......
#
#
# Copyright (C) 2006, 2007 Google Inc.
# Copyright (C) 2006, 2007, 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
......@@ -356,6 +356,12 @@ class NoCtypesError(GenericError):
"""
class IPAddressError(GenericError):
"""Generic IP address error.
"""
# errors should be added above
......
......@@ -175,7 +175,8 @@ class KVMHypervisor(hv_base.BaseHypervisor):
constants.HV_ACPI: hv_base.NO_CHECK,
constants.HV_SERIAL_CONSOLE: hv_base.NO_CHECK,
constants.HV_VNC_BIND_ADDRESS:
(False, lambda x: (netutils.IsValidIP4(x) or utils.IsNormAbsPath(x)),
(False, lambda x: (netutils.IP4Address.IsValid(x) or
utils.IsNormAbsPath(x)),
"the VNC bind address must be either a valid IP address or an absolute"
" pathname", None, None),
constants.HV_VNC_TLS: hv_base.NO_CHECK,
......@@ -572,7 +573,7 @@ class KVMHypervisor(hv_base.BaseHypervisor):
vnc_bind_address = hvp[constants.HV_VNC_BIND_ADDRESS]
if vnc_bind_address:
if netutils.IsValidIP4(vnc_bind_address):
if netutils.IP4Address.IsValid(vnc_bind_address):
if instance.network_port > constants.VNC_BASE_PORT:
display = instance.network_port - constants.VNC_BASE_PORT
if vnc_bind_address == constants.IP4_ADDRESS_ANY:
......
......@@ -551,7 +551,7 @@ class XenHvmHypervisor(XenHypervisor):
hv_base.ParamInSet(True, constants.HT_HVM_VALID_NIC_TYPES),
constants.HV_PAE: hv_base.NO_CHECK,
constants.HV_VNC_BIND_ADDRESS:
(False, netutils.IsValidIP4,
(False, netutils.IP4Address.IsValid,
"VNC bind address is not a valid IP address", None, None),
constants.HV_KERNEL_PATH: hv_base.REQ_FILE_CHECK,
constants.HV_DEVICE_MODEL: hv_base.REQ_FILE_CHECK,
......
......@@ -151,84 +151,6 @@ class HostInfo:
return hostname
def _GenericIsValidIP(family, ip):
"""Generic internal version of ip validation.
@type family: int
@param family: socket.AF_INET | socket.AF_INET6
@type ip: str
@param ip: the address to be checked
@rtype: boolean
@return: True if ip is valid, False otherwise
"""
try:
socket.inet_pton(family, ip)
return True
except socket.error:
return False
def IsValidIP4(ip):
"""Verifies an IPv4 address.
This function checks if the given address is a valid IPv4 address.
@type ip: str
@param ip: the address to be checked
@rtype: boolean
@return: True if ip is valid, False otherwise
"""
return _GenericIsValidIP(socket.AF_INET, ip)
def IsValidIP6(ip):
"""Verifies an IPv6 address.
This function checks if the given address is a valid IPv6 address.
@type ip: str
@param ip: the address to be checked
@rtype: boolean
@return: True if ip is valid, False otherwise
"""
return _GenericIsValidIP(socket.AF_INET6, ip)
def IsValidIP(ip):
"""Verifies an IP address.
This function checks if the given IP address (both IPv4 and IPv6) is valid.
@type ip: str
@param ip: the address to be checked
@rtype: boolean
@return: True if ip is valid, False otherwise
"""
return IsValidIP4(ip) or IsValidIP6(ip)
def GetAddressFamily(ip):
"""Get the address family of the given address.
@type ip: str
@param ip: ip address whose family will be returned
@rtype: int
@return: socket.AF_INET or socket.AF_INET6
@raise errors.GenericError: for invalid addresses
"""
if IsValidIP6(ip):
return socket.AF_INET6
elif IsValidIP4(ip):
return socket.AF_INET
else:
raise errors.GenericError("Address %s not valid" % ip)
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
"""Simple ping implementation using TCP connect(2).
......@@ -251,7 +173,7 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
"""
try:
family = GetAddressFamily(target)
family = IPAddress.GetAddressFamily(target)
except errors.GenericError:
return False
......@@ -279,32 +201,6 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
return success
def OwnIpAddress(address):
"""Check if the current host has the the given IP address.
This is done by trying to bind the given address. We return True if we
succeed or false if a socket.error is raised.
@type address: string
@param address: the address to check
@rtype: bool
@return: True if we own the address
"""
family = GetAddressFamily(address)
s = socket.socket(family, socket.SOCK_DGRAM)
success = False
try:
try:
s.bind((address, 0))
success = True
except socket.error:
success = False
finally:
s.close()
return success
def GetDaemonPort(daemon_name):
"""Get the daemon port for this cluster.
......@@ -327,3 +223,239 @@ def GetDaemonPort(daemon_name):
port = default_port
return port
class IPAddress(object):
"""Class that represents an IP address.
"""
iplen = 0
family = None
loopback_cidr = None
@staticmethod
def _GetIPIntFromString(address):
"""Abstract method to please pylint.
"""
raise NotImplementedError
@classmethod
def IsValid(cls, address):
"""Validate a IP address.
@type address: str
@param address: IP address to be checked
@rtype: bool
@return: True if valid, False otherwise
"""
if cls.family is None:
try:
family = cls.GetAddressFamily(address)
except errors.IPAddressError:
return False
else:
family = cls.family
try:
socket.inet_pton(family, address)
return True
except socket.error:
return False
@classmethod
def Own(cls, address):
"""Check if the current host has the the given IP address.
This is done by trying to bind the given address. We return True if we
succeed or false if a socket.error is raised.
@type address: str
@param address: IP address to be checked
@rtype: bool
@return: True if we own the address, False otherwise
"""
if cls.family is None:
try:
family = cls.GetAddressFamily(address)
except errors.IPAddressError:
return False
else:
family = cls.family
s = socket.socket(family, socket.SOCK_DGRAM)
success = False
try:
try:
s.bind((address, 0))
success = True
except socket.error:
success = False
finally:
s.close()
return success
@classmethod
def InNetwork(cls, cidr, address):
"""Determine whether an address is within a network.
@type cidr: string
@param cidr: Network in CIDR notation, e.g. '192.0.2.0/24', '2001:db8::/64'
@type address: str
@param address: IP address
@rtype: bool
@return: True if address is in cidr, False otherwise
"""
address_int = cls._GetIPIntFromString(address)
subnet = cidr.split("/")
assert len(subnet) == 2
try:
prefix = int(subnet[1])
except ValueError:
return False
assert 0 <= prefix <= cls.iplen
target_int = cls._GetIPIntFromString(subnet[0])
# Convert prefix netmask to integer value of netmask
netmask_int = (2**cls.iplen)-1 ^ ((2**cls.iplen)-1 >> prefix)
# Calculate hostmask
hostmask_int = netmask_int ^ (2**cls.iplen)-1
# Calculate network address by and'ing netmask
network_int = target_int & netmask_int
# Calculate broadcast address by or'ing hostmask
broadcast_int = target_int | hostmask_int
return network_int <= address_int <= broadcast_int
@staticmethod
def GetAddressFamily(address):
"""Get the address family of the given address.
@type address: str
@param address: ip address whose family will be returned
@rtype: int
@return: socket.AF_INET or socket.AF_INET6
@raise errors.GenericError: for invalid addresses
"""
try:
return IP4Address(address).family
except errors.IPAddressError:
pass
try:
return IP6Address(address).family
except errors.IPAddressError:
pass
raise errors.IPAddressError("Invalid address '%s'" % address)
@classmethod
def IsLoopback(cls, address):
"""Determine whether it is a loopback address.
@type address: str
@param address: IP address to be checked
@rtype: bool
@return: True if loopback, False otherwise
"""
try:
return cls.InNetwork(cls.loopback_cidr, address)
except errors.IPAddressError:
return False
class IP4Address(IPAddress):
"""IPv4 address class.
"""
iplen = 32
family = socket.AF_INET
loopback_cidr = "127.0.0.0/8"
def __init__(self, address):
"""Constructor for IPv4 address.
@type address: str
@param address: IP address
@raises errors.IPAddressError: if address invalid
"""
IPAddress.__init__(self)
if not self.IsValid(address):
raise errors.IPAddressError("IPv4 Address %s invalid" % address)
self.address = address
@staticmethod
def _GetIPIntFromString(address):
"""Get integer value of IPv4 address.
@type address: str
@param: IPv6 address
@rtype: int
@return: integer value of given IP address
"""
address_int = 0
parts = address.split(".")
assert len(parts) == 4
for part in parts:
address_int = (address_int << 8) | int(part)
return address_int
class IP6Address(IPAddress):
"""IPv6 address class.
"""
iplen = 128
family = socket.AF_INET6
loopback_cidr = "::1/128"
def __init__(self, address):
"""Constructor for IPv6 address.
@type address: str
@param address: IP address
@raises errors.IPAddressError: if address invalid
"""
IPAddress.__init__(self)
if not self.IsValid(address):
raise errors.IPAddressError("IPv6 Address [%s] invalid" % address)
self.address = address
@staticmethod
def _GetIPIntFromString(address):
"""Get integer value of IPv6 address.
@type address: str
@param: IPv6 address
@rtype: int
@return: integer value of given IP address
"""
doublecolons = address.count("::")
assert not doublecolons > 1
if doublecolons == 1:
# We have a shorthand address, expand it
parts = []
twoparts = address.split("::")
sep = len(twoparts[0].split(':')) + len(twoparts[1].split(':'))
parts = twoparts[0].split(':')
[parts.append("0") for _ in range(8 - sep)]
parts += twoparts[1].split(':')
else:
parts = address.split(":")
address_int = 0
for part in parts:
address_int = (address_int << 16) + int(part or '0', 16)
return address_int
......@@ -1218,7 +1218,7 @@ def ShowInstanceConfig(opts, args):
vnc_console_port = "%s:%s (display %s)" % (instance["pnode"],
port,
display)
elif display > 0 and netutils.IsValidIP4(vnc_bind_address):
elif display > 0 and netutils.IP4Address.IsValid(vnc_bind_address):
vnc_console_port = ("%s:%s (node %s) (display %s)" %
(vnc_bind_address, port,
instance["pnode"], display))
......
......@@ -140,52 +140,114 @@ class TestHostInfo(unittest.TestCase):
netutils.HostInfo.NormalizeName(value)
class TestIsValidIP4(unittest.TestCase):
def test(self):
self.assert_(netutils.IsValidIP4("127.0.0.1"))
self.assert_(netutils.IsValidIP4("0.0.0.0"))
self.assert_(netutils.IsValidIP4("255.255.255.255"))
self.assertFalse(netutils.IsValidIP4("0"))
self.assertFalse(netutils.IsValidIP4("1"))
self.assertFalse(netutils.IsValidIP4("1.1.1"))
self.assertFalse(netutils.IsValidIP4("255.255.255.256"))
self.assertFalse(netutils.IsValidIP4("::1"))
class TestIsValidIP6(unittest.TestCase):
def test(self):
self.assert_(netutils.IsValidIP6("::"))
self.assert_(netutils.IsValidIP6("::1"))
self.assert_(netutils.IsValidIP6("1" + (":1" * 7)))
self.assert_(netutils.IsValidIP6("ffff" + (":ffff" * 7)))
self.assertFalse(netutils.IsValidIP6("0"))
self.assertFalse(netutils.IsValidIP6(":1"))
self.assertFalse(netutils.IsValidIP6("f" + (":f" * 6)))
self.assertFalse(netutils.IsValidIP6("fffg" + (":ffff" * 7)))
self.assertFalse(netutils.IsValidIP6("fffff" + (":ffff" * 7)))
self.assertFalse(netutils.IsValidIP6("1" + (":1" * 8)))
self.assertFalse(netutils.IsValidIP6("127.0.0.1"))
class TestIsValidIP(unittest.TestCase):
def test(self):
self.assert_(netutils.IsValidIP("0.0.0.0"))
self.assert_(netutils.IsValidIP("127.0.0.1"))
self.assert_(netutils.IsValidIP("::"))
self.assert_(netutils.IsValidIP("::1"))
self.assertFalse(netutils.IsValidIP("0"))
self.asse