blob: 4cdad9ac3c054af794a4d617e495c83879c556fd [file] [log] [blame]
Andrew Geisslerc926e172021-05-07 16:11:35 -05001#
# SPDX-License-Identifier: GPL-2.0-only
#
4
5import abc
6import asyncio
7import json
8import os
9import socket
10from . import chunkify, DEFAULT_MAX_CHUNK
11
12
class AsyncClient(object):
    """Asynchronous client for a line-oriented JSON request/reply protocol.

    After opening a connection it sends a "<proto_name> <proto_version>"
    handshake.  Each request is serialized with json.dumps and written in
    pieces produced by chunkify() (bounded by self.max_chunk).  Each reply
    is read as one newline-terminated JSON line.  If that line decodes to a
    mapping containing "chunk-stream", the real reply follows as a series of
    lines terminated by an empty line.
    """

    def __init__(self, proto_name, proto_version, logger):
        self.reader = None
        self.writer = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger

    async def connect_tcp(self, address, port):
        """Configure the client to use TCP *address*:*port*.

        No connection is opened here; connect() opens it lazily.
        """
        async def connect_sock():
            return await asyncio.open_connection(address, port)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Configure the client to use the Unix domain socket at *path*.

        No connection is opened here; connect() opens it lazily.  A relative
        *path* is resolved against the current working directory now, so
        later (re)connections work regardless of the CWD at that time.
        """
        path = os.path.abspath(path)

        async def connect_sock():
            # AF_UNIX has path length issues, so connect using only the
            # basename from inside the socket's directory.  The socket is
            # connected synchronously so the CWD cannot be changed underneath
            # us between the chdir and the connect.
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
            try:
                cwd = os.getcwd()
                try:
                    os.chdir(os.path.dirname(path))
                    sock.connect(os.path.basename(path))
                finally:
                    os.chdir(cwd)
                return await asyncio.open_unix_connection(sock=sock)
            except BaseException:
                sock.close()
                raise

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Send the protocol handshake line(s) on a freshly opened connection."""
        s = '%s %s\n\n' % (self.proto_name, self.proto_version)
        self.writer.write(s.encode("utf-8"))
        await self.writer.drain()

    async def connect(self):
        """Open the connection and send the handshake, unless already open.

        Requires connect_tcp() or connect_unix() to have been called first
        (otherwise self._connect_sock does not exist).
        """
        if self.reader is None or self.writer is None:
            (self.reader, self.writer) = await self._connect_sock()
            await self.setup_connection()

    async def close(self):
        """Drop the reader and close the writer, if any."""
        self.reader = None

        if self.writer is not None:
            self.writer.close()
            self.writer = None

    async def _send_wrapper(self, proc):
        """Run coroutine function *proc* on a connected client, with retries.

        On an I/O or decode error the connection is closed and re-opened.
        The call is retried up to 3 times (4 attempts in total).  The final
        error is raised as a ConnectionError, chained to the original
        exception when a conversion is needed.
        """
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,  # also covers ConnectionError, listed for clarity
                ConnectionError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e)) from e
                    raise
                await self.close()
                count += 1

    async def send_message(self, msg):
        """Send *msg* as JSON and return the decoded JSON reply.

        Raises ConnectionError if the server cannot be reached, or if the
        reply is malformed, after the retries done by _send_wrapper().
        """
        async def get_line():
            line = await self.reader.readline()
            if not line:
                raise ConnectionError("Connection closed")

            line = line.decode("utf-8")

            # readline() returns an unterminated line only at EOF.
            if not line.endswith("\n"):
                raise ConnectionError("Bad message %r" % (line,))

            return line

        async def proc():
            for c in chunkify(json.dumps(msg), self.max_chunk):
                self.writer.write(c.encode("utf-8"))
                await self.writer.drain()

            l = await get_line()

            m = json.loads(l)
            if m and "chunk-stream" in m:
                # The actual reply arrives as several lines, ended by an
                # empty line; join them and decode the whole document.
                lines = []
                while True:
                    l = (await get_line()).rstrip("\n")
                    if not l:
                        break
                    lines.append(l)

                m = json.loads("".join(lines))

            return m

        return await self._send_wrapper(proc)
105
106
class Client(object):
    """Blocking facade over an asynchronous client.

    Subclasses implement _get_async_client() to supply the async client.
    This class creates its own event loop and exposes selected coroutine
    methods of the async client as ordinary blocking methods.
    NOTE(review): nothing in this class closes self.loop.
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        # Expose blocking versions of these async-client coroutines.
        self._add_methods('connect_tcp', 'close')

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the asynchronous client instance driven by this facade.

        Client does not use abc.ABCMeta, so this decorator is not enforced
        at instantiation time; subclasses must still override it.
        """
        pass

    def _get_downcall_wrapper(self, downcall):
        """Return a function that runs coroutine function *downcall* to completion."""
        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Install on self a blocking wrapper for each named async-client method."""
        for m in methods:
            downcall = getattr(self.client, m)
            setattr(self, m, self._get_downcall_wrapper(downcall))

    def connect_unix(self, path):
        """Connect the async client to the Unix domain socket at *path*."""
        # AF_UNIX has path length issues so chdir here to workaround
        sockdir, sockname = os.path.split(path)
        cwd = os.getcwd()
        try:
            # A bare socket filename has no directory part; os.chdir('')
            # would raise, and the name is already relative to the CWD.
            if sockdir:
                os.chdir(sockdir)
            self.loop.run_until_complete(self.client.connect_unix(sockname))
            self.loop.run_until_complete(self.client.connect())
        finally:
            os.chdir(cwd)

    @property
    def max_chunk(self):
        """Maximum chunk size used by the async client when sending."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value