Commit f10c0b1b authored by Giorgos Korfiatis's avatar Giorgos Korfiatis
Browse files

restructure transactions, support multiple databases

parent b796612e
......@@ -22,7 +22,7 @@ from agkyra.syncer import localfs_client
from agkyra.syncer.pithos_client import PithosFileClient
from agkyra.syncer.syncer import FileSyncer
import agkyra.syncer.syncer
from agkyra.syncer import messaging, utils, common
from agkyra.syncer import messaging, utils, common, database
import random
import os
import time
......@@ -105,7 +105,7 @@ class AgkyraTest(unittest.TestCase):
cls.s = FileSyncer(cls.settings, cls.master, cls.slave)
cls.pithos = cls.master.endpoint
cls.pithos.create_container(cls.ID)
cls.db = cls.s.get_db()
cls.db = database.get_db(cls.s.syncer_dbtuple)
m = cls.s.get_next_message(block=True)
assert isinstance(m, messaging.PithosSyncEnabled)
m = cls.s.get_next_message(block=True)
......@@ -483,11 +483,9 @@ class AgkyraTest(unittest.TestCase):
self.assert_message(messaging.UpdateMessage)
with mock.patch(
"agkyra.syncer.database.SqliteFileStateDB.commit") as dbmock:
dbmock.side_effect = [sqlite3.OperationalError("locked"),
common.DatabaseError()]
"agkyra.syncer.database.DB.begin") as dbmock:
dbmock.side_effect = sqlite3.OperationalError("locked")
self.s.decide_file_sync(fil)
self.assert_message(messaging.HeartbeatReplayDecideMessage)
def test_007_multiprobe(self):
fil = "φ007"
......@@ -642,22 +640,24 @@ class AgkyraTest(unittest.TestCase):
f.write("content")
state = self.db.get_state(self.s.SLAVE, fil)
handle = localfs_client.LocalfsTargetHandle(self.s.settings, state)
handle = self.slave.prepare_target(state)
hidden_filename = utils.join_path(
handle.cache_hide_name, utils.hash_string(handle.objname))
hidden_path = handle.get_path_in_cache(hidden_filename)
self.assertFalse(os.path.isfile(hidden_path))
self.assertIsNone(self.db.get_cachename(hidden_filename))
client_db = database.get_db(self.slave.client_dbtuple)
self.assertIsNone(client_db.get_cachename(hidden_filename))
handle.move_file()
self.assertTrue(os.path.isfile(hidden_path))
self.assertIsNotNone(self.db.get_cachename(hidden_filename))
self.assertIsNotNone(client_db.get_cachename(hidden_filename))
handle.move_file()
self.assertTrue(os.path.isfile(hidden_path))
shutil.move(hidden_path, f_path)
self.assertIsNotNone(self.db.get_cachename(hidden_filename))
self.assertIsNotNone(client_db.get_cachename(hidden_filename))
handle.move_file()
self.assertTrue(os.path.isfile(hidden_path))
......@@ -726,7 +726,7 @@ class AgkyraTest(unittest.TestCase):
self.s.probe_file(self.s.SLAVE, fil)
self.assert_message(messaging.UpdateMessage)
state = self.db.get_state(self.s.SLAVE, fil)
handle = localfs_client.LocalfsSourceHandle(self.s.settings, state)
handle = self.slave.stage_file(state)
staged_path = handle.staged_path
self.assertTrue(localfs_client.files_equal(f_path, staged_path))
handle.unstage_file()
......@@ -734,14 +734,14 @@ class AgkyraTest(unittest.TestCase):
with open(f_path, "w") as f:
f.write("content new")
handle = localfs_client.LocalfsSourceHandle(self.s.settings, state)
handle = self.slave.stage_file(state)
self.assert_message(messaging.LiveInfoUpdateMessage)
self.assertTrue(localfs_client.files_equal(f_path, staged_path))
handle.unstage_file()
f = open(f_path, "r")
with self.assertRaises(common.OpenBusyError):
handle = localfs_client.LocalfsSourceHandle(self.s.settings, state)
handle = self.slave.stage_file(state)
ftmp_path = self.get_path("φ014tmp")
with open(ftmp_path, "w") as f:
......@@ -749,21 +749,21 @@ class AgkyraTest(unittest.TestCase):
os.unlink(f_path)
os.symlink(ftmp_path, f_path)
state = self.db.get_state(self.s.SLAVE, fil)
handle = localfs_client.LocalfsSourceHandle(self.s.settings, state)
handle = self.slave.stage_file(state)
self.assert_message(messaging.LiveInfoUpdateMessage)
self.assertIsNone(handle.staged_path)
self.s.probe_file(self.s.SLAVE, fln)
self.assert_message(messaging.UpdateMessage)
state = self.db.get_state(self.s.SLAVE, fln)
handle = localfs_client.LocalfsSourceHandle(self.s.settings, state)
handle = self.slave.stage_file(state)
self.assertIsNone(handle.staged_path)
os.unlink(fln_path)
with open(fln_path, "w") as f:
f.write("reg file")
handle = localfs_client.LocalfsSourceHandle(self.s.settings, state)
handle = self.slave.stage_file(state)
self.assertIsNone(handle.staged_path)
# try to stage now
......@@ -777,7 +777,7 @@ class AgkyraTest(unittest.TestCase):
fmissing_path = self.get_path(fmissing)
self.s.probe_file(self.s.SLAVE, fmissing)
state = self.db.get_state(self.s.SLAVE, fmissing)
handle = localfs_client.LocalfsSourceHandle(self.s.settings, state)
handle = self.slave.stage_file(state)
self.assertIsNone(handle.staged_path)
with open(fmissing_path, "w") as f:
......
......@@ -18,6 +18,9 @@ from collections import namedtuple
OBJECT_DIRSEP = '/'
DBTuple = namedtuple('DBTuple',
['dbtype', 'dbname'])
FileStateTuple = namedtuple('FileStateTuple',
['archive', 'objname', 'serial', 'info'])
......
......@@ -13,74 +13,52 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from functools import wraps
import time
import sqlite3
import json
import logging
import random
import threading
import datetime
import inspect
from agkyra.syncer import common, utils
logger = logging.getLogger(__name__)
class FileStateDB(object):
thread_local_data = threading.local()
def new_serial(self, objname):
raise NotImplementedError
def list_files(self, archive):
raise NotImplementedError
def put_state(self, state):
raise NotImplementedError
def get_state(self, archive, objname):
raise NotImplementedError
class SqliteFileStateDB(FileStateDB):
class DB(object):
    """Base wrapper around a single sqlite3 connection.

    The constructor opens the connection to ``dbname``.  When
    ``initialize`` is true it then calls ``self.init()``, which this
    class does not define; subclasses are expected to provide it.
    """

    def __init__(self, dbname, initialize=False):
        self.dbname = dbname
        self.db = sqlite3.connect(dbname)
        if not initialize:
            return
        self.init()

    def begin(self):
        """Start a transaction by issuing ``begin immediate``."""
        connection = self.db
        connection.execute("begin immediate")

    def commit(self):
        """Commit the current transaction on the wrapped connection."""
        connection = self.db
        connection.commit()

    def rollback(self):
        """Roll back the current transaction on the wrapped connection."""
        connection = self.db
        connection.rollback()
