blob: d2de4891b80d96730e6e2a727a5b07fd8c57b77b [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
Patrick Williams92b42cb2022-09-03 06:53:57 -05002# Copyright BitBake Contributors
3#
Andrew Geisslerc926e172021-05-07 16:11:35 -05004# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
Patrick Williams213cb262021-08-07 19:21:33 -050014import multiprocessing
Andrew Geisslerc926e172021-05-07 16:11:35 -050015from . import chunkify, DEFAULT_MAX_CHUNK
16
17
class ClientError(Exception):
    """Raised when a connected client violates the asyncrpc protocol (e.g. unknown or nested commands)."""
20
21
class ServerError(Exception):
    """Error attributed to the server side of an asyncrpc exchange (not raised within this module)."""
24
25
class AsyncServerConnection(object):
    """
    Server side of a single asyncrpc client connection.

    The client first sends a handshake line ``<proto-name> <version>``,
    followed by header lines that end with an empty line. After that it sends
    newline-delimited JSON messages. Each message is dispatched to the
    handler in ``self.handlers`` whose key is the first one present in the
    message.

    ``validate_proto_version()`` is called during the handshake but is not
    defined here; subclasses are expected to provide it.
    """

    def __init__(self, reader, writer, proto_name, logger):
        self.reader = reader
        self.writer = writer
        self.proto_name = proto_name
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.handlers = {
            'chunk-stream': self.handle_chunk,
            'ping': self.handle_ping,
        }
        self.logger = logger

    async def process_requests(self):
        """
        Perform the handshake, then serve messages until the client
        disconnects.

        The writer is always closed on exit. A ``ClientError`` raised while
        serving is logged and ends the connection. Any other exception
        propagates to the caller.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version
            client_protocol = await self.reader.readline()
            if not client_protocol:
                return

            # A handshake line that is undecodable or not exactly
            # "<name> <version>" is a bad client, not a server failure.
            # Reject it like any other invalid handshake.
            try:
                (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
            except (UnicodeDecodeError, ValueError):
                self.logger.debug('Rejecting malformed protocol line %r' % (client_protocol,))
                return

            if client_proto_name != self.proto_name:
                # Log what the client sent, not what we expected
                self.logger.debug('Rejecting invalid protocol %s' % (client_proto_name))
                return

            # Version components must all be integers, e.g. "1.1"
            try:
                self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            except ValueError:
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """
        Call the handler for the first registered key present in ``msg``,
        passing it ``msg[key]``.

        Raises ``ClientError`` if ``msg`` is not a JSON object or contains no
        known command.
        """
        # Non-object JSON (string, list, number) cannot carry a command. A
        # string would otherwise match keys by substring.
        if not isinstance(msg, dict):
            raise ClientError("Unrecognized command %r" % (msg,))

        for k, handler in self.handlers.items():
            if k in msg:
                self.logger.debug('Handling %s' % k)
                await handler(msg[k])
                return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        """
        Serialize ``msg`` as JSON and queue it on the writer, split by
        ``chunkify`` using ``self.max_chunk``.

        Does not drain; ``process_requests`` drains after each dispatch.
        """
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """
        Read one newline-terminated JSON message.

        Returns the decoded object, or ``None`` on EOF or when the final
        line lacks a trailing newline. Re-raises ``JSONDecodeError`` or
        ``UnicodeDecodeError`` after logging the offending input.
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw bytes: if decoding failed, ``message`` was never
            # bound, and referencing it would raise UnboundLocalError
            self.logger.error('Bad message from client: %r' % (l,))
            raise

    async def handle_chunk(self, request):
        """
        Reassemble a chunked message and dispatch it.

        The chunks follow as individual lines, terminated by an empty line.
        Raises ``ClientError`` on EOF before the terminator or on a nested
        chunk stream.
        """
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                # readline() returns b"" only at EOF; the terminator is b"\n".
                # Treating EOF as the terminator would accept a truncated
                # message.
                if not l:
                    raise ClientError("Unexpected end of stream in chunked message")
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            self.logger.error('Bad message from client: %r' % lines)
            raise e

        if isinstance(msg, dict) and 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_ping(self, request):
        """Reply to a ``ping`` command with ``{'alive': True}``."""
        response = {'alive': True}
        self.write_message(response)
133
Andrew Geisslerc926e172021-05-07 16:11:35 -0500134
class AsyncServer(object):
    """
    Base class for an asyncio-based RPC server.

    Choose a transport with ``start_tcp_server()`` or ``start_unix_server()``.
    This only records a deferred start function; nothing listens until the
    event loop exists. Then run the server with ``serve_forever()`` in the
    current process or ``serve_as_process()`` in a child process.
    Subclasses must override ``accept_client()`` to build the per-connection
    handler object.
    """
    def __init__(self, logger):
        # Callable that removes the UNIX socket file on shutdown, if one was created
        self._cleanup_socket = None
        self.logger = logger
        # Deferred start function installed by start_tcp_server() /
        # start_unix_server(); called once self.loop has been created
        self.start = None
        # Printable listening address, set when the start function runs
        self.address = None
        self.loop = None

    def start_tcp_server(self, host, port):
        """Arrange for the server to listen on TCP ``host``:``port`` when started."""
        def start_tcp():
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.handle_client, host, port)
            )

            for s in self.server.sockets:
                self.logger.debug('Listening on %r' % (s.getsockname(),))
                # Newer python does this automatically. Do it manually here for
                # maximum compatibility
                # NOTE(review): TCP_QUICKACK / TCP_KEEP* are not available on
                # every platform; this code assumes Linux.
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

                # Enable keep alives. This prevents broken client connections
                # from persisting on the server for long periods of time.
                s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

            # Advertise the first socket's address; IPv6 hosts are bracketed
            name = self.server.sockets[0].getsockname()
            if self.server.sockets[0].family == socket.AF_INET6:
                self.address = "[%s]:%d" % (name[0], name[1])
            else:
                self.address = "%s:%d" % (name[0], name[1])

        self.start = start_tcp

    def start_unix_server(self, path):
        """
        Arrange for the server to listen on the UNIX socket ``path`` when
        started. The socket file is unlinked on shutdown.
        """
        def cleanup():
            os.unlink(path)

        def start_unix():
            cwd = os.getcwd()
            try:
                # Work around path length limits in AF_UNIX by binding only
                # the basename from inside the socket's directory. dirname()
                # is '' for a bare filename and os.chdir('') fails, so use
                # the current directory in that case.
                os.chdir(os.path.dirname(path) or '.')
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
                )
            finally:
                os.chdir(cwd)

            self.logger.debug('Listening on %r' % path)

            self._cleanup_socket = cleanup
            self.address = "unix://%s" % os.path.abspath(path)

        self.start = start_unix

    # NOTE: AsyncServer does not use ABCMeta, so this decorator is not
    # enforced at instantiation; it only marks the method subclasses must
    # override.
    @abc.abstractmethod
    def accept_client(self, reader, writer):
        pass

    async def handle_client(self, reader, writer):
        """Serve one connection; log (and print) any exception it raises."""
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            import traceback
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        """Run the event loop until it is stopped or interrupted by Ctrl-C."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        """SIGTERM handler: stop the event loop so shutdown can proceed."""
        self.logger.debug("Got exit signal")
        self.loop.stop()

    def _serve_forever(self):
        """
        Install the SIGTERM handler, run until stopped, then close the
        listening server. Always removes the UNIX socket file if one was
        created.
        """
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            # SIGTERM may have been blocked by serve_as_process() before the
            # fork; unblock it only now that the handler is installed
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self._cleanup_socket is not None:
                self._cleanup_socket()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        # Create loop and override any loop that may have existed in
        # a parent process. It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.start()
        self._serve_forever()

    def serve_as_process(self, *, prefunc=None, args=()):
        """
        Serve requests in a child process

        ``prefunc(self, *args)``, if given, runs in the child after the
        server has started and before it begins serving. Sets
        ``self.address`` from the child and returns the started
        ``multiprocessing.Process``.
        """
        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase. The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to be replaced in this
            # new process.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            try:
                self.start()
            finally:
                # Always report back (possibly None) so the parent's
                # queue.get() does not block forever if start() fails
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            self._serve_forever()

            if sys.version_info >= (3, 6):
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)