protocol.py 21.2 KB
Newer Older
Giorgos Korfiatis's avatar
Giorgos Korfiatis committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (C) 2015 GRNET S.A.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

16
from wsgiref.simple_server import make_server
17
from ws4py.websocket import WebSocket
18 19 20 21
from ws4py.server.wsgiutils import WebSocketWSGIApplication
from ws4py.server.wsgirefserver import WSGIServer, WebSocketWSGIRequestHandler
from hashlib import sha1
from threading import Thread
22 23
import sqlite3
import time
24
import os
25 26
import json
import logging
27
from agkyra.syncer import (
28
    syncer, setup, pithos_client, localfs_client, messaging, utils)
29
from agkyra.config import AgkyraConfig, AGKYRA_DIR
30 31


32 33
# Directory that contains this module; launch_server() resolves the server
# script relative to it
CURPATH = os.path.dirname(os.path.abspath(__file__))
LOG = logging.getLogger(__name__)
# Module-wide registry of syncers (init_sync stores the syncer under key 0)
SYNCERS = utils.ThreadSafeDict()
36 37


38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
def retry_on_locked_db(method, *args, **kwargs):
    """Call ``method(*args, **kwargs)``, trying again while the DB is locked

    The keyword arguments ``wait`` and ``retries`` are consumed by this
    function and are NOT forwarded to ``method``.

    :param method: the callable to run
    :param wait: (keyword) seconds to sleep after a failed attempt,
        default 0.2
    :param retries: (keyword) maximum number of attempts, default 2
    :returns: whatever ``method`` returns, or None if every attempt failed
        because the database was locked
    :raises sqlite3.OperationalError: if ``method`` fails with an operational
        error which is not about a locked database
    """
    # pop (not get): the wrapped method must not receive these two options
    wait = kwargs.pop('wait', 0.2)
    retries = kwargs.pop('retries', 2)
    while retries:
        try:
            return method(*args, **kwargs)
        except sqlite3.OperationalError as oe:
            if 'locked' not in '%s' % oe:
                raise
            LOG.debug('%s, retry' % oe)
        time.sleep(wait)
        retries -= 1


53
class SessionHelper(object):
54 55 56
    """Agkyra Helper Server sets a WebSocket server with the Helper protocol
    It also provides methods for running and killing the Helper server
    """
57
    session_timeout = 20
58

59
    def __init__(self, **kwargs):
        """Connect to the session database and prepare the session table

        :param session_db: (keyword, optional) path of the sqlite file,
            by default 'session.db' inside AGKYRA_DIR
        :param session_relation: (keyword, optional) name of the session
            table, by default 'heart'
        """
        default_db_path = os.path.join(AGKYRA_DIR, 'session.db')
        self.session_db = kwargs.get('session_db', default_db_path)
        self.session_relation = kwargs.get('session_relation', 'heart')

        LOG.debug('Connect to db')
        self.db = sqlite3.connect(self.session_db)
        self._init_db_relation()
72

73
    def _init_db_relation(self):
74
        """Create the session relation"""
75 76 77 78 79 80 81
        self.db.execute('BEGIN')
        self.db.execute(
            'CREATE TABLE IF NOT EXISTS %s ('
            'ui_id VARCHAR(256), address text, beat VARCHAR(32)'
            ')' % self.session_relation)
        self.db.commit()

82
    def load_active_session(self):
83 84 85 86
        """Load a session from db"""
        r = self.db.execute('SELECT * FROM %s' % self.session_relation)
        sessions = r.fetchall()
        if sessions:
87 88 89 90
            last, expected_id = sessions[-1], getattr(self, 'ui_id', None)
            if expected_id and last[0] != '%s' % expected_id:
                LOG.debug('Session ID is old')
                return None
91 92 93 94
            now, last_beat = time.time(), float(last[2])
            if now - last_beat < self.session_timeout:
                # Found an active session
                return dict(ui_id=last[0], address=last[1])
95
        LOG.debug('No active sessions found')
