#
#

# Copyright (C) 2009 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Ganeti confd client

Clients can use the confd client library to send requests to a group of master
candidates running confd. The expected usage is through the asyncore framework,
by sending queries, and asynchronously receiving replies through a callback.

This way the client library doesn't ever need to "wait" on a particular answer,
and can proceed even if some UDP packets are lost. It's up to the user to
reschedule queries if they haven't received responses and they need them.

Example usage::

  client = ConfdClient(...) # includes callback specification
  req = confd_client.ConfdClientRequest(type=constants.CONFD_REQ_PING)
  client.SendRequest(req)
  # then make sure your client calls asyncore.loop() or daemon.Mainloop.Run()
  # ... wait ...
  # And your callback will be called by asyncore, when your query gets a
  # response, or when it expires.

You can use the provided ConfdFilterCallback to act as a filter, passing only
"newer" answers to your callback, and filtering out outdated ones, or ones
confirming what you already got.

"""

# pylint: disable-msg=E0203

# E0203: Access to member %r before its definition, since we use
# objects.py which doesn't explicitly initialise its members

import time
import random

from ganeti import utils
from ganeti import constants
from ganeti import objects
from ganeti import serializer
from ganeti import daemon # contains AsyncUDPSocket
from ganeti import errors
from ganeti import confd
from ganeti import ssconf


class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
  """Asyncore-based UDP transport for L{ConfdClient}.

  The transport lives in its own class, apart from ConfdClient, so that a
  client library which does not rely on asyncore is easy to write.

  """
  def __init__(self, client):
    """Initialise the socket and remember which client owns it.

    @type client: L{ConfdClient}
    @param client: client library which will receive every datagram

    """
    daemon.AsyncUDPSocket.__init__(self)
    self.client = client

  def handle_datagram(self, payload, ip, port):
    """Hand a received datagram over to the owning client.

    Overrides the daemon.AsyncUDPSocket hook of the same name.

    """
    owner = self.client
    owner.HandleResponse(payload, ip, port)


class _Request(object):
  """Bookkeeping record for one outstanding confd request.

  @ivar request: the request data
  @ivar args: any extra arguments for the callback
  @ivar expiry: the expiry timestamp of the request
  @ivar sent: the set of contacted peers
  @ivar rcvd: the set of peers who replied

  """
  def __init__(self, request, args, expiry, sent):
    # the contacted peers are stored as an immutable frozenset, while the
    # repliers start empty and are filled in as answers arrive
    self.sent = frozenset(sent)
    self.rcvd = set()
    self.request, self.args, self.expiry = request, args, expiry


class ConfdClient:
  """Send queries to confd, and get back answers.

  Since the confd model works by querying multiple master candidates, and
  getting back answers, this is an asynchronous library. It can either work
  through asyncore or with your own handling.

113 114 115 116 117
  @type _requests: dict
  @ivar _requests: dictionary indexes by salt, which contains data
      about the outstanding requests; the values are objects of type
      L{_Request}

Guido Trotter's avatar
Guido Trotter committed
118
  """
119
  def __init__(self, hmac_key, peers, callback, port=None, logger=None):
Guido Trotter's avatar
Guido Trotter committed
120 121 122 123 124 125
    """Constructor for ConfdClient

    @type hmac_key: string
    @param hmac_key: hmac key to talk to confd
    @type peers: list
    @param peers: list of peer nodes
126 127
    @type callback: f(L{ConfdUpcallPayload})
    @param callback: function to call when getting answers
128
    @type port: integer
129
    @param port: confd port (default: use GetDaemonPort)
130
    @type logger: logging.Logger
131
    @param logger: optional logger for internal conditions
