# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
5
6from contextlib import closing
7from datetime import datetime
8import asyncio
9import json
10import logging
11import math
12import os
13import signal
14import socket
15import time
16
# Module-level logger used by ServerClient and Server below.
logger = logging.getLogger('hashserv.server')
18
19
class Measurement:
    """Time one interval and report its duration to a sample object.

    The interval can be driven by hand with start()/end(), or the instance
    can be used as a context manager so the interval spans the ``with`` body.
    """

    def __init__(self, sample):
        # Receiver of the measured duration; only its add() method is used.
        self.sample = sample

    def start(self):
        """Remember the current performance-counter reading."""
        self.start_time = time.perf_counter()

    def end(self):
        """Pass the seconds elapsed since start() to ``sample.add``."""
        duration = time.perf_counter() - self.start_time
        self.sample.add(duration)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        self.end()
36
37
class Sample:
    """Accumulate one or more measurements into a single statistics entry.

    Durations reported through add() are summed; end() forwards the sum to
    the owning stats object as one value and clears the accumulator. When
    used as a context manager, end() is called on exit.
    """

    def __init__(self, stats):
        self.stats = stats
        self.num_samples = 0
        self.elapsed = 0

    def measure(self):
        """Return a new Measurement that reports into this sample."""
        return Measurement(self)

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.end()

    def add(self, elapsed):
        """Fold one measured duration into the running total."""
        self.elapsed += elapsed
        self.num_samples += 1

    def end(self):
        """Flush the accumulated time to the stats object, if any was added."""
        if not self.num_samples:
            return

        self.stats.add(self.elapsed)
        self.num_samples = 0
        self.elapsed = 0
62
63
class Stats:
    """Running statistics (count, total, max, mean, stdev) over durations.

    The mean (``m``) and the sum of squared deviations (``s``) are updated
    incrementally on every add(), so no individual values are stored.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every counter to its initial state."""
        self.num = 0
        self.total_time = 0
        self.max_time = 0
        self.m = 0
        self.s = 0
        self.current_elapsed = None

    def add(self, elapsed):
        """Fold one duration into the statistics."""
        self.num += 1

        if self.num > 1:
            # Incremental update of the mean and squared-deviation sum
            previous_mean = self.m
            self.m = previous_mean + (elapsed - previous_mean) / self.num
            self.s = self.s + (elapsed - previous_mean) * (elapsed - self.m)
        else:
            # First value seeds the mean; no spread yet
            self.m = elapsed
            self.s = 0

        self.total_time += elapsed
        self.max_time = max(self.max_time, elapsed)

    def start_sample(self):
        """Return a new Sample that reports into these statistics."""
        return Sample(self)

    @property
    def average(self):
        """Mean duration, or 0 when nothing has been recorded."""
        return self.total_time / self.num if self.num else 0

    @property
    def stdev(self):
        """Sample standard deviation, or 0 with fewer than two values."""
        if self.num > 1:
            return math.sqrt(self.s / (self.num - 1))
        return 0

    def todict(self):
        """Return the reportable statistics as a plain dictionary."""
        return {
            'num': self.num,
            'total_time': self.total_time,
            'max_time': self.max_time,
            'average': self.average,
            'stdev': self.stdev,
        }