96 97
        return None

98
    def create_session(self):
        """Return the active session or create a new one

        If no active session is stored, a random ui_id is generated, a WSGI
        server with the WebSocket protocol is set up on 127.0.0.1 (port 0 is
        requested, the actual port is read back from the server) and the new
        session is recorded in the session table. The server is only set up
        here; start() is what serves it.

        :returns: a dict with the 'ui_id' and 'address' of the session
        """

        def get_session():
            self.db.execute('BEGIN')
            try:
                return self.load_active_session()
            except sqlite3.OperationalError:
                # Close the transaction, so that a retry can BEGIN again
                self.db.rollback()
                raise

        session = retry_on_locked_db(get_session)
        if session:
            self.db.rollback()
            return session

        ui_id = sha1(os.urandom(128)).hexdigest()

        LOCAL_ADDR = '127.0.0.1'
        # The protocol class reads these to recognise and track the session
        WebSocketProtocol.ui_id = ui_id
        WebSocketProtocol.session_db = self.session_db
        WebSocketProtocol.session_relation = self.session_relation
        server = make_server(
            LOCAL_ADDR, 0,
            server_class=WSGIServer,
            handler_class=WebSocketWSGIRequestHandler,
            app=WebSocketWSGIApplication(handler_cls=WebSocketProtocol))
        server.initialize_websockets_manager()
        address = 'ws://%s:%s' % (LOCAL_ADDR, server.server_port)

        # Values are bound, not formatted into the statement (only the table
        # name cannot be bound). The beat is stored as text, as before.
        self.db.execute(
            'INSERT INTO %s VALUES (?, ?, ?)' % self.session_relation,
            (ui_id, address, '%s' % time.time()))
        self.db.commit()

        self.server = server
        self.ui_id = ui_id
        return dict(ui_id=ui_id, address=address)

132 133 134 135 136 137 138 139 140 141 142 143 144
    def wait_session_to_load(self, timeout=20, step=2):
        """Wait while the session is loading e.g. in another process
            :returns: the session or None if timeout
        """
        time_passed = 0
        while time_passed < timeout:
            self.session = self.load_active_session()
            if self.session:
                return self.session
            time_passed += step
            time.sleep(step)
        return None

145 146 147 148 149 150 151 152 153 154
    def wait_session_to_stop(self, timeout=20, step=2):
        """Wait while the session is shutting down
            :returns: True if stopped, False if timed out and still running
        """
        time_passed = 0
        while time_passed < timeout and self.load_active_session():
            time.sleep(step)
            time_passed += step
        return not bool(self.load_active_session())

155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
    def heartbeat(self):
        """General session heartbeat - when heart stops, WSGI server dies"""
        db, alive = sqlite3.connect(self.session_db), True
        while alive:
            time.sleep(2)
            try:
                db.execute('BEGIN')
                r = db.execute('SELECT ui_id FROM %s WHERE ui_id="%s"' % (
                    self.session_relation, self.ui_id))
                if r.fetchall():
                    db.execute('UPDATE %s SET beat="%s" WHERE ui_id="%s"' % (
                        self.session_relation, time.time(), self.ui_id))
                else:
                    alive = False
                db.commit()
            except sqlite3.OperationalError as oe:
                if 'locked' not in '%s' % oe:
                    raise
        db.close()

175 176
    def start(self):
        """Start the helper server in a thread"""
177
        if getattr(self, 'server', None):
178 179 180
            t = Thread(target=self._shutdown_daemon)
            t.start()
            Thread(target=self.heartbeat).start()
181
            self.server.serve_forever()
182 183
            t.join()
            LOG.debug('WSGI server is down')
184

