blob: a350b4fb1262237cd163d4a7da89370be66d1203 [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
Patrick Williams92b42cb2022-09-03 06:53:57 -05002# Copyright BitBake Contributors
3#
Andrew Geisslerc926e172021-05-07 16:11:35 -05004# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import socket
Andrew Geisslereff27472021-10-29 15:35:00 -050012import sys
Patrick Williams44b3caf2024-04-12 16:51:14 -050013import re
Patrick Williams73bd93f2024-02-20 08:07:48 -060014import contextlib
15from threading import Thread
Patrick Williamsac13d5f2023-11-24 18:59:46 -060016from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
17from .exceptions import ConnectionClosedError, InvokeError
Andrew Geisslerc926e172021-05-07 16:11:35 -050018
# URI scheme prefixes recognised by parse_address().
UNIX_PREFIX = "unix://"
WS_PREFIX = "ws://"
WSS_PREFIX = "wss://"

# Address type tags; returned as the first element of parse_address()'s result.
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2


def parse_address(addr):
    """
    Classify an address string and split it into connection arguments.

    Returns a tuple ``(addr_type, args)`` where ``addr_type`` is one of the
    ``ADDR_TYPE_*`` constants and ``args`` is a tuple:

    * ``unix://<path>``         -> ``(ADDR_TYPE_UNIX, (path,))``
    * ``ws://...``, ``wss://...`` -> ``(ADDR_TYPE_WS, (addr,))`` (the full URI)
    * ``[<host>]:<port>``       -> ``(ADDR_TYPE_TCP, (host, port))``
    * ``<host>:<port>``         -> ``(ADDR_TYPE_TCP, (host, port))``

    For the unbracketed TCP form the port is taken after the *last* colon,
    so a host that itself contains colons still parses.

    Raises ValueError if a TCP address has no port or the port is not an
    integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))

    if addr.startswith((WS_PREFIX, WSS_PREFIX)):
        return (ADDR_TYPE_WS, (addr,))

    # Bracketed form, e.g. "[::1]:8686"
    m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
    if m is not None:
        host = m.group("host")
        port = m.group("port")
    else:
        # Split on the last colon only; everything before it is the host
        host, sep, port = addr.rpartition(":")
        if not sep:
            raise ValueError("Address '%s' is missing a port" % addr)

    return (ADDR_TYPE_TCP, (host, int(port)))
Andrew Geisslerc926e172021-05-07 16:11:35 -050041
class AsyncClient(object):
    """
    Asynchronous RPC client.

    One of the ``connect_tcp()``, ``connect_unix()`` or ``connect_websocket()``
    methods selects the transport; the connection itself is opened lazily by
    ``connect()``, which is also called before every request so a dropped
    connection is re-established on the next call.
    """

    def __init__(
        self,
        proto_name,
        proto_version,
        logger,
        timeout=30,
        server_headers=False,
        headers=None,
    ):
        """
        proto_name, proto_version: sent as the first handshake line
        logger: object with a ``warning()`` method, used to report retries
        timeout: passed through to the connection objects
        server_headers: if true, ask the server for headers during handshake
        headers: optional mapping of extra headers sent during handshake
        """
        self.socket = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        self.timeout = timeout
        self.needs_server_headers = server_headers
        self.server_headers = {}
        # A fresh dict per instance; a shared mutable default would leak
        # headers between unrelated clients
        self.headers = {} if headers is None else headers

    async def connect_tcp(self, address, port):
        """Select a TCP transport to address:port (does not connect yet)."""

        async def connect_sock():
            reader, writer = await asyncio.open_connection(address, port)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Select a Unix domain socket transport (does not connect yet)."""

        async def connect_sock():
            # AF_UNIX has path length issues so chdir here to workaround
            cwd = os.getcwd()
            dirname = os.path.dirname(path)
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                # A bare socket name has no directory part; chdir("") would
                # fail, and the socket is already relative to the CWD
                if dirname:
                    os.chdir(dirname)
                # The socket must be opened synchronously so that CWD doesn't get
                # changed out from underneath us so we pass as a sock into asyncio
                sock.connect(os.path.basename(path))
            except BaseException:
                # Don't leak the descriptor when the connection fails
                sock.close()
                raise
            finally:
                os.chdir(cwd)
            reader, writer = await asyncio.open_unix_connection(sock=sock)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_websocket(self, uri):
        """Select a websocket transport to uri (does not connect yet)."""
        # Imported here so the module is only required when websockets are used
        import websockets

        async def connect_sock():
            websocket = await websockets.connect(uri, ping_interval=None)
            return WebsocketConnection(websocket, self.timeout)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """
        Perform the handshake on a freshly opened connection.

        Sends the protocol line, the "needs-headers" flag and any client
        headers, terminated by an empty line. If server headers were
        requested, reads "tag: value" lines until an empty line and stores
        them in ``self.server_headers`` keyed by lower-cased tag.
        """
        # Send headers
        await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
        await self.socket.send(
            "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
        )
        for k, v in self.headers.items():
            await self.socket.send("%s: %s" % (k, v))

        # End of headers
        await self.socket.send("")

        self.server_headers = {}
        if self.needs_server_headers:
            while True:
                line = await self.socket.recv()
                if not line:
                    # End headers
                    break
                tag, value = line.split(":", 1)
                self.server_headers[tag.lower()] = value.strip()

    async def get_header(self, tag, default):
        """
        Return the server header named tag, or default if it was not sent.

        Connects first if necessary. The lookup is case-insensitive, since
        tags are stored lower-cased by setup_connection().
        """
        await self.connect()
        return self.server_headers.get(tag.lower(), default)

    async def connect(self):
        """Open the connection and run the handshake if not already connected."""
        if self.socket is None:
            self.socket = await self._connect_sock()
            await self.setup_connection()

    async def disconnect(self):
        """Close the connection, if any. Safe to call when not connected."""
        if self.socket is not None:
            await self.socket.close()
            self.socket = None

    async def close(self):
        """Close the client; currently the same as disconnect()."""
        await self.disconnect()

    async def _send_wrapper(self, proc):
        """
        Run the coroutine function proc on a connected client, with retries.

        On a communication or decoding error the connection is closed and
        the call is retried on a new connection, up to three retries. The
        final failure is raised as a ConnectionError (other error types are
        converted, keeping the original as the cause).
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                ConnectionClosedError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e)) from e
                    raise
                # Drop the connection so the next attempt reconnects
                await self.close()
                count += 1

    def check_invoke_error(self, msg):
        """Raise InvokeError if msg is a dict carrying an "invoke-error" entry."""
        if isinstance(msg, dict) and "invoke-error" in msg:
            raise InvokeError(msg["invoke-error"]["message"])

    async def invoke(self, msg):
        """
        Send msg and return the reply.

        Raises InvokeError if the reply reports an invocation error, or
        ConnectionError if the server cannot be reached after retrying.
        """

        async def proc():
            await self.socket.send_message(msg)
            return await self.socket.recv_message()

        result = await self._send_wrapper(proc)
        self.check_invoke_error(result)
        return result

    async def ping(self):
        """Send a "ping" request and return the server's reply."""
        return await self.invoke({"ping": {}})

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()
Andrew Geissler09036742021-06-25 14:25:14 -0500176
Andrew Geisslerc926e172021-05-07 16:11:35 -0500177
class Client(object):
    """
    Synchronous wrapper around an asynchronous client.

    Owns a private event loop and exposes selected coroutine methods of the
    wrapped client as blocking calls. Derived classes provide the wrapped
    client through _get_async_client().
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        # Override any pre-existing loop.
        # Without this, the PR server export selftest triggers a hang
        # when running with Python 3.7. The drawback is that there is
        # potential for issues if the PR and hash equiv (or some new)
        # clients need to both be instantiated in the same process.
        # This should be revisited if/when Python 3.9 becomes the
        # minimum required version for BitBake, as it seems not
        # required (but harmless) with it.
        asyncio.set_event_loop(self.loop)

        self._add_methods("connect_tcp", "ping")

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the asynchronous client instance to wrap."""
        raise NotImplementedError("Must be implemented in derived class")

    def _get_downcall_wrapper(self, downcall):
        """Return a blocking function that runs coroutine function downcall."""

        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose the named coroutine methods of the client as blocking methods."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Select the Unix socket at path and connect to it immediately."""
        self.loop.run_until_complete(self.client.connect_unix(path))
        self.loop.run_until_complete(self.client.connect())

    @property
    def max_chunk(self):
        """The wrapped client's max_chunk value."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value

    def disconnect(self):
        """Close the wrapped client's connection; no-op once close() has run."""
        if self.loop:
            self.loop.run_until_complete(self.client.close())

    def close(self):
        """
        Close the wrapped client and release the event loop.

        Safe to call more than once. The loop is closed even if closing the
        client raises.
        """
        if self.loop:
            try:
                self.loop.run_until_complete(self.client.close())
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            finally:
                self.loop.close()
                self.loop = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
Patrick Williams73bd93f2024-02-20 08:07:48 -0600239
240
class ClientPool(object):
    """
    Pool of at most max_clients asynchronous clients shared between tasks.

    Clients are created on demand by _new_client(), which derived classes
    must implement, and are reused across run_tasks() calls until close().
    """

    def __init__(self, max_clients):
        self.avail_clients = []
        self.num_clients = 0
        self.max_clients = max_clients
        # Both are created lazily: the loop in run_tasks(), the condition in
        # the worker thread
        self.loop = None
        self.client_condition = None

    @abc.abstractmethod
    async def _new_client(self):
        """Create and return a new client for the pool."""
        raise NotImplementedError("Must be implemented in derived class")

    def close(self):
        """
        Close all idle clients and release the event loop.

        Safe to call more than once. The loop is closed even if closing a
        client raises.
        """
        self.client_condition = None

        if self.loop:
            try:
                self.loop.run_until_complete(self.__close_clients())
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            finally:
                self.loop.close()
                self.loop = None

    def run_tasks(self, tasks):
        """
        Run every task to completion, blocking until all are done.

        Each task is called with a pooled client and must return an
        awaitable. The tasks run on the pool's event loop in a separate
        thread, which is joined before returning.
        """
        if not self.loop:
            self.loop = asyncio.new_event_loop()

        thread = Thread(target=self.__thread_main, args=(tasks,))
        thread.start()
        thread.join()

    @contextlib.asynccontextmanager
    async def get_client(self):
        """
        Async context manager yielding a client from the pool.

        Reuses an idle client if there is one, creates a new one while fewer
        than max_clients exist, and otherwise waits for a client to be
        returned. The client goes back to the pool when the context exits.
        """
        async with self.client_condition:
            if self.avail_clients:
                client = self.avail_clients.pop()
            elif self.num_clients < self.max_clients:
                # Count the client only once it exists, so a failed creation
                # does not permanently use up a slot in the pool. The
                # condition's lock is held across the await, so no other
                # task can observe the intermediate state.
                client = await self._new_client()
                self.num_clients += 1
            else:
                while not self.avail_clients:
                    await self.client_condition.wait()
                client = self.avail_clients.pop()

        try:
            yield client
        finally:
            async with self.client_condition:
                self.avail_clients.append(client)
                self.client_condition.notify()

    def __thread_main(self, tasks):
        """Thread body for run_tasks(): run all tasks on the pool's loop."""

        async def process_task(task):
            async with self.get_client() as client:
                await task(client)

        asyncio.set_event_loop(self.loop)
        if not self.client_condition:
            self.client_condition = asyncio.Condition()
        tasks = [process_task(t) for t in tasks]
        self.loop.run_until_complete(asyncio.gather(*tasks))

    async def __close_clients(self):
        """Close every idle client and reset the pool's bookkeeping."""
        for c in self.avail_clients:
            await c.close()
        self.avail_clients = []
        self.num_clients = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False