From 2fd70aac67e13e293b2e30df1f7a2694c8aeda02 Mon Sep 17 00:00:00 2001
From: Giorgos Korfiatis <gkorf@grnet.gr>
Date: Sat, 10 Oct 2015 17:54:22 +0300
Subject: [PATCH] use transaction mechanism in protocol

---
 agkyra/cli.py      |   3 +-
 agkyra/protocol.py | 131 ++++++++++++++++++++-------------------------
 2 files changed, 58 insertions(+), 76 deletions(-)

diff --git a/agkyra/cli.py b/agkyra/cli.py
index a938331..3604d00 100644
--- a/agkyra/cli.py
+++ b/agkyra/cli.py
@@ -174,8 +174,7 @@ class AgkyraCLI(cmd.Cmd):
         """Return the helper client instace or None"""
         self._client = getattr(self, '_client', None)
         if not self._client:
-            session = protocol.retry_on_locked_db(
-                self.helper.load_active_session)
+            session = self.helper.load_active_session()
             if session:
                 self._client = protocol_client.UIClient(session)
                 self._client.connect()
diff --git a/agkyra/protocol.py b/agkyra/protocol.py
index 4282c7f..ba754d9 100644
--- a/agkyra/protocol.py
+++ b/agkyra/protocol.py
@@ -27,7 +27,8 @@ import json
 import logging
 import subprocess
 from agkyra.syncer import (
-    syncer, setup, pithos_client, localfs_client, messaging, utils)
+    syncer, setup, pithos_client, localfs_client, messaging, utils, database,
+    common)
 from agkyra.config import AgkyraConfig, AGKYRA_DIR
 
 if getattr(sys, 'frozen', False):
@@ -49,19 +50,36 @@ with open(os.path.join(RESOURCES, 'ui_data/common_en.json')) as f:
 STATUS = COMMON['STATUS']
 
 
-def retry_on_locked_db(method, *args, **kwargs):
-    """If DB is locked, wait and try again"""
-    wait = kwargs.get('wait', 0.2)
-    retries = kwargs.get('retries', 2)
-    while retries:
-        try:
-            return method(*args, **kwargs)
-        except sqlite3.OperationalError as oe:
-            if 'locked' not in '%s' % oe:
-                raise
-            LOG.debug('%s, retry' % oe)
-        time.sleep(wait)
-        retries -= 1
+class SessionDB(database.DB):
+    def init(self):
+        db = self.db
+        db.execute(
+            'CREATE TABLE IF NOT EXISTS heart ('
+            'ui_id VARCHAR(256), address text, beat VARCHAR(32)'
+            ')')
+
+    def get_all_heartbeats(self):
+        db = self.db
+        r = db.execute('SELECT * FROM heart')
+        return r.fetchall()
+
+    def register_heartbeat(self, ui_id, address):
+        db = self.db
+        db.execute('INSERT INTO heart VALUES (?, ?, ?)',
+                   (ui_id, address, time.time()))
+
+    def update_heartbeat(self, ui_id):
+        db = self.db
+        r = db.execute('SELECT ui_id FROM heart WHERE ui_id=?', (ui_id,))
+        if r.fetchall():
+            db.execute('UPDATE heart SET beat=? WHERE ui_id=?',
+                       (time.time(), ui_id))
+            return True
+        return False
+
+    def unregister_heartbeat(self, ui_id):
+        db = self.db
+        db.execute('DELETE FROM heart WHERE ui_id=?', (ui_id,))
 
 
 class SessionHelper(object):
@@ -72,27 +90,18 @@ class SessionHelper(object):
 
     def __init__(self, **kwargs):
         """Setup the helper server"""
