blob: 46bca7cab3283371027bb6c9b5ffe49398a70ce0 [file] [log] [blame]
# Copyright (C) 2018 Garmin Ltd.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
15
16from http.server import BaseHTTPRequestHandler, HTTPServer
17import contextlib
18import urllib.parse
19import sqlite3
20import json
21import traceback
22import logging
23from datetime import datetime
24
25logger = logging.getLogger('hashserv')
26
class HashEquivalenceServer(BaseHTTPRequestHandler):
    """HTTP request handler for the hash equivalence endpoint.

    A single endpoint is served, ``<prefix>/v1/equivalent``:

    * ``GET`` looks up the entry recorded for a ``method``/``taskhash`` pair.
    * ``POST`` reports a task's ``outhash`` and answers with the entry (and
      therefore the unihash) that applies to that task.

    Two class attributes are read but not defined here and must be set on the
    class (or a subclass) before requests are handled: ``prefix``, the URL
    prefix, and ``db``, a database connection whose rows can be indexed by
    column name. ``create_server()`` sets both.
    """

    # Path, relative to ``prefix``, of the only endpoint this handler serves.
    _ENDPOINT = '/v1/equivalent'
    # Columns reported back to the client for a task entry.
    _RESULT_KEYS = ('taskhash', 'method', 'unihash')
    # Optional columns copied from a POST body when the client supplies them.
    _OPTIONAL_KEYS = ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo')

    def log_message(self, f, *args):
        """Route the base handler's log lines to the module logger (debug)."""
        logger.debug(f, *args)

    def _send_json(self, obj):
        """Send ``obj`` as a 200 response with a UTF-8 encoded JSON body."""
        body = json.dumps(obj).encode('utf-8')
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        # Tell the client how much to read instead of relying on the
        # connection being closed to mark the end of the body.
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Look up the oldest entry for the requested method and taskhash.

        The query string must carry ``method`` and ``taskhash``. The response
        body is a JSON object with ``taskhash``, ``method`` and ``unihash``,
        or JSON ``null`` when no entry matches. A request for any other path
        gets a 404; any failure while handling the request (including missing
        parameters) gets a 400 whose explanation is the traceback.
        """
        try:
            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + self._ENDPOINT:
                self.send_error(404)
                return

            # strict_parsing makes a malformed or empty query raise, which is
            # reported to the client as a 400 below.
            query = urllib.parse.parse_qs(p.query, strict_parsing=True)
            method = query['method'][0]
            taskhash = query['taskhash'][0]

            d = None
            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute(
                    'SELECT taskhash, method, unihash FROM tasks_v1 '
                    'WHERE method=:method AND taskhash=:taskhash '
                    'ORDER BY created ASC LIMIT 1',
                    {'method': method, 'taskhash': taskhash})

                row = cursor.fetchone()

                if row is not None:
                    logger.debug('Found equivalent task %s', row['taskhash'])
                    d = {k: row[k] for k in self._RESULT_KEYS}

            self._send_json(d)
        except Exception:
            # Deliberately not a bare ``except``: KeyboardInterrupt and
            # SystemExit must reach the server loop, not become a 400.
            logger.exception('Error in GET')
            self.send_error(400, explain=traceback.format_exc())

    def do_POST(self):
        """Record a task's output hash and report the entry that applies.

        The body is a JSON object with ``method``, ``outhash`` and
        ``taskhash``; ``unihash`` is read whenever a new row has to be added.
        ``owner``, ``PN``, ``PV``, ``PR``, ``task`` and ``outhash_siginfo``
        are stored when present.

        If a row with the same method and outhash already exists for this
        exact taskhash it is returned unchanged. Otherwise a new row is added
        for the taskhash, reusing the unihash of the oldest row with the same
        method and outhash, or the caller's unihash when there is none. The
        response body is a JSON object with ``taskhash``, ``method`` and
        ``unihash``.

        A request for any other path gets a 404. Any failure gets a 400 whose
        explanation is the traceback, and pending database changes are rolled
        back.
        """
        try:
            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + self._ENDPOINT:
                self.send_error(404)
                return

            length = int(self.headers['content-length'])
            data = json.loads(self.rfile.read(length).decode('utf-8'))

            with contextlib.closing(self.db.cursor()) as cursor:
                # Rows with a matching outhash, preferring an exact taskhash
                # match and otherwise the oldest row.
                cursor.execute('''
                    SELECT taskhash, method, unihash FROM tasks_v1 WHERE method=:method AND outhash=:outhash
                    ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                        created ASC
                    LIMIT 1
                    ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

                row = cursor.fetchone()

                if row is None or row['taskhash'] != data['taskhash']:
                    # A matching outhash under a different taskhash donates
                    # its unihash; otherwise the caller's value is used.
                    unihash = data['unihash']
                    if row is not None:
                        unihash = row['unihash']

                    insert_data = {
                        'method': data['method'],
                        'outhash': data['outhash'],
                        'taskhash': data['taskhash'],
                        'unihash': unihash,
                        # Same text sqlite3's default datetime adapter would
                        # store, written explicitly because that adapter is
                        # deprecated since Python 3.12.
                        'created': datetime.now().isoformat(' '),
                    }

                    for k in self._OPTIONAL_KEYS:
                        if k in data:
                            insert_data[k] = data[k]

                    # Column names come only from the fixed keys above, never
                    # from the request, so interpolating them is safe; the
                    # values are bound as named parameters.
                    columns = sorted(insert_data)
                    cursor.execute('''INSERT INTO tasks_v1 (%s) VALUES (%s)''' % (
                        ', '.join(columns),
                        ', '.join(':' + k for k in columns)),
                        insert_data)

                    logger.info('Adding taskhash %s with unihash %s', data['taskhash'], unihash)
                    cursor.execute('SELECT taskhash, method, unihash FROM tasks_v1 WHERE id=:id', {'id': cursor.lastrowid})
                    row = cursor.fetchone()

                    self.db.commit()

                d = {k: row[k] for k in self._RESULT_KEYS}

            self._send_json(d)
        except Exception:
            # Deliberately not a bare ``except``: KeyboardInterrupt and
            # SystemExit must reach the server loop, not become a 400.
            logger.exception('Error in POST')
            # Capture the traceback before the rollback attempt below can
            # replace the active exception.
            explain = traceback.format_exc()
            try:
                # Drop a half-finished insert so that a later request's
                # commit on this shared connection cannot persist it.
                self.db.rollback()
            except sqlite3.Error:
                logger.exception('Rollback failed after POST error')
            self.send_error(400, explain=explain)
122
def create_server(addr, db, prefix=''):
    """Create an HTTP server that answers hash equivalence requests.

    addr is handed to HTTPServer as the server address. db is the database
    connection the request handler will use; its row factory is switched to
    sqlite3.Row and the tasks_v1 table is created if it does not exist yet.
    prefix is the URL prefix the handler expects in front of its endpoint.

    Returns the HTTPServer instance; the caller decides when to start
    serving requests.
    """
    # The request handler reads result rows by column name.
    db.row_factory = sqlite3.Row

    # A fresh subclass per call keeps this server's settings off the shared
    # base handler class.
    class Handler(HashEquivalenceServer):
        pass

    Handler.db = db
    Handler.prefix = prefix

    schema = '''
            CREATE TABLE IF NOT EXISTS tasks_v1 (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                method TEXT NOT NULL,
                outhash TEXT NOT NULL,
                taskhash TEXT NOT NULL,
                unihash TEXT NOT NULL,
                created DATETIME,

                -- Optional fields
                owner TEXT,
                PN TEXT,
                PV TEXT,
                PR TEXT,
                task TEXT,
                outhash_siginfo TEXT
                )
            '''
    with contextlib.closing(db.cursor()) as cur:
        cur.execute(schema)

    logger.info('Starting server on %s', addr)
    server = HTTPServer(addr, Handler)
    return server