| #! /usr/bin/env python3 |
| # |
| # Copyright (C) 2023 Garmin Ltd. |
| # |
| # SPDX-License-Identifier: GPL-2.0-only |
| # |
| |
| import logging |
| from datetime import datetime |
| from . import User |
| |
| from sqlalchemy.ext.asyncio import create_async_engine |
| from sqlalchemy.pool import NullPool |
| from sqlalchemy import ( |
| MetaData, |
| Column, |
| Table, |
| Text, |
| Integer, |
| UniqueConstraint, |
| DateTime, |
| Index, |
| select, |
| insert, |
| exists, |
| literal, |
| and_, |
| delete, |
| update, |
| func, |
| inspect, |
| ) |
| import sqlalchemy.engine |
| from sqlalchemy.orm import declarative_base |
| from sqlalchemy.exc import IntegrityError |
| from sqlalchemy.dialects.postgresql import insert as postgres_insert |
| |
# Declarative base for the current-schema tables. Its metadata is what
# DatabaseEngine.create() passes to create_all() and what
# Database.get_usage() iterates over.
Base = declarative_base()
| |
| |
class UnihashesV3(Base):
    """Current unihash table: one row per (method, taskhash) pair.

    Each row records the unihash associated with a (method, taskhash) pair,
    together with a ``gc_mark`` string that the garbage collection methods of
    ``Database`` compare against the "gc-mark" entry of the config table.
    """

    __tablename__ = "unihashes_v3"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)
    gc_mark = Column(Text, nullable=False)

    __table_args__ = (
        # At most one row per (method, taskhash) pair
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v4", "method", "taskhash"),
        Index("unihash_lookup_v1", "unihash"),
    )
| |
| |
class OuthashesV2(Base):
    """Output hash table: one row per (method, taskhash, outhash) triple.

    Only ``method``, ``taskhash`` and ``outhash`` are mandatory; the other
    columns are optional metadata. ``created`` is what the ``Database``
    lookups order by (oldest first) and what ``Database.clean_unused()``
    compares against its cutoff.
    """

    __tablename__ = "outhashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)
    created = Column(DateTime)
    owner = Column(Text)
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        # At most one row per (method, taskhash, outhash) triple
        UniqueConstraint("method", "taskhash", "outhash"),
        Index("outhash_lookup_v3", "method", "outhash"),
    )
| |
| |
class Users(Base):
    """User accounts table, keyed by a unique ``username``.

    ``permissions`` holds the permission names as a single string; the
    ``Database`` methods write it by joining the names with spaces and
    ``map_user()`` reads it back by splitting on whitespace.
    """

    __tablename__ = "users"
    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(Text, nullable=False)
    token = Column(Text, nullable=False)
    # Nullable, unlike username and token
    permissions = Column(Text)

    __table_args__ = (UniqueConstraint("username"),)
| |
| |
class Config(Base):
    """Name/value configuration table with unique names.

    Within this file it is used to store the "gc-mark" entry, whose value may
    be NULL (``Database.gc_sweep()`` sets it to None).
    """

    __tablename__ = "config"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=False)
    value = Column(Text)
    __table_args__ = (
        UniqueConstraint("name"),
        Index("config_lookup", "name"),
    )
| |
| |
#
# Old table versions
#
# These models are declared on their own declarative base so that they are
# not registered in Base.metadata; they exist only so that
# DatabaseEngine.create() can migrate and drop old tables.
DeprecatedBase = declarative_base()
| |
| |
class UnihashesV2(DeprecatedBase):
    """Previous unihash table layout (no ``gc_mark`` column).

    If this table exists, ``DatabaseEngine.create()`` copies its rows into
    ``UnihashesV3`` with an empty ``gc_mark`` and then drops it.
    """

    __tablename__ = "unihashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )
