blob: 2559bbb3fb2772984d8c42a57109b27624db4ad1 [file] [log] [blame]
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
from contextlib import closing
import json
import logging
import socket
logger = logging.getLogger('hashserv.client')
class HashConnectionError(Exception):
pass
class Client(object):
MODE_NORMAL = 0
MODE_GET_STREAM = 1
def __init__(self):
self._socket = None
self.reader = None
self.writer = None
self.mode = self.MODE_NORMAL
def connect_tcp(self, address, port):
def connect_sock():
s = socket.create_connection((address, port))
s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
return s
self._connect_sock = connect_sock
def connect_unix(self, path):
def connect_sock():
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# AF_UNIX has path length issues so chdir here to workaround
cwd = os.getcwd()
try:
os.chdir(os.path.dirname(path))
s.connect(os.path.basename(path))
finally:
os.chdir(cwd)
return s
self._connect_sock = connect_sock
def connect(self):
if self._socket is None:
self._socket = self._connect_sock()
self.reader = self._socket.makefile('r', encoding='utf-8')
self.writer = self._socket.makefile('w', encoding='utf-8')
self.writer.write('OEHASHEQUIV 1.0\n\n')
self.writer.flush()
# Restore mode if the socket is being re-created
cur_mode = self.mode
self.mode = self.MODE_NORMAL
self._set_mode(cur_mode)
return self._socket
def close(self):
if self._socket is not None:
self._socket.close()
self._socket = None
self.reader = None
self.writer = None
def _send_wrapper(self, proc):
count = 0
while True:
try:
self.connect()
return proc()
except (OSError, HashConnectionError, json.JSONDecodeError, UnicodeDecodeError) as e:
logger.warning('Error talking to server: %s' % e)
if count >= 3:
if not isinstance(e, HashConnectionError):
raise HashConnectionError(str(e))
raise e
self.close()
count += 1
def send_message(self, msg):
def proc():
self.writer.write('%s\n' % json.dumps(msg))
self.writer.flush()
l = self.reader.readline()
if not l:
raise HashConnectionError('Connection closed')
if not l.endswith('\n'):
raise HashConnectionError('Bad message %r' % message)
return json.loads(l)
return self._send_wrapper(proc)
def send_stream(self, msg):
def proc():
self.writer.write("%s\n" % msg)
self.writer.flush()
l = self.reader.readline()
if not l:
raise HashConnectionError('Connection closed')
return l.rstrip()
return self._send_wrapper(proc)
def _set_mode(self, new_mode):
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
r = self.send_stream('END')
if r != 'ok':
raise HashConnectionError('Bad response from server %r' % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
r = self.send_message({'get-stream': None})
if r != 'ok':
raise HashConnectionError('Bad response from server %r' % r)
elif new_mode != self.mode:
raise Exception('Undefined mode transition %r -> %r' % (self.mode, new_mode))
self.mode = new_mode
def get_unihash(self, method, taskhash):
self._set_mode(self.MODE_GET_STREAM)
r = self.send_stream('%s %s' % (method, taskhash))
if not r:
return None
return r
def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
self._set_mode(self.MODE_NORMAL)
m = extra.copy()
m['taskhash'] = taskhash
m['method'] = method
m['outhash'] = outhash
m['unihash'] = unihash
return self.send_message({'report': m})
def get_stats(self):
self._set_mode(self.MODE_NORMAL)
return self.send_message({'get-stats': None})
def reset_stats(self):
self._set_mode(self.MODE_NORMAL)
return self.send_message({'reset-stats': None})