blob: a66117acad1a54a3942d401337aa928e7a2f1141 [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
Patrick Williams92b42cb2022-09-03 06:53:57 -05002# Copyright BitBake Contributors
3#
Andrew Geisslerc926e172021-05-07 16:11:35 -05004# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
Patrick Williams213cb262021-08-07 19:21:33 -050014import multiprocessing
Patrick Williamsac13d5f2023-11-24 18:59:46 -060015import logging
16from .connection import StreamConnection, WebsocketConnection
17from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
Andrew Geisslerc926e172021-05-07 16:11:35 -050018
19
class ClientLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes every message with the client address.

    The address is taken from ``extra["address"]``, supplied when the adapter
    is constructed.
    """

    def process(self, msg, kwargs):
        address = self.extra["address"]
        return "[Client {}] {}".format(address, msg), kwargs
Andrew Geisslerc926e172021-05-07 16:11:35 -050023
24
class AsyncServerConnection(object):
    """Server side of a single client connection.

    ``process_requests`` performs the handshake (protocol line, then header
    lines) and afterwards dispatches every received message to the handler in
    ``self.handlers`` whose key appears in the message.

    Subclasses are expected to provide ``validate_proto_version()`` (it is
    called here but not defined in this class) and may extend
    ``self.handlers`` / override ``handle_headers``.
    """

    # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
    # return message will automatically be sent back to the client
    NO_RESPONSE = object()

    def __init__(self, socket, proto_name, logger):
        self.socket = socket
        self.proto_name = proto_name
        self.handlers = {
            "ping": self.handle_ping,
        }
        self.logger = ClientLoggerAdapter(
            logger,
            {
                "address": socket.address,
            },
        )
        self.client_headers = {}

    async def close(self):
        await self.socket.close()

    async def handle_headers(self, headers):
        """Return the headers to send back when the client asks for them.

        :param headers: dict of client headers (lower-cased tags).
        :returns: dict of ``tag -> value``; the default sends none.
        """
        return {}

    async def process_requests(self):
        """Run the handshake, then serve messages until the client goes away.

        The connection is always closed on exit. Malformed handshake input is
        reported as a ``ClientError`` and logged, rather than escaping as a
        bare ``ValueError``.
        """
        try:
            self.logger.info("Client %r connected" % (self.socket.address,))

            # Read protocol and version: "<name> <dotted-version>"
            client_protocol = await self.socket.recv()
            if not client_protocol:
                return

            try:
                (client_proto_name, client_proto_version) = client_protocol.split()
            except ValueError as e:
                raise ClientError(
                    "Malformed protocol line %r" % (client_protocol,)
                ) from e

            if client_proto_name != self.proto_name:
                # Log what the client sent, not our own protocol name
                self.logger.debug(
                    "Rejecting invalid protocol %s" % (client_proto_name,)
                )
                return

            try:
                self.proto_version = tuple(
                    int(v) for v in client_proto_version.split(".")
                )
            except ValueError as e:
                raise ClientError(
                    "Malformed protocol version %r" % (client_proto_version,)
                ) from e

            if not self.validate_proto_version():
                self.logger.debug(
                    "Rejecting invalid protocol version %s" % (client_proto_version)
                )
                return

            # Read headers ("tag: value" lines) until an empty line
            self.client_headers = {}
            while True:
                header = await self.socket.recv()
                if not header:
                    # Empty line. End of headers
                    break
                tag, sep, value = header.partition(":")
                if not sep:
                    raise ClientError("Malformed header %r" % (header,))
                self.client_headers[tag.lower()] = value.strip()

            if self.client_headers.get("needs-headers", "false") == "true":
                for k, v in (await self.handle_headers(self.client_headers)).items():
                    await self.socket.send("%s: %s" % (k, v))
                await self.socket.send("")

            # Handle messages until recv_message() reports end of stream (None)
            while True:
                d = await self.socket.recv_message()
                if d is None:
                    break
                try:
                    response = await self.dispatch_message(d)
                except InvokeError as e:
                    # Report the failure to the client, then drop the connection
                    await self.socket.send_message(
                        {"invoke-error": {"message": str(e)}}
                    )
                    break

                if response is not self.NO_RESPONSE:
                    await self.socket.send_message(response)

        except ConnectionClosedError as e:
            self.logger.info(str(e))
        except (ClientError, ConnectionError) as e:
            self.logger.error(str(e))
        finally:
            await self.close()

    async def dispatch_message(self, msg):
        """Invoke the first registered handler whose key is present in *msg*.

        :raises ClientError: if no handler key is found in the message.
        """
        for k in self.handlers:
            if k in msg:
                self.logger.debug("Handling %s" % k)
                return await self.handlers[k](msg[k])

        raise ClientError("Unrecognized command %r" % msg)

    async def handle_ping(self, request):
        return {"alive": True}
Patrick Williamsac13d5f2023-11-24 18:59:46 -0600118 return {"alive": True}
119
120
class StreamServer(object):
    """Common base for asyncio stream listeners (TCP and UNIX sockets).

    Every accepted stream is wrapped in a ``StreamConnection`` and handed to
    ``handler``. Once ``stop()`` has been called, newly accepted streams are
    closed immediately instead of being handled.
    """

    def __init__(self, handler, logger):
        self.closed = False
        self.handler = handler
        self.logger = logger

    async def handle_stream_client(self, reader, writer):
        # Left disabled: writer.transport.set_write_buffer_limits(0)
        connection = StreamConnection(reader, writer, -1)
        if not self.closed:
            await self.handler(connection)
            return

        await connection.close()

    async def stop(self):
        """Mark the listener closed so later connections are rejected."""
        self.closed = True
137 self.closed = True
138
139
class TCPStreamServer(StreamServer):
    """Stream listener bound to a TCP ``host``/``port``."""

    def __init__(self, host, port, handler, logger):
        super().__init__(handler, logger)
        self.host, self.port = host, port

    def start(self, loop):
        """Bind the listener on *loop* and return awaitables that finish on close.

        Also stores a printable ``host:port`` form of the first bound socket in
        ``self.address``; IPv6 hosts are wrapped in brackets.
        """
        listen_coro = asyncio.start_server(
            self.handle_stream_client, self.host, self.port
        )
        self.server = loop.run_until_complete(listen_coro)

        for listener in self.server.sockets:
            self.logger.debug("Listening on %r" % (listener.getsockname(),))
            options = (
                # Newer python does this automatically. Do it manually here
                # for maximum compatibility
                (socket.SOL_TCP, socket.TCP_NODELAY, 1),
                (socket.SOL_TCP, socket.TCP_QUICKACK, 1),
                # Enable keep alives. This prevents broken client connections
                # from persisting on the server for long periods of time.
                (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
                (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30),
                (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15),
                (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4),
            )
            for level, option, value in options:
                listener.setsockopt(level, option, value)

        first = self.server.sockets[0]
        bound = first.getsockname()
        if first.family == socket.AF_INET6:
            self.address = "[%s]:%d" % (bound[0], bound[1])
        else:
            self.address = "%s:%d" % (bound[0], bound[1])

        return [self.server.wait_closed()]

    async def stop(self):
        """Reject new clients and close the listening sockets."""
        await super().stop()
        self.server.close()

    def cleanup(self):
        """Nothing to remove for a TCP listener."""
        pass
178 pass
179
180
class UnixStreamServer(StreamServer):
    """Stream listener bound to a UNIX domain socket at ``path``."""

    def __init__(self, path, handler, logger):
        super().__init__(handler, logger)
        self.path = path

    def start(self, loop):
        """Bind the socket on *loop* and return awaitables that finish on close.

        ``self.address`` is set to ``unix://<absolute path>``.
        """
        cwd = os.getcwd()
        sock_dir = os.path.dirname(self.path)
        try:
            # Work around path length limits in AF_UNIX: bind to the basename
            # from inside the socket's directory. A bare filename has an empty
            # dirname and os.chdir("") raises, so only chdir when there is a
            # directory component.
            if sock_dir:
                os.chdir(sock_dir)
            self.server = loop.run_until_complete(
                asyncio.start_unix_server(
                    self.handle_stream_client, os.path.basename(self.path)
                )
            )
        finally:
            os.chdir(cwd)

        self.logger.debug("Listening on %r" % self.path)
        # Resolved after restoring cwd, so it is relative to the caller's cwd
        self.address = "unix://%s" % os.path.abspath(self.path)
        return [self.server.wait_closed()]

    async def stop(self):
        """Reject new clients and close the listening socket."""
        await super().stop()
        self.server.close()

    def cleanup(self):
        """Remove the socket file from the filesystem."""
        os.unlink(self.path)
208 os.unlink(self.path)
209
210
class WebsocketsServer(object):
    """Listener that accepts clients over the websockets protocol.

    Every accepted websocket is wrapped in a ``WebsocketConnection`` and
    handed to ``handler``.
    """

    def __init__(self, host, port, handler, logger):
        self.host, self.port = host, port
        self.handler = handler
        self.logger = logger

    def start(self, loop):
        """Bind the server on *loop* and return awaitables that finish on close.

        Stores a ``ws://host:port`` form of the first bound socket in
        ``self.address``; IPv6 hosts are wrapped in brackets.
        """
        # Imported here, so the websockets package is only needed when a
        # websocket server is actually started.
        import websockets.server

        serve_coro = websockets.server.serve(
            self.client_handler,
            self.host,
            self.port,
            ping_interval=None,
        )
        self.server = loop.run_until_complete(serve_coro)

        for listener in self.server.sockets:
            self.logger.debug("Listening on %r" % (listener.getsockname(),))

            # Enable keep alives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            keepalive_options = (
                (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
                (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30),
                (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15),
                (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4),
            )
            for level, option, value in keepalive_options:
                listener.setsockopt(level, option, value)

        first = self.server.sockets[0]
        bound = first.getsockname()
        if first.family == socket.AF_INET6:
            self.address = "ws://[%s]:%d" % (bound[0], bound[1])
        else:
            self.address = "ws://%s:%d" % (bound[0], bound[1])

        return [self.server.wait_closed()]

    async def stop(self):
        """Close the websocket server."""
        self.server.close()

    def cleanup(self):
        """Nothing to remove for a websocket listener."""
        pass

    async def client_handler(self, websocket):
        connection = WebsocketConnection(websocket, -1)
        await self.handler(connection)
256 await self.handler(socket)
Andrew Geissler09036742021-06-25 14:25:14 -0500257
Andrew Geisslerc926e172021-05-07 16:11:35 -0500258
class AsyncServer(object):
    """Base class for an asyncio RPC server.

    Subclasses implement ``accept_client`` to build a connection object (with
    an async ``process_requests()`` method) for each client. A listener is
    selected with one of the ``start_*_server`` methods. The server is then
    run with ``serve_forever`` (in this process) or ``serve_as_process`` (in a
    child process).
    """

    def __init__(self, logger):
        self.logger = logger
        self.loop = None
        self.run_tasks = []

    def start_tcp_server(self, host, port):
        """Select a TCP listener; it is bound later by ``start()``."""
        self.server = TCPStreamServer(host, port, self._client_handler, self.logger)

    def start_unix_server(self, path):
        """Select a UNIX-socket listener; it is bound later by ``start()``."""
        self.server = UnixStreamServer(path, self._client_handler, self.logger)

    def start_websocket_server(self, host, port):
        """Select a websocket listener; it is bound later by ``start()``."""
        self.server = WebsocketsServer(host, port, self._client_handler, self.logger)

    async def _client_handler(self, socket):
        # Per-connection entry point: any error is logged, and the connection
        # is always closed.
        address = socket.address
        try:
            client = self.accept_client(socket)
            await client.process_requests()
        except Exception as e:
            import traceback

            self.logger.error(
                "Error from client %s: %s" % (address, str(e)), exc_info=True
            )
            traceback.print_exc()
        finally:
            self.logger.debug("Client %s disconnected", address)
            await socket.close()

    @abc.abstractmethod
    def accept_client(self, socket):
        """Return a connection object for *socket*; subclasses must override.

        The class does not use ``abc.ABCMeta``, so the decorator is not
        enforced. Raise explicitly instead of returning ``None`` (which would
        fail later with an opaque ``AttributeError``).
        """
        raise NotImplementedError(
            "%s must implement accept_client()" % type(self).__name__
        )

    async def stop(self):
        self.logger.debug("Stopping server")
        await self.server.stop()

    def start(self):
        """Bind the selected listener on ``self.loop``; return its wait tasks."""
        tasks = self.server.start(self.loop)
        self.address = self.server.address
        return tasks

    def signal_handler(self):
        self.logger.debug("Got exit signal")
        self.loop.create_task(self.stop())

    def _serve_forever(self, tasks):
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
            # SIGTERM may have been blocked by serve_as_process() until the
            # handler above was installed
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.loop.run_until_complete(asyncio.gather(*tasks))

            self.logger.debug("Server shutting down")
        finally:
            self.server.cleanup()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        self._create_loop()
        try:
            tasks = self.start()
            self._serve_forever(tasks)
        finally:
            # Close the loop even if starting or serving failed
            self._close_loop()

    def _create_loop(self):
        # Create loop and override any loop that may have existed in
        # a parent process. It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def _close_loop(self):
        """Finalize outstanding async generators, then close ``self.loop``."""
        try:
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        finally:
            self.loop.close()

    def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
        """
        Serve requests in a child process

        :param prefunc: optional callable run in the child as
            ``prefunc(self, *args)`` after the listener is bound.
        :param log_level: optional level applied to ``self.logger`` in the child.
        :returns: the started ``multiprocessing.Process``. ``self.address`` is
            set to the child's bound address, or ``None`` if startup failed.
        """

        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase. The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to replaced in this
            # new process.
            self.address = None
            try:
                # Loop creation is inside the try so that a failure here
                # still wakes the parent, which blocks in queue.get().
                self._create_loop()
                tasks = self.start()
            finally:
                # Always put the server address to wake up the parent task
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            if log_level is not None:
                self.logger.setLevel(log_level)

            try:
                self._serve_forever(tasks)
            finally:
                self._close_loop()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)
391 signal.pthread_sigmask(signal.SIG_SETMASK, mask)