-        self.session_db = kwargs.get(
+        db_name = kwargs.get(
             'session_db', os.path.join(AGKYRA_DIR, 'session.db'))
-        self.session_relation = kwargs.get('session_relation', 'heart')
-
-        LOG.debug('Connect to db')
-        self.db = sqlite3.connect(self.session_db)
-        retry_on_locked_db(self._init_db_relation)
-
-    def _init_db_relation(self):
-        """Create the session relation"""
-        self.db.execute('BEGIN')
-        self.db.execute(
-            'CREATE TABLE IF NOT EXISTS %s ('
-            'ui_id VARCHAR(256), address text, beat VARCHAR(32)'
-            ')' % self.session_relation)
-        self.db.commit()
+        self.session_db = common.DBTuple(dbtype=SessionDB, dbname=db_name)
+        database.initialize(self.session_db)
 
     def load_active_session(self):
+        with database.TransactedConnection(self.session_db) as db:
+            return self._load_active_session(db)
+
+    def _load_active_session(self, db):
         """Load a session from db"""
-        r = self.db.execute('SELECT * FROM %s' % self.session_relation)
-        sessions = r.fetchall()
+        sessions = db.get_all_heartbeats()
         if sessions:
             last, expected_id = sessions[-1], getattr(self, 'ui_id', None)
             if expected_id and last[0] != '%s' % expected_id:
@@ -107,22 +116,13 @@ class SessionHelper(object):
 
     def create_session_daemon(self):
         """Create and return a new daemon, or None if one exists"""
-
-        def get_session():
-                self.db.execute('BEGIN')
-                return self.load_active_session()
-
-        session = retry_on_locked_db(get_session)
-        if session:
-            self.db.rollback()
-            return None
-
-        session_daemon = SessionDaemon(self.session_db, self.session_relation)
-        self.db.execute('INSERT INTO %s VALUES ("%s", "%s", "%s")' % (
-            self.session_relation, session_daemon.ui_id,
-            session_daemon.address, time.time()))
-        self.db.commit()
-        return session_daemon
+        with database.TransactedConnection(self.session_db) as db:
+            session = self._load_active_session(db)
+            if session:
+                return None
+            session_daemon = SessionDaemon(self.session_db)
+            db.register_heartbeat(session_daemon.ui_id, session_daemon.address)
+            return session_daemon
 
     def wait_session_to_load(self, timeout=20, step=0.2):
         """Wait while the session is loading e.g. in another process
@@ -152,15 +152,13 @@ class SessionDaemon(object):
     """A WebSocket server which inspects a heartbeat and decides whether to
     shut down
     """
-    def __init__(self, session_db, session_relation, *args, **kwargs):
+    def __init__(self, session_db, *args, **kwargs):
         self.session_db = session_db
-        self.session_relation = session_relation
         ui_id = sha1(os.urandom(128)).hexdigest()
 
         LOCAL_ADDR = '127.0.0.1'
         WebSocketProtocol.ui_id = ui_id
         WebSocketProtocol.session_db = session_db
-        WebSocketProtocol.session_relation = session_relation
         server = make_server(
             LOCAL_ADDR, 0,
             server_class=WSGIServer,
@@ -174,23 +172,12 @@ class SessionDaemon(object):
 
     def heartbeat(self):
         """Periodically update the session database timestamp"""
-        db, alive = sqlite3.connect(self.session_db), True
-        while alive:
+        while True:
             time.sleep(2)
-            try:
-                db.execute('BEGIN')
-                r = db.execute('SELECT ui_id FROM %s WHERE ui_id="%s"' % (
-                    self.session_relation, self.ui_id))
-                if r.fetchall():
-                    db.execute('UPDATE %s SET beat="%s" WHERE ui_id="%s"' % (
-                        self.session_relation, time.time(), self.ui_id))
-                else:
-                    alive = False
-                db.commit()
-            except sqlite3.OperationalError as oe:
-                if 'locked' not in '%s' % oe:
-                    raise
-        db.close()
+            with database.TransactedConnection(self.session_db) as db:
+                found = db.update_heartbeat(self.ui_id)
+                if not found:
+                    break
         self.close_manager()
         self.server.shutdown()
 
@@ -276,7 +263,7 @@ class WebSocketProtocol(WebSocket):
         d.update(code=STATUS['UNINITIALIZED'], synced=0, unsynced=0, failed=0)
 
     ui_id = None
-    session_db, session_relation = None, None
+    session_db = None
     accepted = False
     settings = dict(
         token=None, url=None,
@@ -318,12 +305,8 @@ class WebSocketProtocol(WebSocket):
     def clean_db(self):
         """Clean DB from current session trace"""
         LOG.debug('Remove current session trace')
-        db = sqlite3.connect(self.session_db)
-        db.execute('BEGIN')
-        db.execute('DELETE FROM %s WHERE ui_id="%s"' % (
-            self.session_relation, self.ui_id))
-        db.commit()
-        db.close()
+        with database.TransactedConnection(self.session_db) as db:
+            db.unregister_heartbeat(self.ui_id)
 
     def shutdown_syncer(self, syncer_key=0):
         """Shutdown the syncer backend object"""
@@ -612,7 +595,7 @@ class WebSocketProtocol(WebSocket):
                 # Clean db to cause syncer backend to shut down
                 self.set_status(code=STATUS['SHUTTING DOWN'])
                 self.shutdown_syncer()
-                retry_on_locked_db(self.clean_db)
+                self.clean_db()
                 # self._shutdown()
                 # self.terminate()
                 return
-- 
GitLab