From 4138d39f9964173249db1a4e9afe1d06846b8fc7 Mon Sep 17 00:00:00 2001 From: Iustin Pop <iustin@google.com> Date: Fri, 22 Oct 2010 14:29:47 +0200 Subject: [PATCH] Add a "safe" file wrapper over WriteFile Signed-off-by: Iustin Pop <iustin@google.com> Reviewed-by: Michael Hanselmann <hansmi@google.com> --- lib/utils.py | 25 +++++++++++++++++++++++++ test/ganeti.utils_unittest.py | 15 +++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/lib/utils.py b/lib/utils.py index aa84cb1f8..bdd86105c 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -1912,6 +1912,31 @@ def VerifyFileID(fi_disk, fi_ours): return (d1, i1) == (d2, i2) and m1 <= m2 +def SafeWriteFile(file_name, file_id, **kwargs): + """Wraper over L{WriteFile} that locks the target file. + + By keeping the target file locked during WriteFile, we ensure that + cooperating writers will safely serialise access to the file. + + @type file_name: str + @param file_name: the target filename + @type file_id: tuple + @param file_id: a result from L{GetFileID} + + """ + fd = os.open(file_name, os.O_RDONLY | os.O_CREAT) + try: + LockFile(fd) + if file_id is not None: + disk_id = GetFileID(fd=fd) + if not VerifyFileID(disk_id, file_id): + raise errors.LockError("Cannot overwrite file %s, it has been modified" + " since last written" % file_name) + return WriteFile(file_name, **kwargs) + finally: + os.close(fd) + + def ReadOneLineFile(file_name, strict=False): """Return the first non-empty line from a file. diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py index 7fc93db1f..2c46afcb1 100755 --- a/test/ganeti.utils_unittest.py +++ b/test/ganeti.utils_unittest.py @@ -2374,6 +2374,21 @@ class TestFileID(testutils.GanetiTestCase): finally: os.close(fd) + def testWriteFile(self): + name = self._CreateTempFile() + oldi = utils.GetFileID(path=name) + mtime = oldi[2] + os.utime(name, (mtime + 10, mtime + 10)) + self.assertRaises(errors.LockError, utils.SafeWriteFile, name, + oldi, data="") + os.utime(name, (mtime - 10, mtime - 10)) + utils.SafeWriteFile(name, oldi, data="") + oldi = utils.GetFileID(path=name) + mtime = oldi[2] + os.utime(name, (mtime + 10, mtime + 10)) + # this doesn't raise, since we passed None + utils.SafeWriteFile(name, None, data="") + if __name__ == '__main__': testutils.GanetiTestProgram() -- GitLab