blob: f0be9a6cdb23deda91f5047716035c2fb043c5bf [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
Patrick Williams92b42cb2022-09-03 06:53:57 -05002# Copyright BitBake Contributors
3#
Andrew Geisslerc926e172021-05-07 16:11:35 -05004# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
Patrick Williams213cb262021-08-07 19:21:33 -050014import multiprocessing
Patrick Williamsac13d5f2023-11-24 18:59:46 -060015import logging
16from .connection import StreamConnection, WebsocketConnection
17from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
Andrew Geisslerc926e172021-05-07 16:11:35 -050018
19
class ClientLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that tags every message with the client's address.

    The address is read from the ``"address"`` key of the adapter's ``extra``
    mapping.
    """

    def process(self, msg, kwargs):
        """Return *msg* prefixed with ``[Client <address>]``; *kwargs* unchanged."""
        address = self.extra["address"]
        prefixed = "[Client {}] {}".format(address, msg)
        return prefixed, kwargs
Andrew Geisslerc926e172021-05-07 16:11:35 -050023
24
class AsyncServerConnection(object):
    """
    Server side of a single client connection.

    ``process_requests()`` reads a protocol line ("<name> <version>"), then
    header lines up to the first empty one, then messages. Each message is
    passed to ``dispatch_message()`` and the handler's result is sent back
    unless it is ``NO_RESPONSE``. Subclasses register extra handlers in
    ``self.handlers`` and must provide ``validate_proto_version()``, which is
    called here but not defined in this class.
    """

    # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
    # return message will be automatically sent back to the client
    NO_RESPONSE = object()

    def __init__(self, socket, proto_name, logger):
        """
        :param socket: client connection; must provide ``address`` and the
            coroutines ``recv()``, ``recv_message()``, ``send_message()`` and
            ``close()``
        :param proto_name: protocol name the client must announce
        :param logger: base logger, wrapped so every message is tagged with
            the client's address
        """
        self.socket = socket
        self.proto_name = proto_name
        # Maps a message key to the coroutine that handles its payload
        self.handlers = {
            "ping": self.handle_ping,
        }
        self.logger = ClientLoggerAdapter(
            logger,
            {
                "address": socket.address,
            },
        )

    async def close(self):
        """Close the underlying client connection."""
        await self.socket.close()

    async def process_requests(self):
        """
        Run the handshake and then the message loop until the client leaves.

        The connection is always closed on exit. ``ConnectionClosedError`` is
        logged at info level; ``ClientError`` and ``ConnectionError`` are
        logged as errors instead of being raised. An ``InvokeError`` raised by
        a handler is sent to the client as an ``invoke-error`` message and
        ends the session.
        """
        try:
            self.logger.info("Client %r connected" % (self.socket.address,))

            # Read protocol and version
            client_protocol = await self.socket.recv()
            if not client_protocol:
                return

            # A malformed handshake line is the client's fault: report it as a
            # ClientError so it is logged below rather than escaping as an
            # unexpected ValueError.
            try:
                (client_proto_name, client_proto_version) = client_protocol.split()
            except ValueError:
                raise ClientError(
                    "Malformed protocol line %r" % (client_protocol,)
                ) from None

            if client_proto_name != self.proto_name:
                # Log what the client sent, not our own protocol name
                self.logger.debug("Rejecting invalid protocol %s" % (client_proto_name))
                return

            try:
                self.proto_version = tuple(
                    int(v) for v in client_proto_version.split(".")
                )
            except ValueError:
                raise ClientError(
                    "Malformed protocol version %r" % (client_proto_version,)
                ) from None

            if not self.validate_proto_version():
                self.logger.debug(
                    "Rejecting invalid protocol version %s" % (client_proto_version)
                )
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                header = await self.socket.recv()
                if not header:
                    break

            # Handle messages
            while True:
                d = await self.socket.recv_message()
                if d is None:
                    break
                try:
                    response = await self.dispatch_message(d)
                except InvokeError as e:
                    await self.socket.send_message(
                        {"invoke-error": {"message": str(e)}}
                    )
                    break

                if response is not self.NO_RESPONSE:
                    await self.socket.send_message(response)

        except ConnectionClosedError as e:
            self.logger.info(str(e))
        except (ClientError, ConnectionError) as e:
            self.logger.error(str(e))
        finally:
            await self.close()

    async def dispatch_message(self, msg):
        """
        Route *msg* to the first registered handler whose key it contains.

        The handler receives ``msg[key]`` and its return value is returned.

        :raises ClientError: if no registered key is present in *msg*
        """
        for k in self.handlers.keys():
            if k in msg:
                self.logger.debug("Handling %s" % k)
                return await self.handlers[k](msg[k])

        raise ClientError("Unrecognized command %r" % msg)

    async def handle_ping(self, request):
        """Answer a ``ping`` request; the request payload is ignored."""
        return {"alive": True}
107
108
class StreamServer(object):
    """
    Base for servers built on asyncio streams.

    Each accepted (reader, writer) pair is wrapped in a ``StreamConnection``
    and passed to ``handler``. After ``stop()`` has been called, newly
    accepted connections are closed immediately instead.
    """

    def __init__(self, handler, logger):
        self.handler = handler
        self.logger = logger
        # Set by stop(); checked for every newly accepted connection
        self.closed = False

    async def handle_stream_client(self, reader, writer):
        """Connection callback given to asyncio's stream server helpers."""
        connection = StreamConnection(reader, writer, -1)
        if not self.closed:
            await self.handler(connection)
        else:
            await connection.close()

    async def stop(self):
        """Mark the server closed so newly accepted connections are dropped."""
        self.closed = True
126
127
class TCPStreamServer(StreamServer):
    """Stream server that listens on a TCP ``host``/``port``."""

    def __init__(self, host, port, handler, logger):
        super().__init__(handler, logger)
        self.host = host
        self.port = port

    def start(self, loop):
        """
        Start listening on *loop*, tune the listening sockets and set
        ``self.address``.

        :returns: a list holding the server's ``wait_closed()`` awaitable
        """
        starting = asyncio.start_server(self.handle_stream_client, self.host, self.port)
        self.server = loop.run_until_complete(starting)

        for sock in self.server.sockets:
            self.logger.debug("Listening on %r" % (sock.getsockname(),))
            tuned_options = (
                # Newer python does this automatically. Do it manually here
                # for maximum compatibility
                (socket.SOL_TCP, socket.TCP_NODELAY, 1),
                (socket.SOL_TCP, socket.TCP_QUICKACK, 1),
                # Enable keep alives. This prevents broken client connections
                # from persisting on the server for long periods of time.
                (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
                (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30),
                (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15),
                (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4),
            )
            for level, option, value in tuned_options:
                sock.setsockopt(level, option, value)

        # The advertised address comes from the first listening socket;
        # IPv6 hosts are bracketed so the ":port" suffix stays unambiguous.
        primary = self.server.sockets[0]
        bound_host, bound_port = primary.getsockname()[:2]
        if primary.family == socket.AF_INET6:
            self.address = "[%s]:%d" % (bound_host, bound_port)
        else:
            self.address = "%s:%d" % (bound_host, bound_port)

        return [self.server.wait_closed()]

    async def stop(self):
        """Mark the server closed and close the listening server."""
        await super().stop()
        self.server.close()

    def cleanup(self):
        """Nothing to clean up for TCP listeners."""
        pass
167
168
class UnixStreamServer(StreamServer):
    """Stream server that listens on a Unix domain socket at ``path``."""

    def __init__(self, path, handler, logger):
        super().__init__(handler, logger)
        self.path = path

    def start(self, loop):
        """
        Start listening on *loop* and set ``self.address``.

        :returns: a list holding the server's ``wait_closed()`` awaitable
        """
        cwd = os.getcwd()
        # Work around path length limits in AF_UNIX by binding only the
        # basename from inside the socket's directory. A bare filename has an
        # empty dirname and os.chdir("") fails, so only change directory when
        # there is one to change to.
        dirname = os.path.dirname(self.path)
        try:
            if dirname:
                os.chdir(dirname)
            self.server = loop.run_until_complete(
                asyncio.start_unix_server(
                    self.handle_stream_client, os.path.basename(self.path)
                )
            )
        finally:
            os.chdir(cwd)

        self.logger.debug("Listening on %r" % self.path)
        self.address = "unix://%s" % os.path.abspath(self.path)
        return [self.server.wait_closed()]

    async def stop(self):
        """Mark the server closed and close the listening server."""
        await super().stop()
        self.server.close()

    def cleanup(self):
        """
        Remove the socket file.

        A file that is already gone is not an error. cleanup() runs from a
        ``finally`` and must not mask the exception that got us there.
        """
        try:
            os.unlink(self.path)
        except FileNotFoundError:
            pass
197
198
class WebsocketsServer(object):
    """Server that accepts clients over WebSockets via the ``websockets`` package."""

    def __init__(self, host, port, handler, logger):
        self.host = host
        self.port = port
        self.handler = handler
        self.logger = logger

    def start(self, loop):
        """
        Start listening on *loop*, tune the listening sockets and set
        ``self.address``.

        :returns: a list holding the server's ``wait_closed()`` awaitable
        """
        # Imported here so the websockets package is only required when a
        # websocket server is actually started.
        import websockets.server

        serving = websockets.server.serve(
            self.client_handler,
            self.host,
            self.port,
            ping_interval=None,
        )
        self.server = loop.run_until_complete(serving)

        for sock in self.server.sockets:
            self.logger.debug("Listening on %r" % (sock.getsockname(),))

            # Enable keep alives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            keepalive_options = (
                (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
                (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30),
                (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15),
                (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4),
            )
            for level, option, value in keepalive_options:
                sock.setsockopt(level, option, value)

        # Advertise the first listening socket; IPv6 hosts are bracketed.
        primary = self.server.sockets[0]
        bound_host, bound_port = primary.getsockname()[:2]
        if primary.family == socket.AF_INET6:
            self.address = "ws://[%s]:%d" % (bound_host, bound_port)
        else:
            self.address = "ws://%s:%d" % (bound_host, bound_port)

        return [self.server.wait_closed()]

    async def stop(self):
        """Close the listening server."""
        self.server.close()

    def cleanup(self):
        """Nothing to clean up for websocket listeners."""
        pass

    async def client_handler(self, websocket):
        """Wrap an accepted websocket in a WebsocketConnection and serve it."""
        connection = WebsocketConnection(websocket, -1)
        await self.handler(connection)
Andrew Geissler09036742021-06-25 14:25:14 -0500245
Andrew Geisslerc926e172021-05-07 16:11:35 -0500246
class AsyncServer(object):
    """
    Base class for asyncio RPC servers.

    Pick a transport with ``start_tcp_server``, ``start_unix_server`` or
    ``start_websocket_server`` (these only configure ``self.server``), then
    run it with ``serve_forever`` (current process) or ``serve_as_process``
    (child process). Subclasses implement ``accept_client`` to return, for
    each client, an object whose ``process_requests()`` coroutine serves it.
    """

    def __init__(self, logger):
        self.logger = logger
        self.loop = None
        # NOTE(review): not used in this class; presumably for subclasses
        self.run_tasks = []

    def start_tcp_server(self, host, port):
        """Configure a TCP listener on *host*:*port* (started by start())."""
        self.server = TCPStreamServer(host, port, self._client_handler, self.logger)

    def start_unix_server(self, path):
        """Configure a Unix socket listener at *path* (started by start())."""
        self.server = UnixStreamServer(path, self._client_handler, self.logger)

    def start_websocket_server(self, host, port):
        """Configure a websocket listener on *host*:*port* (started by start())."""
        self.server = WebsocketsServer(host, port, self._client_handler, self.logger)

    async def _client_handler(self, socket):
        """Serve one client; log any unexpected error and always close it."""
        address = socket.address
        try:
            client = self.accept_client(socket)
            await client.process_requests()
        except Exception as e:
            import traceback

            self.logger.error(
                "Error from client %s: %s", address, e, exc_info=True
            )
            traceback.print_exc()
        finally:
            self.logger.debug("Client %s disconnected", address)
            await socket.close()

    @abc.abstractmethod
    def accept_client(self, socket):
        """
        Return the connection object that will serve *socket*.

        NOTE(review): this class does not use ABCMeta, so the decorator is
        not enforced; the default implementation returns None.
        """
        pass

    async def stop(self):
        """Stop the configured server."""
        self.logger.debug("Stopping server")
        await self.server.stop()

    def start(self):
        """Start the configured server on self.loop; return its awaitables."""
        tasks = self.server.start(self.loop)
        self.address = self.server.address
        return tasks

    def signal_handler(self):
        """Schedule a server stop on the event loop (installed for exit signals)."""
        self.logger.debug("Got exit signal")
        self.loop.create_task(self.stop())

    def _serve_forever(self, tasks):
        """Install exit-signal handlers and run *tasks* until they finish."""
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
            # serve_as_process blocks SIGTERM before spawning the child; the
            # handler is installed now, so it is safe to receive it.
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.loop.run_until_complete(asyncio.gather(*tasks))

            self.logger.debug("Server shutting down")
        finally:
            self.server.cleanup()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        self._create_loop()
        try:
            tasks = self.start()
            self._serve_forever(tasks)
        finally:
            # Close the loop even if starting or serving failed
            self._close_loop()

    def _create_loop(self):
        # Create loop and override any loop that may have existed in
        # a parent process. It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def _close_loop(self):
        """Finalize outstanding async generators, then close self.loop."""
        try:
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        finally:
            self.loop.close()

    def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
        """
        Serve requests in a child process

        :param prefunc: optional callable run in the child as
            ``prefunc(self, *args)`` after the server has started
        :param args: extra positional arguments for *prefunc*
        :param log_level: if given, set on ``self.logger`` in the child
        :returns: the started ``multiprocessing.Process``; ``self.address``
            is set to the child's server address (None if it failed to start)
        """

        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase. The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to replaced in this
            # new process.
            self._create_loop()
            try:
                self.address = None
                tasks = self.start()
            finally:
                # Always put the server address to wake up the parent task
                queue.put(self.address)
                queue.close()

            try:
                if prefunc is not None:
                    prefunc(self, *args)

                if log_level is not None:
                    self.logger.setLevel(log_level)

                self._serve_forever(tasks)
            finally:
                self._close_loop()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)