blob: cee04bffb037c3389b61470bdba33d59d8f386b1 [file] [log] [blame]
#! /usr/bin/env python3
#
# Copyright (C) 2023 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
import logging
from datetime import datetime
from . import User
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
from sqlalchemy import (
MetaData,
Column,
Table,
Text,
Integer,
UniqueConstraint,
DateTime,
Index,
select,
insert,
exists,
literal,
and_,
delete,
update,
func,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.exc import IntegrityError
# Declarative base shared by every table model in this module. Its metadata
# is what DatabaseEngine.create() hands to create_all() and what
# Database.get_usage() iterates to count rows per table.
Base = declarative_base()
class UnihashesV2(Base):
    """Table model for "unihashes_v2": the unihash stored for a (method, taskhash) pair."""

    __tablename__ = "unihashes_v2"

    # Surrogate primary key.
    id = Column(Integer, autoincrement=True, primary_key=True)

    # All three value columns are mandatory.
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        # At most one row per (method, taskhash).
        UniqueConstraint("method", "taskhash"),
        # Named lookup index over the same column pair.
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )
class OuthashesV2(Base):
    """Table model for "outhashes_v2": output hashes recorded per (method, taskhash)."""

    __tablename__ = "outhashes_v2"

    # Surrogate primary key.
    id = Column(Integer, autoincrement=True, primary_key=True)

    # Mandatory identifying columns.
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)

    # Optional columns. "created" is what the queries in Database order by
    # to pick the oldest matching row.
    created = Column(DateTime)
    owner = Column(Text)
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        # At most one row per (method, taskhash, outhash).
        UniqueConstraint("method", "taskhash", "outhash"),
        # Named lookup index for queries by (method, outhash).
        Index("outhash_lookup_v3", "method", "outhash"),
    )
class Users(Base):
    """Table model for "users": one row per unique username."""

    __tablename__ = "users"

    # Surrogate primary key.
    id = Column(Integer, autoincrement=True, primary_key=True)

    username = Column(Text, nullable=False)
    token = Column(Text, nullable=False)
    # Written by Database as a single space-separated string; nullable.
    permissions = Column(Text)

    __table_args__ = (
        UniqueConstraint("username"),
    )
class DatabaseEngine(object):
    """Owns the async SQLAlchemy engine and hands out Database connections."""

    def __init__(self, url, username=None, password=None):
        """Parse *url*, optionally overriding its username and/or password.

        The engine itself is not created here; call ``create()`` first.
        """
        self.logger = logging.getLogger("hashserv.sqlalchemy")

        db_url = sqlalchemy.engine.make_url(url)
        if username is not None:
            db_url = db_url.set(username=username)
        if password is not None:
            db_url = db_url.set(password=password)
        self.url = db_url

    async def create(self):
        """Create the engine (with NullPool) and create all declared tables."""
        self.logger.info("Using database %s", self.url)
        self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

    def connect(self, logger):
        """Return a new, not yet opened Database bound to this engine."""
        return Database(self.engine, logger)
def map_row(row):
    """Convert a result row into a plain dict; ``None`` passes through.

    The dict is built from ``row._mapping``, so it is a copy that stays
    usable independently of the row it came from.
    """
    if row is not None:
        mapping = row._mapping
        return dict(**mapping)
    return None
def map_user(row):
    """Convert a users-table row into a ``User``; ``None`` passes through.

    The stored permissions string is split on whitespace into a set. The
    ``permissions`` column is nullable, so a NULL value is treated as "no
    permissions" instead of raising ``AttributeError`` on ``None.split()``.
    """
    if row is None:
        return None

    permissions = row.permissions or ""
    return User(
        username=row.username,
        permissions=set(permissions.split()),
    )
