utils.py 26.8 KB
Newer Older
Iustin Pop's avatar
Iustin Pop committed
1
#
Iustin Pop's avatar
Iustin Pop committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#

# Copyright (C) 2006, 2007 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Ganeti small utilities
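
This module holds assorted helper functions: subsystem locking,
running external commands, atomic file writes and /etc/hosts editing,
SSH authorized_keys handling, and name/unit parsing and formatting.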
"""


import sys
import os
import sha
import time
31
import subprocess
Iustin Pop's avatar
Iustin Pop committed
32
33
34
35
import re
import socket
import tempfile
import shutil
36
import errno
37
import pwd
Guido Trotter's avatar
Guido Trotter committed
38
import itertools
39
40
41
42
import select
import fcntl

from cStringIO import StringIO
Iustin Pop's avatar
Iustin Pop committed
43
44
45

from ganeti import logger
from ganeti import errors
Iustin Pop's avatar
Iustin Pop committed
46
from ganeti import constants
Iustin Pop's avatar
Iustin Pop committed
47

48

Iustin Pop's avatar
Iustin Pop committed
49
50
51
# Names of the locks acquired by Lock() and not yet released by Unlock().
_locksheld = []
# Arguments consisting only of these characters are passed through
# unquoted by ShellQuote(); \Z is used instead of $ because $ would
# also match before a trailing newline, which the shell treats as a
# command separator.
_re_shell_unquoted = re.compile(r'^[-.,=:/_+@A-Za-z0-9]+\Z')

# When true, RunResult logs a debug message for every failed command.
debug = False

54

Iustin Pop's avatar
Iustin Pop committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class RunResult(object):
  """Simple class for holding the result of running external programs.

  Instance variables:
    exit_code: the exit code of the program, or None (if the program
               didn't exit())
    signal: numeric signal that caused the program to finish, or None
            (if the program wasn't terminated by a signal)
    stdout: the standard output of the program
    stderr: the standard error of the program
    failed: a Boolean value which is True in case the program was
            terminated by a signal or exited with a non-zero exit code
    fail_reason: a string detailing the termination reason (it is set
                 even for successful runs, e.g. "exited with exit code 0")
    cmd: the command that was run, as a string

  """
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
               "failed", "fail_reason", "cmd"]

  def __init__(self, exit_code, signal, stdout, stderr, cmd):
    self.cmd = cmd
    self.exit_code = exit_code
    self.signal = signal
    self.stdout = stdout
    self.stderr = stderr
    self.failed = (signal is not None or exit_code != 0)

    if self.signal is not None:
      self.fail_reason = "terminated by signal %s" % self.signal
    elif self.exit_code is not None:
      self.fail_reason = "exited with exit code %s" % self.exit_code
    else:
      self.fail_reason = "unable to determine termination reason"

    # 'debug' is the module-level flag, not an argument
    if debug and self.failed:
      logger.Debug("Command '%s' failed (%s); output: %s" %
                   (self.cmd, self.fail_reason, self.output))

  def _GetOutput(self):
    """Returns the combined stdout and stderr for easier usage.

    """
    return self.stdout + self.stderr

  output = property(_GetOutput, None, None, "Return full output")


def _GetLockFile(subsystem):
  """Compute the file name for a given lock name.

  Args:
    subsystem: the lock name, as passed to Lock()/Unlock()

  Returns:
    the path of the lock file inside constants.LOCK_DIR

  """
  return "%s/ganeti_lock_%s" % (constants.LOCK_DIR, subsystem)
Iustin Pop's avatar
Iustin Pop committed
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133


def Lock(name, max_retries=None, debug=False):
  """Lock a given subsystem.

  In case the lock is already held by an alive process, the function
  will sleep indefintely and poll with a one second interval.

  When the optional integer argument 'max_retries' is passed with a
  non-zero value, the function will sleep only for this number of
  times, and then it will will raise a LockError if the lock can't be
  acquired. Passing in a negative number will cause only one try to
  get the lock. Passing a positive number will make the function retry
  for approximately that number of seconds.

  """
  lockfile = _GetLockFile(name)

  if name in _locksheld:
    raise errors.LockError('Lock "%s" already held!' % (name,))

  errcount = 0

  retries = 0
  while True:
    try:
      fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR | os.O_SYNC)
      break
    except OSError, creat_err:
134
      if creat_err.errno != errno.EEXIST:
135
136
        raise errors.LockError("Can't create the lock file. Error '%s'." %
                               str(creat_err))
Iustin Pop's avatar
Iustin Pop committed
137
138
139
140
141
142

      try:
        pf = open(lockfile, 'r')
      except IOError, open_err:
        errcount += 1
        if errcount >= 5:
143
144
          raise errors.LockError("Lock file exists but cannot be opened."
                                 " Error: '%s'." % str(open_err))
Iustin Pop's avatar
Iustin Pop committed
145
146
147
148
149
150
        time.sleep(1)
        continue

      try:
        pid = int(pf.read())
      except ValueError:
151
        raise errors.LockError("Invalid pid string in %s" %
Iustin Pop's avatar
Iustin Pop committed
152
153
154
                               (lockfile,))

      if not IsProcessAlive(pid):
155
156
        raise errors.LockError("Stale lockfile %s for pid %d?" %
                               (lockfile, pid))
Iustin Pop's avatar
Iustin Pop committed
157
158

      if max_retries and max_retries <= retries:
159
160
        raise errors.LockError("Can't acquire lock during the specified"
                               " time, aborting.")
Iustin Pop's avatar
Iustin Pop committed
161
162
163
164
165
166
167
168
169
170
171
172
173
174
      if retries == 5 and (debug or sys.stdin.isatty()):
        logger.ToStderr("Waiting for '%s' lock from pid %d..." % (name, pid))

      time.sleep(1)
      retries += 1
      continue

  os.write(fd, '%d\n' % (os.getpid(),))
  os.close(fd)

  _locksheld.append(name)


def Unlock(name):
  """Unlock a given subsystem.

  The lock file is removed only if it records the pid of the current
  process.

  Args:
    name: the subsystem name, as passed to Lock()

  Raises:
    errors.LockError: if the lock file cannot be opened, does not
      contain a valid pid, or records the pid of another process

  """
  lockfile = _GetLockFile(name)

  try:
    fd = os.open(lockfile, os.O_RDONLY)
  except OSError:
    raise errors.LockError('Lock "%s" not held.' % (name,))

  f = os.fdopen(fd, 'r')
  try:
    pid_str = f.read()
  finally:
    # the file object owns fd, so this closes the descriptor as well
    f.close()

  try:
    pid = int(pid_str)
  except ValueError:
    raise errors.LockError('Unable to determine PID of locking process.')

  if pid != os.getpid():
    raise errors.LockError('Lock not held by me (%d != %d)' %
                           (os.getpid(), pid,))

  os.unlink(lockfile)
  _locksheld.remove(name)


def LockCleanup():
  """Remove all locks.

  Calls Unlock() for every lock this process acquired with Lock() and
  still holds.

  """
  # Unlock() removes each name from _locksheld; iterating over the live
  # list would therefore skip every other lock, so walk a snapshot.
  for lock in _locksheld[:]:
    Unlock(lock)


def RunCmd(cmd):
  """Execute a (shell) command.

  The command should not read from its standard input, as it will be
  closed.

  Args:
    cmd: command to run. (str)

  Returns: `RunResult` instance

  """
  if isinstance(cmd, list):
    cmd = [str(val) for val in cmd]
223
224
225
226
227
    strcmd = " ".join(cmd)
    shell = False
  else:
    strcmd = cmd
    shell = True
228
229
  env = os.environ.copy()
  env["LC_ALL"] = "C"
230
  poller = select.poll()
231
232
233
234
  child = subprocess.Popen(cmd, shell=shell,
                           stderr=subprocess.PIPE,
                           stdout=subprocess.PIPE,
                           stdin=subprocess.PIPE,
235
                           close_fds=True, env=env)
236
237

  child.stdin.close()
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
  poller.register(child.stdout, select.POLLIN)
  poller.register(child.stderr, select.POLLIN)
  out = StringIO()
  err = StringIO()
  fdmap = {
    child.stdout.fileno(): (out, child.stdout),
    child.stderr.fileno(): (err, child.stderr),
    }
  for fd in fdmap:
    status = fcntl.fcntl(fd, fcntl.F_GETFL)
    fcntl.fcntl(fd, fcntl.F_SETFL, status | os.O_NONBLOCK)

  while fdmap:
    for fd, event in poller.poll():
      if event & select.POLLIN or event & select.POLLPRI:
        data = fdmap[fd][1].read()
        # no data from read signifies EOF (the same as POLLHUP)
        if not data:
          poller.unregister(fd)
          del fdmap[fd]
          continue
        fdmap[fd][0].write(data)
      if (event & select.POLLNVAL or event & select.POLLHUP or
          event & select.POLLERR):
        poller.unregister(fd)
        del fdmap[fd]

  out = out.getvalue()
  err = err.getvalue()
Iustin Pop's avatar
Iustin Pop committed
267
268

  status = child.wait()
269
270
  if status >= 0:
    exitcode = status
Iustin Pop's avatar
Iustin Pop committed
271
272
273
    signal = None
  else:
    exitcode = None
274
    signal = -status
Iustin Pop's avatar
Iustin Pop committed
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312

  return RunResult(exitcode, signal, out, err, strcmd)


def RunCmdUnlocked(cmd):
  """Execute a shell command without the 'cmd' lock.

  This variant of `RunCmd()` drops the 'cmd' lock before running the
  command and re-acquires it afterwards, thus it can be used to call
  other ganeti commands. The lock is re-acquired even when `RunCmd()`
  raises an exception.

  The argument and return values are the same as for the `RunCmd()`
  function.

  Args:
    cmd - command to run. (str)

  Returns:
    `RunResult`

  """
  Unlock('cmd')
  try:
    ret = RunCmd(cmd)
  finally:
    # restore the caller's lock state whatever happened in RunCmd
    Lock('cmd')

  return ret


def RemoveFile(filename):
  """Remove a file ignoring some errors.

  Remove a file, ignoring non-existing ones or directories. Other
  errors are passed.

  """
  try:
    os.unlink(filename)
  except OSError, err:
313
    if err.errno not in (errno.ENOENT, errno.EISDIR):
Iustin Pop's avatar
Iustin Pop committed
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
      raise


def _FingerprintFile(filename):
  """Compute the fingerprint of a file.

  If the file does not exist (or is not a regular file), a None will
  be returned instead.

  Args:
    filename - Filename (str)

  Returns:
    the hex SHA-1 digest of the file contents, or None

  """
  # isfile() is False for missing paths, so it covers the exists() test
  if not os.path.isfile(filename):
    return None

  f = open(filename, 'rb')
  try:
    fp = sha.sha()
    while True:
      data = f.read(4096)
      if not data:
        break
      fp.update(data)
  finally:
    f.close()

  return fp.hexdigest()


def FingerprintFiles(files):
  """Fingerprint several files at once.

  Args:
    files - list of file names ( [str, ...] )

  Return value:
    a dict mapping each file name to its fingerprint; names for which
    _FingerprintFile() returns no fingerprint are left out

  """
  fingerprints = {}
  for name in files:
    digest = _FingerprintFile(name)
    if not digest:
      continue
    fingerprints[name] = digest
  return fingerprints


def CheckDict(target, template, logname=None):
  """Fill in the keys of `target` that `template` has but `target` lacks.

  Each missing key is inserted into `target` (in place) with the value
  it has in `template`; keys already present are left untouched.

  Args:
    target   - the dictionary to complete (modified in place)
    template - dictionary providing the required keys and their defaults
    logname  - optional caller-chosen tag; when given and some keys were
               missing, a debug log entry listing them is emitted

  Returns value:
    None

  """
  added = [key for key in template if key not in target]
  for key in added:
    target[key] = template[key]

  if logname and added:
    logger.Debug('%s missing keys %s' %
                 (logname, ', '.join(added)))


def IsProcessAlive(pid):
  """Check if a given pid exists on the system.

  Returns: true or false, depending on if the pid exists or not

  Remarks: zombie processes treated as not alive

  """
  try:
    f = open("/proc/%d/status" % pid)
  except IOError, err:
402
    if err.errno in (errno.ENOENT, errno.ENOTDIR):
Iustin Pop's avatar
Iustin Pop committed
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
      return False

  alive = True
  try:
    data = f.readlines()
    if len(data) > 1:
      state = data[1].split()
      if len(state) > 1 and state[1] == "Z":
        alive = False
  finally:
    f.close()

  return alive


def MatchNameComponent(key, name_list):
  """Look up a possibly abbreviated name in a list of full names.

  `key` matches a list entry when it equals the entry, or is a prefix
  of it that ends exactly before a dot: 'test1' and 'test1.example'
  both match 'test1.example.com', while 'test1.ex' does not. When more
  than one entry matches (e.g. 'test1' against ['test1.example.com',
  'test1.example.org']) this is treated as no match at all.

  Args:
    key: the name to be searched
    name_list: the list of strings against which to search the key

  Returns:
    the single matching element of name_list, or None when there are
    no matches or several

  """
  pattern = re.compile(r"^%s(\..*)?$" % re.escape(key))
  hits = []
  for candidate in name_list:
    if pattern.match(candidate) is not None:
      hits.append(candidate)
  if len(hits) == 1:
    return hits[0]
  return None


443
class HostInfo:
444
  """Class implementing resolver and hostname functionality
445
446

  """
447
  def __init__(self, name=None):
448
449
    """Initialize the host name object.

450
451
    If the name argument is not passed, it will use this system's
    name.
452
453

    """
454
455
456
457
458
    if name is None:
      name = self.SysName()

    self.query = name
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
459
460
    self.ip = self.ipaddrs[0]

461
462
463
464
465
466
  def ShortName(self):
    """Returns the hostname without domain.

    """
    return self.name.split('.')[0]

467
468
469
  @staticmethod
  def SysName():
    """Return the current system's name.
470

471
    This is simply a wrapper over socket.gethostname()
Iustin Pop's avatar
Iustin Pop committed
472

473
474
    """
    return socket.gethostname()
Iustin Pop's avatar
Iustin Pop committed
475

476
477
478
  @staticmethod
  def LookupHostname(hostname):
    """Look up hostname
Iustin Pop's avatar
Iustin Pop committed
479

480
481
482
483
484
485
486
487
488
489
490
491
492
    Args:
      hostname: hostname to look up

    Returns:
      a tuple (name, aliases, ipaddrs) as returned by socket.gethostbyname_ex
      in case of errors in resolving, we raise a ResolverError

    """
    try:
      result = socket.gethostbyname_ex(hostname)
    except socket.gaierror, err:
      # hostname not found in DNS
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
Iustin Pop's avatar
Iustin Pop committed
493

494
    return result
Iustin Pop's avatar
Iustin Pop committed
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671


def ListVolumeGroups():
  """List volume groups and their size

  Returns:
     Dictionary with keys volume name and values the size of the volume

  """
  command = "vgs --noheadings --units m --nosuffix -o name,size"
  result = RunCmd(command)
  retval = {}
  if result.failed:
    return retval

  for line in result.stdout.splitlines():
    try:
      name, size = line.split()
      size = int(float(size))
    except (IndexError, ValueError), err:
      logger.Error("Invalid output from vgs (%s): %s" % (err, line))
      continue

    retval[name] = size

  return retval


def BridgeExists(bridge):
  """Tell whether a network bridge with the given name exists.

  A bridge is recognised by the presence of the sysfs directory
  /sys/class/net/<bridge>/bridge.

  Returns:
     True if it does, false otherwise.

  """
  sysfs_path = "/sys/class/net/%s/bridge" % bridge
  return os.path.isdir(sysfs_path)


def NiceSort(name_list):
  """Sort a list of strings based on digit and non-digit groupings.

  Given a list of names ['a1', 'a10', 'a11', 'a2'] this function will
  sort the list in the logical order ['a1', 'a2', 'a10', 'a11'].

  The sort algorithm breaks each name in groups of either only-digits
  or no-digits, comparing digit groups numerically. Only the first
  eight such groups are considered; names that tie on those are ordered
  by their full string value.

  Return value
    - a copy of the list sorted according to our algorithm

  """
  _SORTER_BASE = r"(\D+|\d+)"
  # Every group is optional, including the first, so that the empty
  # string still matches (with all groups None) instead of making
  # match() return None.
  _SORTER_FULL = r"^%s?%s?%s?%s?%s?%s?%s?%s?.*$" % ((_SORTER_BASE,) * 8)
  _SORTER_RE = re.compile(_SORTER_FULL)
  _SORTER_NODIGIT = re.compile(r"^\D*$")

  def _TryInt(val):
    """Attempts to convert a variable to integer."""
    if val is None or _SORTER_NODIGIT.match(val):
      return val
    return int(val)

  # (group keys, name): the name itself breaks ties beyond eight groups
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
             for name in name_list]
  to_sort.sort()
  return [tup[1] for tup in to_sort]


def CheckDaemonAlive(pid_file, process_string):
  """Check whether the specified daemon is alive.

  Args:
   - pid_file: file to read the daemon pid from, the file is
               expected to contain only a single line containing
               only the PID
   - process_string: a substring that we expect to find in
                     the command line of the daemon process

  Returns:
   - True if the daemon is judged to be alive (that is:
      - the PID file exists, is readable and contains a number
      - a process of the specified PID is running
      - that process contains the specified string in its
        command line
      - the process is not in state Z (zombie))
   - False otherwise

  """
  try:
    pid_fh = open(pid_file, 'r')
    try:
      pid = int(pid_fh.readline())
    finally:
      pid_fh.close()

    cmdline_fh = open("/proc/%s/cmdline" % (pid), 'r')
    try:
      cmdline = cmdline_fh.readline()
    finally:
      cmdline_fh.close()

    if process_string not in cmdline:
      return False

    stat_fh = open("/proc/%s/stat" % (pid), 'r')
    try:
      stat_line = stat_fh.readline()
    finally:
      stat_fh.close()

    # The line reads "pid (comm) state ...", and comm may itself contain
    # spaces or parentheses, so take the state from after the last ')'
    # rather than from a fixed whitespace-separated position.
    process_state = stat_line[stat_line.rindex(')') + 1:].split()[0]

    if process_state == 'Z':
      return False

  except (IndexError, IOError, ValueError):
    return False

  return True


def TryConvert(fn, val):
  """Try to convert a value ignoring errors.

  This function tries to apply function `fn` to `val`. If no
  ValueError or TypeError exceptions are raised, it will return the
  result, else it will return the original value. Any other exceptions
  are propagated to the caller.

  """
  try:
    return fn(val)
  except (ValueError, TypeError):
    return val


def IsValidIP(ip):
  """Verifies the syntax of an IPv4 address.

  This function checks if the ip address passed is valid or not based
  on syntax only: four dot-separated decimal numbers of up to three
  digits without leading zeros. Ranges, class calculations or anything
  else are not checked (e.g. '999.1.1.1' is accepted).

  Returns:
    True if the syntax is valid, False otherwise

  """
  unit = r"(0|[1-9]\d{0,2})"
  # \Z rather than $: $ would also accept a trailing newline
  return bool(re.match(r"^%s\.%s\.%s\.%s\Z" % (unit, unit, unit, unit), ip))


def IsValidShellParam(word):
  """Verifies if the given word is safe from the shell's p.o.v.

  This means that we can pass this to a command via the shell and be
  sure that it doesn't alter the command line and is passed as such to
  the actual command.

  Note that we are overly restrictive here, in order to be on the safe
  side.

  """
  # \Z rather than $: $ also matches before a trailing newline, and a
  # newline acts as a command separator for the shell
  return bool(re.match(r"^[-a-zA-Z0-9._+/:%@]+\Z", word))


def BuildShellCmd(template, *args):
  """Build a safe shell command line from the given arguments.

  This function will check all arguments in the args list so that they
  are valid shell parameters (i.e. they don't contain shell
  metacharacters). If everything is ok, it will return the result of
  template % args.

  Raises:
    errors.ProgrammerError: if any argument fails IsValidShellParam()

  """
  for word in args:
    if not IsValidShellParam(word):
      raise errors.ProgrammerError("Shell argument '%s' contains"
                                   " invalid characters" % word)
  return template % args


def FormatUnit(value):
  """Render an amount given in MiB with a suitable unit suffix.

  Amounts below 1024 are shown as whole MiB ("M"); amounts below
  1024*1024 as GiB ("G"), anything else as TiB ("T"), the last two with
  one decimal. `value` must be numeric; the result is always a string.

  """
  mib_per_gib = 1024
  mib_per_tib = 1024 * 1024

  if value < mib_per_gib:
    return "%dM" % round(value, 0)
  if value < mib_per_tib:
    return "%0.1fG" % round(float(value) / mib_per_gib, 1)
  return "%0.1fT" % round(float(value) / mib_per_tib, 1)


def ParseUnit(input_string):
  """Tries to extract number and scale from the given string.

  Input must be in the format NUMBER+ [DOT NUMBER+] SPACE* [UNIT]. If no unit
  is specified, it defaults to MiB. Recognised units (case-insensitive) are
  m/mb/mib, g/gb/gib and t/tb/tib. Return value is always an int in MiB,
  rounded up to the next multiple of 4.

  Raises:
    errors.UnitParseError: if the string is malformed or the unit unknown

  """
  m = re.match(r'^([.\d]+)\s*([a-zA-Z]+)?$', input_string)
  if not m:
    raise errors.UnitParseError("Invalid format")

  try:
    value = float(m.groups()[0])
  except ValueError:
    # the character class above also admits e.g. '.' or '1.2.3'
    raise errors.UnitParseError("Invalid format")

  unit = m.groups()[1]
  if unit:
    lcunit = unit.lower()
  else:
    lcunit = 'm'

  if lcunit in ('m', 'mb', 'mib'):
    # Value already in MiB
    pass

  elif lcunit in ('g', 'gb', 'gib'):
    value *= 1024

  elif lcunit in ('t', 'tb', 'tib'):
    value *= 1024 * 1024

  else:
    raise errors.UnitParseError("Unknown unit: %s" % unit)

  # Make sure we round up
  if int(value) < value:
    value += 1

  # Round up to the next multiple of 4
  value = int(value)
  if value % 4:
    value += 4 - value % 4

  return value


def AddAuthorizedKey(file_name, key):
  """Adds an SSH public key to an authorized_keys file.

  The key is appended on its own line unless a line equal to it
  (ignoring whitespace differences) is already present. The file is
  created if it does not exist.

  Args:
    file_name: Path to authorized_keys file
    key: String containing key
  """
  key_fields = key.split()

  f = open(file_name, 'a+')
  try:
    # The initial read position of an 'a+' stream is platform dependent
    # (it may be end-of-file), so rewind explicitly before scanning.
    f.seek(0)
    nl = True
    for line in f:
      # Ignore whitespace changes
      if line.split() == key_fields:
        break
      nl = line.endswith('\n')
    else:
      # a positioning call is required between reading and writing
      f.seek(0, 2)
      if not nl:
        f.write("\n")
      f.write(key.rstrip('\r\n'))
      f.write("\n")
      f.flush()
  finally:
    f.close()


def RemoveAuthorizedKey(file_name, key):
  """Removes an SSH public key from an authorized_keys file.

  Every line equal to the key (ignoring whitespace differences) is
  dropped. The result is written to a temporary file in the same
  directory, which is fsync'ed, given the original file's permission
  bits, and renamed over the original.

  Args:
    file_name: Path to authorized_keys file
    key: String containing key
  """
  key_fields = key.split()

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          # Ignore whitespace changes while comparing lines
          if line.split() != key_fields:
            out.write(line)

        out.flush()
        # make the new contents durable before they replace the old file
        os.fsync(out.fileno())
        # mkstemp creates the file as 0600; keep the original's mode
        shutil.copymode(file_name, tmpname)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    # cleanup only; the exception is re-raised
    RemoveFile(tmpname)
    raise


795
796
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
  """Sets the name of an IP address and hostname in /etc/hosts.

  Every existing non-comment line for `ip` is dropped and a single
  "ip<TAB>hostname [aliases...]" line is appended. The file is rewritten
  through a temporary file in the same directory, which is fsync'ed,
  given the original file's permission bits and renamed over it.

  Args:
    file_name: path of the hosts file to modify
    ip: the IP address
    hostname: the primary name for ip
    aliases: list of additional names; duplicates and repetitions of
             hostname are dropped

  """
  # Ensure aliases are unique
  aliases = UniqueSequence([hostname] + aliases)[1:]

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        last_nl = True
        for line in f:
          fields = line.split()
          if fields and not fields[0].startswith('#') and ip == fields[0]:
            continue
          out.write(line)
          last_nl = line.endswith('\n')

        # don't glue the new entry onto an unterminated last line
        if not last_nl:
          out.write('\n')

        out.write("%s\t%s" % (ip, hostname))
        if aliases:
          out.write(" %s" % ' '.join(aliases))
        out.write('\n')

        out.flush()
        os.fsync(out)
        # mkstemp creates the file as 0600; copy the original's mode so
        # that e.g. a world-readable hosts file stays world-readable
        shutil.copymode(file_name, tmpname)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    # cleanup only; the exception is re-raised
    RemoveFile(tmpname)
    raise
830
831
832


def RemoveEtcHostsEntry(file_name, hostname):
  """Removes a hostname from /etc/hosts.

  The name is removed from every non-comment line listing it; such a
  line is rewritten with its remaining names separated by single
  spaces, or dropped entirely when no names are left. All other lines
  are kept as they are. The file is rewritten through a temporary file
  in the same directory, which is fsync'ed, given the original file's
  permission bits and renamed over it.

  Args:
    file_name: path of the hosts file to modify
    hostname: the name to remove

  """
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if len(fields) > 1 and not fields[0].startswith('#'):
            names = fields[1:]
            if hostname in names:
              while hostname in names:
                names.remove(hostname)
              if names:
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
              continue

          out.write(line)

        out.flush()
        os.fsync(out)
        # mkstemp creates the file as 0600; copy the original's mode so
        # that e.g. a world-readable hosts file stays world-readable
        shutil.copymode(file_name, tmpname)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    # cleanup only; the exception is re-raised
    RemoveFile(tmpname)
    raise
Iustin Pop's avatar
Iustin Pop committed
866
867
868
869
870
871
872
873
874


def CreateBackup(file_name):
  """Creates a backup of a file.

  The copy is created (via tempfile.mkstemp) in the same directory,
  under a unique name starting with "<basename>.backup-<unix time>.".

  Returns: the path to the newly created backup file.

  Raises:
    errors.ProgrammerError: if file_name is not a regular file

  """
  if not os.path.isfile(file_name):
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
                                file_name)

  prefix = '%s.backup-%d.' % (os.path.basename(file_name), int(time.time()))
  dir_name = os.path.dirname(file_name)

  fsrc = open(file_name, 'rb')
  try:
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
    try:
      fdst = os.fdopen(fd, 'wb')
      try:
        shutil.copyfileobj(fsrc, fdst)
      finally:
        fdst.close()
    except:
      # don't leave a partial backup behind; the exception is re-raised
      RemoveFile(backup_name)
      raise
  finally:
    fsrc.close()

  return backup_name


def ShellQuote(value):
  """Quote a single shell argument following POSIX rules.

  A value made only of characters accepted by _re_shell_unquoted is
  returned unchanged; anything else is wrapped in single quotes, with
  each embedded single quote written as '\\''.

  """
  if not _re_shell_unquoted.match(value):
    return "'%s'" % value.replace("'", "'\\''")
  return value


def ShellQuoteArgs(args):
  """Quote every argument with ShellQuote() and join them with spaces.

  """
  return ' '.join(map(ShellQuote, args))
910
911


912

913
def TcpPing(source, target, port, timeout=10, live_port_needed=False):
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
  """Simple ping implementation using TCP connect(2).

  Try to do a TCP connect(2) from the specified source IP to the specified
  target IP and the specified target port. If live_port_needed is set to true,
  requires the remote end to accept the connection. The timeout is specified
  in seconds and defaults to 10 seconds

  """
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

  sucess = False

  try:
    sock.bind((source, 0))
  except socket.error, (errcode, errstring):
929
    if errcode == errno.EADDRNOTAVAIL:
930
931
932
933
934
935
936
937
938
939
940
      success = False

  sock.settimeout(timeout)

  try:
    sock.connect((target, port))
    sock.close()
    success = True
  except socket.timeout:
    success = False
  except socket.error, (errcode, errstring):
941
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
942
943

  return success
944
945
946
947
948
949


def ListVisibleFiles(path):
  """List the entries of a directory, leaving out hidden (dot) ones.

  Returns:
    the sorted list of names in `path` that do not start with "."

  """
  entries = os.listdir(path)
  return sorted([entry for entry in entries if entry[:1] != "."])
953
954


955
956
957
958
959
960
def GetHomeDir(user, default=None):
  """Try to get the homedir of the given user.

  The user can be passed either as a string (denoting the name) or as
  an integer (denoting the user id). If the user is not found, the
  'default' argument is returned, which defaults to None.

  Raises:
    errors.ProgrammerError: if user is neither a string nor an integer

  """
  try:
    if isinstance(user, basestring):
      result = pwd.getpwnam(user)
    elif isinstance(user, (int, long)):
      result = pwd.getpwuid(user)
    else:
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
                                   type(user))
  except KeyError:
    # the pwd lookup raises KeyError for unknown users
    return default
  return result.pw_dir
974
975


976
def NewUUID():
  """Returns a random UUID.

  The UUID is read from the kernel's /proc/sys/kernel/random/uuid file
  (Linux-specific) and returned without its trailing newline.

  """
  f = open("/proc/sys/kernel/random/uuid", "r")
  try:
    return f.read(128).rstrip("\n")
  finally:
    f.close()
Iustin Pop's avatar
Iustin Pop committed
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035


def WriteFile(file_name, fn=None, data=None,
              mode=None, uid=-1, gid=-1,
              atime=None, mtime=None):
  """(Over)write a file atomically.

  The file_name and either fn (a function taking one argument, the
  file descriptor, and which should write the data to it) or data (the
  contents of the file) must be passed. The other arguments are
  optional and allow setting the file mode, owner and group, and the
  mtime/atime of the file.

  The contents are written to a temporary file in the same directory,
  which is fsync'ed and then renamed over file_name.

  If the function doesn't raise an exception, it has succeeded and the
  target file has the new contents. If it raises an exception, an
  existing target file should be unmodified and the temporary file
  should be removed.

  Args:
    file_name: absolute path of the file to write
    fn: callable receiving the open file descriptor (exclusive with data)
    data: the complete new contents (exclusive with fn)
    mode: permission bits for the new file; None keeps mkstemp's default
    uid, gid: new owner/group; -1 leaves the respective value unchanged
    atime, mtime: access/modification times; both or neither must be set

  Raises:
    errors.ProgrammerError: if file_name is not absolute, if not exactly
      one of fn/data is given, or if only one of atime/mtime is given

  """
  if not os.path.isabs(file_name):
    raise errors.ProgrammerError("Path passed to WriteFile is not"
                                 " absolute: '%s'" % file_name)

  if [fn, data].count(None) != 1:
    raise errors.ProgrammerError("fn or data required")

  if [atime, mtime].count(None) == 1:
    raise errors.ProgrammerError("Both atime and mtime must be either"
                                 " set or None")

  dir_name, base_name = os.path.split(file_name)
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
  # here we need to make sure we remove the temp file, if any error
  # leaves it in place
  try:
    if uid != -1 or gid != -1:
      os.chown(new_name, uid, gid)
    # 'is not None' so that an explicit mode of 0 is honoured too
    if mode is not None:
      os.chmod(new_name, mode)
    if data is not None:
      # os.write may write fewer bytes than requested; loop until done
      while data:
        written = os.write(fd, data)
        data = data[written:]
    else:
      fn(fd)
    os.fsync(fd)
    if atime is not None and mtime is not None:
      os.utime(new_name, (atime, mtime))
    os.rename(new_name, file_name)
  finally:
    os.close(fd)
    # after a successful rename new_name is gone and this is a no-op;
    # otherwise it removes the leftover temporary file
    RemoveFile(new_name)
Guido Trotter's avatar
Guido Trotter committed
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049


def all(seq, pred=bool):
  """Return True if pred(x) is true for every element x of seq.

  Evaluation stops at the first element for which pred is false. An
  empty sequence yields True. A pred of None means plain truth testing.

  """
  if pred is None:
    pred = bool
  for elem in seq:
    if not pred(elem):
      return False
  return True


def any(seq, pred=bool):
  """Return True if pred(x) is true for at least one element x of seq.

  Evaluation stops at the first element for which pred is true. An
  empty sequence yields False. A pred of None means plain truth testing.

  """
  if pred is None:
    pred = bool
  for elem in seq:
    if pred(elem):
      return True
  return False
1050
1051
1052
1053
1054
1055
1056
1057
1058


def UniqueSequence(seq):
  """Return the elements of seq as a list with duplicates removed.

  The first occurrence of each element is kept and the original order
  is preserved; elements must be hashable.
  """
  seen = set()
  result = []
  for item in seq:
    if item in seen:
      continue
    seen.add(item)
    result.append(item)
  return result
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068


def IsValidMac(mac):
  """Predicate to check if a MAC address is valid.

  Checks whether the supplied MAC address is formally correct: exactly
  six two-digit lowercase hexadecimal groups separated by colons, e.g.
  'aa:00:11:bb:cc:ff'. Only the colon-separated format is accepted.
  """
  # five "xx:" groups followed by a final "xx"; \Z (unlike $) does not
  # accept a trailing newline
  mac_check = re.compile(r"^([0-9a-f]{2}:){5}[0-9a-f]{2}\Z")
  return mac_check.match(mac) is not None
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078


def TestDelay(duration):
  """Sleep for `duration` seconds.

  Returns:
    False, without sleeping, if duration is negative; True otherwise

  """
  negative = duration < 0
  if not negative:
    time.sleep(duration)
  return not negative