blob: 81050715ea1613c06f0ec3dbbd73c32e55c7837a [file] [log] [blame]
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

6from contextlib import closing
7from datetime import datetime
8import asyncio
9import json
10import logging
11import math
12import os
13import signal
14import socket
15import time
Andrew Geissler475cb722020-07-10 16:00:51 -050016from . import chunkify, DEFAULT_MAX_CHUNK
Brad Bishopa34c0302019-09-23 22:34:48 -040017
18logger = logging.getLogger('hashserv.server')
19
20
class Measurement(object):
    """Times one operation and reports the elapsed wall time to a Sample.

    Can be driven manually via start()/end() or used as a context manager,
    in which case entering starts the clock and exiting reports the time.
    """

    def __init__(self, sample):
        self.sample = sample

    def start(self):
        """Record the starting timestamp."""
        self.start_time = time.perf_counter()

    def end(self):
        """Report the time elapsed since start() to the owning sample."""
        elapsed = time.perf_counter() - self.start_time
        self.sample.add(elapsed)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        self.end()
37
38
class Sample(object):
    """Accumulates one or more timed measurements into a single sample.

    Individual measurement times are summed; end() flushes the running
    total into the parent Stats object and clears the accumulator, so the
    same Sample can be reused for a fresh aggregate. Also usable as a
    context manager, which calls end() on exit.
    """

    def __init__(self, stats):
        self.stats = stats
        self.num_samples = 0
        self.elapsed = 0

    def measure(self):
        """Return a Measurement that reports its elapsed time here."""
        return Measurement(self)

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.end()

    def add(self, elapsed):
        """Fold one measurement's elapsed time into the running total."""
        self.num_samples += 1
        self.elapsed += elapsed

    def end(self):
        """Flush the accumulated time into the parent stats and reset.

        Does nothing if no measurements were recorded.
        """
        if self.num_samples:
            self.stats.add(self.elapsed)
            self.num_samples = 0
            self.elapsed = 0
63
64
class Stats(object):
    """Running elapsed-time statistics using Welford's online algorithm.

    Tracks count, total, maximum, mean and sample standard deviation
    without storing the individual observations.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all accumulated statistics."""
        self.num = 0
        self.total_time = 0
        self.max_time = 0
        self.m = 0  # running mean (Welford)
        self.s = 0  # running sum of squared deviations (Welford)
        self.current_elapsed = None

    def add(self, elapsed):
        """Fold a single elapsed-time observation into the statistics."""
        self.num += 1
        if self.num == 1:
            self.m = elapsed
            self.s = 0
        else:
            prev_mean = self.m
            self.m = prev_mean + (elapsed - prev_mean) / self.num
            self.s += (elapsed - prev_mean) * (elapsed - self.m)

        self.total_time += elapsed
        self.max_time = max(self.max_time, elapsed)

    def start_sample(self):
        """Begin a new Sample that reports into these statistics."""
        return Sample(self)

    @property
    def average(self):
        """Mean elapsed time, or 0 when there are no observations."""
        return self.total_time / self.num if self.num else 0

    @property
    def stdev(self):
        """Sample standard deviation, or 0 with fewer than two observations."""
        return math.sqrt(self.s / (self.num - 1)) if self.num > 1 else 0

    def todict(self):
        """Summarize the statistics as a plain dict (JSON-serializable)."""
        return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
109
110
class ClientError(Exception):
    """Raised for client protocol violations (e.g. an unrecognized command
    or nested chunk streams); logged and the connection is closed rather
    than crashing the server."""
    pass