class ClientDB(DB):
def init(self):
logger.info("Initializing DB '%s'" % self.dbname)
db = self.db
Q = ("create table if not exists "
"archives(archive text, objname text, serial integer, "
"info blob, primary key (archive, objname))")
db.execute(Q)
Q = ("create table if not exists "
"serials(objname text, nextserial bigint, primary key (objname))")
db.execute(Q)
Q = ("create table if not exists "
"cachenames(cachename text, client text, objname text, "
"primary key (cachename))")
db.execute(Q)
Q = ("create table if not exists "
"config(key text, value text, primary key (key))")
db.execute(Q)
self.commit()
def begin(self):
self.db.execute("begin")
def commit(self):
self.db.commit()
def rollback(self):
self.db.rollback()
def get_cachename(self, cachename):
db = self.db
Q = "select * from cachenames where cachename = ?"
......@@ -102,6 +80,27 @@ class SqliteFileStateDB(FileStateDB):
Q = "delete from cachenames where cachename = ?"
db.execute(Q, (cachename,))
class SyncerDB(DB):
def init(self):
    """Create the syncer tables if they are missing, then commit.

    Three tables are created with ``create table if not exists``:
    ``archives``, ``serials`` and ``config``.
    """
    logger.info("Initializing DB '%s'" % self.dbname)
    connection = self.db
    schema = (
        ("create table if not exists "
         "archives(archive text, objname text, serial integer, "
         "info blob, primary key (archive, objname))"),
        ("create table if not exists "
         "serials(objname text, nextserial bigint, primary key (objname))"),
        ("create table if not exists "
         "config(key text, value text, primary key (key))"),
    )
    for statement in schema:
        connection.execute(statement)
    self.commit()
