diff --git a/lib/utils.py b/lib/utils.py index aa84cb1f806d03363e651df9fd476769c9253597..bdd86105c88d8177471d610709e4a9b80480338e 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -1912,6 +1912,31 @@ def VerifyFileID(fi_disk, fi_ours): return (d1, i1) == (d2, i2) and m1 <= m2 +def SafeWriteFile(file_name, file_id, **kwargs): + """Wraper over L{WriteFile} that locks the target file. + + By keeping the target file locked during WriteFile, we ensure that + cooperating writers will safely serialise access to the file. + + @type file_name: str + @param file_name: the target filename + @type file_id: tuple + @param file_id: a result from L{GetFileID} + + """ + fd = os.open(file_name, os.O_RDONLY | os.O_CREAT) + try: + LockFile(fd) + if file_id is not None: + disk_id = GetFileID(fd=fd) + if not VerifyFileID(disk_id, file_id): + raise errors.LockError("Cannot overwrite file %s, it has been modified" + " since last written" % file_name) + return WriteFile(file_name, **kwargs) + finally: + os.close(fd) + + def ReadOneLineFile(file_name, strict=False): """Return the first non-empty line from a file. diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py index 7fc93db1f63ab98187ea536e3d397ccce9a3df2c..2c46afcb1ae2ea2eaaa44fc403e7e82221c85a2b 100755 --- a/test/ganeti.utils_unittest.py +++ b/test/ganeti.utils_unittest.py @@ -2374,6 +2374,21 @@ class TestFileID(testutils.GanetiTestCase): finally: os.close(fd) + def testWriteFile(self): + name = self._CreateTempFile() + oldi = utils.GetFileID(path=name) + mtime = oldi[2] + os.utime(name, (mtime + 10, mtime + 10)) + self.assertRaises(errors.LockError, utils.SafeWriteFile, name, + oldi, data="") + os.utime(name, (mtime - 10, mtime - 10)) + utils.SafeWriteFile(name, oldi, data="") + oldi = utils.GetFileID(path=name) + mtime = oldi[2] + os.utime(name, (mtime + 10, mtime + 10)) + # this doesn't raise, since we passed None + utils.SafeWriteFile(name, None, data="") + if __name__ == '__main__': testutils.GanetiTestProgram()