# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
6from contextlib import closing
7from datetime import datetime
8import asyncio
9import json
10import logging
11import math
12import os
13import signal
14import socket
15import time
16
17logger = logging.getLogger('hashserv.server')
18
19
20class Measurement(object):
21 def __init__(self, sample):
22 self.sample = sample
23
24 def start(self):
25 self.start_time = time.perf_counter()
26
27 def end(self):
28 self.sample.add(time.perf_counter() - self.start_time)
29
30 def __enter__(self):
31 self.start()
32 return self
33
34 def __exit__(self, *args, **kwargs):
35 self.end()
36
37
38class Sample(object):
39 def __init__(self, stats):
40 self.stats = stats
41 self.num_samples = 0
42 self.elapsed = 0
43
44 def measure(self):
45 return Measurement(self)
46
47 def __enter__(self):
48 return self
49
50 def __exit__(self, *args, **kwargs):
51 self.end()
52
53 def add(self, elapsed):
54 self.num_samples += 1
55 self.elapsed += elapsed
56
57 def end(self):
58 if self.num_samples:
59 self.stats.add(self.elapsed)
60 self.num_samples = 0
61 self.elapsed = 0
62
63
64class Stats(object):
65 def __init__(self):
66 self.reset()
67
68 def reset(self):
69 self.num = 0
70 self.total_time = 0
71 self.max_time = 0
72 self.m = 0
73 self.s = 0
74 self.current_elapsed = None
75
76 def add(self, elapsed):
77 self.num += 1
78 if self.num == 1:
79 self.m = elapsed
80 self.s = 0
81 else:
82 last_m = self.m
83 self.m = last_m + (elapsed - last_m) / self.num
84 self.s = self.s + (elapsed - last_m) * (elapsed - self.m)
85
86 self.total_time += elapsed
87
88 if self.max_time < elapsed:
89 self.max_time = elapsed
90
91 def start_sample(self):
92 return Sample(self)
93
94 @property
95 def average(self):
96 if self.num == 0:
97 return 0
98 return self.total_time / self.num
99
100 @property
101 def stdev(self):
102 if self.num <= 1:
103 return 0
104 return math.sqrt(self.s / (self.num - 1))
105
106 def todict(self):
107 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
108
109
110class ServerClient(object):
111 def __init__(self, reader, writer, db, request_stats):
112 self.reader = reader
113 self.writer = writer
114 self.db = db
115 self.request_stats = request_stats
116
117 async def process_requests(self):
118 try:
119 self.addr = self.writer.get_extra_info('peername')
120 logger.debug('Client %r connected' % (self.addr,))
121
122 # Read protocol and version
123 protocol = await self.reader.readline()
124 if protocol is None:
125 return
126
127 (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
128 if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
129 return
130
131 # Read headers. Currently, no headers are implemented, so look for
132 # an empty line to signal the end of the headers
133 while True:
134 line = await self.reader.readline()
135 if line is None:
136 return
137
138 line = line.decode('utf-8').rstrip()
139 if not line:
140 break
141
142 # Handle messages
143 handlers = {
144 'get': self.handle_get,
145 'report': self.handle_report,
Andrew Geissler82c905d2020-04-13 13:39:40 -0500146 'report-equiv': self.handle_equivreport,
Brad Bishopa34c0302019-09-23 22:34:48 -0400147 'get-stream': self.handle_get_stream,
148 'get-stats': self.handle_get_stats,
149 'reset-stats': self.handle_reset_stats,
150 }
151
152 while True:
153 d = await self.read_message()
154 if d is None:
155 break
156
157 for k in handlers.keys():
158 if k in d:
159 logger.debug('Handling %s' % k)
160 if 'stream' in k:
161 await handlers[k](d[k])
162 else:
163 with self.request_stats.start_sample() as self.request_sample, \
164 self.request_sample.measure():
165 await handlers[k](d[k])
166 break
167 else:
168 logger.warning("Unrecognized command %r" % d)
169 break
170
171 await self.writer.drain()
172 finally:
173 self.writer.close()
174
175 def write_message(self, msg):
176 self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))
177
178 async def read_message(self):
179 l = await self.reader.readline()
180 if not l:
181 return None
182
183 try:
184 message = l.decode('utf-8')
185
186 if not message.endswith('\n'):
187 return None
188
189 return json.loads(message)
190 except (json.JSONDecodeError, UnicodeDecodeError) as e:
191 logger.error('Bad message from client: %r' % message)
192 raise e
193
194 async def handle_get(self, request):
195 method = request['method']
196 taskhash = request['taskhash']
197
198 row = self.query_equivalent(method, taskhash)
199 if row is not None:
200 logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
201 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
202
203 self.write_message(d)
204 else:
205 self.write_message(None)
206
207 async def handle_get_stream(self, request):
208 self.write_message('ok')
209
210 while True:
211 l = await self.reader.readline()
212 if not l:
213 return
214
215 try:
216 # This inner loop is very sensitive and must be as fast as
217 # possible (which is why the request sample is handled manually
218 # instead of using 'with', and also why logging statements are
219 # commented out.
220 self.request_sample = self.request_stats.start_sample()
221 request_measure = self.request_sample.measure()
222 request_measure.start()
223
224 l = l.decode('utf-8').rstrip()
225 if l == 'END':
226 self.writer.write('ok\n'.encode('utf-8'))
227 return
228
229 (method, taskhash) = l.split()
230 #logger.debug('Looking up %s %s' % (method, taskhash))
231 row = self.query_equivalent(method, taskhash)
232 if row is not None:
233 msg = ('%s\n' % row['unihash']).encode('utf-8')
234 #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
235 else:
236 msg = '\n'.encode('utf-8')
237
238 self.writer.write(msg)
239 finally:
240 request_measure.end()
241 self.request_sample.end()
242
243 await self.writer.drain()
244
245 async def handle_report(self, data):
246 with closing(self.db.cursor()) as cursor:
247 cursor.execute('''
248 -- Find tasks with a matching outhash (that is, tasks that
249 -- are equivalent)
250 SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash
251
252 -- If there is an exact match on the taskhash, return it.
253 -- Otherwise return the oldest matching outhash of any
254 -- taskhash
255 ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
256 created ASC
257
258 -- Only return one row
259 LIMIT 1
260 ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
261
262 row = cursor.fetchone()
263
264 # If no matching outhash was found, or one *was* found but it
265 # wasn't an exact match on the taskhash, a new entry for this
266 # taskhash should be added
267 if row is None or row['taskhash'] != data['taskhash']:
268 # If a row matching the outhash was found, the unihash for
269 # the new taskhash should be the same as that one.
270 # Otherwise the caller provided unihash is used.
271 unihash = data['unihash']
272 if row is not None:
273 unihash = row['unihash']
274
275 insert_data = {
276 'method': data['method'],
277 'outhash': data['outhash'],
278 'taskhash': data['taskhash'],
279 'unihash': unihash,
280 'created': datetime.now()
281 }
282
283 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
284 if k in data:
285 insert_data[k] = data[k]
286
287 cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
288 ', '.join(sorted(insert_data.keys())),
289 ', '.join(':' + k for k in sorted(insert_data.keys()))),
290 insert_data)
291
292 self.db.commit()
293
294 logger.info('Adding taskhash %s with unihash %s',
295 data['taskhash'], unihash)
296
297 d = {
298 'taskhash': data['taskhash'],
299 'method': data['method'],
300 'unihash': unihash
301 }
302 else:
303 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
304
305 self.write_message(d)
306
Andrew Geissler82c905d2020-04-13 13:39:40 -0500307 async def handle_equivreport(self, data):
308 with closing(self.db.cursor()) as cursor:
309 insert_data = {
310 'method': data['method'],
311 'outhash': "",
312 'taskhash': data['taskhash'],
313 'unihash': data['unihash'],
314 'created': datetime.now()
315 }
316
317 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
318 if k in data:
319 insert_data[k] = data[k]
320
321 cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % (
322 ', '.join(sorted(insert_data.keys())),
323 ', '.join(':' + k for k in sorted(insert_data.keys()))),
324 insert_data)
325
326 self.db.commit()
327
328 # Fetch the unihash that will be reported for the taskhash. If the
329 # unihash matches, it means this row was inserted (or the mapping
330 # was already valid)
331 row = self.query_equivalent(data['method'], data['taskhash'])
332
333 if row['unihash'] == data['unihash']:
334 logger.info('Adding taskhash equivalence for %s with unihash %s',
335 data['taskhash'], row['unihash'])
336
337 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
338
339 self.write_message(d)
340
341
Brad Bishopa34c0302019-09-23 22:34:48 -0400342 async def handle_get_stats(self, request):
343 d = {
344 'requests': self.request_stats.todict(),
345 }
346
347 self.write_message(d)
348
349 async def handle_reset_stats(self, request):
350 d = {
351 'requests': self.request_stats.todict(),
352 }
353
354 self.request_stats.reset()
355 self.write_message(d)
356
357 def query_equivalent(self, method, taskhash):
358 # This is part of the inner loop and must be as fast as possible
359 try:
360 cursor = self.db.cursor()
361 cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
362 {'method': method, 'taskhash': taskhash})
363 return cursor.fetchone()
364 except:
365 cursor.close()
366
367
368class Server(object):
369 def __init__(self, db, loop=None):
370 self.request_stats = Stats()
371 self.db = db
372
373 if loop is None:
374 self.loop = asyncio.new_event_loop()
375 self.close_loop = True
376 else:
377 self.loop = loop
378 self.close_loop = False
379
380 self._cleanup_socket = None
381
382 def start_tcp_server(self, host, port):
383 self.server = self.loop.run_until_complete(
384 asyncio.start_server(self.handle_client, host, port, loop=self.loop)
385 )
386
387 for s in self.server.sockets:
388 logger.info('Listening on %r' % (s.getsockname(),))
389 # Newer python does this automatically. Do it manually here for
390 # maximum compatibility
391 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
392 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
393
394 name = self.server.sockets[0].getsockname()
395 if self.server.sockets[0].family == socket.AF_INET6:
396 self.address = "[%s]:%d" % (name[0], name[1])
397 else:
398 self.address = "%s:%d" % (name[0], name[1])
399
400 def start_unix_server(self, path):
401 def cleanup():
402 os.unlink(path)
403
404 cwd = os.getcwd()
405 try:
406 # Work around path length limits in AF_UNIX
407 os.chdir(os.path.dirname(path))
408 self.server = self.loop.run_until_complete(
409 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
410 )
411 finally:
412 os.chdir(cwd)
413
414 logger.info('Listening on %r' % path)
415
416 self._cleanup_socket = cleanup
417 self.address = "unix://%s" % os.path.abspath(path)
418
419 async def handle_client(self, reader, writer):
420 # writer.transport.set_write_buffer_limits(0)
421 try:
422 client = ServerClient(reader, writer, self.db, self.request_stats)
423 await client.process_requests()
424 except Exception as e:
425 import traceback
426 logger.error('Error from client: %s' % str(e), exc_info=True)
427 traceback.print_exc()
428 writer.close()
429 logger.info('Client disconnected')
430
431 def serve_forever(self):
432 def signal_handler():
433 self.loop.stop()
434
435 self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
436
437 try:
438 self.loop.run_forever()
439 except KeyboardInterrupt:
440 pass
441
442 self.server.close()
443 self.loop.run_until_complete(self.server.wait_closed())
444 logger.info('Server shutting down')
445
446 if self.close_loop:
447 self.loop.close()
448
449 if self._cleanup_socket is not None:
450 self._cleanup_socket()