185
    def _shutdown_daemon(self):
        """Shutdown WSGI server when the heart stops

        Check the session table every 4 seconds. Once the row of this
        helper's ui_id is gone, wait 5 more seconds and shut the server down.
        """
        query = 'SELECT ui_id FROM %s WHERE ui_id=?' % self.session_relation
        db = sqlite3.connect(self.session_db)
        try:
            while True:
                time.sleep(4)
                try:
                    r = db.execute(query, ('%s' % self.ui_id, ))
                    session_is_gone = not r.fetchall()
                except sqlite3.OperationalError as oe:
                    # Best effort: keep polling, but leave a trace in the log
                    LOG.debug('Shutdown daemon, db error: %s' % oe)
                    continue
                if session_is_gone:
                    break
        finally:
            db.close()
        time.sleep(5)
        t = Thread(target=self.server.shutdown)
        t.start()
        t.join()
202 203


204 205 206
class WebSocketProtocol(WebSocket):
    """Helper-side WebSocket protocol for communication with GUI:

207
    -- INTERNAL HANDSHAKE --
208
    GUI: {"method": "post", "ui_id": <GUI ID>}
209
    HELPER: {"ACCEPTED": 202, "action": "post ui_id"} or
        {"REJECTED": 401, "action": "post ui_id"}
211

212 213 214 215
    -- ERRORS WITH SIGNIFICANCE --
    If the token doesn't work:
    HELPER: {"action": <action that caused the error>, "UNAUTHORIZED": 401}

216 217 218 219 220 221 222
    -- SHUT DOWN --
    GUI: {"method": "post", "path": "shutdown"}

    -- PAUSE --
    GUI: {"method": "post", "path": "pause"}
    HELPER: {"OK": 200, "action": "post pause"} or error

223
    -- START --
224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
    GUI: {"method": "post", "path": "start"}
    HELPER: {"OK": 200, "action": "post start"} or error

    -- GET SETTINGS --
    GUI: {"method": "get", "path": "settings"}
    HELPER:
        {
            "action": "get settings",
            "token": <user token>,
            "url": <auth url>,
            "container": <container>,
            "directory": <local directory>,
            "exclude": <file path>
        } or {<ERROR>: <ERROR CODE>}

    -- PUT SETTINGS --
    GUI: {
            "method": "put", "path": "settings",
            "token": <user token>,
            "url": <auth url>,
            "container": <container>,
            "directory": <local directory>,
            "exclude": <file path>
        }
    HELPER: {"CREATED": 201, "action": "put settings"} or
        {<ERROR>: <ERROR CODE>, "action": "put settings"}

    -- GET STATUS --
    GUI: {"method": "get", "path": "status"}
253 254 255 256 257
    HELPER: {
        "can_sync": <boolean>,
        "progress": <int>,
        "paused": <boolean>,
        "action": "get status"} or
258 259 260
        {<ERROR>: <ERROR CODE>, "action": "get status"}
    """

261
    # Session identification; the three values are assigned on the class by
    # SessionHelper.create_session
    ui_id = None
    session_db = None
    session_relation = None
    # Turns True when a "post ui_id" request with the right ui_id arrives
    accepted = False
    settings = {
        'token': None, 'url': None,
        'container': None, 'directory': None,
        'exclude': None,
    }
    status = {
        'progress': 0, 'synced': 0, 'unsynced': 0,
        'paused': True, 'can_sync': False,
    }
    cnf = AgkyraConfig()
    # The settings that can_sync() requires
    essentials = ('url', 'token', 'container', 'directory')
272

273 274 275 276 277 278 279
    @property
    def syncer(self):
        """The first syncer found in SYNCERS, or None if SYNCERS is empty"""
        with SYNCERS.lock() as registry:
            for _key, registered in registry.items():
                return registered
        return None

