diff --git a/lib/utils.py b/lib/utils.py index 1f33670684ca07a6a79ab11a2780793757a7df1e..7d13617afe0eafe1930ada820835b017d008e5fa 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -81,6 +81,8 @@ X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" % HEX_CHAR_RE, HEX_CHAR_RE), re.S | re.I) +_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$") + # Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...): # struct ucred { pid_t pid; uid_t uid; gid_t gid; }; # @@ -1155,6 +1157,30 @@ class HostInfo: return hostname +def ValidateServiceName(name): + """Validate the given service name. + + @type name: number or string + @param name: Service name or port specification + + """ + try: + numport = int(name) + except (ValueError, TypeError): + # Non-numeric service name + valid = _VALID_SERVICE_NAME_RE.match(name) + else: + # Numeric port (protocols other than TCP or UDP might need adjustments + # here) + valid = (numport >= 0 and numport < (1 << 16)) + + if not valid: + raise errors.OpPrereqError("Invalid service name '%s'" % name, + errors.ECODE_INVAL) + + return name + + def GetHostInfo(name=None): """Lookup host name and raise an OpPrereqError for failures""" diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py index 4b8d4d13fc549d7d16df04292ce011ecb60b6624..650fd619d4272f9090c5c03e2ba6d6dc23e8c9d5 100755 --- a/test/ganeti.utils_unittest.py +++ b/test/ganeti.utils_unittest.py @@ -1910,6 +1910,32 @@ class TestHostInfo(unittest.TestCase): HostInfo.NormalizeName(value) +class TestValidateServiceName(unittest.TestCase): + def testValid(self): + testnames = [ + 0, 1, 2, 3, 1024, 65000, 65534, 65535, + "ganeti", + "gnt-masterd", + "HELLO_WORLD_SVC", + "hello.world.1", + "0", "80", "1111", "65535", + ] + + for name in testnames: + self.assertEqual(utils.ValidateServiceName(name), name) + + def testInvalid(self): + testnames = [ + -15756, -1, 65536, 133428083, + "", "Hello World!", "!", "'", "\"", "\t", "\n", "`", + "-8546", "-1", "65536", + (129 * "A"), + ] + + for name in testnames: + self.assertRaises(OpPrereqError, utils.ValidateServiceName, name) + + class TestParseAsn1Generalizedtime(unittest.TestCase): def test(self): # UTC