| |
| |
class DatabaseEngine(object):
    """Owns the SQLAlchemy async engine and hands out ``Database`` objects.

    The engine itself is only constructed in ``create()``, so ``create()``
    must be awaited before ``connect()`` is used.
    """

    def __init__(self, url, username=None, password=None):
        """Parse the database URL, optionally overriding its credentials.

        url: database URL, parsed with sqlalchemy.engine.make_url()
        username: if not None, replaces the username in the URL
        password: if not None, replaces the password in the URL
        """
        self.logger = logging.getLogger("hashserv.sqlalchemy")
        self.url = sqlalchemy.engine.make_url(url)

        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the engine, create missing tables and migrate old ones.

        If the deprecated unihashes_v2 table exists, its rows are copied into
        unihashes_v3 (with an empty gc_mark) and the old table is dropped, all
        within the same transaction as the table creation.
        """

        def check_table_exists(conn, name):
            # Runs synchronously via run_sync(); conn is the sync connection
            return inspect(conn).has_table(name)

        # Render the URL explicitly with the password masked so that
        # credentials are never written to the log, independent of what
        # str(URL) does in the installed SQLAlchemy version.
        self.logger.info(
            "Using database %s", self.url.render_as_string(hide_password=True)
        )
        if self.url.drivername == 'postgresql+psycopg':
            # Psycopg 3 (psycopg) driver can handle async connection pooling
            self.engine = create_async_engine(self.url, max_overflow=-1)
        else:
            self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

            if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__):
                self.logger.info("Upgrading Unihashes V2 -> V3...")
                # The column name list must stay in the same order as the
                # columns of the select() below.
                statement = insert(UnihashesV3).from_select(
                    ["id", "method", "unihash", "taskhash", "gc_mark"],
                    select(
                        UnihashesV2.id,
                        UnihashesV2.method,
                        UnihashesV2.unihash,
                        UnihashesV2.taskhash,
                        literal("").label("gc_mark"),
                    ),
                )
                self.logger.debug("%s", statement)
                await conn.execute(statement)

                await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__])
                self.logger.info("Upgrade complete")

    def connect(self, logger):
        """Return a new ``Database`` bound to this engine and ``logger``.

        The returned object does not open a connection until it is entered
        as an async context manager.
        """
        return Database(self.engine, logger)
| |
| |
def map_row(row):
    """Return the row's column mapping as a plain dict, or None for no row."""
    return None if row is None else dict(**row._mapping)
| |
| |
def map_user(row):
    """Convert a users-table row into a ``User``, or return None for no row.

    The permissions column holds the permission names as one
    whitespace-separated string. The column is nullable, so a NULL value is
    treated as an empty permission set instead of raising AttributeError.
    """
    if row is None:
        return None
    return User(
        username=row.username,
        permissions=set((row.permissions or "").split()),
    )
| |
| |
def _make_condition_statement(table, condition):
    """Build a list of equality clauses for ``table`` from ``condition``.

    For every column of ``table`` (in column order) whose key is present in
    ``condition`` with a value other than None, one ``column == value``
    expression is produced. Keys that do not name a column are ignored.
    """
    return [
        column == condition[column.key]
        for column in table.__table__.columns
        if column.key in condition and condition[column.key] is not None
    ]
