| #! /usr/bin/env python3 |
| # |
| # Copyright (C) 2023 Garmin Ltd. |
| # |
| # SPDX-License-Identifier: GPL-2.0-only |
| # |
| |
| import logging |
| from datetime import datetime |
| from . import User |
| |
| from sqlalchemy.ext.asyncio import create_async_engine |
| from sqlalchemy.pool import NullPool |
| from sqlalchemy import ( |
| MetaData, |
| Column, |
| Table, |
| Text, |
| Integer, |
| UniqueConstraint, |
| DateTime, |
| Index, |
| select, |
| insert, |
| exists, |
| literal, |
| and_, |
| delete, |
| update, |
| func, |
| ) |
| import sqlalchemy.engine |
| from sqlalchemy.orm import declarative_base |
| from sqlalchemy.exc import IntegrityError |
| |
# Declarative base shared by the table models defined in this module. Its
# metadata is what DatabaseEngine.create() uses to create the tables and what
# Database.get_usage() iterates to count rows.
Base = declarative_base()
| |
| |
class UnihashesV2(Base):
    """Table ``unihashes_v2``: one unihash per (method, taskhash) pair.

    The pair is unique, so inserting a second row for the same method and
    taskhash raises an integrity error.
    """

    __tablename__ = "unihashes_v2"
    __table_args__ = (
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Lookup key: (method, taskhash).
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    # Value stored for the key.
    unihash = Column(Text, nullable=False)
| |
| |
class OuthashesV2(Base):
    """Table ``outhashes_v2``: recorded output hashes and their metadata.

    A row is unique per (method, taskhash, outhash). Only those three columns
    are mandatory; every other column may be NULL.
    """

    __tablename__ = "outhashes_v2"
    __table_args__ = (
        UniqueConstraint("method", "taskhash", "outhash"),
        Index("outhash_lookup_v3", "method", "outhash"),
    )

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Mandatory identifying columns.
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)
    # Timestamp used by the queries in this module to pick the oldest row and
    # to expire unused rows.
    created = Column(DateTime)
    # Optional free-form metadata supplied with the outhash report.
    owner = Column(Text)
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)
| |
| |
class Users(Base):
    """Table ``users``: one row per account, keyed by a unique username."""

    __tablename__ = "users"
    __table_args__ = (UniqueConstraint("username"),)

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(Text, nullable=False)
    token = Column(Text, nullable=False)
    # Whitespace-separated permission names (written with " ".join() and read
    # back with str.split() elsewhere in this module). May be NULL.
    permissions = Column(Text)
| |
| |
class DatabaseEngine(object):
    """Owns the async SQLAlchemy engine and hands out Database connections.

    Typical use: construct, ``await create()`` once, then call ``connect()``
    for each connection wrapper that is needed.
    """

    def __init__(self, url, username=None, password=None):
        """Parse the database URL and apply optional credential overrides.

        Args:
            url: Database URL, in any form accepted by
                ``sqlalchemy.engine.make_url``.
            username: If not None, replaces the username in ``url``.
            password: If not None, replaces the password in ``url``.
        """
        self.logger = logging.getLogger("hashserv.sqlalchemy")
        self.url = sqlalchemy.engine.make_url(url)

        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the engine and make sure all tables exist."""
        # Render the URL explicitly with the password masked: formatting the
        # URL object directly with %s may emit the clear-text password
        # depending on the SQLAlchemy version, and credentials must never end
        # up in the log.
        self.logger.info(
            "Using database %s", self.url.render_as_string(hide_password=True)
        )
        # NullPool: connections are not pooled by the engine.
        self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

    def connect(self, logger):
        """Return a new, not yet opened, Database bound to this engine.

        ``create()`` must have been awaited first, since it is what sets
        ``self.engine``.

        Args:
            logger: Logger the returned Database uses for its debug output.
        """
        return Database(self.engine, logger)
| |
| |
def map_row(row):
    """Return the row's column mapping as a plain dict, or None for no row."""
    return dict(**row._mapping) if row is not None else None