def new_serial(self, objname):
db = self.db
Q = ("select nextserial from serials where objname = ?")
......@@ -230,41 +229,91 @@ def rand(lim):
return random.random() * lim
def transaction(max_wait=60, init_wait=0.4, exp_backoff=1.1):
def wrap(func):
@wraps(func)
def inner(*args, **kwargs):
obj = args[0]
db = obj.get_db()
attempt = 0
current_max_wait = init_wait
while True:
try:
db.begin()
r = func(*args, **kwargs)
db.commit()
return r
except Exception as e:
db.rollback()
# TODO check conflict
if isinstance(e, sqlite3.OperationalError) and \
"locked" in e.message:
if current_max_wait <= max_wait:
attempt += 1
logger.warning(
"Got DB error '%s' while running '%s' "
"with args '%s' and kwargs '%s'. "
"Retrying transaction (%s times)." %
(e, func.__name__, args, kwargs, attempt))
time.sleep(rand(current_max_wait))
current_max_wait *= exp_backoff
else:
logger.error(
"Got DB error '%s' while running '%s' "
"with args '%s' and kwargs '%s'. Aborting." %
(e, func.__name__, args, kwargs))
raise common.DatabaseError(e)
def get_db(dbtuple, initialize=False):
    """Return the calling thread's connection object for ``dbtuple``.

    Connections are cached per thread in ``thread_local_data.dbs``,
    keyed by ``dbtuple.dbname``.  On a cache miss a new object is built
    as ``dbtuple.dbtype(dbname, initialize=initialize)`` and stored.
    ``initialize`` is only consulted when a new object is built; a
    cached connection is returned unchanged.
    """
    name = dbtuple.dbname
    factory = dbtuple.dbtype
    cache = getattr(thread_local_data, "dbs", None)
    connection = cache.get(name) if cache is not None else None
    if connection is not None:
        return connection

    logger.debug("Connecting db: '%s', thread: %s" %
                 (name, threading.current_thread().ident))
    connection = factory(name, initialize=initialize)
    # Create the per-thread cache lazily, only once a connection exists.
    if cache is None:
        thread_local_data.dbs = {}
    thread_local_data.dbs[name] = connection
    return connection
def initialize(dbtuple):
    """Fetch the thread's connection for ``dbtuple``, asking for initialization.

    Delegates to ``get_db`` with ``initialize=True`` and returns its result.
    """
    connection = get_db(dbtuple, initialize=True)
    return connection