280 281 282 283 284 285 286 287 288 289 290 291
    def _shutdown(self):
        """Shutdown the service

        Close the socket. If there is a syncer and the settings allow
        syncing, stop the syncer daemons and wait for the open syncs.
        """
        LOG.debug('Shutdown syncer')
        self.close()
        # can_sync() only inspects the settings: the syncer itself is None
        # when init_sync() has not run yet or has failed
        syncer_ = self.syncer
        if syncer_ and self.can_sync():
            syncer_.stop_all_daemons()
            LOG.debug('Wait open syncs to complete')
            syncer_.wait_sync_threads()

    def clean_db(self):
        """Clean DB from session traces

        Delete every row of the session table, through a connection of its
        own.

        :raises sqlite3.OperationalError: e.g. when the database is locked
        """
        LOG.debug('Remove session traces')
        db = sqlite3.connect(self.session_db)
        try:
            db.execute('BEGIN')
            db.execute('DELETE FROM %s' % self.session_relation)
            db.commit()
        finally:
            # Close on failure too: this method is called through a retry
            # helper and must not leave a connection behind on each attempt
            db.close()

    def heartbeat(self):
        """Check if socket should be alive

        Every second, refresh the beat of this session's row in the session
        table. When the row is not found any more, stop and shut down.
        Operational db errors are logged and the check is repeated at the
        next beat.
        """
        relation = self.session_relation
        db, alive = sqlite3.connect(self.session_db), True
        try:
            while alive:
                time.sleep(1)
                ui_id = '%s' % self.ui_id
                try:
                    db.execute('BEGIN')
                    r = db.execute(
                        'SELECT ui_id FROM %s WHERE ui_id=?' % relation,
                        (ui_id, ))
                    if r.fetchall():
                        db.execute(
                            'UPDATE %s SET beat=? WHERE ui_id=?' % relation,
                            ('%s' % time.time(), ui_id))
                    else:
                        alive = False
                    db.commit()
                except sqlite3.OperationalError as oe:
                    LOG.debug('Heartbeat, db error: %s' % oe)
                    # Abandon the half-done transaction, otherwise every
                    # following BEGIN fails and the loop never ends
                    db.rollback()
                    alive = True
        finally:
            db.close()
        self._shutdown()
317

318 319 320 321
    def _get_default_sync(self):
        """Return the name of the sync to work with

        That is global.default_sync, if it is set. Otherwise it is the first
        sync of the configuration or, if there are no syncs, 'default'; in
        this case the name is also set as global.default_sync.
        """
        name = self.cnf.get('global', 'default_sync')
        if name:
            return name
        for name in self.cnf.keys('sync'):
            break
        name = name or 'default'
        self.cnf.set('global', 'default_sync', name)
        return name

    def _get_sync_cloud(self, sync):
        """Return the name of the cloud used by a sync

        That is <sync>.cloud, if it is set. Otherwise it is the first cloud
        of the configuration or, if there are no clouds, 'default'; in this
        case the name is also set as <sync>.cloud, for future sessions.

        :param sync: the name of the sync
        """
        try:
            name = self.cnf.get_sync(sync, 'cloud')
        except KeyError:
            name = None
        if name:
            return name
        for name in self.cnf.keys('cloud'):
            break
        name = name or 'default'
        self.cnf.set_sync(sync, 'cloud', name)
        return name
344

345 346 347 348
    def _load_settings(self):
        """Load the settings of the default sync from the configuration

        url and token are read from the cloud of the sync; each of them is
        set to None if it cannot be retrieved. container and directory are
        read from the sync itself and are left as they are if not set.
        """
        LOG.debug('Start loading settings')
        sync = self._get_default_sync()
        cloud = self._get_sync_cloud(sync)

        for option in ('url', 'token'):
            try:
                self.settings[option] = self.cnf.get_cloud(cloud, option)
            except Exception:
                # Broad on purpose (as it always was): whatever goes wrong,
                # this option is treated as not set
                self.settings[option] = None

        # for option in ('container', 'directory', 'exclude'):
        for option in ('container', 'directory'):
            try:
                self.settings[option] = self.cnf.get_sync(sync, option)
            except KeyError:
                LOG.debug('No %s is set' % option)

        LOG.debug('Finished loading settings')