Guido Trotter's avatar
Guido Trotter committed
132 133

    """
134 135
    if not callable(callback):
      raise errors.ProgrammerError("callback must be callable")
Guido Trotter's avatar
Guido Trotter committed
136

137
    self.UpdatePeerList(peers)
Guido Trotter's avatar
Guido Trotter committed
138 139
    self._hmac_key = hmac_key
    self._socket = ConfdAsyncUDPClient(self)
140
    self._callback = callback
141
    self._confd_port = port
142
    self._logger = logger
143
    self._requests = {}
144 145 146

    if self._confd_port is None:
      self._confd_port = utils.GetDaemonPort(constants.CONFD)
Guido Trotter's avatar
Guido Trotter committed
147

148 149 150 151 152 153 154
  def UpdatePeerList(self, peers):
    """Update the list of peers

    @type peers: list
    @param peers: list of peer nodes

    """
155 156
    # we are actually called from init, so:
    # pylint: disable-msg=W0201
157 158
    if not isinstance(peers, list):
      raise errors.ProgrammerError("peers must be a list")
159 160
    # make a copy of peers, since we're going to shuffle the list, later
    self._peers = list(peers)
161

Guido Trotter's avatar
Guido Trotter committed
162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
  def _PackRequest(self, request, now=None):
    """Prepare a request to be sent on the wire.

    This function puts a proper salt in a confd request, puts the proper salt,
    and adds the correct magic number.

    """
    if now is None:
      now = time.time()
    tstamp = '%d' % now
    req = serializer.DumpSignedJson(request.ToDict(), self._hmac_key, tstamp)
    return confd.PackMagic(req)

  def _UnpackReply(self, payload):
    in_payload = confd.UnpackMagic(payload)
177 178
    (dict_answer, salt) = serializer.LoadSignedJson(in_payload, self._hmac_key)
    answer = objects.ConfdReply.FromDict(dict_answer)
Guido Trotter's avatar
Guido Trotter committed
179 180
    return answer, salt

181 182
  def ExpireRequests(self):
    """Delete all the expired requests.
Guido Trotter's avatar
Guido Trotter committed
183 184 185

    """
    now = time.time()
186 187
    for rsalt, rq in self._requests.items():
      if now >= rq.expiry:
188 189 190
        del self._requests[rsalt]
        client_reply = ConfdUpcallPayload(salt=rsalt,
                                          type=UPCALL_EXPIRE,
191 192
                                          orig_request=rq.request,
                                          extra_args=rq.args,
193 194
                                          client=self,
                                          )
195
        self._callback(client_reply)
Guido Trotter's avatar
Guido Trotter committed
196

197
  def SendRequest(self, request, args=None, coverage=None, async=True):
Guido Trotter's avatar
Guido Trotter committed
198 199 200 201 202
    """Send a confd request to some MCs

    @type request: L{objects.ConfdRequest}
    @param request: the request to send
    @type args: tuple
203
    @param args: additional callback arguments
Guido Trotter's avatar
Guido Trotter committed
204
    @type coverage: integer
205
    @param coverage: number of remote nodes to contact
206 207
    @type async: boolean
    @param async: handle the write asynchronously
