blob: 3eb4fdde8ae516b961cf815d7fe09447d4ce0af5 [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import socket
10from . import chunkify, DEFAULT_MAX_CHUNK
11
12
class AsyncClient(object):
    """Asynchronous line-oriented JSON RPC client.

    After the transport is opened, the client sends a
    ``"<proto_name> <proto_version>\\n\\n"`` header.

    Requests and replies:

    * Each request is serialized with ``json.dumps`` and written in the
      pieces produced by ``chunkify(..., self.max_chunk)``.
    * The reply is one newline-terminated JSON line.
    * A reply containing a ``"chunk-stream"`` key is followed by the real
      payload, split over several lines and terminated by an empty line.

    Transport and decode errors are retried, reconnecting each time.
    """

    def __init__(self, proto_name, proto_version, logger, timeout=30):
        """
        :param proto_name: protocol name sent in the connection header.
        :param proto_version: protocol version sent in the connection header.
        :param logger: logger used to report communication errors.
        :param timeout: seconds to wait for each reply line.
        """
        self.reader = None
        self.writer = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        self.timeout = timeout

    async def connect_tcp(self, address, port):
        """Select a TCP transport; the connection is opened lazily by connect()."""
        async def connect_sock():
            return await asyncio.open_connection(address, port)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Select a Unix-domain socket transport; connect() opens it lazily.

        ``path`` is made absolute now. Later (re)connections therefore reach
        the same socket even if the process working directory changes in
        between.

        AF_UNIX socket addresses are length-limited. To work around this,
        each connection attempt temporarily changes into the socket's
        directory and connects by basename.
        """
        path = os.path.abspath(path)

        async def connect_sock():
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                cwd = os.getcwd()
                try:
                    os.chdir(os.path.dirname(path))
                    # Connect synchronously: nothing is awaited between the
                    # chdir() and its restore, so no other coroutine can
                    # observe or change the working directory meanwhile.
                    sock.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
                return await asyncio.open_unix_connection(sock=sock)
            except BaseException:
                # Do not leak the socket if connecting or wrapping fails.
                sock.close()
                raise

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Send the protocol header that opens every connection."""
        s = '%s %s\n\n' % (self.proto_name, self.proto_version)
        self.writer.write(s.encode("utf-8"))
        await self.writer.drain()

    async def connect(self):
        """Open the selected transport and send the header, unless already connected.

        If sending the header fails, the half-open connection is closed
        before the error propagates. The next attempt then performs a fresh
        handshake instead of reusing a stream whose header may never have
        been sent.
        """
        if self.reader is None or self.writer is None:
            (self.reader, self.writer) = await self._connect_sock()
            try:
                await self.setup_connection()
            except BaseException:
                await self.close()
                raise

    async def close(self):
        """Close the writer (if any) and forget the current connection."""
        self.reader = None

        if self.writer is not None:
            self.writer.close()
            self.writer = None

    async def _send_wrapper(self, proc):
        """Run coroutine function ``proc`` over a live connection, with retries.

        After each transport, decode, or JSON error the connection is
        closed. A partially read or timed-out reply would otherwise leave
        the stream out of sync.

        ``proc`` is attempted at most four times. The last error is then
        raised as a ConnectionError; other exception types are wrapped, with
        the original kept as ``__cause__``.
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s", e)
                await self.close()
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e)) from e
                    raise
                count += 1

    async def send_message(self, msg):
        """Send ``msg`` as JSON and return the decoded JSON reply.

        :raises ConnectionError: if the server cannot be reached, closes the
            connection, times out, or keeps returning malformed data.
        """
        async def get_line():
            # Read one reply line, bounded by self.timeout seconds.
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for server")

            if not line:
                raise ConnectionError("Connection closed")

            line = line.decode("utf-8")

            # readline() returns a partial line at EOF; treat it as corrupt.
            if not line.endswith("\n"):
                raise ConnectionError("Bad message %r" % (line))

            return line

        async def proc():
            for c in chunkify(json.dumps(msg), self.max_chunk):
                self.writer.write(c.encode("utf-8"))
            await self.writer.drain()

            reply = json.loads(await get_line())
            if reply and "chunk-stream" in reply:
                # Streamed reply: the payload follows as a series of lines
                # terminated by an empty line; join them and decode once.
                parts = []
                while True:
                    part = (await get_line()).rstrip("\n")
                    if not part:
                        break
                    parts.append(part)

                reply = json.loads("".join(parts))

            return reply

        return await self._send_wrapper(proc)

    async def ping(self):
        """Send a ``{"ping": {}}`` request and return the server's reply."""
        return await self.send_message(
            {'ping': {}}
        )


class Client(object):
    """Blocking wrapper around an asynchronous RPC client.

    Subclasses override _get_async_client() to supply the asynchronous
    client. Its connect_tcp(), close() and ping() coroutines are exposed on
    this object as ordinary methods. Each one runs to completion on a
    private event loop created in __init__.
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        self._add_methods('connect_tcp', 'close', 'ping')

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the asynchronous client object wrapped by this instance.

        The object must provide:

        * the coroutine methods named in __init__;
        * connect_unix() and connect() coroutines;
        * a ``max_chunk`` attribute.

        Client does not use ``abc.ABCMeta``, so ``@abc.abstractmethod`` is
        not enforced when the class is instantiated. Raising here reports a
        missing override clearly, instead of failing later with an
        AttributeError on ``None``.
        """
        raise NotImplementedError(
            "%s must implement _get_async_client()" % type(self).__name__
        )

    def _get_downcall_wrapper(self, downcall):
        """Return a blocking function that runs coroutine function ``downcall`` on self.loop."""
        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose each named coroutine method of self.client as a blocking method on self."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Connect to the server on the Unix-domain socket at ``path``.

        AF_UNIX socket paths are length-limited. To work around this, the
        process temporarily changes into the socket's directory and the
        wrapped client is given only the basename. Both connect_unix() and
        connect() of the wrapped client run while that directory is current.

        The previous working directory is always restored. The working
        directory is process-wide state.
        """
        cwd = os.getcwd()
        try:
            dirname = os.path.dirname(path)
            # A bare filename has no directory part; os.chdir("") would raise
            # FileNotFoundError, so simply stay in the current directory.
            if dirname:
                os.chdir(dirname)
            self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path)))
            self.loop.run_until_complete(self.client.connect())
        finally:
            os.chdir(cwd)

    @property
    def max_chunk(self):
        """Proxy for the wrapped client's ``max_chunk`` attribute."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value
155 self.client.max_chunk = value