From 28f3404879c74722674bcc1e10398f7ad6388c9a Mon Sep 17 00:00:00 2001 From: Michael Hanselmann <hansmi@google.com> Date: Fri, 11 Jun 2010 13:52:19 +0200 Subject: [PATCH] utils: Add function to validate service name Signed-off-by: Michael Hanselmann <hansmi@google.com> Reviewed-by: Guido Trotter <ultrotter@google.com> --- lib/utils.py | 26 ++++++++++++++++++++++++++ test/ganeti.utils_unittest.py | 26 ++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/lib/utils.py b/lib/utils.py index 1f3367068..7d13617af 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -81,6 +81,8 @@ X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" % HEX_CHAR_RE, HEX_CHAR_RE), re.S | re.I) +_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$") + # Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...): # struct ucred { pid_t pid; uid_t uid; gid_t gid; }; # @@ -1155,6 +1157,30 @@ class HostInfo: return hostname +def ValidateServiceName(name): + """Validate the given service name. + + @type name: number or string + @param name: Service name or port specification + + """ + try: + numport = int(name) + except (ValueError, TypeError): + # Non-numeric service name + valid = _VALID_SERVICE_NAME_RE.match(name) + else: + # Numeric port (protocols other than TCP or UDP might need adjustments + # here) + valid = (numport >= 0 and numport < (1 << 16)) + + if not valid: + raise errors.OpPrereqError("Invalid service name '%s'" % name, + errors.ECODE_INVAL) + + return name + + def GetHostInfo(name=None): """Lookup host name and raise an OpPrereqError for failures""" diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py index 4b8d4d13f..650fd619d 100755 --- a/test/ganeti.utils_unittest.py +++ b/test/ganeti.utils_unittest.py @@ -1910,6 +1910,32 @@ class TestHostInfo(unittest.TestCase): HostInfo.NormalizeName(value) +class TestValidateServiceName(unittest.TestCase): + def testValid(self): + testnames = [ + 0, 1, 2, 3, 1024, 65000, 65534, 65535, + "ganeti", + "gnt-masterd", + "HELLO_WORLD_SVC", + "hello.world.1", + "0", "80", "1111", "65535", + ] + + for name in testnames: + self.assertEqual(utils.ValidateServiceName(name), name) + + def testInvalid(self): + testnames = [ + -15756, -1, 65536, 133428083, + "", "Hello World!", "!", "'", "\"", "\t", "\n", "`", + "-8546", "-1", "65536", + (129 * "A"), + ] + + for name in testnames: + self.assertRaises(OpPrereqError, utils.ValidateServiceName, name) + + class TestParseAsn1Generalizedtime(unittest.TestCase): def test(self): # UTC -- GitLab