protocol.py 17.4 KB
Newer Older
Giorgos Korfiatis's avatar
Giorgos Korfiatis committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright (C) 2015 GRNET S.A.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

16
from wsgiref.simple_server import make_server
17
from ws4py.websocket import WebSocket
18
19
20
21
from ws4py.server.wsgiutils import WebSocketWSGIApplication
from ws4py.server.wsgirefserver import WSGIServer, WebSocketWSGIRequestHandler
from hashlib import sha1
from threading import Thread
22
23
import sqlite3
import time
24
import os
25
26
import json
import logging
27
from agkyra.syncer import (
28
    syncer, setup, pithos_client, localfs_client, messaging, utils)
29
from agkyra.config import AgkyraConfig, AGKYRA_DIR
30
31
32
33
34


LOG = logging.getLogger(__name__)


35
class SessionHelper(object):
    """Agkyra Helper Server sets up a WebSocket server with the Helper protocol

    It also provides methods for running and shutting down the Helper server.
    Session credentials (GUI id and server address) are stored in an SQLite
    relation, so that an existing session whose heartbeat is younger than
    ``session_timeout`` seconds is reused instead of creating a new server.
    """
    # Maximum age (seconds) of the last heartbeat for a session to count as
    # active
    session_timeout = 20

    def __init__(self, **kwargs):
        """Setup the helper server

        Keyword arguments:
            session_db: path of the SQLite session database
                (default: <AGKYRA_DIR>/session.db)
            session_relation: name of the session table (default: 'heart')
        """
        self.session_db = kwargs.get(
            'session_db', os.path.join(AGKYRA_DIR, 'session.db'))
        self.session_relation = kwargs.get('session_relation', 'heart')

        LOG.debug('Connect to db')
        self.db = sqlite3.connect(self.session_db)
        try:
            self._init_db_relation()
            self.session = (
                self._load_active_session() or self._create_session())
        finally:
            # Release the connection even if session setup fails
            self.db.close()

    def _init_db_relation(self):
        """Create the session relation (ui_id, address, beat) if missing"""
        self.db.execute('BEGIN')
        # The relation name is an identifier, so it cannot be bound as a
        # parameter and is interpolated instead
        self.db.execute(
            'CREATE TABLE IF NOT EXISTS %s ('
            'ui_id VARCHAR(256), address text, beat VARCHAR(32)'
            ')' % self.session_relation)
        self.db.commit()

    def _load_active_session(self):
        """Load a session from db

        :returns: dict(ui_id=..., address=...) built from the last fetched
            row, if its beat is less than session_timeout seconds old;
            None otherwise
        """
        r = self.db.execute('SELECT * FROM %s' % self.session_relation)
        sessions = r.fetchall()
        if sessions:
            last = sessions[-1]
            now, last_beat = time.time(), float(last[2])
            if now - last_beat < self.session_timeout:
                # Found an active session
                return dict(ui_id=last[0], address=last[1])
        return None

    def _create_session(self):
        """Create session credentials

        Generates a random GUI id, creates a WebSocket server on an
        ephemeral port of the loopback interface and stores
        (ui_id, address, beat) as the only row of the session relation.

        :returns: dict(ui_id=..., address=...)
        """
        ui_id = sha1(os.urandom(128)).hexdigest()

        LOCAL_ADDR = '127.0.0.1'
        # The protocol class reads these to authenticate the GUI and to
        # refresh the heartbeat
        WebSocketProtocol.ui_id = ui_id
        WebSocketProtocol.session_db = self.session_db
        WebSocketProtocol.session_relation = self.session_relation
        server = make_server(
            LOCAL_ADDR, 0,
            server_class=WSGIServer,
            handler_class=WebSocketWSGIRequestHandler,
            app=WebSocketWSGIApplication(handler_cls=WebSocketProtocol))
        server.initialize_websockets_manager()
        address = 'ws://%s:%s' % (LOCAL_ADDR, server.server_port)
        self.server = server

        self.db.execute('BEGIN')
        self.db.execute('DELETE FROM %s' % self.session_relation)
        # Values are bound as parameters; the beat keeps the same textual
        # float representation as before ('%s' % time.time())
        self.db.execute(
            'INSERT INTO %s VALUES (?, ?, ?)' % self.session_relation,
            (ui_id, address, '%s' % time.time()))
        self.db.commit()

        return dict(ui_id=ui_id, address=address)

    def start(self):
        """Start the helper server in a thread"""
        if getattr(self, 'server', None):
            # Keep a handle on the serving thread so shutdown can join it
            self._serve_thread = Thread(target=self.server.serve_forever)
            self._serve_thread.start()

    def shutdown(self):
        """Shutdown the server (needs another thread) and join threads"""
        if getattr(self, 'server', None):
            t = Thread(target=self.server.shutdown)
            t.start()
            t.join()
            serve_thread = getattr(self, '_serve_thread', None)
            if serve_thread is not None:
                # serve_forever has returned once shutdown() completes
                serve_thread.join()