368
    def _dump_settings(self):
        """Save the current settings to the configuration and write it

        If a url is set, url and token are saved in a cloud, which is then
        set as the cloud of the default sync. If the cloud of the sync holds
        another url already, a new cloud name is used instead of overwriting
        it. directory and container are always saved in the default sync.
        """
        LOG.debug('Saving settings')
        sync = self._get_default_sync()

        if not self.settings.get('url', None):
            LOG.debug('No cloud settings to save')
        else:
            LOG.debug('Save cloud settings')
            cloud = self._get_sync_cloud(sync)

            try:
                old_url = self.cnf.get_cloud(cloud, 'url') or ''
            except KeyError:
                old_url = self.settings['url']

            # The cloud holds a different url: append the sync name to the
            # cloud name until a name with no url is reached
            while old_url and old_url != self.settings['url']:
                cloud = '%s_%s' % (cloud, sync)
                try:
                    self.cnf.get_cloud(cloud, 'url')
                except KeyError:
                    break

            LOG.debug('Cloud name is %s' % cloud)
            self.cnf.set_cloud(cloud, 'url', self.settings['url'])
            self.cnf.set_cloud(cloud, 'token', self.settings['token'] or '')
            self.cnf.set_sync(sync, 'cloud', cloud)

        LOG.debug('Save sync settings, name is %s' % sync)
        # for option in ('directory', 'container', 'exclude'):
        for option in ('directory', 'container'):
            self.cnf.set_sync(sync, option, self.settings[option] or '')

        # The sync options are set on every call, so there is always
        # something to write
        self.cnf.write()
        LOG.debug('Settings saved')
408

409 410 411 412 413
    def _essentials_changed(self, new_settings):
        """Check if essential settings have changed in new_settings"""
        return all([
            self.settings[e] == self.settings[e] for e in self.essentials])

414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
    def _update_statistics(self):
        """Consume the pending syncer messages and update the sync counters

        'unsynced' counts the SyncMessage messages and 'synced' the
        AckSyncMessage ones. If there is no message to consume and the two
        counters are equal, both are reset to 0. Nothing is done if the
        settings do not allow syncing.
        """
        if not self.can_sync():
            return
        counters = self.status
        msg = self.syncer.get_next_message()
        if not msg and counters['unsynced'] == counters['synced']:
            counters['unsynced'] = 0
            counters['synced'] = 0
        while msg:
            if isinstance(msg, messaging.SyncMessage):
                LOG.info('Start syncing "%s"' % msg.objname)
                counters['unsynced'] += 1
            elif isinstance(msg, messaging.AckSyncMessage):
                LOG.info('Finished syncing "%s"' % msg.objname)
                counters['synced'] += 1
            elif isinstance(msg, messaging.CollisionMessage):
                LOG.info('Collision for "%s"' % msg.objname)
            elif isinstance(msg, messaging.ConflictStashMessage):
                LOG.info('Conflict for "%s"' % msg.objname)
            else:
                LOG.debug('Consumed msg %s' % msg)
            msg = self.syncer.get_next_message()

    def can_sync(self):
        """Check if settings are enough to setup a syncing proccess"""
        return all([self.settings[e] for e in self.essentials])

441 442 443
    def init_sync(self):
        """Build a syncer from the current settings and initiate its probe

        Whatever has been built by the time this method ends (None, if the
        syncer could not be constructed) is stored in SYNCERS under key 0,
        even when an error is raised.
        """
        sync = self._get_default_sync()
        cloud = self._get_sync_cloud(sync)

        options = {'agkyra_path': AGKYRA_DIR}
        # Get SSL settings
        try:
            ssl_option = self.cnf.get_cloud(cloud, 'ignore_ssl')
        except KeyError:
            ignore_ssl = None
        else:
            ignore_ssl = ssl_option in ('on', )
            options['ignore_ssl'] = ignore_ssl
        if not ignore_ssl:
            try:
                options['ca_certs'] = self.cnf.get_cloud(cloud, 'ca_certs')
            except KeyError:
                pass

        new_syncer = None
        try:
            syncer_settings = setup.SyncerSettings(
                self.settings['url'], self.settings['token'],
                self.settings['container'], self.settings['directory'],
                **options)
            master = pithos_client.PithosFileClient(syncer_settings)
            slave = localfs_client.LocalfsFileClient(syncer_settings)
            new_syncer = syncer.FileSyncer(syncer_settings, master, slave)
            self.syncer_settings = syncer_settings
            new_syncer.initiate_probe()
        finally:
            with SYNCERS.lock() as registry:
                registry[0] = new_syncer