108
109
class ServerClient(object):
    """Handle the protocol conversation with one connected client.

    reader and writer are the stream objects for the connection, db is the
    database connection holding the ``tasks_v2`` table, and request_stats is
    the statistics object that every handled request is timed into.
    """

    def __init__(self, reader, writer, db, request_stats):
        self.reader = reader
        self.writer = writer
        self.db = db
        self.request_stats = request_stats

    async def process_requests(self):
        """Run the handshake and then serve messages until the client is done.

        The connection is closed when the client disconnects, sends an
        unsupported or malformed handshake, or sends an unrecognized
        command. Exceptions from decoding or from the handlers propagate to
        the caller; the writer is closed in every case.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            logger.debug('Client %r connected', self.addr)

            # Read protocol and version. readline() signals EOF with an
            # empty bytes object, never with None.
            protocol = await self.reader.readline()
            if not protocol:
                return

            # Anything other than exactly "<name> <version>" with the
            # supported values is rejected, including a wrong token count.
            if protocol.decode('utf-8').rstrip().split() != ['OEHASHEQUIV', '1.0']:
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    # EOF before the headers were terminated
                    return

                if not line.decode('utf-8').rstrip():
                    break

            # Handle messages
            handlers = {
                'get': self.handle_get,
                'report': self.handle_report,
                'get-stream': self.handle_get_stream,
                'get-stats': self.handle_get_stats,
                'reset-stats': self.handle_reset_stats,
            }

            while True:
                d = await self.read_message()
                if d is None:
                    break

                for k in handlers:
                    if k in d:
                        logger.debug('Handling %s', k)
                        if 'stream' in k:
                            # Streaming handlers time each lookup themselves
                            await handlers[k](d[k])
                        else:
                            with self.request_stats.start_sample() as self.request_sample, \
                                    self.request_sample.measure():
                                await handlers[k](d[k])
                        break
                else:
                    logger.warning("Unrecognized command %r", d)
                    break

                await self.writer.drain()
        finally:
            self.writer.close()

    def write_message(self, msg):
        """Queue msg for sending as one line of JSON."""
        self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))

    async def read_message(self):
        """Read one newline-terminated JSON message.

        Returns the decoded message, or None on EOF or when the final line
        is not newline-terminated. Re-raises json.JSONDecodeError or
        UnicodeDecodeError after logging the offending line.
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw bytes: the decoded text does not exist when the
            # decode itself is what failed.
            logger.error('Bad message from client: %r', l)
            raise

    async def handle_get(self, request):
        """Reply with the stored entry for a method/taskhash, or null."""
        method = request['method']
        taskhash = request['taskhash']

        row = self.query_equivalent(method, taskhash)
        if row is not None:
            logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self.write_message(d)
        else:
            self.write_message(None)

    async def handle_get_stream(self, request):
        """Serve line-based "<method> <taskhash>" lookups until END or EOF.

        Each lookup is answered with the unihash on its own line, or with an
        empty line when nothing is found.
        """
        self.write_message('ok')

        while True:
            l = await self.reader.readline()
            if not l:
                return

            try:
                # This inner loop is very sensitive and must be as fast as
                # possible (which is why the request sample is handled manually
                # instead of using 'with', and also why logging statements are
                # commented out.
                self.request_sample = self.request_stats.start_sample()
                request_measure = self.request_sample.measure()
                request_measure.start()

                l = l.decode('utf-8').rstrip()
                if l == 'END':
                    self.writer.write('ok\n'.encode('utf-8'))
                    return

                (method, taskhash) = l.split()
                #logger.debug('Looking up %s %s' % (method, taskhash))
                row = self.query_equivalent(method, taskhash)
                if row is not None:
                    msg = ('%s\n' % row['unihash']).encode('utf-8')
                    #logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
                else:
                    msg = '\n'.encode('utf-8')

                self.writer.write(msg)
            finally:
                request_measure.end()
                self.request_sample.end()

            await self.writer.drain()

    async def handle_report(self, data):
        """Record a reported task and reply with the unihash to use for it.

        data must contain 'method', 'outhash', 'taskhash' and 'unihash'; the
        optional keys 'owner', 'PN', 'PV', 'PR', 'task' and
        'outhash_siginfo' are stored when present.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute('''
                -- Find tasks with a matching outhash (that is, tasks that
                -- are equivalent)
                SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash

                -- If there is an exact match on the taskhash, return it.
                -- Otherwise return the oldest matching outhash of any
                -- taskhash
                ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                    created ASC

                -- Only return one row
                LIMIT 1
                ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

            row = cursor.fetchone()

            # If no matching outhash was found, or one *was* found but it
            # wasn't an exact match on the taskhash, a new entry for this
            # taskhash should be added
            if row is None or row['taskhash'] != data['taskhash']:
                # If a row matching the outhash was found, the unihash for
                # the new taskhash should be the same as that one.
                # Otherwise the caller provided unihash is used.
                unihash = data['unihash']
                if row is not None:
                    unihash = row['unihash']

                insert_data = {
                    'method': data['method'],
                    'outhash': data['outhash'],
                    'taskhash': data['taskhash'],
                    'unihash': unihash,
                    'created': datetime.now()
                }

                for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                    if k in data:
                        insert_data[k] = data[k]

                # Column names come only from the fixed key lists above, so
                # interpolating them into the statement is safe; the values
                # are bound as named parameters.
                columns = sorted(insert_data)
                cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
                    ', '.join(columns),
                    ', '.join(':' + k for k in columns)),
                    insert_data)

                self.db.commit()

                logger.info('Adding taskhash %s with unihash %s',
                            data['taskhash'], unihash)

                d = {
                    'taskhash': data['taskhash'],
                    'method': data['method'],
                    'unihash': unihash
                }
            else:
                d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)

    async def handle_get_stats(self, request):
        """Reply with the current request statistics."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.write_message(d)

    async def handle_reset_stats(self, request):
        """Reply with the request statistics as they were, then reset them."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.request_stats.reset()
        self.write_message(d)

    def query_equivalent(self, method, taskhash):
        """Return the oldest row for method/taskhash, or None.

        None is also returned when the query fails; the failure is logged
        rather than silently discarded. The cursor is closed in every case.
        """
        # This is part of the inner loop and must be as fast as possible
        try:
            with closing(self.db.cursor()) as cursor:
                cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
                               {'method': method, 'taskhash': taskhash})
                return cursor.fetchone()
        except Exception:
            logger.exception('Equivalence query failed for %s %s', method, taskhash)
            return None
330
331
class Server(object):
    """Listen on a TCP or UNIX socket and serve each client from one database.

    db is handed to every ServerClient. When no event loop is supplied, a
    new one is created and closed again by serve_forever(); a supplied loop
    is left open.
    """

    def __init__(self, db, loop=None):
        self.request_stats = Stats()
        self.db = db

        if loop is None:
            self.loop = asyncio.new_event_loop()
            self.close_loop = True
        else:
            self.loop = loop
            self.close_loop = False

        # Callable that removes the UNIX socket file, set by
        # start_unix_server()
        self._cleanup_socket = None

    def start_tcp_server(self, host, port):
        """Start listening on host:port and record the bound address.

        The address of the first listening socket is stored in
        ``self.address`` ("host:port", or "[host]:port" for IPv6).
        """
        # No explicit loop= argument: the coroutine runs on self.loop and
        # picks it up as the running loop. The parameter itself was
        # deprecated in Python 3.8 and removed in 3.10.
        self.server = self.loop.run_until_complete(
            asyncio.start_server(self.handle_client, host, port)
        )

        for s in self.server.sockets:
            logger.info('Listening on %r', s.getsockname())
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility
            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            # TCP_QUICKACK is not defined on every platform
            if hasattr(socket, 'TCP_QUICKACK'):
                s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

    def start_unix_server(self, path):
        """Start listening on the UNIX socket at path.

        The address is stored in ``self.address`` as "unix://<absolute
        path>", and the socket file is removed when serve_forever() ends.
        """
        # Resolve once, up front: this gives a directory to change into even
        # for a bare file name, and keeps cleanup independent of any later
        # change of working directory.
        socket_path = os.path.abspath(path)

        def cleanup():
            try:
                os.unlink(socket_path)
            except FileNotFoundError:
                # Already removed, e.g. by asyncio when the server closed
                pass

        cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX
            os.chdir(os.path.dirname(socket_path))
            self.server = self.loop.run_until_complete(
                asyncio.start_unix_server(self.handle_client, os.path.basename(socket_path))
            )
        finally:
            os.chdir(cwd)

        logger.info('Listening on %r' % path)

        self._cleanup_socket = cleanup
        self.address = "unix://%s" % socket_path

    async def handle_client(self, reader, writer):
        """Serve one connection, logging any error it raises."""
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = ServerClient(reader, writer, self.db, self.request_stats)
            await client.process_requests()
        except Exception as e:
            import traceback
            logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        logger.info('Client disconnected')

    def serve_forever(self):
        """Run the event loop until SIGTERM or KeyboardInterrupt, then shut down.

        The listening server is closed, the loop is closed if this object
        created it, and the UNIX socket file (if any) is removed. The last
        two steps run even when shutting the server down raises.
        """
        def signal_handler():
            self.loop.stop()

        self.loop.add_signal_handler(signal.SIGTERM, signal_handler)

        try:
            try:
                self.loop.run_forever()
            except KeyboardInterrupt:
                pass

            self.server.close()
            self.loop.run_until_complete(self.server.wait_closed())
            logger.info('Server shutting down')
        finally:
            if self.close_loop:
                self.loop.close()

            if self._cleanup_socket is not None:
                self._cleanup_socket()