class TransactedConnection(object):
    """Context manager that wraps its body in a database transaction.

    Entering begins a transaction on the connection returned by
    ``get_db(dbtuple)`` and yields that connection.  Leaving commits if
    the body finished normally and rolls back if it raised.

    A "locked" ``sqlite3.OperationalError`` raised while beginning is
    retried after a random sleep of up to ``current_max_wait`` seconds;
    that bound starts at ``init_wait`` and is multiplied by
    ``exp_backoff`` after every retry.  Retrying stops once the total
    time slept exceeds ``max_wait``.

    Raises:
        common.DatabaseError: when beginning fails for good, when the
            body raised an ``sqlite3.Error``, or when the commit fails.
    """

    def __init__(self, dbtuple, max_wait=60, init_wait=0.4, exp_backoff=1.1):
        self.db = get_db(dbtuple)
        self.max_wait = max_wait
        self.init_wait = init_wait
        self.exp_backoff = exp_backoff

    def __enter__(self):
        """Begin the transaction, retrying while the database is locked."""
        # The caller's name is used only in log messages and does not
        # change between retries, so it is looked up once.
        calframe = inspect.getouterframes(inspect.currentframe(), 2)
        caller_name = calframe[1][3]

        attempt = 0
        current_max_wait = self.init_wait
        total_wait = 0
        while True:
            tbefore = datetime.datetime.now()
            try:
                self.db.begin()
            except sqlite3.Error as e:
                tfail = datetime.datetime.now()
                self.db.rollback()
                # str(e) rather than e.message: the latter exists only
                # on Python 2 exceptions.
                locked = (isinstance(e, sqlite3.OperationalError) and
                          "locked" in str(e))
                if not locked:
                    logger.error(
                        "Got sqlite3 error '%s' while beginning "
                        "transaction %s. Aborting." %
                        (e, caller_name))
                    raise common.DatabaseError(e)
                if total_wait > self.max_wait:
                    logger.error(
                        "Got DB error '%s' while beginning transaction %s. "
                        "Aborting." %
                        (e, caller_name))
                    raise common.DatabaseError(e)
                attempt += 1
                logger.warning(
                    "Got DB error '%s' while beginning transaction %s "
                    "after %s sec. Retrying (%s times)." %
                    (e, caller_name, tfail-tbefore, attempt))
                sleeptime = rand(current_max_wait)
                total_wait += sleeptime
                time.sleep(sleeptime)
                current_max_wait *= self.exp_backoff
            else:
                tafter = datetime.datetime.now()
                logger.debug("BEGIN %s %s" % (tafter-tbefore, caller_name))
                return self.db

    def __exit__(self, exctype, value, traceback):
        """Commit on success; roll back when the body raised.

        Returns ``False`` after a rollback so that a non-sqlite
        exception from the body propagates unchanged.  An
        ``sqlite3.Error`` from the body, or from the commit, is
        converted to ``common.DatabaseError``.
        """
        if value is not None:
            try:
                self.db.rollback()
            finally:
                if issubclass(exctype, sqlite3.Error):
                    raise common.DatabaseError(value)
            return False  # re-raise
        else:
            try:
                self.db.commit()
            except sqlite3.Error as e:
                try:
                    self.db.rollback()
                finally:
                    raise common.DatabaseError(e)
