blob: f65956617b9cfc94db181a3822f084141c6f2156 [file] [log] [blame]
Brad Bishopa34c0302019-09-23 22:34:48 -04001# Copyright (C) 2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6from contextlib import closing
7import json
8import logging
9import socket
Brad Bishop00e122a2019-10-05 11:10:57 -040010import os
Brad Bishopa34c0302019-09-23 22:34:48 -040011
12
# Module-wide logger; Client reports communication errors (before retrying)
# through it.
logger = logging.getLogger(name='hashserv.client')
14
15
class HashConnectionError(Exception):
    """Raised when talking to the hash server fails.

    Used for a closed connection, a malformed or unexpected reply, and (after
    retries are exhausted) any socket/decoding error seen by the Client.
    """
18
19
class Client(object):
    """Line-based client for the 'OEHASHEQUIV 1.0' hash server protocol.

    Call connect_tcp() or connect_unix() first to choose how the server is
    reached. The socket itself is opened lazily by the first request and is
    re-opened (with the current mode restored) after communication errors.
    """

    # Requests are single-line JSON objects, each answered by one JSON line
    # (see send_message()).
    MODE_NORMAL = 0
    # Entered via a {'get-stream': None} request; afterwards plain text lines
    # are exchanged (see send_stream()) until 'END' is sent.
    MODE_GET_STREAM = 1

    def __init__(self):
        self._socket = None
        # Text-mode file objects wrapping self._socket, created by connect().
        self.reader = None
        self.writer = None
        self.mode = self.MODE_NORMAL

    def connect_tcp(self, address, port):
        """Select a TCP connection to (address, port) for future requests.

        Nothing is connected here; connect() calls the stored factory when a
        socket is actually needed.
        """
        def connect_sock():
            s = socket.create_connection((address, port))

            # Requests and replies are short lines: disable write coalescing
            # and delayed ACKs (presumably to keep round trips fast) and
            # enable keepalive on the reused connection.
            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            return s

        self._connect_sock = connect_sock

    def connect_unix(self, path):
        """Select the Unix-domain stream socket at *path* for future requests.

        Nothing is connected here; connect() calls the stored factory when a
        socket is actually needed.
        """
        def connect_sock():
            s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                # AF_UNIX has path length issues, so connect using only the
                # base name from inside the socket's directory.
                # NOTE: os.chdir() affects the whole process while it lasts.
                cwd = os.getcwd()
                dirname = os.path.dirname(path)
                try:
                    # A bare file name has no directory part and
                    # os.chdir('') would fail; stay where we are then.
                    if dirname:
                        os.chdir(dirname)
                    s.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
            except BaseException:
                # Do not leak the descriptor when the connect fails.
                s.close()
                raise
            return s

        self._connect_sock = connect_sock

    def connect(self):
        """Return the server socket, connecting first if there is none.

        A new connection is sent the protocol greeting and is then switched
        back into the mode the client was in before it was (re)created.
        """
        if self._socket is None:
            self._socket = self._connect_sock()

            self.reader = self._socket.makefile('r', encoding='utf-8')
            self.writer = self._socket.makefile('w', encoding='utf-8')

            self.writer.write('OEHASHEQUIV 1.0\n\n')
            self.writer.flush()

            # Restore mode if the socket is being re-created; this code treats
            # a fresh connection as being in normal mode.
            cur_mode = self.mode
            self.mode = self.MODE_NORMAL
            try:
                self._set_mode(cur_mode)
            except BaseException:
                # Remember the requested mode so the next reconnect retries
                # the transition instead of silently staying in normal mode.
                self.mode = cur_mode
                raise

        return self._socket

    def close(self):
        """Drop the current connection, if any.

        self.mode is intentionally kept so that connect() can restore it.
        """
        if self._socket is not None:
            # The socket's descriptor is only released once every file
            # object created by makefile() is closed as well, so close them
            # explicitly instead of relying on garbage collection.
            for f in (self.reader, self.writer):
                if f is not None:
                    try:
                        f.close()
                    except OSError:
                        # Best effort: flushing a broken connection may fail,
                        # but the underlying file is closed regardless.
                        pass
            self._socket.close()
            self._socket = None
            self.reader = None
            self.writer = None

    def _send_wrapper(self, proc):
        """Run proc() on a connected socket, reconnecting and retrying on errors.

        proc is attempted up to 4 times. After the last failure the connection
        is dropped (so the next request starts from a clean stream) and the
        error is raised as HashConnectionError.
        """
        count = 0
        while True:
            try:
                self.connect()
                return proc()
            except (OSError, HashConnectionError, json.JSONDecodeError, UnicodeDecodeError) as e:
                logger.warning('Error talking to server: %s', e)
                self.close()
                if count >= 3:
                    if not isinstance(e, HashConnectionError):
                        raise HashConnectionError(str(e)) from e
                    raise
                count += 1

    def send_message(self, msg):
        """Send *msg* as one JSON line and return the decoded JSON reply.

        Raises HashConnectionError if the server closes the connection or
        replies with a partial line.
        """
        def proc():
            self.writer.write('%s\n' % json.dumps(msg))
            self.writer.flush()

            line = self.reader.readline()
            if not line:
                raise HashConnectionError('Connection closed')

            # readline() only omits the newline when the stream ended
            # mid-line, i.e. the reply was truncated.
            if not line.endswith('\n'):
                raise HashConnectionError('Bad message %r' % line)

            return json.loads(line)

        return self._send_wrapper(proc)

    def send_stream(self, msg):
        """Send *msg* as a raw text line; return the reply with trailing whitespace stripped."""
        def proc():
            self.writer.write("%s\n" % msg)
            self.writer.flush()
            line = self.reader.readline()
            if not line:
                raise HashConnectionError('Connection closed')
            return line.rstrip()

        return self._send_wrapper(proc)

    def _set_mode(self, new_mode):
        """Switch the connection to *new_mode*, telling the server if needed.

        Raises HashConnectionError if the server does not acknowledge with
        'ok', and Exception for a transition that is not defined.
        """
        if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
            r = self.send_stream('END')
            if r != 'ok':
                raise HashConnectionError('Bad response from server %r' % r)
        elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
            r = self.send_message({'get-stream': None})
            if r != 'ok':
                raise HashConnectionError('Bad response from server %r' % r)
        elif new_mode != self.mode:
            raise Exception('Undefined mode transition %r -> %r' % (self.mode, new_mode))

        self.mode = new_mode

    def get_unihash(self, method, taskhash):
        """Query the unihash for *taskhash* under *method* (stream mode).

        Returns the server's reply line, or None when the reply is empty.
        """
        self._set_mode(self.MODE_GET_STREAM)
        r = self.send_stream('%s %s' % (method, taskhash))
        if not r:
            return None
        return r

    def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """Report an outhash/unihash for *taskhash*; return the server's JSON reply.

        *extra* optionally supplies additional report fields; it is copied and
        never modified.
        """
        self._set_mode(self.MODE_NORMAL)
        m = extra.copy() if extra is not None else {}
        m['taskhash'] = taskhash
        m['method'] = method
        m['outhash'] = outhash
        m['unihash'] = unihash
        return self.send_message({'report': m})

    def get_stats(self):
        """Return the server's decoded reply to a 'get-stats' request."""
        self._set_mode(self.MODE_NORMAL)
        return self.send_message({'get-stats': None})

    def reset_stats(self):
        """Return the server's decoded reply to a 'reset-stats' request."""
        self._set_mode(self.MODE_NORMAL)
        return self.send_message({'reset-stats': None})
157 return self.send_message({'reset-stats': None})