473

474 475
    # Syncer-related methods
    def get_status(self):
        """Return the status dict, after bringing it up to date

        With a syncer and settings that allow syncing, the statistics are
        updated and 'paused' and 'can_sync' are refreshed. Otherwise the
        status is replaced by the initial (paused, cannot sync) one.
        """
        if not (getattr(self, 'syncer', None) and self.can_sync()):
            self.status = dict(
                progress=0, synced=0, unsynced=0, paused=True, can_sync=False)
            return self.status
        self._update_statistics()
        self.status['paused'] = self.syncer.paused
        self.status['can_sync'] = self.can_sync()
        return self.status

    def get_settings(self):
        return self.settings

    def set_settings(self, new_settings):
        """Apply new settings, save them and resume syncing accordingly

        An active (not paused) syncer is paused before the settings change.
        Afterwards, if the settings allow syncing, the syncer is initialized
        again when the essential settings have changed or when syncing was
        not possible before, and it is started again if it was active.

        :param new_settings: dict with the settings to update
        """
        # Prepare setting save
        could_sync = getattr(self, 'syncer', None) and self.can_sync()
        was_active = False
        if could_sync:
            was_active = not self.syncer.paused
        if was_active:
            self.pause_sync()
        must_reset_syncing = self._essentials_changed(new_settings)

        # save settings
        self.settings.update(new_settings)
        self._dump_settings()

        # Restart
        if not self.can_sync():
            return
        if must_reset_syncing or not could_sync:
            self.init_sync()
        if was_active:
            self.start_sync()

509
    def pause_sync(self):
        """Call stop_decide on the syncer, then wait for its sync threads"""
        active_syncer = self.syncer
        active_syncer.stop_decide()
        LOG.debug('Wait open syncs to complete')
        active_syncer.wait_sync_threads()
513 514

    def start_sync(self):
        """Call start_decide on the current syncer"""
        active_syncer = self.syncer
        active_syncer.start_decide()
516 517 518 519 520 521 522 523 524 525 526

    def send_json(self, msg):
        """Log a message and send it through the socket, dumped as JSON"""
        LOG.debug('send: %s' % msg)
        payload = json.dumps(msg)
        self.send(payload)

    # Protocol handling methods
    def _post(self, r):
        """Handle POST requests

        On an accepted socket, the 'path' of the request is the action:
        'shutdown' removes the session from the db, 'start' and 'pause'
        control syncing and are answered with OK. On a socket which is not
        accepted yet, the request is the handshake: the right 'ui_id' gets
        ACCEPTED, anything else gets REJECTED and the socket is terminated.

        :param r: the request (dict), without its 'method'
        :raises KeyError: on an accepted socket, if 'path' is missing or is
            not a known action
        """
        if self.accepted:
            action = r['path']
            if action == 'shutdown':
                # No reply here: heartbeat() calls _shutdown() as soon as it
                # finds the session row missing
                retry_on_locked_db(self.clean_db)
                return
            {
                'start': self.start_sync,
                'pause': self.pause_sync
            }[action]()
            self.send_json({'OK': 200, 'action': 'post %s' % action})
            return

        # A request with no ui_id is rejected like one with a wrong ui_id,
        # and it must never match a ui_id which is still unset (None)
        ui_id = r.get('ui_id')
        if ui_id is not None and ui_id == self.ui_id:
            self.accepted = True
            Thread(target=self.heartbeat).start()
            self.send_json({'ACCEPTED': 202, 'action': 'post ui_id'})
            self._load_settings()
            if (not self.syncer) and self.can_sync():
                self.init_sync()
                self.start_sync()
        else:
            action = r.get('path', 'ui_id')
            self.send_json({'REJECTED': 401, 'action': 'post %s' % action})
            self.terminate()

    def _put(self, r):
        """Handle PUT requests

        On an accepted socket, everything in the request except for 'path'
        is applied as settings and is echoed back along with CREATED.
        Otherwise the reply is UNAUTHORIZED UI and the socket is terminated.

        :param r: the request (dict), without its 'method'
        :raises KeyError: if the request has no 'path'
        """
        if not self.accepted:
            action = r['path']
            self.send_json({
                'UNAUTHORIZED UI': 401, 'action': 'put %s' % action})
            self.terminate()
            return
        LOG.debug('put %s' % r)
        action = r.pop('path')
        self.set_settings(r)
        r['CREATED'] = 201
        r['action'] = 'put %s' % action
        self.send_json(r)
