blob: 79919c5be6d9c94597a9b4659ba004b43e56a9a6 [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import socket
10from . import chunkify, DEFAULT_MAX_CHUNK
11
12
class AsyncClient(object):
    """Asynchronous client for a line-oriented JSON request/reply protocol.

    One of the transport selectors (``connect_tcp`` or ``connect_unix``) must
    be awaited first. The socket itself is opened lazily by ``connect()``, and
    ``_send_wrapper`` closes and re-opens it after a communication error.
    """

    def __init__(self, proto_name, proto_version, logger):
        """
        :param proto_name: protocol name sent in the connection greeting.
        :param proto_version: protocol version sent in the connection greeting.
        :param logger: logger used to report communication errors.
        """
        self.reader = None
        self.writer = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger

    async def connect_tcp(self, address, port):
        """Select a TCP transport to ``address``:``port`` (opened lazily)."""
        async def connect_sock():
            return await asyncio.open_connection(address, port)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Select a UNIX domain socket transport to ``path`` (opened lazily).

        ``path`` is resolved against the working directory at the time of this
        call. Later (re)connects therefore do not depend on whatever the working
        directory happens to be when they run. This matters because callers may
        chdir temporarily and pass a bare file name.
        """
        path = os.path.abspath(path)

        async def connect_sock():
            # AF_UNIX socket paths have a small length limit, so connect to the
            # basename from inside the socket's directory. The connect is done
            # synchronously (nothing is awaited while the working directory is
            # switched), and the connected socket is then handed to asyncio.
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                cwd = os.getcwd()
                os.chdir(os.path.dirname(path))
                try:
                    sock.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
                return await asyncio.open_unix_connection(sock=sock)
            except BaseException:
                sock.close()
                raise

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Send the greeting: protocol name and version, then a blank line."""
        s = '%s %s\n\n' % (self.proto_name, self.proto_version)
        self.writer.write(s.encode("utf-8"))
        await self.writer.drain()

    async def connect(self):
        """Open the connection and send the greeting, unless already open."""
        if self.reader is None or self.writer is None:
            (self.reader, self.writer) = await self._connect_sock()
            await self.setup_connection()

    async def close(self):
        """Drop the current connection; the next request reconnects."""
        self.reader = None

        if self.writer is not None:
            self.writer.close()
            self.writer = None

    async def _send_wrapper(self, proc):
        """Run ``proc()`` on a connected stream, retrying on communication errors.

        The following errors cause the connection to be closed and ``proc`` to
        be retried: OSError/ConnectionError, a JSON decode error, or a UTF-8
        decode error. After the fourth failed attempt the error is raised.

        :param proc: zero-argument coroutine function performing one exchange.
        :returns: whatever ``proc()`` returns.
        :raises ConnectionError: when every attempt fails. A non-connection
            error is wrapped, with the original kept as ``__cause__``.
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s", e)
                # Always drop the (possibly half-set-up) connection, including on
                # the final attempt. A later request then starts from a clean
                # connect() + greeting instead of reusing a broken stream.
                await self.close()
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e)) from e
                    raise
                count += 1

    async def send_message(self, msg):
        """Send ``msg`` encoded with ``json.dumps`` and return the decoded reply.

        The request text is written in the pieces produced by ``chunkify`` (with
        ``self.max_chunk``). If the first reply line decodes to an object
        containing ``"chunk-stream"``, the real reply follows: it is split over
        several lines, terminated by an empty line, and those lines are joined
        and decoded instead.

        :raises ConnectionError: if the server closes the connection, a reply
            line is not newline-terminated, or all retries fail.
        """
        async def get_line():
            line = await self.reader.readline()
            if not line:
                raise ConnectionError("Connection closed")

            line = line.decode("utf-8")

            if not line.endswith("\n"):
                # Report the malformed data actually received, not the request.
                raise ConnectionError("Bad message %r" % (line,))

            return line

        async def proc():
            for c in chunkify(json.dumps(msg), self.max_chunk):
                self.writer.write(c.encode("utf-8"))
                await self.writer.drain()

            l = await get_line()

            m = json.loads(l)
            if m and "chunk-stream" in m:
                lines = []
                while True:
                    l = (await get_line()).rstrip("\n")
                    if not l:
                        break
                    lines.append(l)

                m = json.loads("".join(lines))

            return m

        return await self._send_wrapper(proc)

    async def ping(self):
        """Send a ``ping`` request and return the server's decoded reply."""
        return await self.send_message(
            {'ping': {}}
        )
110
Andrew Geisslerc926e172021-05-07 16:11:35 -0500111
class Client(object):
    """Synchronous facade over an AsyncClient, driven by a per-instance event loop.

    Subclasses implement ``_get_async_client`` to supply the async client.
    Selected coroutine methods of that client are exposed on this object as
    blocking methods.
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        self._add_methods('connect_tcp', 'close', 'ping')

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the async client instance to wrap; subclasses must override.

        Client does not derive from abc.ABC, so the decorator alone does not
        prevent instantiation. This method therefore fails loudly instead of
        returning None, which would otherwise surface later as an obscure
        AttributeError.
        """
        raise NotImplementedError(
            "%s must implement _get_async_client()" % type(self).__name__
        )

    def _get_downcall_wrapper(self, downcall):
        """Wrap coroutine function ``downcall`` so each call runs to completion on self.loop."""
        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose each named coroutine method of self.client as a blocking attribute."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Connect to the UNIX domain socket at ``path`` and send the greeting.

        The working directory is switched temporarily and restored afterwards,
        even on error.
        """
        # AF_UNIX has path length issues so chdir here to workaround
        cwd = os.getcwd()
        dirname = os.path.dirname(path)
        try:
            # A bare file name has no directory part and os.chdir("") would
            # raise, so stay in the current directory in that case.
            if dirname:
                os.chdir(dirname)
            self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path)))
            self.loop.run_until_complete(self.client.connect())
        finally:
            os.chdir(cwd)

    @property
    def max_chunk(self):
        """Chunk size used by the wrapped client when sending messages."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value