blob: 585bc121da1764824212260ab7278a06746db2ff [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import signal
10import socket
11import sys
Patrick Williams213cb262021-08-07 19:21:33 -050012import multiprocessing
Andrew Geisslerc926e172021-05-07 16:11:35 -050013from . import chunkify, DEFAULT_MAX_CHUNK
14
15
class ClientError(Exception):
    """Raised for protocol errors caused by a connected client (e.g. an unrecognized command)."""
18
19
class ServerError(Exception):
    """Error attributed to the server side; not raised by the code in this file (presumably used by client code)."""
22
23
class AsyncServerConnection(object):
    """
    Server side of a single asyncrpc client connection.

    The client first sends a "<proto_name> <version>" line, then header
    lines terminated by an empty line, then one JSON message per line. Each
    message is dispatched to the handler of the first key in
    ``self.handlers`` that is present in the message.

    ``validate_proto_version()`` is called during the handshake but is not
    defined here; presumably subclasses provide it (confirm in subclasses).
    """

    def __init__(self, reader, writer, proto_name, logger):
        self.reader = reader
        self.writer = writer
        self.proto_name = proto_name
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.handlers = {
            'chunk-stream': self.handle_chunk,
            'ping': self.handle_ping,
        }
        self.logger = logger

    async def process_requests(self):
        """
        Run the connection: protocol handshake, headers, then the message loop.

        Returns when the client disconnects or is rejected. A ClientError is
        logged and ends the connection; other exceptions propagate. The
        writer is closed in every case.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version. StreamReader.readline() returns b''
            # (never None) at EOF, so test for an empty result.
            client_protocol = await self.reader.readline()
            if not client_protocol:
                return

            try:
                (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
            except (UnicodeDecodeError, ValueError):
                raise ClientError('Invalid protocol line %r' % client_protocol) from None

            if client_proto_name != self.proto_name:
                self.logger.debug('Rejecting invalid protocol %s' % (client_proto_name))
                return

            try:
                self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            except ValueError:
                # A non-numeric version is treated like any other unacceptable version.
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    # Client disconnected in the middle of the headers
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """
        Call the handler for the first registered key present in ``msg``.

        Raises ClientError if ``msg`` is not a JSON object or contains no
        known key. The dict check matters: ``k in msg`` on a JSON string would
        otherwise be a substring test.
        """
        if isinstance(msg, dict):
            for k, handler in self.handlers.items():
                if k in msg:
                    self.logger.debug('Handling %s' % k)
                    await handler(msg[k])
                    return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        """Serialize ``msg`` as JSON and write it in chunks of at most ``self.max_chunk`` (no drain)."""
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """
        Read and decode one newline-terminated JSON message.

        Returns None at EOF or when the final line is not newline-terminated.
        A malformed message is logged and the JSONDecodeError or
        UnicodeDecodeError is re-raised.
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw bytes: ``message`` is unbound if decoding failed.
            self.logger.error('Bad message from client: %r' % l)
            raise

    async def handle_chunk(self, request):
        """
        Reassemble a chunked message and dispatch it.

        Lines are read until an empty line (or EOF), then joined and parsed
        as a single JSON message. ``request`` (the 'chunk-stream' value) is
        unused. Nested chunk streams raise ClientError.
        """
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.logger.error('Bad message from client: %r' % lines)
            raise

        if isinstance(msg, dict) and 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_ping(self, request):
        """Reply to a ping with ``{'alive': True}``."""
        response = {'alive': True}
        self.write_message(response)
131
Andrew Geisslerc926e172021-05-07 16:11:35 -0500132
class AsyncServer(object):
    """
    Base class for an asyncio RPC server.

    Call start_tcp_server() or start_unix_server() to choose an endpoint;
    this only records a deferred start function. Then call serve_forever()
    (current process) or serve_as_process() (child process), which create
    the event loop and run it. Subclasses implement accept_client() to
    return an object with an awaitable ``process_requests()``.
    """

    def __init__(self, logger):
        self._cleanup_socket = None
        self.logger = logger
        # Deferred start function installed by start_tcp_server() or
        # start_unix_server(); called once self.loop exists.
        self.start = None
        # Printable listening address, set by the start function.
        self.address = None
        self.loop = None
        self.server = None

    @staticmethod
    def _setsockopt_if_supported(sock, level, optname, value):
        """
        Set the socket option called ``optname`` in the ``socket`` module,
        skipping it when this platform's ``socket`` module does not define it
        (TCP_QUICKACK and the TCP keepalive tunables are not universal).
        """
        opt = getattr(socket, optname, None)
        if opt is not None:
            sock.setsockopt(level, opt, value)

    def start_tcp_server(self, host, port):
        """Arrange for serving on TCP ``host``:``port`` once the loop starts."""
        def start_tcp():
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.handle_client, host, port)
            )

            for s in self.server.sockets:
                self.logger.debug('Listening on %r' % (s.getsockname(),))
                # Newer python does this automatically. Do it manually here for
                # maximum compatibility
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                self._setsockopt_if_supported(s, socket.SOL_TCP, 'TCP_QUICKACK', 1)

                # Enable keep alives. This prevents broken client connections
                # from persisting on the server for long periods of time.
                s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                self._setsockopt_if_supported(s, socket.IPPROTO_TCP, 'TCP_KEEPIDLE', 30)
                self._setsockopt_if_supported(s, socket.IPPROTO_TCP, 'TCP_KEEPINTVL', 15)
                self._setsockopt_if_supported(s, socket.IPPROTO_TCP, 'TCP_KEEPCNT', 4)

            name = self.server.sockets[0].getsockname()
            if self.server.sockets[0].family == socket.AF_INET6:
                self.address = "[%s]:%d" % (name[0], name[1])
            else:
                self.address = "%s:%d" % (name[0], name[1])

        self.start = start_tcp

    def start_unix_server(self, path):
        """Arrange for serving on the AF_UNIX socket ``path`` once the loop starts."""
        def cleanup():
            os.unlink(path)

        def start_unix():
            cwd = os.getcwd()
            try:
                # Work around path length limits in AF_UNIX by binding to the
                # basename from inside the socket's directory. A bare file
                # name has an empty dirname, which os.chdir() rejects.
                dirname = os.path.dirname(path)
                if dirname:
                    os.chdir(dirname)
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
                )
            finally:
                os.chdir(cwd)

            self.logger.debug('Listening on %r' % path)

            self._cleanup_socket = cleanup
            self.address = "unix://%s" % os.path.abspath(path)

        self.start = start_unix

    @abc.abstractmethod
    def accept_client(self, reader, writer):
        """
        Return a connection object for a new client. Must be overridden.

        NOTE: this class does not use ABCMeta, so the abstractmethod marker
        is not enforced at instantiation time.
        """
        pass

    async def handle_client(self, reader, writer):
        """Serve one client connection, logging any error, and always close it."""
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            import traceback
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
        finally:
            # In finally so the transport is also closed on cancellation.
            writer.close()
            self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        """Run the event loop until stopped; Ctrl-C just ends the loop."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        """SIGTERM handler: stop the event loop."""
        self.logger.debug("Got exit signal")
        self.loop.stop()

    def _serve_forever(self):
        """
        Run the already-started server until SIGTERM or Ctrl-C, then shut down.

        SIGTERM is unblocked here because serve_as_process() blocks it while
        the child is starting. The socket cleanup (if any) always runs.
        """
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self._cleanup_socket is not None:
                self._cleanup_socket()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        # Create loop and override any loop that may have existed in
        # a parent process. It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.start()
        self._serve_forever()

    def serve_as_process(self, *, prefunc=None, args=()):
        """
        Serve requests in a child process

        ``prefunc(self, *args)`` is called in the child after the server has
        started. Returns the multiprocessing.Process; ``self.address`` is set
        from the address the child reports.
        """
        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase. The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to replaced in this
            # new process.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            try:
                self.start()
            finally:
                # Always report (possibly None) so the parent's queue.get()
                # does not block forever if start() fails.
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            self._serve_forever()

            if sys.version_info >= (3, 6):
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)