113
class ServerClient(object):
    """Serves the hash equivalence protocol on a single client connection.

    The wire protocol is line-oriented: after an "OEHASHEQUIV <version>"
    banner and an (empty) header section, each request is one JSON object
    whose sole key names the command; responses are JSON written back in
    protocol chunks. One instance exists per connection and is driven by
    process_requests() until the client disconnects or misbehaves.
    """

    # Fast-path lookup: fetch only the columns needed to answer a query.
    FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
    # Full lookup: every column of the oldest matching row.
    ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'

    def __init__(self, reader, writer, db, request_stats):
        """
        reader/writer: asyncio stream pair for the connection.
        db: database connection; rows must be accessible by column name.
        request_stats: shared Stats object for per-request timing.
        """
        self.reader = reader
        self.writer = writer
        self.db = db
        self.request_stats = request_stats
        self.max_chunk = DEFAULT_MAX_CHUNK

        # Dispatch table mapping the message's command key to its handler
        # coroutine.
        self.handlers = {
            'get': self.handle_get,
            'report': self.handle_report,
            'report-equiv': self.handle_equivreport,
            'get-stream': self.handle_get_stream,
            'get-stats': self.handle_get_stats,
            'reset-stats': self.handle_reset_stats,
            'chunk-stream': self.handle_chunk,
        }

    async def process_requests(self):
        """Run the protocol handshake, then dispatch messages until EOF.

        Protocol violations (ClientError) are logged and the connection is
        closed; other exceptions propagate to the caller. The writer is
        always closed on exit.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version
            protocol = await self.reader.readline()
            if protocol is None:
                return

            (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
            if proto_name != 'OEHASHEQUIV':
                return

            # Only protocol versions 1.0 through 1.1 are understood
            proto_version = tuple(int(v) for v in proto_version.split('.'))
            if proto_version < (1, 0) or proto_version > (1, 1):
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if line is None:
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """Invoke the handler matching the message's command key.

        Streaming commands manage their own request sampling (for speed);
        all others are timed with a Stats sample. Raises ClientError when
        no known command key is present.
        """
        for k in self.handlers.keys():
            if k in msg:
                logger.debug('Handling %s' % k)
                if 'stream' in k:
                    await self.handlers[k](msg[k])
                else:
                    with self.request_stats.start_sample() as self.request_sample, \
                            self.request_sample.measure():
                        await self.handlers[k](msg[k])
                return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        """JSON-encode msg and write it out, split into protocol chunks."""
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """Read one newline-terminated JSON message.

        Returns the decoded object, or None on EOF or a truncated line.
        Re-raises JSON/Unicode decode errors after logging the bad input.
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            # A missing trailing newline means the peer hung up mid-message
            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw bytes: 'message' is unbound if decode() itself
            # failed (the original logged 'message' here, a latent NameError)
            logger.error('Bad message from client: %r' % l)
            raise

    async def handle_chunk(self, request):
        """Reassemble a chunked message from subsequent lines and dispatch it.

        Chunks are terminated by an empty line. Nested chunk streams are
        refused with ClientError.
        """
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the reassembled payload (the original referenced an
            # undefined name 'message' here, raising NameError instead)
            logger.error('Bad message from client: %r' % ''.join(lines))
            raise

        if 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_get(self, request):
        """Answer a unihash lookup for (method, taskhash).

        Replies with the full row as a dict, or None when no equivalent
        task is known. 'all' selects the full-column query.
        """
        method = request['method']
        taskhash = request['taskhash']

        if request.get('all', False):
            row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
        else:
            row = self.query_equivalent(method, taskhash, self.FAST_QUERY)

        if row is not None:
            # Lazy %-style logging args must be passed separately (the
            # original passed a single tuple, which raised a formatting
            # error whenever debug logging was enabled)
            logger.debug('Found equivalent task %s -> %s',
                         row['taskhash'], row['unihash'])
            d = {k: row[k] for k in row.keys()}

            self.write_message(d)
        else:
            self.write_message(None)

    async def handle_get_stream(self, request):
        """Switch into streaming mode.

        Each subsequent "<method> <taskhash>" line is answered with the
        equivalent unihash (or a blank line when unknown) until the client
        sends 'END'.
        """
        self.write_message('ok')

        while True:
            l = await self.reader.readline()
            if not l:
                return

            try:
                # This inner loop is very sensitive and must be as fast as
                # possible (which is why the request sample is handled manually
                # instead of using 'with', and also why logging statements are
                # commented out.
                self.request_sample = self.request_stats.start_sample()
                request_measure = self.request_sample.measure()
                request_measure.start()

                l = l.decode('utf-8').rstrip()
                if l == 'END':
                    self.writer.write('ok\n'.encode('utf-8'))
                    return

                (method, taskhash) = l.split()
                #logger.debug('Looking up %s %s' % (method, taskhash))
                row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
                if row is not None:
                    msg = ('%s\n' % row['unihash']).encode('utf-8')
                    #logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
                else:
                    msg = '\n'.encode('utf-8')

                self.writer.write(msg)
            finally:
                request_measure.end()
                self.request_sample.end()

            await self.writer.drain()

    async def handle_report(self, data):
        """Record a reported (method, outhash, taskhash) and reply with the
        canonical unihash for the taskhash.

        If an equivalent outhash already exists, its unihash is reused so
        equivalent tasks converge on one unihash; otherwise the unihash
        supplied by the client is stored.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute('''
                -- Find tasks with a matching outhash (that is, tasks that
                -- are equivalent)
                SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash

                -- If there is an exact match on the taskhash, return it.
                -- Otherwise return the oldest matching outhash of any
                -- taskhash
                ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                    created ASC

                -- Only return one row
                LIMIT 1
                ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

            row = cursor.fetchone()

            # If no matching outhash was found, or one *was* found but it
            # wasn't an exact match on the taskhash, a new entry for this
            # taskhash should be added
            if row is None or row['taskhash'] != data['taskhash']:
                # If a row matching the outhash was found, the unihash for
                # the new taskhash should be the same as that one.
                # Otherwise the caller provided unihash is used.
                unihash = data['unihash']
                if row is not None:
                    unihash = row['unihash']

                insert_data = {
                    'method': data['method'],
                    'outhash': data['outhash'],
                    'taskhash': data['taskhash'],
                    'unihash': unihash,
                    'created': datetime.now()
                }

                # Optional metadata columns are stored when provided
                for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                    if k in data:
                        insert_data[k] = data[k]

                cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
                    ', '.join(sorted(insert_data.keys())),
                    ', '.join(':' + k for k in sorted(insert_data.keys()))),
                    insert_data)

                self.db.commit()

                logger.info('Adding taskhash %s with unihash %s',
                            data['taskhash'], unihash)

                d = {
                    'taskhash': data['taskhash'],
                    'method': data['method'],
                    'unihash': unihash
                }
            else:
                d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)

    async def handle_equivreport(self, data):
        """Record a client-asserted taskhash -> unihash equivalence.

        The insert is INSERT OR IGNORE, so an existing mapping wins; the
        reply carries whichever unihash the database now reports for the
        taskhash.
        """
        with closing(self.db.cursor()) as cursor:
            insert_data = {
                'method': data['method'],
                'outhash': "",
                'taskhash': data['taskhash'],
                'unihash': data['unihash'],
                'created': datetime.now()
            }

            for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                if k in data:
                    insert_data[k] = data[k]

            cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % (
                ', '.join(sorted(insert_data.keys())),
                ', '.join(':' + k for k in sorted(insert_data.keys()))),
                insert_data)

            self.db.commit()

            # Fetch the unihash that will be reported for the taskhash. If the
            # unihash matches, it means this row was inserted (or the mapping
            # was already valid)
            row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)

            if row['unihash'] == data['unihash']:
                logger.info('Adding taskhash equivalence for %s with unihash %s',
                            data['taskhash'], row['unihash'])

            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)

    async def handle_get_stats(self, request):
        """Reply with the current request statistics."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.write_message(d)

    async def handle_reset_stats(self, request):
        """Reply with the current request statistics, then reset them."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.request_stats.reset()
        self.write_message(d)

    def query_equivalent(self, method, taskhash, query):
        """Run an equivalence lookup query; return the row or None.

        This is part of the inner loop and must be as fast as possible.
        Lookup failures are deliberately treated as "no equivalent found"
        (None) rather than tearing down the connection; the original bare
        'except:' is narrowed to Exception so KeyboardInterrupt/SystemExit
        are no longer swallowed, and the cursor guard avoids a NameError
        when cursor() itself fails.
        """
        cursor = None
        try:
            cursor = self.db.cursor()
            cursor.execute(query, {'method': method, 'taskhash': taskhash})
            return cursor.fetchone()
        except Exception:
            if cursor is not None:
                cursor.close()
405
406
class Server(object):
    """Hash equivalence server: owns the event loop, the listening socket
    (TCP or AF_UNIX), and the shared database connection and statistics.
    One ServerClient is created per accepted connection."""

    def __init__(self, db, loop=None):
        """
        db: database connection shared by all client handlers.
        loop: optional asyncio event loop; when omitted a private loop is
              created and will be closed by serve_forever().
        """
        self.request_stats = Stats()
        self.db = db

        # Track whether we own the loop so serve_forever() knows whether
        # it is responsible for closing it.
        if loop is None:
            self.loop = asyncio.new_event_loop()
            self.close_loop = True
        else:
            self.loop = loop
            self.close_loop = False

        # Set by start_unix_server(); removes the socket file on shutdown.
        self._cleanup_socket = None

    def start_tcp_server(self, host, port):
        """Bind a TCP listening socket and record its address in
        self.address ("host:port", bracketed for IPv6)."""
        # NOTE(review): the explicit loop= argument is deprecated in newer
        # Python (removed in 3.10) — confirm the supported interpreter range.
        self.server = self.loop.run_until_complete(
            asyncio.start_server(self.handle_client, host, port, loop=self.loop)
        )

        for s in self.server.sockets:
            logger.info('Listening on %r' % (s.getsockname(),))
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility
            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            # NOTE(review): socket.TCP_QUICKACK exists only on Linux —
            # confirm target platforms before relying on this attribute.
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

    def start_unix_server(self, path):
        """Bind an AF_UNIX listening socket at path and record its address
        in self.address ("unix://<abspath>"). The socket file is unlinked
        at shutdown via the cleanup callback."""
        def cleanup():
            os.unlink(path)

        cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX: bind using a
            # relative name from the socket's own directory.
            os.chdir(os.path.dirname(path))
            self.server = self.loop.run_until_complete(
                asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
            )
        finally:
            os.chdir(cwd)

        logger.info('Listening on %r' % path)

        self._cleanup_socket = cleanup
        self.address = "unix://%s" % os.path.abspath(path)

    async def handle_client(self, reader, writer):
        """Per-connection entry point: run a ServerClient over the stream
        pair and log (rather than propagate) any unexpected error."""
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = ServerClient(reader, writer, self.db, self.request_stats)
            await client.process_requests()
        except Exception as e:
            import traceback
            logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        logger.info('Client disconnected')

    def serve_forever(self):
        """Run the event loop until SIGTERM or KeyboardInterrupt, then shut
        down the listener, optionally close an owned loop, and remove any
        AF_UNIX socket file."""
        def signal_handler():
            self.loop.stop()

        self.loop.add_signal_handler(signal.SIGTERM, signal_handler)

        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

        self.server.close()
        # Wait for the listening socket to finish closing before teardown.
        self.loop.run_until_complete(self.server.wait_closed())
        logger.info('Server shutting down')

        if self.close_loop:
            self.loop.close()

        if self._cleanup_socket is not None:
            self._cleanup_socket()