blob: cb3384639d7f58f44dae58999dcc74cf6d944ca6 [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import signal
10import socket
11import sys
12from . import chunkify, DEFAULT_MAX_CHUNK
13
14
class ClientError(Exception):
    """Error caused by a misbehaving client.

    Raised in this module for unrecognized commands and for nested chunk
    streams; AsyncServerConnection.process_requests() catches and logs it.
    """
17
18
class ServerError(Exception):
    """Error attributed to the server side of the connection.

    Nothing in this part of the module raises it; it is presumably provided
    for subclasses and other users of the package (confirm against callers).
    """
21
22
class AsyncServerConnection(object):
    """Server-side handler for one client connection.

    The connection first performs a handshake (a "<name> <version>" line
    followed by header lines terminated by an empty line) and then processes
    newline-delimited JSON messages until the client disconnects.

    This class does not define validate_proto_version(); process_requests()
    calls it, so a subclass must supply it. Subclasses may also add entries
    to self.handlers, which maps a top-level message key to the coroutine
    that handles the value stored under that key.
    """

    def __init__(self, reader, writer, proto_name, logger):
        """Store the stream pair, expected protocol name and logger."""
        self.reader = reader
        self.writer = writer
        self.proto_name = proto_name
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.handlers = {
            'chunk-stream': self.handle_chunk,
        }
        self.logger = logger

    async def process_requests(self):
        """Run the handshake and then the message loop for this client.

        ClientError is logged and ends the connection; any other exception
        propagates to the caller. The writer is closed in every case.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version
            client_protocol = await self.reader.readline()
            # readline() reports EOF as empty bytes, never as None
            if not client_protocol:
                return

            (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
            if client_proto_name != self.proto_name:
                # Report what the client actually sent, not our own name
                self.logger.debug('Rejecting invalid protocol %s' % (client_proto_name))
                return

            self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                # Empty bytes means the client went away mid-handshake; a
                # blank header line still carries its newline here
                if not line:
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """Invoke the first registered handler whose key is present in msg.

        Raises:
            ClientError: if no handler key is found in msg.
        """
        for key, handler in self.handlers.items():
            if key in msg:
                self.logger.debug('Handling %s' % key)
                await handler(msg[key])
                return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        """Serialize msg as JSON and write each piece yielded by chunkify()."""
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """Read one newline-terminated JSON message.

        Returns:
            The decoded JSON value, or None on EOF or when the line is not
            terminated by a newline (truncated input).

        Raises:
            json.JSONDecodeError, UnicodeDecodeError: if the line cannot be
            decoded; the offending raw line is logged first.
        """
        line = await self.reader.readline()
        if not line:
            return None

        try:
            message = line.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw bytes: the decoded text does not exist when the
            # failure is the UTF-8 decode itself
            self.logger.error('Bad message from client: %r' % (line,))
            raise

    async def handle_chunk(self, request):
        """Reassemble a chunked message and dispatch it.

        Lines are read until an empty one (or EOF), concatenated and parsed
        as a single JSON document. The request argument is not used.

        Raises:
            ClientError: if the reassembled message is itself a chunk stream.
            json.JSONDecodeError, UnicodeDecodeError: on undecodable input.
        """
        lines = []
        try:
            while True:
                chunk = await self.reader.readline()
                chunk = chunk.rstrip(b"\n").decode("utf-8")
                if not chunk:
                    break
                lines.append(chunk)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.logger.error('Bad message from client: %r' % lines)
            raise

        if 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)
125
126
class AsyncServer(object):
    """Base class for an asyncio stream server over TCP or a Unix socket.

    Subclasses implement accept_client() to build the per-connection object.
    Typical use is start_tcp_server() or start_unix_server() followed by
    serve_forever().
    """

    def __init__(self, logger, loop=None):
        """Store the logger and pick the event loop.

        When no loop is supplied a new one is created, and serve_forever()
        closes it on exit; a caller-supplied loop is left open.
        """
        if loop is None:
            self.loop = asyncio.new_event_loop()
            self.close_loop = True
        else:
            self.loop = loop
            self.close_loop = False

        self._cleanup_socket = None
        self.logger = logger

    def start_tcp_server(self, host, port):
        """Start listening on host:port and record the bound address.

        self.address is set to "host:port", with the host in brackets for
        IPv6 sockets.
        """
        # The coroutine runs on self.loop via run_until_complete(), so asyncio
        # picks up that loop as the running loop. The explicit loop= keyword
        # was removed in Python 3.10 and raises TypeError there.
        self.server = self.loop.run_until_complete(
            asyncio.start_server(self.handle_client, host, port)
        )

        for s in self.server.sockets:
            self.logger.info('Listening on %r' % (s.getsockname(),))
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility
            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

    def start_unix_server(self, path):
        """Start listening on the Unix domain socket at path.

        self.address is set to "unix://<absolute path>", and a cleanup
        callback that removes the socket file is registered for
        serve_forever() to call on exit.
        """
        def cleanup():
            # The socket may already be gone: asyncio in Python 3.13+ removes
            # it when the server is closed
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass

        cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX. A bare filename has
            # an empty dirname, which chdir() rejects, so fall back to '.'
            os.chdir(os.path.dirname(path) or '.')
            self.server = self.loop.run_until_complete(
                asyncio.start_unix_server(self.handle_client, os.path.basename(path))
            )
        finally:
            os.chdir(cwd)

        self.logger.info('Listening on %r' % path)

        self._cleanup_socket = cleanup
        self.address = "unix://%s" % os.path.abspath(path)

    @abc.abstractmethod
    def accept_client(self, reader, writer):
        """Return the connection object for a new client.

        handle_client() awaits process_requests() on the returned object.
        The class does not use an ABC metaclass, so this decorator documents
        the requirement but does not enforce it at instantiation.
        """
        pass

    async def handle_client(self, reader, writer):
        """Serve one client; log any error and close the writer on failure."""
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            import traceback
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        self.logger.info('Client disconnected')

    def run_loop_forever(self):
        """Run the event loop until it is stopped or interrupted by Ctrl-C."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        """Stop the event loop; installed for SIGTERM by serve_forever()."""
        self.loop.stop()

    def serve_forever(self):
        """Serve until the loop stops, then shut everything down.

        The server is closed once the loop stops. On the way out the loop is
        closed if this object created it, and the Unix socket cleanup
        callback runs if one was registered.
        """
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.info('Server shutting down')
        finally:
            if self.close_loop:
                if sys.version_info >= (3, 6):
                    self.loop.run_until_complete(self.loop.shutdown_asyncgens())
                self.loop.close()

            if self._cleanup_socket is not None:
                self._cleanup_socket()
218 self._cleanup_socket()