| |
| |
def map_user(row):
    """Build a User from a users-table row, or return None for no row.

    Args:
        row: Result row exposing ``username`` and ``permissions`` attributes,
            or None.

    Returns:
        A User whose ``permissions`` is the set of whitespace-separated names
        stored in the row, or None when ``row`` is None.
    """
    if row is None:
        return None
    # The permissions column is nullable; treat NULL as "no permissions"
    # instead of failing on None.split().
    return User(
        username=row.username,
        permissions=set((row.permissions or "").split()),
    )
| |
| |
class Database(object):
    """A single database connection with the hash and user queries on top.

    Intended to be used as an async context manager: entering opens a
    connection from the engine, leaving closes it. Every query method runs in
    its own transaction on that connection and logs the statement at debug
    level before executing it.
    """

    def __init__(self, engine, logger):
        """Remember the engine and logger; no connection is opened yet."""
        self.engine = engine
        self.db = None
        self.logger = logger

    async def __aenter__(self):
        """Open the connection and return this object."""
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the connection; exceptions are not suppressed."""
        await self.close()

    async def close(self):
        """Close the connection if one is open.

        Safe to call when no connection is open (never entered, or already
        closed), so an explicit close() followed by leaving the context
        manager does not fail.
        """
        if self.db is not None:
            await self.db.close()
            self.db = None

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row for (method, taskhash) with its unihash.

        Returns:
            A dict of the outhash columns plus ``unihash``, or None when no
            matching row exists.
        """
        statement = (
            select(
                OuthashesV2,
                UnihashesV2.unihash.label("unihash"),
            )
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.taskhash == taskhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) with its unihash.

        Returns:
            A dict of the outhash columns plus ``unihash``, or None when no
            matching row exists.
        """
        statement = (
            select(OuthashesV2, UnihashesV2.unihash.label("unihash"))
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) as a dict.

        Unlike get_unihash_by_outhash() this does not join the unihash table.
        Returns None when no matching row exists.
        """
        statement = (
            select(OuthashesV2)
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )

        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Find the oldest *other* taskhash that produced the same outhash.

        Rows whose taskhash equals ``taskhash`` are excluded.

        Returns:
            A dict with ``taskhash`` and ``unihash`` keys, or None when no
            matching row exists.
        """
        statement = (
            select(
                OuthashesV2.taskhash.label("taskhash"),
                UnihashesV2.unihash.label("unihash"),
            )
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
                OuthashesV2.taskhash != taskhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_equivalent(self, method, taskhash):
        """Return the unihash row for (method, taskhash) as a dict, or None.

        The dict has ``unihash``, ``method`` and ``taskhash`` keys.
        """
        statement = select(
            UnihashesV2.unihash,
            UnihashesV2.method,
            UnihashesV2.taskhash,
        ).where(
            UnihashesV2.method == method,
            UnihashesV2.taskhash == taskhash,
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def remove(self, condition):
        """Delete unihash and outhash rows matching ``condition``.

        Args:
            condition: Dict mapping column names to required values. Entries
                whose value is None, and names that are not columns of a
                given table, are ignored for that table.

        Returns:
            Total number of rows deleted from both tables. A table for which
            no usable condition remains is left untouched, so an empty
            condition never deletes everything.
        """

        async def do_remove(table):
            # Keep only the conditions that name a column of this table.
            clauses = [
                column == condition[column.key]
                for column in table.__table__.columns
                if column.key in condition and condition[column.key] is not None
            ]
            if not clauses:
                return 0

            statement = delete(table).where(*clauses)
            self.logger.debug("%s", statement)
            async with self.db.begin():
                result = await self.db.execute(statement)
            return result.rowcount

        count = 0
        count += await do_remove(UnihashesV2)
        count += await do_remove(OuthashesV2)

        return count

    async def clean_unused(self, oldest):
        """Delete old outhash rows that no unihash row refers to.

        A row is removed when it was created before ``oldest`` and there is no
        unihash row with the same method and taskhash.

        Returns:
            Number of rows deleted.
        """
        statement = delete(OuthashesV2).where(
            OuthashesV2.created < oldest,
            ~(
                select(UnihashesV2.id)
                .where(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                )
                .limit(1)
                .exists()
            ),
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
        return result.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash row.

        Returns:
            True when the row was inserted, False when the insert was rejected
            with an integrity error (e.g. the (method, taskhash) pair already
            exists).
        """
        statement = insert(UnihashesV2).values(
            method=method,
            taskhash=taskhash,
            unihash=unihash,
        )
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError:
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row built from ``data``.

        Args:
            data: Dict of values; keys that are not outhash columns are
                dropped. ``created`` may be a datetime, an ISO 8601 string
                (converted with ``datetime.fromisoformat``) or None.

        Returns:
            True when the row was inserted, False when the insert was rejected
            with an integrity error.
        """
        outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)

        data = {k: v for k, v in data.items() if k in outhash_columns}

        # The column is nullable, so None is stored as-is; only non-datetime
        # values (ISO strings) need converting.
        created = data.get("created")
        if created is not None and not isinstance(created, datetime):
            data["created"] = datetime.fromisoformat(created)

        statement = insert(OuthashesV2).values(**data)
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError:
            # .get(): a missing mandatory key is itself a cause of the
            # integrity error, and must not turn into a KeyError here.
            self.logger.debug(
                "%s, %s already in outhash database",
                data.get("method"),
                data.get("outhash"),
            )
            return False

    async def _get_user(self, username):
        """Return the raw row (username, permissions, token) or None."""
        statement = select(
            Users.username,
            Users.permissions,
            Users.token,
        ).where(
            Users.username == username,
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return result.first()

    async def lookup_user_token(self, username):
        """Return ``(user, token)`` for ``username``, or ``(None, None)``."""
        row = await self._get_user(username)
        if not row:
            return None, None
        return map_user(row), row.token

    async def lookup_user(self, username):
        """Return the user named ``username``, or None if there is none."""
        return map_user(await self._get_user(username))

    async def set_user_token(self, username, token):
        """Replace the user's token; return True if a row was updated."""
        statement = (
            update(Users)
            .where(
                Users.username == username,
            )
            .values(
                token=token,
            )
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
        return result.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Replace the user's permissions; return True if a row was updated.

        Args:
            username: Name of the user to update.
            permissions: Iterable of permission names; stored joined by a
                single space.
        """
        statement = (
            update(Users)
            .where(Users.username == username)
            .values(permissions=" ".join(permissions))
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
        return result.rowcount != 0

    async def get_all_users(self):
        """Return a list with every user (tokens are not selected)."""
        statement = select(
            Users.username,
            Users.permissions,
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return [map_user(row) for row in result]

    async def new_user(self, username, permissions, token):
        """Create a user.

        Returns:
            True when the user was created, False when the insert was rejected
            with an integrity error (e.g. the username is already taken).
        """
        statement = insert(Users).values(
            username=username,
            permissions=" ".join(permissions),
            token=token,
        )
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError as e:
            self.logger.debug("Cannot create new user %s: %s", username, e)
            return False

    async def delete_user(self, username):
        """Delete the user; return True if a row was removed."""
        statement = delete(Users).where(Users.username == username)
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
        return result.rowcount != 0

    async def get_usage(self):
        """Return ``{table_name: {"rows": count}}`` for every known table.

        All counts are taken inside one transaction.
        """
        usage = {}
        async with self.db.begin():
            for name, table in Base.metadata.tables.items():
                statement = select(func.count()).select_from(table)
                self.logger.debug("%s", statement)
                result = await self.db.execute(statement)
                usage[name] = {
                    "rows": result.scalar(),
                }

        return usage

    async def get_query_columns(self):
        """Return the names of the Text columns of the two hash tables.

        Names shared by both tables appear once; the order of the returned
        list is unspecified.
        """
        columns = {
            column.key
            for table in (UnihashesV2, OuthashesV2)
            for column in table.__table__.columns
            if isinstance(column.type, Text)
        }
        return list(columns)