| |
| |
class Database(object):
    """A single database connection exposing the hash server's queries.

    Instances are async context managers: the connection is opened on entry
    and closed on exit. The public query methods each run inside their own
    ``self.db.begin()`` transaction; the ``_``-prefixed config helpers do not
    open a transaction themselves.
    """

    def __init__(self, engine, logger):
        """Store the engine and logger; no connection is opened yet."""
        self.engine = engine
        self.db = None
        self.logger = logger

    async def __aenter__(self):
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the connection; does nothing if no connection is open."""
        if self.db is not None:
            await self.db.close()
            self.db = None

    async def _execute(self, statement):
        """Log ``statement`` at debug level and execute it."""
        self.logger.debug("%s", statement)
        return await self.db.execute(statement)

    async def _set_config(self, name, value):
        """Set config entry ``name`` to ``value``, inserting it if missing.

        Tries an UPDATE first; if no row was updated, an INSERT is attempted.
        If the INSERT hits an IntegrityError the whole sequence is retried.
        """
        while True:
            result = await self._execute(
                update(Config).where(Config.name == name).values(value=value)
            )

            if result.rowcount == 0:
                self.logger.debug("Config '%s' not found. Adding it", name)
                try:
                    await self._execute(insert(Config).values(name=name, value=value))
                except IntegrityError:
                    # Race. Try again
                    # NOTE(review): callers run this inside a transaction; on
                    # backends that abort the transaction after a failed
                    # statement the retry may not be able to succeed - confirm.
                    continue

            break

    def _get_config_subquery(self, name, default=None):
        """Return a scalar subquery selecting the value of config ``name``.

        If ``default`` is not None the subquery is wrapped in COALESCE so
        that a missing or NULL value yields ``default``.
        """
        subquery = select(Config.value).where(Config.name == name).scalar_subquery()
        if default is not None:
            return func.coalesce(subquery, default)
        return subquery

    async def _get_config(self, name):
        """Return the value of config ``name``, or None if there is no row."""
        result = await self._execute(select(Config.value).where(Config.name == name))
        row = result.first()
        if row is None:
            return None
        return row.value

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row for (method, taskhash) plus its unihash.

        The result is a dict of the outhash columns and "unihash", or None
        if nothing matches.
        """
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2,
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.taskhash == taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) plus its unihash.

        The result is a dict of the outhash columns and "unihash", or None
        if nothing matches.
        """
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def unihash_exists(self, unihash):
        """Return True if any unihash row has the given ``unihash``."""
        async with self.db.begin():
            result = await self._execute(
                select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1)
            )

            return result.first() is not None

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) as a dict, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2)
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Find the oldest other taskhash that produced the same outhash.

        Returns a dict with "taskhash" and "unihash" for the oldest outhash
        row matching (method, outhash) whose taskhash differs from
        ``taskhash`` and which has a unihash row, or None if there is none.
        """
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2.taskhash.label("taskhash"),
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                    OuthashesV2.taskhash != taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent(self, method, taskhash):
        """Return the unihash row for (method, taskhash) as a dict, or None.

        The dict has the keys "unihash", "method" and "taskhash".
        """
        async with self.db.begin():
            result = await self._execute(
                select(
                    UnihashesV3.unihash,
                    UnihashesV3.method,
                    UnihashesV3.taskhash,
                ).where(
                    UnihashesV3.method == method,
                    UnihashesV3.taskhash == taskhash,
                )
            )
            return map_row(result.first())

    async def remove(self, condition):
        """Delete unihash and outhash rows matching ``condition``.

        ``condition`` maps column names to required values. For each table
        only the keys naming one of its columns (with a non-None value) are
        used; a table for which no key applies is left untouched. Each table
        is handled in its own transaction. Returns the total number of rows
        deleted.
        """

        async def do_remove(table):
            where = _make_condition_statement(table, condition)
            if where:
                async with self.db.begin():
                    result = await self._execute(delete(table).where(*where))
                    return result.rowcount

            # Never issue an unconditional DELETE
            return 0

        count = 0
        count += await do_remove(UnihashesV3)
        count += await do_remove(OuthashesV2)

        return count

    async def get_current_gc_mark(self):
        """Return the current "gc-mark" config value, or None if unset."""
        async with self.db.begin():
            return await self._get_config("gc-mark")

    async def gc_status(self):
        """Return ``(keep_rows, remove_rows, mark)`` for the current GC mark.

        ``keep_rows`` counts unihash rows whose gc_mark equals the current
        mark and ``remove_rows`` those whose gc_mark differs; a missing or
        NULL mark is treated as the empty string for both counts. ``mark``
        is the raw config value.
        """
        async with self.db.begin():
            gc_mark_subquery = self._get_config_subquery("gc-mark", "")

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark == gc_mark_subquery)
            )
            keep_rows = result.scalar()

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark != gc_mark_subquery)
            )
            remove_rows = result.scalar()

            return (keep_rows, remove_rows, await self._get_config("gc-mark"))

    async def gc_mark(self, mark, condition):
        """Set the current GC mark and stamp matching unihash rows with it.

        The "gc-mark" config value is always set to ``mark``. If
        ``condition`` yields no usable column filter, no rows are updated and
        0 is returned; otherwise the number of updated rows is returned.
        """
        async with self.db.begin():
            await self._set_config("gc-mark", mark)

            where = _make_condition_statement(UnihashesV3, condition)
            if not where:
                return 0

            result = await self._execute(
                update(UnihashesV3)
                .values(gc_mark=self._get_config_subquery("gc-mark", ""))
                .where(*where)
            )
            return result.rowcount

    async def gc_sweep(self):
        """Delete unihash rows not carrying the current mark, then clear it.

        Returns the number of rows deleted.
        """
        async with self.db.begin():
            result = await self._execute(
                delete(UnihashesV3).where(
                    # A sneaky conditional that provides some errant use
                    # protection: If the config mark is NULL, this will not
                    # match any rows because no default is specified in the
                    # select statement
                    UnihashesV3.gc_mark
                    != self._get_config_subquery("gc-mark")
                )
            )
            await self._set_config("gc-mark", None)

            return result.rowcount

    async def clean_unused(self, oldest):
        """Delete outhash rows created before ``oldest`` that have no unihash.

        A row is kept if a unihash row with the same method and taskhash
        exists. Returns the number of rows deleted.
        """
        async with self.db.begin():
            result = await self._execute(
                delete(OuthashesV2).where(
                    OuthashesV2.created < oldest,
                    ~(
                        select(UnihashesV3.id)
                        .where(
                            UnihashesV3.method == OuthashesV2.method,
                            UnihashesV3.taskhash == OuthashesV2.taskhash,
                        )
                        .limit(1)
                        .exists()
                    ),
                )
            )
            return result.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash row stamped with the current GC mark.

        Returns True if a row was inserted and False if a row for
        (method, taskhash) already existed.
        """
        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(UnihashesV3)
                .values(
                    method=method,
                    taskhash=taskhash,
                    unihash=unihash,
                    gc_mark=self._get_config_subquery("gc-mark", ""),
                )
                .on_conflict_do_nothing(index_elements=("method", "taskhash"))
            )
        else:
            statement = insert(UnihashesV3).values(
                method=method,
                taskhash=taskhash,
                unihash=unihash,
                gc_mark=self._get_config_subquery("gc-mark", ""),
            )

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            # Other backends report the duplicate as an error instead
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row built from ``data``.

        Keys of ``data`` that are not outhash columns are dropped. A
        "created" value that is not already a datetime is parsed with
        ``datetime.fromisoformat()``. Returns True if a row was inserted and
        False if the (method, taskhash, outhash) triple already existed.
        """
        outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)

        data = {k: v for k, v in data.items() if k in outhash_columns}

        if "created" in data and not isinstance(data["created"], datetime):
            data["created"] = datetime.fromisoformat(data["created"])

        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(OuthashesV2)
                .values(**data)
                .on_conflict_do_nothing(
                    index_elements=("method", "taskhash", "outhash")
                )
            )
        else:
            statement = insert(OuthashesV2).values(**data)

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            # Other backends report the duplicate as an error instead
            self.logger.debug(
                "%s, %s already in outhash database", data["method"], data["outhash"]
            )
            return False

    async def _get_user(self, username):
        """Return the (username, permissions, token) row, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                    Users.token,
                ).where(
                    Users.username == username,
                )
            )
            return result.first()

    async def lookup_user_token(self, username):
        """Return ``(user, token)`` for ``username``, or ``(None, None)``."""
        row = await self._get_user(username)
        if row is None:
            return None, None
        return map_user(row), row.token

    async def lookup_user(self, username):
        """Return the ``User`` for ``username``, or None if there is none."""
        return map_user(await self._get_user(username))

    async def set_user_token(self, username, token):
        """Replace the user's token; returns False if no such user exists."""
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(
                    Users.username == username,
                )
                .values(
                    token=token,
                )
            )
            return result.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Replace the user's permissions; returns False if no such user exists.

        ``permissions`` is an iterable of strings, stored joined by spaces.
        """
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(Users.username == username)
                .values(permissions=" ".join(permissions))
            )
            return result.rowcount != 0

    async def get_all_users(self):
        """Return a list with one ``User`` per row of the users table."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                )
            )
            return [map_user(row) for row in result]

    async def new_user(self, username, permissions, token):
        """Create a user; returns False if the insert violates a constraint.

        ``permissions`` is an iterable of strings, stored joined by spaces.
        """
        try:
            async with self.db.begin():
                await self._execute(
                    insert(Users).values(
                        username=username,
                        permissions=" ".join(permissions),
                        token=token,
                    )
                )
                return True
        except IntegrityError as e:
            self.logger.debug("Cannot create new user %s: %s", username, e)
            return False

    async def delete_user(self, username):
        """Delete the user; returns False if no such user existed."""
        async with self.db.begin():
            result = await self._execute(
                delete(Users).where(Users.username == username)
            )
            return result.rowcount != 0

    async def get_usage(self):
        """Return ``{table_name: {"rows": row_count}}`` for every current table."""
        usage = {}
        async with self.db.begin():
            for name, table in Base.metadata.tables.items():
                result = await self._execute(
                    select(func.count()).select_from(table)
                )
                usage[name] = {
                    "rows": result.scalar(),
                }

        return usage

    async def get_query_columns(self):
        """Return the sorted names of all Text columns of the hash tables.

        Columns of the unihash and outhash tables are merged, so a name
        present in both appears once.
        """
        columns = set()
        for table in (UnihashesV3, OuthashesV2):
            for c in table.__table__.columns:
                if not isinstance(c.type, Text):
                    continue
                columns.add(c.key)

        # Sorted so that the result is deterministic
        return sorted(columns)