client.py 6.92 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
#
#

# Copyright (C) 2013 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Module for generic RPC clients.

"""

import logging

import ganeti.rpc.transport as t

from ganeti import constants
from ganeti import errors
from ganeti.rpc.errors import (ProtocolError, RequestError, LuxiError)
from ganeti import serializer

# Short local aliases for the LUXI message key names; they are used by the
# request/response (de)serialization helpers in this module.
(KEY_METHOD,
 KEY_ARGS,
 KEY_SUCCESS,
 KEY_RESULT,
 KEY_VERSION) = (constants.LUXI_KEY_METHOD,
                 constants.LUXI_KEY_ARGS,
                 constants.LUXI_KEY_SUCCESS,
                 constants.LUXI_KEY_RESULT,
                 constants.LUXI_KEY_VERSION)


def ParseRequest(msg):
  """Parses a request message.

  @param msg: the serialized request, as received from a client
  @rtype: tuple
  @return: C{(method, args, version)}; C{version} is C{None} if the
      request carries no version
  @raise ProtocolError: if C{msg} cannot be parsed (the deserializer
      raises C{ValueError}), does not decode to a dict, or lacks the
      method or the arguments

  """
  try:
    request = serializer.LoadJson(msg)
  except ValueError as err:
    raise ProtocolError("Invalid RPC request (parsing error): %s" % err)

  logging.debug("RPC request: %s", request)

  if not isinstance(request, dict):
    logging.error("RPC request not a dict: %r", msg)
    raise ProtocolError("Invalid RPC request (not a dict)")

  method = request.get(KEY_METHOD, None) # pylint: disable=E1103
  args = request.get(KEY_ARGS, None) # pylint: disable=E1103
  # The version is optional; its absence is reported to the caller as None
  version = request.get(KEY_VERSION, None) # pylint: disable=E1103

  if method is None or args is None:
    logging.error("RPC request missing method or arguments: %r", msg)
    raise ProtocolError(("Invalid RPC request (no method or arguments"
                         " in request): %r") % msg)

  return (method, args, version)


def ParseResponse(msg):
  """Parses a response message.

  @param msg: the serialized response, as received from the server
  @rtype: tuple
  @return: C{(success, result, version)}; C{version} is C{None} if the
      response carries no version
  @raise ProtocolError: if C{msg} cannot be deserialized, or does not
      decode to a dict holding both the success flag and the result

  """
  # Any deserialization failure is reported as a protocol error. There is
  # no need to special-case KeyboardInterrupt: it derives from
  # BaseException rather than Exception, so it is never caught here.
  try:
    data = serializer.LoadJson(msg)
  except Exception as err:
    raise ProtocolError("Error while deserializing response: %s" % err)

  # Validate response
  if not (isinstance(data, dict) and
          KEY_SUCCESS in data and
          KEY_RESULT in data):
    raise ProtocolError("Invalid response from server: %r" % data)

  return (data[KEY_SUCCESS], data[KEY_RESULT],
          data.get(KEY_VERSION, None)) # pylint: disable=E1103


def FormatResponse(success, result, version=None):
  """Serializes a response message.

  @param success: whether the call succeeded
  @param result: the call's result (or error information on failure)
  @param version: optional protocol version; left out of the message
      when C{None}
  @return: the serialized response

  """
  fields = [(KEY_SUCCESS, success), (KEY_RESULT, result)]
  if version is not None:
    fields.append((KEY_VERSION, version))
  response = dict(fields)

  logging.debug("RPC response: %s", response)

  return serializer.DumpJson(response)


def FormatRequest(method, args, version=None):
  """Serializes a request message.

  @param method: the name of the method to invoke
  @param args: the method's arguments
  @param version: optional protocol version; left out of the message
      when C{None}
  @return: the serialized request, produced with the serializer's
      private-field encoder

  """
  pairs = [(KEY_METHOD, method), (KEY_ARGS, args)]
  if version is not None:
    pairs.append((KEY_VERSION, version))

  return serializer.DumpJson(dict(pairs),
                             private_encoder=serializer.EncodeWithPrivateFields)


def CallRPCMethod(transport_cb, method, args, version=None):
  """Sends an RPC request through a transport callback and returns the result.

  @param transport_cb: callable taking the serialized request and
      returning the serialized response
  @param method: the name of the method to invoke
  @param args: the method's arguments
  @param version: optional protocol version sent with the request
  @return: the result field of a successful response
  @raise LuxiError: if the response carries a version different from
      C{version}
  @raise RequestError: if the response reports failure and
      C{errors.MaybeRaise} did not raise a more specific error

  """
  assert callable(transport_cb)

  outgoing = FormatRequest(method, args, version=version)
  incoming = transport_cb(outgoing)
  (ok, payload, peer_version) = ParseResponse(incoming)

  # A response without a version is accepted; one with a version must match
  if peer_version is not None and peer_version != version:
    raise LuxiError("RPC version mismatch, client %s, response %s" %
                    (version, peer_version))

  if not ok:
    # Give the errors module a chance to re-raise an encoded exception first
    errors.MaybeRaise(payload)
    raise RequestError(payload)

  return payload


class AbstractClient(object):
  """High-level client abstraction.

  This uses a backing Transport-like class on top of which it
  implements data serialization/deserialization.

  Subclasses must override L{_GetAddress} to provide the address the
  transport connects to.

  """

  def __init__(self, timeouts=None, transport=t.Transport):
    """Constructor for the Client class.

    The transport is not created here; it is created lazily, on first use
    (see L{_InitTransport}).

    Arguments:
      - timeouts: a list of timeouts, to be used on connect and read/write
      - transport: a Transport-like class; it is instantiated with the
        address returned by L{_GetAddress} and with C{timeouts}

    If timeouts is not passed, the default timeouts of the transport
    class are used.

    """
    self.timeouts = timeouts
    self.transport_class = transport
    self.transport = None
    # The version used in RPC communication, by default unused:
    self.version = None

  def _GetAddress(self):
    """Returns the address the transport should connect to.

    @raise NotImplementedError: always; subclasses must override this

    """
    raise NotImplementedError

  def _InitTransport(self):
    """(Re)initialize the transport if needed.

    A new transport is created only when none is currently held.

    """
    if self.transport is None:
      self.transport = self.transport_class(self._GetAddress(),
                                            timeouts=self.timeouts)

  def _CloseTransport(self):
    """Close the transport, ignoring errors.

    Errors raised while closing are logged at debug level and otherwise
    ignored, since the transport is being discarded anyway.

    """
    old_transp = self.transport
    if old_transp is None:
      return
    # Forget the transport first, so that the next call re-creates it even
    # if closing the old one fails
    self.transport = None
    try:
      old_transp.Close()
    except Exception as err: # pylint: disable=W0703
      logging.debug("Ignoring error while closing RPC transport: %s", err)

  def _SendMethodCall(self, data):
    """Send serialized request data over the transport and return the reply.

    Retrying is delegated to C{t.Transport.RetryOnBrokenPipe}. The cleanup
    callback given to it closes the transport, so that the next attempt
    (presumably made after a broken pipe, see that method) reconnects.

    """
    # Send request and wait for response
    def send(try_no):
      if try_no:
        logging.debug("RPC peer disconnected, retrying")
      self._InitTransport()
      return self.transport.Call(data)
    return t.Transport.RetryOnBrokenPipe(send, lambda _: self._CloseTransport())

  def Close(self):
    """Close the underlying connection.

    """
    self._CloseTransport()

  def close(self):
    """Same as L{Close}, to be used with contextlib.closing(...).

    """
    self.Close()

  def CallMethod(self, method, args):
    """Send a generic request and return the response.

    @param method: the name of the method to call
    @type args: list or tuple
    @param args: the positional arguments of the call
    @return: the call's result, as returned by L{CallRPCMethod}
    @raise errors.ProgrammerError: if C{args} is not a list or tuple

    """
    if not isinstance(args, (list, tuple)):
      raise errors.ProgrammerError("Invalid parameter passed to CallMethod:"
                                   " expected list, got %s" % type(args))
    return CallRPCMethod(self._SendMethodCall, method, args,
                         version=self.version)


class AbstractStubClient(AbstractClient):
  """An abstract Client that connects a generated stub client to a L{Transport}.

  Subclasses should inherit from this class (first) as well and a designated
  stub (second).
  """

  def __init__(self, timeouts=None, transport=t.Transport):
    """Constructor for the class.

    Arguments are the same as for L{AbstractClient} and are passed on
    unchanged.

    No check is performed here that the stub class provides a socket path;
    L{_GetAddress} relies on C{_GetSocketPath}, which is presumably supplied
    by the stub class (it is not defined on this class).

    """
    super(AbstractStubClient, self).__init__(timeouts=timeouts,
                                             transport=transport)

  def _GenericInvoke(self, method, *args):
    """Invoke C{method}, forwarding the positional arguments as a tuple.

    """
    call_args = args
    return self.CallMethod(method, call_args)

  def _GetAddress(self):
    """Return the stub's socket path as the transport address.

    """
    socket_path = self._GetSocketPath() # pylint: disable=E1101
    return socket_path