blob: 7f0cf6ba96ec8115df76143ad98e0a5ce2d97a21 [file] [log] [blame]
Patrick Williamsac13d5f2023-11-24 18:59:46 -06001#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import asyncio
8import itertools
9import json
10from datetime import datetime
11from .exceptions import ClientError, ConnectionClosedError
12
13
# The Python asyncio server defaults to a 64 KiB receive buffer, so the maximum
# chunk size is hard-coded to half of that. Having client and server exchange
# their maximum chunk sizes would be more flexible, but it would add a round
# trip to connection setup, so that is deliberately avoided unless it ever
# becomes necessary.
DEFAULT_MAX_CHUNK = 32768
20
21
def chunkify(msg, max_chunk):
    """Split *msg* into newline-terminated lines for the line-based stream protocol.

    If *msg* is shorter than ``max_chunk - 1`` characters, it is yielded as a
    single line.

    Otherwise the following are yielded, in order:

    * a ``{"chunk-stream": null}`` header line;
    * *msg* in pieces of at most ``max_chunk - 1`` characters, each followed
      by a newline, so every payload line is at most *max_chunk* characters
      long including its newline;
    * a single empty line that terminates the chunk stream.

    The receiver splits on newlines, so *msg* should not contain any.
    ``json.dumps`` output without ``indent`` never does.

    :param msg: the serialized message text.
    :param max_chunk: maximum length of a payload line, newline included.
    :raises ValueError: if *max_chunk* is smaller than 2, which would leave no
        room for payload characters in a chunk.
    """
    if max_chunk < 2:
        # With max_chunk <= 1 every chunk would be empty. Previously this
        # silently dropped the payload, sending only the header and terminator.
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    if len(msg) < max_chunk - 1:
        yield msg + "\n"
        return

    yield json.dumps({"chunk-stream": None}) + "\n"

    step = max_chunk - 1
    for start in range(0, len(msg), step):
        yield msg[start:start + step] + "\n"

    # An empty line tells the receiver that the chunk stream is complete.
    yield "\n"
32
33
def json_serialize(obj):
    """``default=`` hook for ``json.dumps`` that encodes datetimes as ISO 8601.

    :param obj: an object the json encoder could not serialize natively.
    :returns: ``obj.isoformat()`` when *obj* is a ``datetime``.
    :raises TypeError: for any other type, as the json module expects.
    """
    if not isinstance(obj, datetime):
        raise TypeError("Type %s not serializeable" % type(obj))
    return obj.isoformat()
38
39
class StreamConnection(object):
    """JSON message connection over an asyncio ``StreamReader``/``StreamWriter`` pair.

    Each message is a JSON document on its own line. A message too long for
    one line is split by ``chunkify`` and reassembled by ``recv_message``.
    """

    def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
        self.reader = reader
        self.writer = writer
        # A negative timeout makes recv() wait indefinitely. Otherwise it is
        # passed to asyncio.wait_for() as the per-line receive timeout.
        self.timeout = timeout
        # Maximum line length handed to chunkify() when sending.
        self.max_chunk = max_chunk

    @property
    def address(self):
        """Peer address as reported by the writer (``get_extra_info("peername")``)."""
        return self.writer.get_extra_info("peername")

    async def send_message(self, msg):
        """Serialize *msg* as JSON, chunking it if needed, and send it.

        All chunks are written first, then the writer is drained once.
        """
        for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
            self.writer.write(c.encode("utf-8"))
        await self.writer.drain()

    async def recv_message(self):
        """Receive one JSON message, reassembling it if it was sent in chunks.

        :returns: the decoded JSON value.
        :raises ConnectionError: on timeout or a line without a trailing newline.
        :raises ConnectionClosedError: if the connection is closed while reading.
        """
        line = await self.recv()

        m = json.loads(line)
        if not m:
            return m

        # Only a JSON object can be the chunk-stream header. A bare membership
        # test on other JSON values is wrong: a string matches on a substring,
        # a list matches on an element, and a number raises TypeError.
        if isinstance(m, dict) and "chunk-stream" in m:
            lines = []
            while True:
                line = await self.recv()
                # An empty line terminates the chunk stream.
                if not line:
                    break
                lines.append(line)

            m = json.loads("".join(lines))

        return m

    async def send(self, msg):
        """Send *msg* as a single newline-terminated UTF-8 line, without chunking."""
        self.writer.write(("%s\n" % msg).encode("utf-8"))
        await self.writer.drain()

    async def recv(self):
        """Read one line and return it decoded, with trailing whitespace stripped.

        :raises ConnectionError: if no line arrives within ``self.timeout``
            seconds (when non-negative), or if the data read does not end
            with a newline.
        :raises ConnectionClosedError: if the reader returns no data (EOF).
        """
        if self.timeout < 0:
            line = await self.reader.readline()
        else:
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")

        if not line:
            raise ConnectionClosedError("Connection closed")

        line = line.decode("utf-8")

        # readline() can return a partial line without "\n" at EOF.
        if not line.endswith("\n"):
            raise ConnectionError("Bad message %r" % (line))

        return line.rstrip()

    async def close(self):
        """Drop the reader and close the writer, if one is still set.

        The writer's ``close()`` is called but ``wait_closed()`` is not awaited.
        Calling this again after the first call is a no-op.
        """
        self.reader = None
        if self.writer is not None:
            self.writer.close()
            self.writer = None
103
104
class WebsocketConnection(object):
    """JSON message connection over a ``websockets`` socket.

    Each message is one JSON document sent as a single websocket message, so
    no chunking is performed.
    """

    def __init__(self, socket, timeout):
        self.socket = socket
        # A negative timeout makes recv() wait indefinitely.
        self.timeout = timeout

    @property
    def address(self):
        """The socket's ``remote_address`` components joined with ':'."""
        return ":".join(map(str, self.socket.remote_address))

    async def send_message(self, msg):
        """Serialize *msg* as JSON and send it."""
        payload = json.dumps(msg, default=json_serialize)
        await self.send(payload)

    async def recv_message(self):
        """Receive one message and decode it as JSON."""
        return json.loads(await self.recv())

    async def send(self, msg):
        """Send raw *msg*, mapping a closed websocket to ConnectionClosedError."""
        from websockets.exceptions import ConnectionClosed

        try:
            await self.socket.send(msg)
        except ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def recv(self):
        """Receive one raw message.

        :raises ConnectionError: if ``self.timeout`` is non-negative and
            nothing arrives in time.
        :raises ConnectionClosedError: if the websocket is closed.
        """
        from websockets.exceptions import ConnectionClosed

        try:
            if self.timeout < 0:
                return await self.socket.recv()

            try:
                received = await asyncio.wait_for(self.socket.recv(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")
            return received
        except ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def close(self):
        """Close the websocket if one is still set; later calls do nothing."""
        if self.socket is None:
            return
        await self.socket.close()
        self.socket = None
146 self.socket = None