blob: ef20cb71df43d3182369adb7185743a8d670a830 [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import signal
10import socket
11import sys
12from . import chunkify, DEFAULT_MAX_CHUNK
13
14
class ClientError(Exception):
    """Raised when a connected client violates the RPC protocol (e.g. sends an unknown command)."""
17
18
class ServerError(Exception):
    """Exception type for failures attributed to the server side of the RPC; not raised by this module itself."""
21
22
class AsyncServerConnection(object):
    """Server-side state for one client of a line-oriented JSON RPC protocol.

    A session consists of a handshake line ``"<proto_name> <version>"``,
    zero or more header lines terminated by an empty line, and then a stream
    of newline-terminated JSON messages. Each message is passed to the first
    entry of ``self.handlers`` (in dict order) whose key is present in it.

    Subclasses must define ``validate_proto_version()``; it is called after
    ``self.proto_version`` has been parsed and is not defined in this class.
    """

    def __init__(self, reader, writer, proto_name, logger):
        # reader/writer: the stream pair for this connection; readline(),
        # write(), drain(), close() and get_extra_info() are used on them.
        self.reader = reader
        self.writer = writer
        self.proto_name = proto_name
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.handlers = {
            'chunk-stream': self.handle_chunk,
            'ping': self.handle_ping,
        }
        self.logger = logger

    async def process_requests(self):
        """Perform the handshake and serve messages until the client goes away.

        A malformed handshake, or a ClientError raised while handling a
        message, is logged and ends the session. The writer is always
        closed before returning.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version. readline() signals EOF with b''
            # (it never returns None), so test for an empty result.
            client_protocol = await self.reader.readline()
            if not client_protocol:
                return

            try:
                (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
            except ValueError as e:
                # Wrong number of tokens, or undecodable bytes
                # (UnicodeDecodeError is a subclass of ValueError).
                raise ClientError('Invalid protocol line %r' % client_protocol) from e

            if client_proto_name != self.proto_name:
                self.logger.debug('Rejecting invalid protocol %s' % (client_proto_name))
                return

            try:
                self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            except ValueError as e:
                raise ClientError('Invalid protocol version %r' % client_proto_version) from e

            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    # EOF before the end of the headers
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """Invoke the first handler whose key is present in ``msg``.

        The handler is awaited with ``msg[key]`` as its only argument.

        Raises:
            ClientError: if ``msg`` is not a JSON object or names no known command.
        """
        # Only JSON objects can carry a command; other JSON values would
        # otherwise hit confusing TypeErrors in the membership/index below.
        if isinstance(msg, dict):
            for k, handler in self.handlers.items():
                if k in msg:
                    self.logger.debug('Handling %s' % k)
                    await handler(msg[k])
                    return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        """Serialize ``msg`` as JSON and queue it on the writer.

        The text is split with chunkify() using ``self.max_chunk``. Data is
        only buffered here; process_requests() drains the writer after each
        dispatched message.
        """
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """Read and decode one newline-terminated JSON message.

        Returns:
            The decoded JSON value, or None at EOF or when the last line is
            not newline-terminated (a truncated message).

        Raises:
            json.JSONDecodeError, UnicodeDecodeError: on malformed input,
            after logging the raw line.
        """
        line = await self.reader.readline()
        if not line:
            return None

        try:
            message = line.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw bytes: 'message' is unbound when decode() fails.
            self.logger.error('Bad message from client: %r' % line)
            raise

    async def handle_chunk(self, request):
        """Reassemble a chunked message and dispatch it.

        Lines are read until an empty line, joined without their newlines
        and parsed as JSON. ``request`` (the 'chunk-stream' payload) is not
        used.

        Raises:
            ClientError: if the reassembled message is itself a chunk stream.
            json.JSONDecodeError, UnicodeDecodeError: on malformed chunks.
        """
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.logger.error('Bad message from client: %r' % lines)
            raise

        if isinstance(msg, dict) and 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_ping(self, request):
        """Reply with ``{'alive': True}``; the request payload is ignored."""
        response = {'alive': True}
        self.write_message(response)
130
Andrew Geisslerc926e172021-05-07 16:11:35 -0500131
class AsyncServer(object):
    """Event-loop server that hands each new connection to ``accept_client``.

    Subclasses implement ``accept_client(reader, writer)`` to return an
    object with an async ``process_requests()`` method. Call one of the
    ``start_*_server`` methods, then ``serve_forever()``.
    """

    def __init__(self, logger, loop=None):
        # Create (and later close) a private loop unless the caller supplies one.
        if loop is None:
            self.loop = asyncio.new_event_loop()
            self.close_loop = True
        else:
            self.loop = loop
            self.close_loop = False

        self._cleanup_socket = None
        self.logger = logger

    def start_tcp_server(self, host, port):
        """Listen on ``host``:``port``; the bound address is stored in ``self.address``."""
        # No 'loop=' argument: it was deprecated in Python 3.8 and removed in
        # 3.10. The coroutine runs inside self.loop.run_until_complete(), so
        # asyncio already binds the server to self.loop.
        self.server = self.loop.run_until_complete(
            asyncio.start_server(self.handle_client, host, port)
        )

        for s in self.server.sockets:
            self.logger.debug('Listening on %r' % (s.getsockname(),))
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility
            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

    def start_unix_server(self, path):
        """Listen on the AF_UNIX socket ``path``; stores ``unix://<abs path>`` in ``self.address``.

        serve_forever() removes the socket file on shutdown.
        """
        abs_path = os.path.abspath(path)

        def cleanup():
            # asyncio may already have unlinked the socket when the server
            # was closed (Python 3.13+), so a missing file is not an error.
            try:
                os.unlink(abs_path)
            except FileNotFoundError:
                pass

        cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX by binding to the
            # basename from inside the socket's directory. A bare filename
            # has an empty dirname and os.chdir('') fails, so skip it then.
            dirname = os.path.dirname(path)
            if dirname:
                os.chdir(dirname)
            self.server = self.loop.run_until_complete(
                asyncio.start_unix_server(self.handle_client, os.path.basename(path))
            )
        finally:
            os.chdir(cwd)

        self.logger.debug('Listening on %r' % path)

        self._cleanup_socket = cleanup
        self.address = "unix://%s" % abs_path

    @abc.abstractmethod
    def accept_client(self, reader, writer):
        """Return a connection object (with async ``process_requests()``) for a new client.

        Must be overridden. NOTE: the class is not declared with ABCMeta, so
        this decorator is not enforced when instantiating.
        """
        pass

    async def handle_client(self, reader, writer):
        """Serve one connection; any exception is logged and the writer closed."""
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            import traceback
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        """Run the event loop until stopped, treating Ctrl-C as a normal stop."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        """SIGTERM handler: stop the event loop so serve_forever() can shut down."""
        self.loop.stop()

    def serve_forever(self):
        """Serve until SIGTERM or KeyboardInterrupt, then shut down.

        Closes the listening server; if this instance created the loop it is
        closed too, and a unix socket created by start_unix_server() is removed.
        """
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self.close_loop:
                if sys.version_info >= (3, 6):
                    self.loop.run_until_complete(self.loop.shutdown_asyncgens())
                self.loop.close()

            if self._cleanup_socket is not None:
                self._cleanup_socket()