......@@ -28,8 +28,8 @@ from watchdog.events import FileSystemEventHandler
import logging
from agkyra.syncer.file_client import FileClient
from agkyra.syncer import utils, common, messaging
from agkyra.syncer.database import transaction
from agkyra.syncer import utils, common, messaging, database
from agkyra.syncer.database import TransactedConnection
logger = logging.getLogger(__name__)
......@@ -258,14 +258,17 @@ def is_info_eq(info1, info2, unhandled_equal=True):
class LocalfsTargetHandle(object):
def __init__(self, settings, target_state):
def __init__(self, client, target_state):
self.client = client
settings = client.settings
self.settings = settings
self.SIGNATURE = "LocalfsTargetHandle"
self.rootpath = settings.local_root_path
self.cache_hide_name = settings.cache_hide_name
self.cache_hide_path = settings.cache_hide_path
self.cache_path = settings.cache_path
self.get_db = settings.get_db
self.syncer_dbtuple = settings.syncer_dbtuple
self.client_dbtuple = client.client_dbtuple
self.mtime_lag = settings.mtime_lag
self.target_state = target_state
self.objname = target_state.objname
......@@ -276,9 +279,11 @@ class LocalfsTargetHandle(object):
def get_path_in_cache(self, name):
return utils.join_path(self.cache_path, name)
@transaction()
def register_hidden_name(self, filename):
db = self.get_db()
with TransactedConnection(self.client_dbtuple) as db:
return self._register_hidden_name(db, filename)
def _register_hidden_name(self, db, filename):
f = utils.hash_string(filename)
hide_filename = utils.join_path(self.cache_hide_name, f)
self.hidden_filename = hide_filename
......@@ -288,10 +293,9 @@ class LocalfsTargetHandle(object):
db.insert_cachename(hide_filename, self.SIGNATURE, filename)
return True
@transaction()
def unregister_hidden_name(self, hidden_filename):
db = self.get_db()
db.delete_cachename(hidden_filename)
with TransactedConnection(self.client_dbtuple) as db:
db.delete_cachename(hidden_filename)
self.hidden_filename = None
self.hidden_path = None
......@@ -412,9 +416,11 @@ class LocalfsTargetHandle(object):
class LocalfsSourceHandle(object):
@transaction()
def register_stage_name(self, filename):
db = self.get_db()
with TransactedConnection(self.client_dbtuple) as db:
return self._register_stage_name(db, filename)
def _register_stage_name(self, db, filename):
f = utils.hash_string(filename)
stage_filename = utils.join_path(self.cache_stage_name, f)
self.stage_filename = stage_filename
......@@ -425,10 +431,9 @@ class LocalfsSourceHandle(object):
db.insert_cachename(stage_filename, self.SIGNATURE, filename)
return True
@transaction()
def unregister_stage_name(self, stage_filename):
db = self.get_db()
db.delete_cachename(stage_filename)
with TransactedConnection(self.client_dbtuple) as db:
db.delete_cachename(stage_filename)
self.stage_filename = None
self.staged_path = None
......@@ -500,14 +505,17 @@ class LocalfsSourceHandle(object):
logger.warning(m)
raise common.ChangedBusyError(m)
def __init__(self, settings, source_state):
def __init__(self, client, source_state):
self.client = client
settings = client.settings
self.settings = settings
self.SIGNATURE = "LocalfsSourceHandle"
self.rootpath = settings.local_root_path
self.cache_stage_name = settings.cache_stage_name
self.cache_stage_path = settings.cache_stage_path
self.cache_path = settings.cache_path
self.get_db = settings.get_db
self.syncer_dbtuple = settings.syncer_dbtuple
self.client_dbtuple = client.client_dbtuple
self.source_state = source_state
self.objname = source_state.objname
self.fspath = utils.join_path(self.rootpath, self.objname)
......@@ -517,10 +525,9 @@ class LocalfsSourceHandle(object):
if info_of_regular_file(self.source_state.info):
self.stage_file()
@transaction()
def update_state(self, state):
db = self.get_db()
db.put_state(state)
with TransactedConnection(self.syncer_dbtuple) as db:
db.put_state(state)
def check_update_source_state(self, live_info):
if not is_info_eq(live_info, self.source_state.info):
......@@ -568,7 +575,12 @@ class LocalfsFileClient(FileClient):
self.SIGNATURE = "LocalfsFileClient"
self.ROOTPATH = settings.local_root_path
self.CACHEPATH = settings.cache_path
self.get_db = settings.get_db
self.syncer_dbtuple = settings.syncer_dbtuple
client_dbname = self.SIGNATURE+'.db'
self.client_dbtuple = common.DBTuple(
dbtype=database.ClientDB,
dbname=utils.join_path(settings.instance_path, client_dbname))
database.initialize(self.client_dbtuple)
self.probe_candidates = utils.ThreadSafeDict()
self.check_enabled()
......@@ -638,10 +650,9 @@ class LocalfsFileClient(FileClient):
logger.debug("Candidates: %s" % candidates)
return candidates
@transaction()
def list_files(self):
db = self.get_db()
return db.list_files(self.SIGNATURE)
with TransactedConnection(self.syncer_dbtuple) as db:
return db.list_files(self.SIGNATURE)
def _local_path_changes(self, name, state):
local_path = utils.join_path(self.ROOTPATH, name)
......@@ -678,15 +689,14 @@ class LocalfsFileClient(FileClient):
return live_state
def stage_file(self, source_state):
return LocalfsSourceHandle(self.settings, source_state)
return LocalfsSourceHandle(self, source_state)
def prepare_target(self, target_state):
return LocalfsTargetHandle(self.settings, target_state)
return LocalfsTargetHandle(self, target_state)
@transaction()
def get_dir_contents(self, objname):
db = self.get_db()
return db.get_dir_contents(self.SIGNATURE, objname)
with TransactedConnection(self.syncer_dbtuple) as db:
return db.get_dir_contents(self.SIGNATURE, objname)
def notifier(self):
def handle_path(path, rec=False):
......
......@@ -19,10 +19,10 @@ import os
import logging
import re
from agkyra.syncer import utils, common, messaging
from agkyra.syncer import utils, common, messaging, database
from agkyra.syncer.file_client import FileClient
from agkyra.syncer.setup import ClientError
from agkyra.syncer.database import transaction
from agkyra.syncer.database import TransactedConnection
logger = logging.getLogger(__name__)
......@@ -41,21 +41,26 @@ def handle_client_errors(f):
class PithosSourceHandle(object):
def __init__(self, settings, source_state):
def __init__(self, client, source_state):
self.SIGNATURE = "PithosSourceHandle"
self.client = client
settings = client.settings
self.settings = settings
self.endpoint = settings.endpoint
self.cache_fetch_name = settings.cache_fetch_name
self.cache_fetch_path = settings.cache_fetch_path
self.cache_path = settings.cache_path
self.get_db = settings.get_db
self.syncer_dbtuple = settings.syncer_dbtuple
self.client_dbtuple = client.client_dbtuple
self.source_state = source_state
self.objname = source_state.objname
self.heartbeat = settings.heartbeat
@transaction()
def register_fetch_name(self, filename):
db = self.get_db()
with TransactedConnection(self.client_dbtuple) as db:
return self._register_fetch_name(db, filename)
def _register_fetch_name(self, db, filename):
f = utils.hash_string(filename) + "_" + \
utils.str_time_stamp()
fetch_name = utils.join_path(self.cache_fetch_name, f)
......@@ -98,10 +103,9 @@ class PithosSourceHandle(object):
os.mkdir(fetched_fspath)
return fetched_fspath
@transaction()
def update_state(self, state):
db = self.get_db()
db.put_state(state)
with TransactedConnection(self.syncer_dbtuple) as db:
db.put_state(state)
def check_update_source_state(self, actual_info):
if actual_info != self.source_state.info:
......@@ -125,7 +129,9 @@ exclude_pattern = re.compile(exclude_staged_regex)
class PithosTargetHandle(object):
def __init__(self, settings, target_state):
def __init__(self, client, target_state):
self.client = client
settings = client.settings
self.settings = settings
self.endpoint = settings.endpoint
self.target_state = target_state
......@@ -236,7 +242,12 @@ class PithosFileClient(FileClient):
self.auth_url = settings.auth_url
self.auth_token = settings.auth_token
self.container