110
111


112
113
114
115
class WebSocketProtocol(WebSocket):
    """Helper-side WebSocket protocol for communication with GUI:

    -- INTERNAL HANDSHAKE --
    GUI: {"method": "post", "ui_id": <GUI ID>}
    HELPER: {"ACCEPTED": 202, "action": "post ui_id"} or
        {"REJECTED": 401, "action": "post ui_id"}

    -- ERRORS WITH SIGNIFICANCE --
    If the token doesn't work:
    HELPER: {"action": <action that caused the error>, "UNAUTHORIZED": 401}

    A get/put request sent before a successful handshake gets
    {"UNAUTHORIZED UI": 401, "action": <action>} and the connection is
    terminated. Malformed requests get {"BAD REQUEST": 400}. Unexpected
    failures get {"INTERNAL ERROR": 500} and the connection is terminated.

    -- SHUT DOWN --
    GUI: {"method": "post", "path": "shutdown"}

    -- PAUSE --
    GUI: {"method": "post", "path": "pause"}
    HELPER: {"OK": 200, "action": "post pause"} or error

    -- START --
    GUI: {"method": "post", "path": "start"}
    HELPER: {"OK": 200, "action": "post start"} or error

    -- GET SETTINGS --
    GUI: {"method": "get", "path": "settings"}
    HELPER:
        {
            "action": "get settings",
            "token": <user token>,
            "url": <auth url>,
            "container": <container>,
            "directory": <local directory>,
            "exclude": <file path>
        } or {<ERROR>: <ERROR CODE>}

    -- PUT SETTINGS --
    GUI: {
            "method": "put", "path": "settings",
            "token": <user token>,
            "url": <auth url>,
            "container": <container>,
            "directory": <local directory>,
            "exclude": <file path>
        }
    HELPER: {"CREATED": 201, "action": "put settings",} (the submitted
        settings are echoed back as well) or
        {<ERROR>: <ERROR CODE>, "action": "put settings",}

    -- GET STATUS --
    GUI: {"method": "get", "path": "status"}
    HELPER: {
        "can_sync": <boolean>,
        "progress": <int>,
        "synced": <int>,
        "unsynced": <int>,
        "paused": <boolean>,
        "action": "get status"} or
        {<ERROR>: <ERROR CODE>, "action": "get status"}
    """

    # ui_id, session_db and session_relation are assigned on the class by
    # SessionHelper._create_session
    ui_id = None
    db, session_db, session_relation = None, None, None

    accepted = False
    # NOTE(review): settings (and status, until get_status rebinds it on the
    # instance) are class-level dicts mutated in place, hence shared by all
    # connections handled by this class.
    settings = dict(
        token=None, url=None,
        container=None, directory=None,
        exclude=None)
    status = dict(
        progress=0, synced=0, unsynced=0, paused=True, can_sync=False)
    file_syncer = None
    cnf = AgkyraConfig()
    # Settings that must all be set before a syncer can be set up
    essentials = ('url', 'token', 'container', 'directory')

    def heartbeat(self):
        """Refresh the beat of this session in the session db, then sleep 2s

        Used as the body of the heart thread (see ``opened``); presumably
        invoked repeatedly by utils.StoppableThread until stopped. The db
        connection is created lazily on the first call, i.e. in the heart
        thread, and reused afterwards.
        """
        if not self.db:
            self.db = sqlite3.connect(self.session_db)
        self.db.execute('BEGIN')
        # Only the relation name (an identifier) is interpolated; values are
        # bound, keeping the '%s' % time.time() text representation
        self.db.execute(
            'UPDATE %s SET beat=? WHERE ui_id=?' % self.session_relation,
            ('%s' % time.time(), self.ui_id))
        self.db.commit()
        time.sleep(2)

    def _get_default_sync(self):
        """Get global.default_sync or pick the first sync as default
        If there are no syncs, create a 'default' sync.
        """
        sync = self.cnf.get('global', 'default_sync')
        if not sync:
            # Bind sync to the first configured sync name, if there is one
            for sync in self.cnf.keys('sync'):
                break
            self.cnf.set('global', 'default_sync', sync or 'default')
        return sync or 'default'

    def _get_sync_cloud(self, sync):
        """Get the <sync>.cloud or pick the first cloud and use it
        In case of cloud picking, set the cloud as the <sync>.cloud for future
        sessions.
        If no clouds are found, create a 'default' cloud, with an empty url.
        """
        try:
            cloud = self.cnf.get_sync(sync, 'cloud')
        except KeyError:
            cloud = None
        if not cloud:
            # Bind cloud to the first configured cloud name, if there is one
            for cloud in self.cnf.keys('cloud'):
                break
            self.cnf.set_sync(sync, 'cloud', cloud or 'default')
        return cloud or 'default'

    def _load_settings(self):
        """Load the default sync's settings from the config into settings

        url and token come from the sync's cloud (None if they cannot be
        read); container and directory come from the sync itself and are
        left untouched when not configured.
        """
        LOG.debug('Start loading settings')
        sync = self._get_default_sync()
        cloud = self._get_sync_cloud(sync)

        try:
            self.settings['url'] = self.cnf.get_cloud(cloud, 'url')
        except Exception:
            self.settings['url'] = None
        try:
            self.settings['token'] = self.cnf.get_cloud(cloud, 'token')
        except Exception:
            self.settings['token'] = None

        # for option in ('container', 'directory', 'exclude'):
        for option in ('container', 'directory'):
            try:
                self.settings[option] = self.cnf.get_sync(sync, option)
            except KeyError:
                LOG.debug('No %s is set' % option)

        LOG.debug('Finished loading settings')

    def _dump_settings(self):
        """Save settings to the config file of the default sync

        If the sync's cloud already holds a different url, a new cloud name
        is derived as <cloud>_<sync> (repeatedly) until a name is found that
        is unused or holds the same (or an empty) url.
        """
        LOG.debug('Saving settings')
        sync = self._get_default_sync()
        changes = False

        if not self.settings.get('url', None):
            LOG.debug('No cloud settings to save')
        else:
            LOG.debug('Save cloud settings')
            cloud = self._get_sync_cloud(sync)

            try:
                old_url = self.cnf.get_cloud(cloud, 'url') or ''
            except KeyError:
                old_url = self.settings['url']

            while old_url and old_url != self.settings['url']:
                cloud = '%s_%s' % (cloud, sync)
                try:
                    # Re-check against the candidate's own url, so that an
                    # existing matching cloud is reused
                    old_url = self.cnf.get_cloud(cloud, 'url') or ''
                except KeyError:
                    break

            LOG.debug('Cloud name is %s' % cloud)
            self.cnf.set_cloud(cloud, 'url', self.settings['url'])
            self.cnf.set_cloud(cloud, 'token', self.settings['token'] or '')
            self.cnf.set_sync(sync, 'cloud', cloud)
            changes = True

        LOG.debug('Save sync settings, name is %s' % sync)
        # for option in ('directory', 'container', 'exclude'):
        for option in ('directory', 'container'):
            self.cnf.set_sync(sync, option, self.settings[option] or '')
            changes = True

        if changes:
            self.cnf.write()
            LOG.debug('Settings saved')
        else:
            LOG.debug('No setting changes spotted')

    def _essentials_changed(self, new_settings):
        """Check if essential settings have changed in new_settings

        Keys missing from new_settings count as unchanged, because
        settings.update(new_settings) leaves them as they are.
        """
        return any(
            new_settings.get(e, self.settings[e]) != self.settings[e]
            for e in self.essentials)

    def _update_statistics(self):
        """Update statistics by consuming and understanding syncer messages"""
        if self.can_sync():
            msg = self.syncer.get_next_message()
            if not msg:
                if self.status['unsynced'] == self.status['synced']:
                    self.status['unsynced'] = 0
                    self.status['synced'] = 0
            while (msg):
                if isinstance(msg, messaging.SyncMessage):
                    LOG.info('Start syncing "%s"' % msg.objname)
                    self.status['unsynced'] += 1
                elif isinstance(msg, messaging.AckSyncMessage):
                    LOG.info('Finished syncing "%s"' % msg.objname)
                    self.status['synced'] += 1
                elif isinstance(msg, messaging.CollisionMessage):
                    LOG.info('Collision for "%s"' % msg.objname)
                elif isinstance(msg, messaging.ConflictStashMessage):
                    LOG.info('Conflict for "%s"' % msg.objname)
                else:
                    LOG.debug('Consumed msg %s' % msg)
                msg = self.syncer.get_next_message()

    def can_sync(self):
        """Check if settings are enough to setup a syncing proccess"""
        return all([self.settings[e] for e in self.essentials])

    def init_sync(self):
        """Initialize syncer

        SSL options come from the default sync's cloud: ignore_ssl ('on')
        and, when SSL is not ignored, ca_certs.
        :raises setup.ClientError: re-raised after resetting syncer to None
        """
        sync = self._get_default_sync()

        kwargs = dict(agkyra_path=AGKYRA_DIR)
        # Get SSL settings
        cloud = self._get_sync_cloud(sync)
        try:
            ignore_ssl = self.cnf.get_cloud(cloud, 'ignore_ssl') in ('on', )
            kwargs['ignore_ssl'] = ignore_ssl
        except KeyError:
            ignore_ssl = None
        if not ignore_ssl:
            try:
                kwargs['ca_certs'] = self.cnf.get_cloud(cloud, 'ca_certs')
            except KeyError:
                pass

        try:
            syncer_settings = setup.SyncerSettings(
                self.settings['url'], self.settings['token'],
                self.settings['container'], self.settings['directory'],
                **kwargs)
            master = pithos_client.PithosFileClient(syncer_settings)
            slave = localfs_client.LocalfsFileClient(syncer_settings)
            self.syncer = syncer.FileSyncer(syncer_settings, master, slave)
            self.syncer_settings = syncer_settings
            self.syncer.initiate_probe()
        except setup.ClientError:
            self.syncer = None
            raise

    # Syncer-related methods
    def get_status(self):
        """Return the status dict, refreshed from the syncer if there is one

        Without a usable syncer the status is reset to its idle defaults.
        """
        if getattr(self, 'syncer', None) and self.can_sync():
            self._update_statistics()
            self.status['paused'] = self.syncer.paused
            self.status['can_sync'] = self.can_sync()
        else:
            self.status = dict(
                progress=0, synced=0, unsynced=0, paused=True, can_sync=False)
        return self.status

    def get_settings(self):
        """Return the current settings dict (not a copy)"""
        return self.settings

    def set_settings(self, new_settings):
        """Set the settings and dump them to permanent storage if needed

        An active syncer is paused first. Afterwards, the syncer is
        re-initialized if essential settings changed or there was no usable
        syncer before, and restarted if it had been active.
        """
        # Prepare setting save
        could_sync = getattr(self, 'syncer', None) and self.can_sync()
        was_active = False
        if could_sync and not self.syncer.paused:
            was_active = True
            self.pause_sync()
        must_reset_syncing = self._essentials_changed(new_settings)

        # save settings
        self.settings.update(new_settings)
        self._dump_settings()

        # Restart
        if self.can_sync():
            if must_reset_syncing or not could_sync:
                self.init_sync()
            if was_active:
                self.start_sync()

    def pause_sync(self):
        """Stop deciding on new syncs and wait for open syncs to complete"""
        self.syncer.stop_decide()
        LOG.debug('Wait open syncs to complete')
        self.syncer.wait_sync_threads()

    def start_sync(self):
        """Resume deciding on syncs"""
        self.syncer.start_decide()

    # WebSocket connection methods
    def opened(self):
        """Start the heart thread that keeps the session beat fresh"""
        LOG.debug('Helper: connection established')
        self.heart = utils.StoppableThread()
        self.heart.run_body = self.heartbeat
        self.heart.start()

    def closed(self, *args):
        """Stop server heart, empty DB and exit"""
        LOG.debug('Stop protocol heart')
        heart = getattr(self, 'heart', None)
        if heart is not None:
            heart.stop()
        self.clean_db()

    def clean_db(self):
        """Clean DB from session traces"""
        LOG.debug('Remove session traces')
        # Use a private connection: self.db is created in (and used by) the
        # heart thread, and sqlite3 connections may not be shared between
        # threads by default.
        db = sqlite3.connect(self.session_db)
        try:
            db.execute('BEGIN')
            db.execute('DELETE FROM %s' % self.session_relation)
            db.commit()
        finally:
            db.close()

    def send_json(self, msg):
        """Serialize msg as JSON and send it to the GUI"""
        LOG.debug('send: %s' % msg)
        self.send(json.dumps(msg))

    # Protocol handling methods
    def _post(self, r):
        """Handle POST requests

        Before the handshake only {"ui_id": ...} is accepted; afterwards the
        paths shutdown, start and pause are handled.
        """
        if self.accepted:
            action = r['path']
            if action == 'shutdown':
                # Only stop syncing if a syncer was actually set up
                if getattr(self, 'syncer', None) and self.can_sync():
                    self.syncer.stop_all_daemons()
                    LOG.debug('Wait open syncs to complete')
                    self.syncer.wait_sync_threads()
                self.close()
                return
            {
                'start': self.start_sync,
                'pause': self.pause_sync
            }[action]()
            self.send_json({'OK': 200, 'action': 'post %s' % action})
        elif r['ui_id'] == self.ui_id:
            self.accepted = True
            self.send_json({'ACCEPTED': 202, 'action': 'post ui_id'})
            self._load_settings()
            if self.can_sync():
                self.init_sync()
                self.start_sync()
        else:
            action = r.get('path', 'ui_id')
            self.send_json({'REJECTED': 401, 'action': 'post %s' % action})
            self.terminate()

    def _put(self, r):
        """Handle PUT requests (store settings and echo them back)"""
        if self.accepted:
            LOG.debug('put %s' % r)
            action = r.pop('path')
            self.set_settings(r)
            r.update({'CREATED': 201, 'action': 'put %s' % action})
            self.send_json(r)
        else:
            action = r['path']
            self.send_json({
                'UNAUTHORIZED UI': 401, 'action': 'put %s' % action})
            self.terminate()

    def _get(self, r):
        """Handle GET requests (settings or status)"""
        action = r.pop('path')
        if not self.accepted:
            self.send_json({
                'UNAUTHORIZED UI': 401, 'action': 'get %s' % action})
            self.terminate()
        else:
            # Copy, so the 'action' key does not leak into the shared
            # settings/status dicts
            data = dict({
                'settings': self.get_settings,
                'status': self.get_status,
            }[action]())
            data['action'] = 'get %s' % action
            self.send_json(data)

    def received_message(self, message):
        """Route requests to corresponding handling methods"""
        LOG.debug('recv: %s' % message)
        try:
            r = json.loads('%s' % message)
        except ValueError as ve:
            self.send_json({'BAD REQUEST': 400})
            LOG.error('JSON ERROR: %s' % ve)
            return
        if not isinstance(r, dict):
            self.send_json({'BAD REQUEST': 400})
            LOG.error('REQUEST ERROR: not a JSON object: %s' % r)
            return
        # The get/put handlers pop 'path' from r, so remember it up front to
        # report the failing action on client errors
        path = r.get('path', 'ui_id' if 'ui_id' in r else '')
        try:
            method = r.pop('method')
            {
                'post': self._post,
                'put': self._put,
                'get': self._get
            }[method](r)
        except KeyError as ke:
            self.send_json({'BAD REQUEST': 400})
            LOG.error('KEY ERROR: %s' % ke)
        except setup.ClientError as ce:
            action = '%s %s' % (method, path)
            self.send_json({'%s' % ce: ce.status, 'action': action})
            return
        except Exception as e:
            # print_exc prints the traceback of the exception being handled
            # (print_stack expects a frame, not an exception)
            from traceback import print_exc
            print_exc()
            self.send_json({'INTERNAL ERROR': 500})
            LOG.error('EXCEPTION: %s' % e)
            self.terminate()