562 563 564 565 566

    def _get(self, r):
        """Handle GET requests

        On an accepted socket, the reply is the settings or the status
        (depending on 'path') plus the 'action'. Otherwise the reply is
        UNAUTHORIZED UI and the socket is terminated.

        :param r: the request (dict), without its 'method'
        :raises KeyError: if 'path' is missing or it is neither 'settings'
            nor 'status'
        """
        action = r.pop('path')
        if not self.accepted:
            self.send_json({
                'UNAUTHORIZED UI': 401, 'action': 'get %s' % action})
            self.terminate()
            return
        getters = {
            'settings': self.get_settings,
            'status': self.get_status,
        }
        # Reply with a copy: the getters return the very dicts this object
        # works with, which must not end up with an 'action' entry
        data = dict(getters[action]())
        data['action'] = 'get %s' % action
        self.send_json(data)

    def received_message(self, message):
        """Route requests to corresponding handling methods

        The message must be a JSON object with a 'method' which is one of
        'post', 'put', 'get'. Failures are reported to the other side:
        BAD REQUEST for malformed JSON and for missing or unknown keys, the
        error and status of a setup.ClientError, INTERNAL ERROR (followed by
        termination of the socket) for anything else.

        :param message: the message received from the socket
        """
        LOG.debug('recv: %s' % message)
        try:
            r = json.loads('%s' % message)
        except ValueError as ve:
            self.send_json({'BAD REQUEST': 400})
            LOG.error('JSON ERROR: %s' % ve)
            return
        # Bound before the try block: the handlers below report the method
        # even when the request does not have one
        method = ''
        try:
            method = r.pop('method')
            {
                'post': self._post,
                'put': self._put,
                'get': self._get
            }[method](r)
        except KeyError as ke:
            # %-formatting, so that a method which is not a string (it ends
            # up here as an unknown method) cannot break the error report
            action = '%s %s' % (method, r.get('path', ''))
            self.send_json({'BAD REQUEST': 400, 'action': action})
            LOG.error('KEY ERROR: %s' % ke)
        except setup.ClientError as ce:
            action = '%s %s' % (
                method, r.get('path', 'ui_id' if 'ui_id' in r else ''))
            self.send_json({'%s' % ce: ce.status, 'action': action})
            return
        except Exception as e:
            self.send_json({'INTERNAL ERROR': 500})
            LOG.error('EXCEPTION: %s' % e)
            self.terminate()
607 608 609 610 611 612 613 614 615


def launch_server():
    """Launch the server in a separate process

    The process is forked and the child is replaced by a python interpreter
    (looked up in PATH) which runs scripts/server.py, found next to this
    module. The parent returns right after the fork.
    """
    LOG.info('Start SessionHelper session')
    pid = os.fork()
    if not pid:
        server_path = os.path.join(CURPATH, 'scripts', 'server.py')
        try:
            os.execlp("python", "python", server_path)
        except OSError as err:
            LOG.error('Failed to launch %s: %s' % (server_path, err))
        finally:
            # Reached only if exec failed: the child must end here, or it
            # would carry on as a second copy of the calling program
            os._exit(1)