Guido Trotter's avatar
Guido Trotter committed
208 209 210 211 212 213 214 215 216 217 218 219

    """
    if coverage is None:
      coverage = min(len(self._peers), constants.CONFD_DEFAULT_REQ_COVERAGE)

    if coverage > len(self._peers):
      raise errors.ConfdClientError("Not enough MCs known to provide the"
                                    " desired coverage")

    if not request.rsalt:
      raise errors.ConfdClientError("Missing request rsalt")

220 221
    self.ExpireRequests()
    if request.rsalt in self._requests:
Guido Trotter's avatar
Guido Trotter committed
222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
      raise errors.ConfdClientError("Duplicate request rsalt")

    if request.type not in constants.CONFD_REQS:
      raise errors.ConfdClientError("Invalid request type")

    random.shuffle(self._peers)
    targets = self._peers[:coverage]

    now = time.time()
    payload = self._PackRequest(request, now=now)

    for target in targets:
      try:
        self._socket.enqueue_send(target, self._confd_port, payload)
      except errors.UdpDataSizeError:
        raise errors.ConfdClientError("Request too big")

    expire_time = now + constants.CONFD_CLIENT_EXPIRE_TIMEOUT
240 241
    self._requests[request.rsalt] = _Request(request, args, expire_time,
                                             targets)
Guido Trotter's avatar
Guido Trotter committed
242

243 244 245
    if not async:
      self.FlushSendQueue()

Guido Trotter's avatar
Guido Trotter committed
246 247 248 249 250 251 252 253 254
  def HandleResponse(self, payload, ip, port):
    """Asynchronous handler for a confd reply

    Call the relevant callback associated to the current request.

    """
    try:
      try:
        answer, salt = self._UnpackReply(payload)
255 256 257
      except (errors.SignatureError, errors.ConfdMagicError), err:
        if self._logger:
          self._logger.debug("Discarding broken package: %s" % err)
Guido Trotter's avatar
Guido Trotter committed
258 259 260
        return

      try:
261
        rq = self._requests[salt]
Guido Trotter's avatar
Guido Trotter committed
262
      except KeyError:
263 264
        if self._logger:
          self._logger.debug("Discarding unknown (expired?) reply: %s" % err)
265 266
        return

267 268
      rq.rcvd.add(ip)

269 270 271
      client_reply = ConfdUpcallPayload(salt=salt,
                                        type=UPCALL_REPLY,
                                        server_reply=answer,
272
                                        orig_request=rq.request,
273 274
                                        server_ip=ip,
                                        server_port=port,
275
                                        extra_args=rq.args,
276 277
                                        client=self,
                                       )
278
      self._callback(client_reply)
Guido Trotter's avatar
Guido Trotter committed
279 280

    finally:
281 282
      self.ExpireRequests()

283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302
  def FlushSendQueue(self):
    """Send out all pending requests.

    Can be used for synchronous client use.

    """
    while self._socket.writable():
      self._socket.handle_write()

  def ReceiveReply(self, timeout=1):
    """Receive one reply.

    @type timeout: float
    @param timeout: how long to wait for the reply
    @rtype: boolean
    @return: True if some data has been handled, False otherwise

    """
    return self._socket.process_next_packet(timeout=timeout)

303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379
  @staticmethod
  def _NeededReplies(peer_cnt):
    """Compute the minimum safe number of replies for a query.

    The algorithm is designed to work well for both small and big
    number of peers:
        - for less than three, we require all responses
        - for less than five, we allow one miss
        - otherwise, half the number plus one

    This guarantees that we progress monotonically: 1->1, 2->2, 3->2,
    4->2, 5->3, 6->3, 7->4, etc.

    @type peer_cnt: int
    @param peer_cnt: the number of peers contacted
    @rtype: int
    @return: the number of replies which should give a safe coverage

    """
    if peer_cnt < 3:
      return peer_cnt
    elif peer_cnt < 5:
      return peer_cnt - 1
    else:
      return int(peer_cnt/2) + 1

  def WaitForReply(self, salt, timeout=constants.CONFD_CLIENT_EXPIRE_TIMEOUT):
    """Wait for replies to a given request.

    This method will wait until either the timeout expires or a
    minimum number (computed using L{_NeededReplies}) of replies are
    received for the given salt. It is useful when doing synchronous
    calls to this library.

    @param salt: the salt of the request we want responses for
    @param timeout: the maximum timeout (should be less or equal to
        L{ganeti.constants.CONFD_CLIENT_EXPIRE_TIMEOUT}
    @rtype: tuple
    @return: a tuple of (timed_out, sent_cnt, recv_cnt); if the
        request is unknown, timed_out will be true and the counters
        will be zero

    """
    def _CheckResponse():
      if salt not in self._requests:
        # expired?
        if self._logger:
          self._logger.debug("Discarding unknown/expired request: %s" % salt)
        return MISSING
      rq = self._requests[salt]
      if len(rq.rcvd) >= expected:
        # already got all replies
        return (False, len(rq.sent), len(rq.rcvd))
      # else wait, using default timeout
      self.ReceiveReply()
      raise utils.RetryAgain()

    MISSING = (True, 0, 0)

    if salt not in self._requests:
      return MISSING
    # extend the expire time with the current timeout, so that we
    # don't get the request expired from under us
    rq = self._requests[salt]
    rq.expiry += timeout
    sent = len(rq.sent)
    expected = self._NeededReplies(sent)

    try:
      return utils.Retry(_CheckResponse, 0, timeout)
    except utils.RetryTimeout:
      if salt in self._requests:
        rq = self._requests[salt]
        return (True, len(rq.sent), len(rq.rcvd))
      else:
        return MISSING


# UPCALL_REPLY: server reply upcall
# has all ConfdUpcallPayload fields populated
UPCALL_REPLY = 1
# UPCALL_EXPIRE: internal library request expire
# has only salt, type, orig_request, extra_args and client populated
UPCALL_EXPIRE = 2
# all the valid values for ConfdUpcallPayload.type
CONFD_UPCALL_TYPES = frozenset([
  UPCALL_REPLY,
  UPCALL_EXPIRE,
  ])


class ConfdUpcallPayload(objects.ConfigObject):
  """Argument handed to confd client callbacks.

  @type salt: string
  @ivar salt: salt associated with the query
  @type type: one of confd.client.CONFD_UPCALL_TYPES
  @ivar type: upcall type (server reply, expired request, ...)
  @type orig_request: L{objects.ConfdRequest}
  @ivar orig_request: original request
  @type server_reply: L{objects.ConfdReply}
  @ivar server_reply: server reply
  @type server_ip: string
  @ivar server_ip: answering server ip address
  @type server_port: int
  @ivar server_port: answering server port
  @type extra_args: any
  @ivar extra_args: 'args' argument of the SendRequest function
  @type client: L{ConfdClient}
  @ivar client: current confd client instance

  """
  __slots__ = [
    "salt", "type", "orig_request", "server_reply",
    "server_ip", "server_port", "extra_args", "client",
    ]


class ConfdClientRequest(objects.ConfdRequest):
  """Client-side flavour of L{objects.ConfdRequest}.

  When not explicitly given, the salt is set to a new UUID and the
  protocol to the current confd protocol version; the request type is
  validated at construction time.

  @raise errors.ConfdClientError: if the request type is not one of
      C{constants.CONFD_REQS}

  """
  def __init__(self, **kwargs):
    objects.ConfdRequest.__init__(self, **kwargs)
    # fill in the defaults only for values the caller left empty
    self.rsalt = self.rsalt or utils.NewUUID()
    self.protocol = self.protocol or constants.CONFD_PROTOCOL_VERSION
    if self.type not in constants.CONFD_REQS:
      raise errors.ConfdClientError("Invalid request type")


class ConfdFilterCallback:
  """Callback that calls another callback, but filters duplicate results.

  @ivar consistent: a dictionary indexed by salt; for each salt, if
      all responses were identical, this will be True; this is the
      expected state on a healthy cluster; on inconsistent or
      partitioned clusters, this might be False, if we see answers
      with the same serial but different contents

  """
  def __init__(self, callback, logger=None):
    """Constructor for ConfdFilterCallback

    @type callback: f(L{ConfdUpcallPayload})
    @param callback: function to call when getting answers
    @type logger: logging.Logger
    @param logger: optional logger for internal conditions
    @raise errors.ProgrammerError: if callback is not callable

    """
    if not callable(callback):
      raise errors.ProgrammerError("callback must be callable")

    self._callback = callback
    self._logger = logger
    # answers contains a dict of salt -> answer
    self._answers = {}
    self.consistent = {}

  def _LogFilter(self, salt, new_reply, old_reply):
    """Log why a reply is being filtered out (if a logger is set).

    @param salt: the query salt
    @param new_reply: the reply being filtered
    @param old_reply: the reply previously recorded for the salt

    """
    if not self._logger:
      return

    if new_reply.serial > old_reply.serial:
      self._logger.debug("Filtering confirming answer, with newer"
                         " serial for query %s", salt)
    elif new_reply.serial == old_reply.serial:
      if new_reply.answer != old_reply.answer:
        self._logger.warning("Got incoherent answers for query %s"
                             " (serial: %s)", salt, new_reply.serial)
      else:
        self._logger.debug("Filtering confirming answer, with same"
                           " serial for query %s", salt)
    else:
      # here new_reply.serial < old_reply.serial: print them in this order
      # so that the "<" in the message holds
      self._logger.debug("Filtering outdated answer for query %s"
                         " serial: (%d < %d)", salt, new_reply.serial,
                         old_reply.serial)

  def _HandleExpire(self, up):
    """Forget all the state recorded for an expired query."""
    # the query may have expired without any reply, so both keys are optional
    self._answers.pop(up.salt, None)
    self.consistent.pop(up.salt, None)

  def _HandleReply(self, up):
    """Handle a single confd reply, and decide whether to filter it.

    @rtype: boolean
    @return: True if the reply should be filtered, False if it should be passed
             on to the up-callback

    """
    salt = up.salt
    new_reply = up.server_reply
    if salt not in self.consistent:
      self.consistent[salt] = True

    if salt not in self._answers:
      # first answer for a query (don't filter, and record)
      self._answers[salt] = new_reply
      return False

    old_reply = self._answers[salt]
    if new_reply.serial > old_reply.serial:
      # newer answer (record, and compare contents)
      self._answers[salt] = new_reply
      if new_reply.answer == old_reply.answer:
        # same content (filter) (version upgrade was unrelated)
        self._LogFilter(salt, new_reply, old_reply)
        return True
      # different content, pass up a second answer
      return False

    # older or same-version answer (duplicate or outdated, filter)
    if (new_reply.serial == old_reply.serial and
        new_reply.answer != old_reply.answer):
      self.consistent[salt] = False
    self._LogFilter(salt, new_reply, old_reply)
    return True

  def __call__(self, up):
    """Filtering callback

    Replies are passed on to the wrapped callback only when L{_HandleReply}
    does not filter them; every other upcall (e.g. expirations) is always
    passed on.

    @type up: L{ConfdUpcallPayload}
    @param up: upcall payload

    """
    filter_upcall = False
    if up.type == UPCALL_REPLY:
      filter_upcall = self._HandleReply(up)
    elif up.type == UPCALL_EXPIRE:
      self._HandleExpire(up)

    if not filter_upcall:
      self._callback(up)


class ConfdCountingCallback:
  """Callback that calls another callback, and counts the answers

  Every upcall is passed on, unchanged, to the wrapped callback; in
  addition, the number of replies received for each registered query is
  tracked.

  """
  def __init__(self, callback, logger=None):
    """Constructor for ConfdCountingCallback

    @type callback: f(L{ConfdUpcallPayload})
    @param callback: function to call when getting answers
    @type logger: logging.Logger
    @param logger: optional logger for internal conditions
    @raise errors.ProgrammerError: if callback is not callable

    """
    if not callable(callback):
      raise errors.ProgrammerError("callback must be callable")

    self._callback = callback
    self._logger = logger
    # answers contains a dict of salt -> count
    self._answers = {}

  def RegisterQuery(self, salt):
    """Start counting the replies for a query.

    @param salt: the salt of the query to track
    @raise errors.ProgrammerError: if the salt is already registered

    """
    if salt in self._answers:
      raise errors.ProgrammerError("query already registered")
    self._answers[salt] = 0

  def AllAnswered(self):
    """Have all the registered queries received at least an answer?

    Queries dropped on expiration (see L{_HandleExpire}) are no longer
    registered, and thus are not taken into account.

    """
    return utils.all(self._answers.values())

  def _HandleExpire(self, up):
    """Stop tracking an expired query."""
    # the query is forgotten whether or not it got any reply
    if up.salt in self._answers:
      del self._answers[up.salt]

  def _HandleReply(self, up):
    """Count a reply towards its query.

    Replies for queries which are not registered (or have already expired)
    are ignored. This callback never filters anything, so nothing is
    returned.

    """
    if up.salt in self._answers:
      self._answers[up.salt] += 1

  def __call__(self, up):
    """Counting callback

    Update the reply counters, then pass the upcall on to the wrapped
    callback.

    @type up: L{ConfdUpcallPayload}
    @param up: upcall payload

    """
    if up.type == UPCALL_REPLY:
      self._HandleReply(up)
    elif up.type == UPCALL_EXPIRE:
      self._HandleExpire(up)
    self._callback(up)


class StoreResultCallback:
  """Callback that simply stores the most recent answer.

  Upcalls are not forwarded anywhere; the latest reply for each salt is
  recorded and can be fetched later with L{GetResponse}.

  @ivar _answers: dict of salt to (have_answer, reply)

  """
  _NO_KEY = (False, None)

  def __init__(self):
    """Constructor for StoreResultCallback

    """
    # salt -> (True, latest reply upcall)
    self._answers = {}

  def GetResponse(self, salt):
    """Return the stored result for a salt.

    @return: a (have_answer, upcall) tuple; (False, None) when no reply
        has been recorded for the salt

    """
    try:
      return self._answers[salt]
    except KeyError:
      return self._NO_KEY

  def _HandleExpire(self, up):
    """Expiration handler.

    Only entries equal to the "no answer" marker are dropped; since
    L{_HandleReply} stores only (True, up) tuples, answered queries keep
    their result after expiring.

    """
    if self._answers.get(up.salt) == self._NO_KEY:
      del self._answers[up.salt]

  def _HandleReply(self, up):
    """Record a reply, replacing any previous one for the same salt.

    """
    self._answers[up.salt] = (True, up)

  def __call__(self, up):
    """Storing callback: dispatch the upcall by its type.

    @type up: L{ConfdUpcallPayload}
    @param up: upcall payload

    """
    kind = up.type
    if kind == UPCALL_REPLY:
      self._HandleReply(up)
    elif kind == UPCALL_EXPIRE:
      self._HandleExpire(up)


def GetConfdClient(callback):
  """Build a L{ConfdClient} from the local cluster configuration.

  The master candidate IP list is read from the ssconf store and the HMAC
  key from C{constants.CONFD_HMAC_KEY}, so callers only supply the callback.

  @attention: This should only be called on nodes which are part of a
      cluster, since it depends on a valid (ganeti) data directory;
      for code running outside of a cluster, you need to create the
      client manually

  """
  store = ssconf.SimpleStore()
  candidates_path = store.KeyToFilename(constants.SS_MASTER_CANDIDATES_IPS)
  candidates = utils.ReadFile(candidates_path).splitlines()
  key = utils.ReadFile(constants.CONFD_HMAC_KEY)
  return ConfdClient(key, candidates, callback)