diff --git a/lib/rapi/baserlib.py b/lib/rapi/baserlib.py index 020d7bcba06057c73c849c46bb09b3a055f08178..fb4c48e77a731191764393cfce89bc3d7d530c7e 100644 --- a/lib/rapi/baserlib.py +++ b/lib/rapi/baserlib.py @@ -38,6 +38,10 @@ from ganeti import opcodes from ganeti import errors +# Dummy value to detect unchanged parameters +_DEFAULT = object() + + def BuildUriList(ids, uri_format, uri_fields=("name", "uri")): """Builds a URI list as used by index resources. @@ -213,6 +217,53 @@ def FeedbackFn(ts, log_type, log_msg): # pylint: disable-msg=W0613 logging.info("%s: %s", log_type, log_msg) +def CheckType(value, exptype, descr): + """Abort request if value type doesn't match expected type. + + @param value: Value + @type exptype: type + @param exptype: Expected type + @type descr: string + @param descr: Description of value + @return: Value (allows inline usage) + + """ + if not isinstance(value, exptype): + raise http.HttpBadRequest("%s: Type is '%s', but '%s' is expected" % + (descr, type(value).__name__, exptype.__name__)) + + return value + + +def CheckParameter(data, name, default=_DEFAULT, exptype=_DEFAULT): + """Check and return the value for a given parameter. + + If no default value was given and the parameter doesn't exist in the input + data, an error is raise. + + @type data: dict + @param data: Dictionary containing input data + @type name: string + @param name: Parameter name + @param default: Default value (can be None) + @param exptype: Expected type (can be None) + + """ + try: + value = data[name] + except KeyError: + if default is not _DEFAULT: + return default + + raise http.HttpBadRequest("Required parameter '%s' is missing" % + name) + + if exptype is _DEFAULT: + return value + + return CheckType(value, exptype, "'%s' parameter" % name) + + class R_Generic(object): """Generic class for resources. @@ -279,13 +330,10 @@ class R_Generic(object): @param name: the required parameter """ - if name in self.req.request_body: - return self.req.request_body[name] - elif args: - return args[0] - else: - raise http.HttpBadRequest("Required parameter '%s' is missing" % - name) + if args: + return CheckParameter(self.req.request_body, name, default=args[0]) + + return CheckParameter(self.req.request_body, name) def useLocking(self): """Check if the request specifies locking.