| # Copyright (C) 2019 Garmin Ltd. |
| # |
| # SPDX-License-Identifier: GPL-2.0-only |
| # |
| |
| import asyncio |
| import json |
| import logging |
| import socket |
| import os |
| from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client |
| |
| |
# Module-level logger for this client, registered under the fixed name
# "hashserv.client" (used below to report retried server errors).
logger = logging.getLogger(name="hashserv.client")
| |
| |
class HashConnectionError(Exception):
    """Failure while talking to the hash equivalence server.

    Raised when the connection closes, a reply is malformed, or the server
    answers a mode switch with something other than ``"ok"``. Also used to
    wrap transport/decoding errors once the client gives up retrying.
    """
| |
| |
class AsyncClient(object):
    """Asynchronous client for the hash equivalence server.

    The client speaks a line-oriented protocol over a TCP or Unix socket and
    has two modes:

    * ``MODE_NORMAL``: requests are JSON messages (see :meth:`send_message`).
    * ``MODE_GET_STREAM``: each request is a plain text line and one text
      line is read back per request (see :meth:`send_stream`).

    The connection is opened lazily on the first request. After a transport
    error it is re-established, and the previously selected mode is replayed
    on the new connection.
    """

    MODE_NORMAL = 0
    MODE_GET_STREAM = 1

    def __init__(self):
        self.reader = None
        self.writer = None
        self.mode = self.MODE_NORMAL
        # Size limit handed to chunkify() when writing JSON requests.
        self.max_chunk = DEFAULT_MAX_CHUNK

    async def connect_tcp(self, address, port):
        """Target ``address``:``port`` over TCP.

        No connection is made here; it is opened on the first request.
        """

        async def connect_sock():
            return await asyncio.open_connection(address, port)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Target the Unix domain socket at ``path``.

        No connection is made here; it is opened on the first request.
        """

        async def connect_sock():
            return await asyncio.open_unix_connection(path)

        self._connect_sock = connect_sock

    async def _connect(self):
        """Open the connection if there is none, then restore the current mode."""
        if self.reader is None or self.writer is None:
            (self.reader, self.writer) = await self._connect_sock()

            # Protocol handshake: version banner followed by an empty line.
            self.writer.write("OEHASHEQUIV 1.1\n\n".encode("utf-8"))
            await self.writer.drain()

            # Treat the new connection as being in normal mode, then switch
            # it to whatever mode the caller had selected before.
            cur_mode = self.mode
            self.mode = self.MODE_NORMAL
            await self._set_mode(cur_mode)

    async def close(self):
        """Drop the current connection.

        The next request will reconnect. The selected mode is kept.
        """
        self.reader = None

        if self.writer is not None:
            self.writer.close()
            self.writer = None

    async def _send_wrapper(self, proc):
        """Run the coroutine function ``proc`` on a connected socket, with retries.

        On ``OSError``, ``HashConnectionError``, ``json.JSONDecodeError`` or
        ``UnicodeDecodeError`` the connection is closed and the attempt is
        retried. After the fourth failed attempt the error is raised as a
        :class:`HashConnectionError`, chained to the original exception when
        it was of another type.
        """
        count = 0
        while True:
            try:
                await self._connect()
                return await proc()
            except (
                OSError,
                HashConnectionError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                logger.warning("Error talking to server: %s", e)
                if count >= 3:
                    if not isinstance(e, HashConnectionError):
                        raise HashConnectionError(str(e)) from e
                    raise
                await self.close()
                count += 1

    async def send_message(self, msg):
        """Send ``msg`` as JSON and return the decoded JSON reply.

        The request is serialized with :func:`json.dumps` and written in the
        pieces produced by ``chunkify``. The reply is one JSON line. If that
        line is an object containing a ``"chunk-stream"`` key, the actual
        reply follows as lines terminated by an empty line. Those lines are
        joined and decoded as JSON.

        Raises:
            HashConnectionError: the connection closed or a reply line was
                incomplete, and retries were exhausted.
        """

        async def get_line():
            line = await self.reader.readline()
            if not line:
                raise HashConnectionError("Connection closed")

            line = line.decode("utf-8")

            # readline() returns partial data at EOF; reject it.
            if not line.endswith("\n"):
                raise HashConnectionError("Bad message %r" % line)

            return line

        async def proc():
            for c in chunkify(json.dumps(msg), self.max_chunk):
                self.writer.write(c.encode("utf-8"))
            await self.writer.drain()

            reply = json.loads(await get_line())
            # A reply may be any JSON value (e.g. null or a string); only an
            # object can announce a chunk stream.
            if isinstance(reply, dict) and "chunk-stream" in reply:
                parts = []
                while True:
                    part = (await get_line()).rstrip("\n")
                    if not part:
                        break
                    parts.append(part)

                reply = json.loads("".join(parts))

            return reply

        return await self._send_wrapper(proc)

    async def send_stream(self, msg):
        """Send ``msg`` as a single line and return the reply line.

        The reply is decoded as UTF-8 with trailing whitespace stripped.

        Raises:
            HashConnectionError: the connection closed and retries were
                exhausted.
        """

        async def proc():
            self.writer.write(("%s\n" % msg).encode("utf-8"))
            await self.writer.drain()
            line = await self.reader.readline()
            if not line:
                raise HashConnectionError("Connection closed")
            return line.decode("utf-8").rstrip()

        return await self._send_wrapper(proc)

    async def _set_mode(self, new_mode):
        """Switch the connection to ``new_mode`` if it differs from the current one.

        Leaving stream mode sends ``END``. Entering it sends a ``get-stream``
        message. In both cases the server must answer ``"ok"``, otherwise
        :class:`HashConnectionError` is raised. Any other transition raises
        ``Exception``.
        """
        if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
            r = await self.send_stream("END")
            if r != "ok":
                raise HashConnectionError("Bad response from server %r" % r)
        elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
            r = await self.send_message({"get-stream": None})
            if r != "ok":
                raise HashConnectionError("Bad response from server %r" % r)
        elif new_mode != self.mode:
            raise Exception(
                "Undefined mode transition %r -> %r" % (self.mode, new_mode)
            )

        self.mode = new_mode

    async def get_unihash(self, method, taskhash):
        """Look up the unihash for ``taskhash`` under ``method`` in stream mode.

        Returns the reply string, or ``None`` when the server replies with an
        empty line.
        """
        await self._set_mode(self.MODE_GET_STREAM)
        r = await self.send_stream("%s %s" % (method, taskhash))
        if not r:
            return None
        return r

    async def report_unihash(self, taskhash, method, outhash, unihash, extra=None):
        """Send a ``report`` message and return the server's reply.

        ``extra`` holds optional additional fields. It is copied, never
        modified, and the explicit arguments override keys of the same name.
        """
        await self._set_mode(self.MODE_NORMAL)
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["outhash"] = outhash
        m["unihash"] = unihash
        return await self.send_message({"report": m})

    async def report_unihash_equiv(self, taskhash, method, unihash, extra=None):
        """Send a ``report-equiv`` message and return the server's reply.

        ``extra`` holds optional additional fields. It is copied, never
        modified, and the explicit arguments override keys of the same name.
        """
        await self._set_mode(self.MODE_NORMAL)
        m = {} if extra is None else extra.copy()
        m["taskhash"] = taskhash
        m["method"] = method
        m["unihash"] = unihash
        return await self.send_message({"report-equiv": m})

    async def get_taskhash(self, method, taskhash, all_properties=False):
        """Send a ``get`` query for ``taskhash`` and return the decoded reply."""
        await self._set_mode(self.MODE_NORMAL)
        return await self.send_message(
            {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
        )

    async def get_stats(self):
        """Send ``get-stats`` and return the decoded reply."""
        await self._set_mode(self.MODE_NORMAL)
        return await self.send_message({"get-stats": None})

    async def reset_stats(self):
        """Send ``reset-stats`` and return the decoded reply."""
        await self._set_mode(self.MODE_NORMAL)
        return await self.send_message({"reset-stats": None})

    async def backfill_wait(self):
        """Send ``backfill-wait`` and return the ``"tasks"`` field of the reply."""
        await self._set_mode(self.MODE_NORMAL)
        return (await self.send_message({"backfill-wait": None}))["tasks"]
| |
| |
class Client(object):
    """Blocking facade over :class:`AsyncClient`.

    Each coroutine method listed in ``__init__`` is re-exposed on this object
    as an ordinary method. Calling it runs the coroutine to completion on the
    event loop stored in ``self.loop``.
    """

    def __init__(self):
        self.client = AsyncClient()
        self.loop = asyncio.new_event_loop()

        proxied_names = (
            "connect_tcp",
            "connect_unix",
            "close",
            "get_unihash",
            "report_unihash",
            "report_unihash_equiv",
            "get_taskhash",
            "get_stats",
            "reset_stats",
            "backfill_wait",
        )
        for name in proxied_names:
            coroutine_fn = getattr(self.client, name)
            blocking_fn = self._get_downcall_wrapper(coroutine_fn)
            setattr(self, name, blocking_fn)

    def _get_downcall_wrapper(self, downcall):
        """Return a synchronous callable that runs ``downcall`` on ``self.loop``.

        ``self.loop`` is looked up on every call, not captured here.
        """

        def run_blocking(*args, **kwargs):
            loop = self.loop
            return loop.run_until_complete(downcall(*args, **kwargs))

        return run_blocking

    @property
    def max_chunk(self):
        """Chunk size limit of the wrapped AsyncClient (read/write proxy)."""
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value