From 4138d39f9964173249db1a4e9afe1d06846b8fc7 Mon Sep 17 00:00:00 2001
From: Iustin Pop <iustin@google.com>
Date: Fri, 22 Oct 2010 14:29:47 +0200
Subject: [PATCH] Add a "safe" file wrapper over WriteFile

Signed-off-by: Iustin Pop <iustin@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>
---
 lib/utils.py                  | 25 +++++++++++++++++++++++++
 test/ganeti.utils_unittest.py | 15 +++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/lib/utils.py b/lib/utils.py
index aa84cb1f8..bdd86105c 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -1912,6 +1912,31 @@ def VerifyFileID(fi_disk, fi_ours):
   return (d1, i1) == (d2, i2) and m1 <= m2
 
 
+def SafeWriteFile(file_name, file_id, **kwargs):
+  """Wraper over L{WriteFile} that locks the target file.
+
+  By keeping the target file locked during WriteFile, we ensure that
+  cooperating writers will safely serialise access to the file.
+
+  @type file_name: str
+  @param file_name: the target filename
+  @type file_id: tuple
+  @param file_id: a result from L{GetFileID}
+
+  """
+  fd = os.open(file_name, os.O_RDONLY | os.O_CREAT)
+  try:
+    LockFile(fd)
+    if file_id is not None:
+      disk_id = GetFileID(fd=fd)
+      if not VerifyFileID(disk_id, file_id):
+        raise errors.LockError("Cannot overwrite file %s, it has been modified"
+                               " since last written" % file_name)
+    return WriteFile(file_name, **kwargs)
+  finally:
+    os.close(fd)
+
+
 def ReadOneLineFile(file_name, strict=False):
   """Return the first non-empty line from a file.
 
diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py
index 7fc93db1f..2c46afcb1 100755
--- a/test/ganeti.utils_unittest.py
+++ b/test/ganeti.utils_unittest.py
@@ -2374,6 +2374,21 @@ class TestFileID(testutils.GanetiTestCase):
     finally:
       os.close(fd)
 
+  def testWriteFile(self):
+    name = self._CreateTempFile()
+    oldi = utils.GetFileID(path=name)
+    mtime = oldi[2]
+    os.utime(name, (mtime + 10, mtime + 10))
+    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
+                      oldi, data="")
+    os.utime(name, (mtime - 10, mtime - 10))
+    utils.SafeWriteFile(name, oldi, data="")
+    oldi = utils.GetFileID(path=name)
+    mtime = oldi[2]
+    os.utime(name, (mtime + 10, mtime + 10))
+    # this doesn't raise, since we passed None
+    utils.SafeWriteFile(name, None, data="")
+
 
 if __name__ == '__main__':
   testutils.GanetiTestProgram()
-- 
GitLab