class Database(object):
    """A single database connection opened from a shared async engine.

    Used as an async context manager: entering opens a connection from the
    engine, exiting closes it. Each query method runs its statement inside
    its own ``begin()`` transaction block on that connection.
    """

    def __init__(self, engine, logger):
        self.engine = engine
        # Connection object; set by __aenter__ and cleared by close().
        self.db = None
        self.logger = logger

    async def __aenter__(self):
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the connection if one is open.

        Calling this again (or before entering the context) is a no-op, so
        an explicit close() followed by leaving the context does not fail.
        """
        if self.db is not None:
            await self.db.close()
            self.db = None

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row for (method, taskhash) with its unihash.

        The outhash row is joined to the unihash row with the same method and
        taskhash. Returns a dict, or None when nothing matches.
        """
        statement = (
            select(
                OuthashesV2,
                UnihashesV2.unihash.label("unihash"),
            )
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.taskhash == taskhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) with its unihash.

        Returns a dict, or None when nothing matches.
        """
        statement = (
            select(OuthashesV2, UnihashesV2.unihash.label("unihash"))
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) as a dict, or None."""
        statement = (
            select(OuthashesV2)
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Find the oldest row with this outhash that belongs to another taskhash.

        Returns a dict with the keys "taskhash" and "unihash", or None when
        no other taskhash has recorded the same (method, outhash).
        """
        statement = (
            select(
                OuthashesV2.taskhash.label("taskhash"),
                UnihashesV2.unihash.label("unihash"),
            )
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
                OuthashesV2.taskhash != taskhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_equivalent(self, method, taskhash):
        """Return the unihash row (unihash, method, taskhash) as a dict, or None."""
        statement = select(
            UnihashesV2.unihash,
            UnihashesV2.method,
            UnihashesV2.taskhash,
        ).where(
            UnihashesV2.method == method,
            UnihashesV2.taskhash == taskhash,
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def remove(self, condition):
        """Delete rows matching *condition* from the unihash and outhash tables.

        *condition* maps column names to required values. For each table,
        only the entries whose key names a column of that table and whose
        value is not None are used, combined with AND. A table for which no
        entry applies is left untouched. Returns the total number of rows
        deleted across both tables.
        """

        async def do_remove(table):
            clauses = [
                column == condition[column.key]
                for column in table.__table__.columns
                if column.key in condition and condition[column.key] is not None
            ]
            # Never issue an unconditional DELETE against a table.
            if not clauses:
                return 0

            statement = delete(table).where(*clauses)
            self.logger.debug("%s", statement)
            async with self.db.begin():
                result = await self.db.execute(statement)
            return result.rowcount

        count = 0
        count += await do_remove(UnihashesV2)
        count += await do_remove(OuthashesV2)
        return count

    async def clean_unused(self, oldest):
        """Delete outhash rows older than *oldest* that have no unihash row.

        A row counts as unused when no unihash row exists with the same
        method and taskhash. Returns the number of rows deleted.
        """
        statement = delete(OuthashesV2).where(
            OuthashesV2.created < oldest,
            ~(
                select(UnihashesV2.id)
                .where(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                )
                .limit(1)
                .exists()
            ),
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return result.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash row.

        Returns True on success and False when the insert raised
        IntegrityError (logged as the row already being present).
        """
        statement = insert(UnihashesV2).values(
            method=method,
            taskhash=taskhash,
            unihash=unihash,
        )
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError:
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row built from *data*.

        Keys of *data* that are not outhash-table columns are ignored. A
        "created" value that is neither None nor a datetime is parsed with
        datetime.fromisoformat(). Returns True on success and False when the
        insert raised IntegrityError.
        """
        outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)

        data = {k: v for k, v in data.items() if k in outhash_columns}

        # "created" is a nullable column, so None is passed through as-is;
        # only non-datetime values are parsed.
        created = data.get("created")
        if created is not None and not isinstance(created, datetime):
            data["created"] = datetime.fromisoformat(created)

        statement = insert(OuthashesV2).values(**data)
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError:
            # .get() so that a row rejected for a missing column cannot turn
            # this handler into a KeyError.
            self.logger.debug(
                "%s, %s already in outhash database",
                data.get("method"),
                data.get("outhash"),
            )
            return False

    async def _get_user(self, username):
        """Return the raw (username, permissions, token) row, or None."""
        statement = select(
            Users.username,
            Users.permissions,
            Users.token,
        ).where(
            Users.username == username,
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return result.first()

    async def lookup_user_token(self, username):
        """Return ``(user, token)``, or ``(None, None)`` if there is no such user."""
        row = await self._get_user(username)
        if not row:
            return None, None
        return map_user(row), row.token

    async def lookup_user(self, username):
        """Return the mapped user for *username*, or None if there is no such user."""
        return map_user(await self._get_user(username))

    async def set_user_token(self, username, token):
        """Replace the user's token; returns True if a row was updated."""
        statement = (
            update(Users)
            .where(
                Users.username == username,
            )
            .values(
                token=token,
            )
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return result.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Replace the user's permissions; returns True if a row was updated.

        *permissions* is an iterable of strings, stored space-separated.
        """
        statement = (
            update(Users)
            .where(Users.username == username)
            .values(permissions=" ".join(permissions))
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return result.rowcount != 0

    async def get_all_users(self):
        """Return a list with one mapped user per row of the users table."""
        statement = select(
            Users.username,
            Users.permissions,
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return [map_user(row) for row in result]

    async def new_user(self, username, permissions, token):
        """Insert a user.

        Returns True on success and False when the insert raised
        IntegrityError.
        """
        statement = insert(Users).values(
            username=username,
            permissions=" ".join(permissions),
            token=token,
        )
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError as e:
            self.logger.debug("Cannot create new user %s: %s", username, e)
            return False

    async def delete_user(self, username):
        """Delete the user; returns True if a row was deleted."""
        statement = delete(Users).where(Users.username == username)
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return result.rowcount != 0

    async def get_usage(self):
        """Return ``{table_name: {"rows": count}}`` for every declared table.

        All counts are taken inside a single transaction block.
        """
        usage = {}
        async with self.db.begin():
            for name, table in Base.metadata.tables.items():
                statement = select(func.count()).select_from(table)
                self.logger.debug("%s", statement)
                result = await self.db.execute(statement)
                usage[name] = {
                    "rows": result.scalar(),
                }

        return usage

    async def get_query_columns(self):
        """Return the names of all Text columns of the unihash and outhash tables.

        Names shared by both tables appear once; the order of the returned
        list is unspecified.
        """
        columns = {
            c.key
            for table in (UnihashesV2, OuthashesV2)
            for c in table.__table__.columns
            if isinstance(c.type, Text)
        }
        return list(columns)