diff --git a/lib/cache.py b/lib/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1c49779d4545faf8da1b14aa0729796d0d580b --- /dev/null +++ b/lib/cache.py @@ -0,0 +1,235 @@ +# +# + +# Copyright (C) 2011 Google Inc. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + + +"""This module implements caching.""" + + +import time + +from ganeti import locking +from ganeti import serializer + + +TIMESTAMP = "timestamp" +TTL = "ttl" +VALUE = "value" + + +class CacheBase: + """This is the base class for all caches. + + """ + def __init__(self): + """Base init method. + + """ + + def Store(self, key, value, ttl=0): + """Stores key with value in the cache. + + @param key: The key to associate this cached value + @param value: The value to cache + @param ttl: TTL in seconds after when this entry is considered outdated + @returns: L{True} on success, L{False} on failure + + """ + raise NotImplementedError + + def GetMulti(self, keys): + """Retrieve multiple values from the cache. + + @param keys: The keys to retrieve + @returns: The list of values + + """ + raise NotImplementedError + + def Get(self, key): + """Retrieve the value from the cache. + + @param key: The key to retrieve + @returns: The value or L{None} if not found + + """ + raise NotImplementedError + + def Invalidate(self, keys): + """Invalidate given keys. + + @param keys: The list of keys to invalidate + @returns: L{True} on success, L{False} otherwise + + """ + raise NotImplementedError + + def Flush(self): + """Invalidates all of the keys and flushes the cache. + + """ + raise NotImplementedError + + def ResetState(self): + """Used to reset the state of the cache. + + This can be used to reinstantiate connection or any other state refresh + + """ + + def Cleanup(self): + """Cleanup the cache from expired entries. + + """ + + +class SimpleCache(CacheBase): + """Implements a very simple, dict base cache. + + """ + CLEANUP_ROUND = 1800 + _LOCK = "lock" + + def __init__(self, _time_fn=time.time): + """Initialize this class. + + @param _time_fn: Function used to return time (unittest only) + + """ + CacheBase.__init__(self) + + self._time_fn = _time_fn + + self.cache = {} + self.lock = locking.SharedLock("SimpleCache") + self.last_cleanup = self._time_fn() + + def _UnlockedCleanup(self): + """Does cleanup of the cache. + + """ + check_time = self._time_fn() + if (self.last_cleanup + self.CLEANUP_ROUND) <= check_time: + keys = [] + for key, value in self.cache.items(): + if not value[TTL]: + continue + + expired = value[TIMESTAMP] + value[TTL] + if expired < check_time: + keys.append(key) + self._UnlockedInvalidate(keys) + self.last_cleanup = check_time + + @locking.ssynchronized(_LOCK) + def Cleanup(self): + """Cleanup our cache. + + """ + self._UnlockedCleanup() + + @locking.ssynchronized(_LOCK) + def Store(self, key, value, ttl=0): + """Stores a value at key in the cache. + + See L{CacheBase.Store} for parameter description + + """ + assert ttl >= 0 + self._UnlockedCleanup() + val = serializer.Dump(value) + cache_val = { + TIMESTAMP: self._time_fn(), + TTL: ttl, + VALUE: val + } + self.cache[key] = cache_val + return True + + @locking.ssynchronized(_LOCK, shared=1) + def GetMulti(self, keys): + """Retrieve the values of keys from cache. + + See L{CacheBase.GetMulti} for parameter description + + """ + return [self._ExtractValue(key) for key in keys] + + @locking.ssynchronized(_LOCK, shared=1) + def Get(self, key): + """Retrieve the value of key from cache. + + See L{CacheBase.Get} for parameter description + + """ + return self._ExtractValue(key) + + @locking.ssynchronized(_LOCK) + def Invalidate(self, keys): + """Invalidates value for keys in cache. + + See L{CacheBase.Invalidate} for parameter description + + """ + return self._UnlockedInvalidate(keys) + + @locking.ssynchronized(_LOCK) + def Flush(self): + """Invalidates all keys and values in cache. + + See L{CacheBase.Flush} for parameter description + + """ + self.cache.clear() + self.last_cleanup = self._time_fn() + + def _UnlockedInvalidate(self, keys): + """Invalidate keys in cache. + + This is the unlocked version, see L{Invalidate} for parameter description + + """ + for key in keys: + self.cache.pop(key, None) + + return True + + def _ExtractValue(self, key): + """Extracts just the value for a key. + + This method is taking care if the value did not expire ans returns it + + @param key: The key to look for + @returns: The value if key is not expired, L{None} otherwise + + """ + try: + cache_val = self.cache[key] + except KeyError: + return None + else: + if cache_val[TTL] == 0: + return serializer.Load(cache_val[VALUE]) + else: + expired = cache_val[TIMESTAMP] + cache_val[TTL] + + if self._time_fn() <= expired: + return serializer.Load(cache_val[VALUE]) + else: + return None diff --git a/test/ganeti.cache_unittest.py b/test/ganeti.cache_unittest.py new file mode 100755 index 0000000000000000000000000000000000000000..d9dffb197e508a532417ad04c68bcfd3f1075678 --- /dev/null +++ b/test/ganeti.cache_unittest.py @@ -0,0 +1,104 @@ +#!/usr/bin/python +# + +# Copyright (C) 2011 Google Inc. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +"""Script for testing ganeti.cache""" + +import testutils +import unittest + +from ganeti import cache + + +class ReturnStub: + def __init__(self, values): + self.values = values + + def __call__(self): + assert self.values + return self.values.pop(0) + + +class SimpleCacheTest(unittest.TestCase): + def setUp(self): + self.cache = cache.SimpleCache() + + def testNoKey(self): + self.assertEqual(self.cache.GetMulti(["i-dont-exist", "neither-do-i", "no"]), + [None, None, None]) + + def testCache(self): + value = 0xc0ffee + self.assert_(self.cache.Store("i-exist", value)) + self.assertEqual(self.cache.GetMulti(["i-exist"]), [value]) + + def testMixed(self): + value = 0xb4dc0de + self.assert_(self.cache.Store("i-exist", value)) + self.assertEqual(self.cache.GetMulti(["i-exist", "i-dont"]), [value, None]) + + def testTtl(self): + my_times = ReturnStub([0, 1, 1, 2, 3, 5]) + ttl_cache = cache.SimpleCache(_time_fn=my_times) + self.assert_(ttl_cache.Store("test-expire", 0xdeadbeef, ttl=2)) + + # At this point time will return 2, 1 (start) + 2 (ttl) = 3, still valid + self.assertEqual(ttl_cache.Get("test-expire"), 0xdeadbeef) + + # At this point time will return 3, 1 (start) + 2 (ttl) = 3, still valid + self.assertEqual(ttl_cache.Get("test-expire"), 0xdeadbeef) + + # We are at 5, < 3, invalid + self.assertEqual(ttl_cache.Get("test-expire"), None) + self.assertFalse(my_times.values) + + def testCleanup(self): + my_times = ReturnStub([0, 1, 1, 2, 2, 3, 3, 5, 5, + 21 + cache.SimpleCache.CLEANUP_ROUND, + 34 + cache.SimpleCache.CLEANUP_ROUND, + 55 + cache.SimpleCache.CLEANUP_ROUND * 2, + 89 + cache.SimpleCache.CLEANUP_ROUND * 3]) + # Index 0 + ttl_cache = cache.SimpleCache(_time_fn=my_times) + # Index 1, 2 + self.assert_(ttl_cache.Store("foobar", 0x1dea, ttl=6)) + # Index 3, 4 + self.assert_(ttl_cache.Store("baz", 0xc0dea55, ttl=11)) + # Index 6, 7 + self.assert_(ttl_cache.Store("long-foobar", "pretty long", + ttl=(22 + cache.SimpleCache.CLEANUP_ROUND))) + # Index 7, 8 + self.assert_(ttl_cache.Store("foobazbar", "alive forever")) + + self.assertEqual(set(ttl_cache.cache.keys()), + set(["foobar", "baz", "long-foobar", "foobazbar"])) + ttl_cache.Cleanup() + self.assertEqual(set(ttl_cache.cache.keys()), + set(["long-foobar", "foobazbar"])) + ttl_cache.Cleanup() + self.assertEqual(set(ttl_cache.cache.keys()), + set(["long-foobar", "foobazbar"])) + ttl_cache.Cleanup() + self.assertEqual(set(ttl_cache.cache.keys()), set(["foobazbar"])) + ttl_cache.Cleanup() + self.assertEqual(set(ttl_cache.cache.keys()), set(["foobazbar"])) + + +if __name__ == "__main__": + testutils.GanetiTestProgram()