blob: 2559bbb3fb2772984d8c42a57109b27624db4ad1 [file] [log] [blame]
Brad Bishopa34c0302019-09-23 22:34:48 -04001# Copyright (C) 2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
from contextlib import closing
import json
import logging
import os
import socket
10
11
# Module logger; used by Client to report communication errors.
logger = logging.getLogger(name='hashserv.client')
13
14
class HashConnectionError(Exception):
    """Raised by Client when talking to the hash equivalence server fails."""
17
18
class Client(object):
    """Synchronous client for the hash equivalence server.

    Select a transport with connect_tcp() or connect_unix().  The socket is
    opened lazily by connect() when the first request is sent.  After a
    communication error, _send_wrapper() closes the socket and reconnects.

    The client tracks one of two protocol modes:
      MODE_NORMAL      -- requests and replies are single lines of JSON
                          (send_message()).
      MODE_GET_STREAM  -- requests and replies are single raw text lines
                          (send_stream()).
    """

    MODE_NORMAL = 0
    MODE_GET_STREAM = 1

    def __init__(self):
        self._socket = None
        self.reader = None
        self.writer = None
        self.mode = self.MODE_NORMAL

    def connect_tcp(self, address, port):
        """Configure the client to connect to *address*:*port* over TCP.

        No connection is made here; it is deferred until connect().
        """
        def connect_sock():
            s = socket.create_connection((address, port))
            try:
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
                s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            except BaseException:
                # Don't leak the connected socket if an option can't be set.
                s.close()
                raise
            return s

        self._connect_sock = connect_sock

    def connect_unix(self, path):
        """Configure the client to connect to the Unix domain socket *path*.

        No connection is made here; it is deferred until connect().
        """
        def connect_sock():
            s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                # AF_UNIX has path length issues, so connect using only the
                # basename from inside the socket's directory.  Note that
                # os.chdir() changes the process-wide working directory
                # until it is restored below.  os.path.dirname() returns ''
                # for a bare file name, and chdir('') fails, so fall back
                # to '.'.
                cwd = os.getcwd()
                try:
                    os.chdir(os.path.dirname(path) or '.')
                    s.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
            except BaseException:
                s.close()
                raise
            return s

        self._connect_sock = connect_sock

    def connect(self):
        """Open the connection if it is not already open, and return the socket.

        On a fresh connection this sends the protocol header.  It then
        re-negotiates the mode the client was in before the connection was
        (re)created.
        """
        if self._socket is None:
            self._socket = self._connect_sock()

            self.reader = self._socket.makefile('r', encoding='utf-8')
            self.writer = self._socket.makefile('w', encoding='utf-8')

            self.writer.write('OEHASHEQUIV 1.0\n\n')
            self.writer.flush()

            # Restore mode if the socket is being re-created.  The code
            # treats a new connection as starting in normal mode, so
            # transition from there to the previously active mode.
            cur_mode = self.mode
            self.mode = self.MODE_NORMAL
            self._set_mode(cur_mode)

        return self._socket

    def close(self):
        """Close the connection, if open.  The next request reconnects."""
        if self._socket is not None:
            # File objects created by makefile() hold a reference on the
            # socket; the descriptor is only released once they are closed
            # too.  Errors are ignored because close() is also used to
            # discard a connection that has already failed.
            for f in (self.writer, self.reader):
                if f is not None:
                    try:
                        f.close()
                    except OSError:
                        pass
            self._socket.close()
            self._socket = None
            self.reader = None
            self.writer = None

    def _send_wrapper(self, proc):
        """Run *proc* on a live connection, reconnecting and retrying on error.

        *proc* is attempted up to four times.  After each failed attempt the
        connection is closed, so that connect() re-creates it.  If the final
        attempt fails, the error is raised as a HashConnectionError.
        """
        count = 0
        while True:
            try:
                self.connect()
                return proc()
            except (OSError, HashConnectionError, json.JSONDecodeError, UnicodeDecodeError) as e:
                logger.warning('Error talking to server: %s', e)
                if count >= 3:
                    if not isinstance(e, HashConnectionError):
                        raise HashConnectionError(str(e)) from e
                    raise
                self.close()
                count += 1

    def send_message(self, msg):
        """Send *msg* as one line of JSON and return the decoded JSON reply."""
        def proc():
            self.writer.write('%s\n' % json.dumps(msg))
            self.writer.flush()

            line = self.reader.readline()
            if not line:
                raise HashConnectionError('Connection closed')

            if not line.endswith('\n'):
                raise HashConnectionError('Bad message %r' % line)

            return json.loads(line)

        return self._send_wrapper(proc)

    def send_stream(self, msg):
        """Send *msg* as one raw text line and return the reply line.

        Trailing whitespace, including the newline, is stripped from the reply.
        """
        def proc():
            self.writer.write("%s\n" % msg)
            self.writer.flush()
            line = self.reader.readline()
            if not line:
                raise HashConnectionError('Connection closed')
            return line.rstrip()

        return self._send_wrapper(proc)

    def _set_mode(self, new_mode):
        """Switch the protocol mode to *new_mode*.  Does nothing if already in it.

        Raises HashConnectionError if the server does not acknowledge the
        switch with 'ok'.
        """
        if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
            r = self.send_stream('END')
            if r != 'ok':
                raise HashConnectionError('Bad response from server %r' % r)
        elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
            r = self.send_message({'get-stream': None})
            if r != 'ok':
                raise HashConnectionError('Bad response from server %r' % r)
        elif new_mode != self.mode:
            raise Exception('Undefined mode transition %r -> %r' % (self.mode, new_mode))

        self.mode = new_mode

    def get_unihash(self, method, taskhash):
        """Look up the unihash for *taskhash* under *method*.

        Returns None when the server sends an empty reply.
        """
        self._set_mode(self.MODE_GET_STREAM)
        r = self.send_stream('%s %s' % (method, taskhash))
        if not r:
            return None
        return r

    def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """Send a 'report' request and return the server's decoded reply.

        *extra* is an optional dict of additional fields.  It is copied, not
        modified, and the named arguments override any keys it contains.
        """
        self._set_mode(self.MODE_NORMAL)
        m = dict(extra) if extra is not None else {}
        m['taskhash'] = taskhash
        m['method'] = method
        m['outhash'] = outhash
        m['unihash'] = unihash
        return self.send_message({'report': m})

    def get_stats(self):
        """Request the server's statistics and return the decoded reply."""
        self._set_mode(self.MODE_NORMAL)
        return self.send_message({'get-stats': None})

    def reset_stats(self):
        """Ask the server to reset its statistics and return the decoded reply."""
        self._set_mode(self.MODE_NORMAL)
        return self.send_message({'reset-stats': None})
156 return self.send_message({'reset-stats': None})