Patrick Williams | ac13d5f | 2023-11-24 18:59:46 -0600 | [diff] [blame] | 1 | #! /usr/bin/env python3 |
| 2 | # |
| 3 | # Copyright (C) 2023 Garmin Ltd. |
| 4 | # |
| 5 | # SPDX-License-Identifier: GPL-2.0-only |
| 6 | # |
| 7 | |
| 8 | import logging |
| 9 | from datetime import datetime |
| 10 | from . import User |
| 11 | |
| 12 | from sqlalchemy.ext.asyncio import create_async_engine |
| 13 | from sqlalchemy.pool import NullPool |
| 14 | from sqlalchemy import ( |
| 15 | MetaData, |
| 16 | Column, |
| 17 | Table, |
| 18 | Text, |
| 19 | Integer, |
| 20 | UniqueConstraint, |
| 21 | DateTime, |
| 22 | Index, |
| 23 | select, |
| 24 | insert, |
| 25 | exists, |
| 26 | literal, |
| 27 | and_, |
| 28 | delete, |
| 29 | update, |
| 30 | func, |
| 31 | ) |
| 32 | import sqlalchemy.engine |
| 33 | from sqlalchemy.orm import declarative_base |
| 34 | from sqlalchemy.exc import IntegrityError |
| 35 | |
# Declarative base shared by every ORM model in this module. Its
# ``metadata`` is what DatabaseEngine.create() passes to create_all() and
# what Database.get_usage() iterates to report per-table row counts.
Base = declarative_base()
| 37 | |
| 38 | |
class UnihashesV2(Base):
    """ORM model for the ``unihashes_v2`` table.

    Each row maps a (method, taskhash) pair to its unihash. The unique
    constraint allows at most one unihash per pair.
    """

    __tablename__ = "unihashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        # One unihash per (method, taskhash). Database.insert_unihash treats
        # the resulting IntegrityError as "already present".
        UniqueConstraint("method", "taskhash"),
        # NOTE(review): this appears to duplicate the index most backends
        # build for the unique constraint above; confirm before removing it,
        # since doing so changes the schema of existing databases.
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )
| 50 | |
| 51 | |
class OuthashesV2(Base):
    """ORM model for the ``outhashes_v2`` table.

    Each row records that a task (method, taskhash) produced a given
    outhash, together with optional metadata about who reported it and
    for which recipe/task.
    """

    __tablename__ = "outhashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)
    # Queries order by this column ascending to pick the oldest report;
    # Database.clean_unused deletes rows older than a cutoff.
    created = Column(DateTime)
    owner = Column(Text)
    # Recipe name / version / revision of the reporting task
    # (presumably BitBake PN/PV/PR; confirm).
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        # The same (method, taskhash, outhash) triple is stored at most once.
        UniqueConstraint("method", "taskhash", "outhash"),
        # Supports the (method, outhash) lookups in Database.get_outhash,
        # get_unihash_by_outhash and get_equivalent_for_outhash.
        Index("outhash_lookup_v3", "method", "outhash"),
    )
| 70 | |
| 71 | |
class Users(Base):
    """ORM model for the ``users`` table (one row per unique username)."""

    __tablename__ = "users"
    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(Text, nullable=False)
    # Stored exactly as passed by callers of Database.new_user /
    # set_user_token. NOTE(review): presumably already hashed by the
    # caller; confirm.
    token = Column(Text, nullable=False)
    # Space-separated permission names (written with " ".join(...) and read
    # back with str.split()). The column is nullable.
    permissions = Column(Text)

    # Duplicate usernames make Database.new_user fail with IntegrityError.
    __table_args__ = (UniqueConstraint("username"),)
| 80 | |
| 81 | |
class DatabaseEngine:
    """Factory for :class:`Database` connections backed by one async engine.

    The connection URL is parsed at construction time. Explicit *username*
    and *password* arguments override any credentials embedded in *url*.
    ``await create()`` must run before :meth:`connect`, because it is what
    builds ``self.engine``.
    """

    def __init__(self, url, username=None, password=None):
        self.logger = logging.getLogger("hashserv.sqlalchemy")
        self.url = sqlalchemy.engine.make_url(url)

        # URL objects are immutable; set() returns an updated copy.
        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the async engine and any tables that do not exist yet."""
        # Log a password-masked rendering: on SQLAlchemy 1.4, str(URL)
        # includes the password in clear text.
        self.logger.info(
            "Using database %s", self.url.render_as_string(hide_password=True)
        )
        # NullPool: connections are not pooled; each checkout opens a new
        # DBAPI connection and closes it when released.
        self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

    def connect(self, logger):
        """Return a new :class:`Database` bound to this engine.

        The returned object opens its connection when entered with
        ``async with``. *logger* receives that connection's query logging.
        """
        return Database(self.engine, logger)
| 104 | |
| 105 | |
def map_row(row):
    """Turn a SQLAlchemy result row into a plain ``dict`` keyed by column name.

    ``None`` (e.g. from ``result.first()`` on an empty result) is passed
    through unchanged.
    """
    return None if row is None else dict(**row._mapping)
| 110 | |
| 111 | |
def map_user(row):
    """Build a ``User`` from a row with ``username`` and ``permissions`` fields.

    ``permissions`` is stored as one space-separated string and becomes a
    set of names. The column is nullable, so a NULL value yields an empty
    set instead of raising ``AttributeError``. Returns ``None`` when *row*
    is ``None``.
    """
    if row is None:
        return None
    return User(
        username=row.username,
        permissions=set((row.permissions or "").split()),
    )
| 119 | |
| 120 | |
class Database:
    """One async connection to the hash-equivalence database.

    Normally obtained from ``DatabaseEngine.connect()`` and used as an async
    context manager: ``__aenter__`` opens the connection and ``__aexit__``
    closes it. Every method that talks to the database opens its own short
    transaction with ``self.db.begin()``.
    """

    def __init__(self, engine, logger):
        self.engine = engine
        # Connection returned by engine.connect(); None while not open.
        self.db = None
        self.logger = logger

    async def __aenter__(self):
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the open connection, if any.

        This is a no-op when no connection is open (``close()`` already ran,
        or the context manager was never entered), so double-closing is
        harmless.
        """
        if self.db is not None:
            await self.db.close()
            self.db = None

    # -- private helpers -------------------------------------------------

    @staticmethod
    def _unihash_join():
        """Return the ON clause pairing an outhash row with its unihash row."""
        return and_(
            UnihashesV2.method == OuthashesV2.method,
            UnihashesV2.taskhash == OuthashesV2.taskhash,
        )

    async def _first_row(self, statement):
        """Run a SELECT in its own transaction; return the first row as a dict or None."""
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def _rowcount(self, statement):
        """Run an UPDATE/DELETE in its own transaction; return rows affected."""
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return result.rowcount

    # -- hash queries ----------------------------------------------------

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row for (*method*, *taskhash*), joined to its unihash.

        The result is a dict of the outhash columns plus ``unihash``, or
        ``None`` if nothing matches.
        """
        statement = (
            select(OuthashesV2, UnihashesV2.unihash.label("unihash"))
            .join(UnihashesV2, self._unihash_join())
            .where(
                OuthashesV2.method == method,
                OuthashesV2.taskhash == taskhash,
            )
            .order_by(OuthashesV2.created.asc())
            .limit(1)
        )
        return await self._first_row(statement)

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row for (*method*, *outhash*), joined to its unihash.

        The result is a dict of the outhash columns plus ``unihash``, or
        ``None`` if nothing matches.
        """
        statement = (
            select(OuthashesV2, UnihashesV2.unihash.label("unihash"))
            .join(UnihashesV2, self._unihash_join())
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(OuthashesV2.created.asc())
            .limit(1)
        )
        return await self._first_row(statement)

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for (*method*, *outhash*) as a dict, or None."""
        statement = (
            select(OuthashesV2)
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(OuthashesV2.created.asc())
            .limit(1)
        )
        return await self._first_row(statement)

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Find an older task with the same outhash but a different taskhash.

        Returns ``{"taskhash": ..., "unihash": ...}`` for the oldest such
        row that has a unihash, or ``None``.
        """
        statement = (
            select(
                OuthashesV2.taskhash.label("taskhash"),
                UnihashesV2.unihash.label("unihash"),
            )
            .join(UnihashesV2, self._unihash_join())
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
                OuthashesV2.taskhash != taskhash,
            )
            .order_by(OuthashesV2.created.asc())
            .limit(1)
        )
        return await self._first_row(statement)

    async def get_equivalent(self, method, taskhash):
        """Return ``{"unihash", "method", "taskhash"}`` for the pair, or None."""
        statement = select(
            UnihashesV2.unihash,
            UnihashesV2.method,
            UnihashesV2.taskhash,
        ).where(
            UnihashesV2.method == method,
            UnihashesV2.taskhash == taskhash,
        )
        return await self._first_row(statement)

    async def remove(self, condition):
        """Delete rows matching *condition* from the unihash and outhash tables.

        *condition* maps column names to values. For each table, only keys
        that are columns of that table and whose value is not ``None`` are
        used, ANDed together. A table with no applicable keys is left
        untouched, so an empty condition never deletes everything. Each
        table is cleaned in its own transaction. Returns the total number of
        rows deleted.
        """

        async def do_remove(table):
            clauses = [
                c == condition[c.key]
                for c in table.__table__.columns
                if c.key in condition and condition[c.key] is not None
            ]
            if not clauses:
                return 0
            return await self._rowcount(delete(table).where(*clauses))

        count = await do_remove(UnihashesV2)
        count += await do_remove(OuthashesV2)
        return count

    async def clean_unused(self, oldest):
        """Delete outhash rows created before *oldest* that have no unihash row.

        Returns the number of rows deleted.
        """
        statement = delete(OuthashesV2).where(
            OuthashesV2.created < oldest,
            ~(
                select(UnihashesV2.id)
                .where(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                )
                .limit(1)
                .exists()
            ),
        )
        return await self._rowcount(statement)

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash row.

        Returns ``True`` on success, or ``False`` if the insert violated a
        constraint (e.g. the (method, taskhash) pair already exists).
        """
        statement = insert(UnihashesV2).values(
            method=method,
            taskhash=taskhash,
            unihash=unihash,
        )
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError:
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row built from the mapping *data*.

        Keys that are not columns of ``outhashes_v2`` are ignored, and the
        caller's mapping is not modified. A non-datetime ``created`` value is
        parsed with ``datetime.fromisoformat``. Returns ``True`` on success,
        or ``False`` on a constraint violation (duplicate row or a missing
        required column).
        """
        outhash_columns = {c.key for c in OuthashesV2.__table__.columns}

        data = {k: v for k, v in data.items() if k in outhash_columns}

        if "created" in data and not isinstance(data["created"], datetime):
            data["created"] = datetime.fromisoformat(data["created"])

        statement = insert(OuthashesV2).values(**data)
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError:
            # Use .get(): a NOT NULL violation on a missing key also lands
            # here, and indexing would then raise KeyError.
            self.logger.debug(
                "%s, %s already in outhash database",
                data.get("method"),
                data.get("outhash"),
            )
            return False

    # -- users -----------------------------------------------------------

    async def _get_user(self, username):
        """Return the raw (username, permissions, token) row for *username*, or None."""
        statement = select(
            Users.username,
            Users.permissions,
            Users.token,
        ).where(
            Users.username == username,
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return result.first()

    async def lookup_user_token(self, username):
        """Return ``(User, token)`` for *username*, or ``(None, None)`` if absent."""
        row = await self._get_user(username)
        if row is None:
            return None, None
        return map_user(row), row.token

    async def lookup_user(self, username):
        """Return the ``User`` for *username*, or None if absent."""
        return map_user(await self._get_user(username))

    async def set_user_token(self, username, token):
        """Replace the stored token; return True if the user exists."""
        statement = (
            update(Users)
            .where(Users.username == username)
            .values(token=token)
        )
        return (await self._rowcount(statement)) != 0

    async def set_user_perms(self, username, permissions):
        """Replace the permission set (an iterable of names); return True if the user exists."""
        statement = (
            update(Users)
            .where(Users.username == username)
            .values(permissions=" ".join(permissions))
        )
        return (await self._rowcount(statement)) != 0

    async def get_all_users(self):
        """Return a list of ``User`` objects for every row in ``users``."""
        statement = select(
            Users.username,
            Users.permissions,
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return [map_user(row) for row in result]

    async def new_user(self, username, permissions, token):
        """Create a user.

        Returns ``True`` on success, or ``False`` if the insert violated a
        constraint (e.g. the username is already taken).
        """
        statement = insert(Users).values(
            username=username,
            permissions=" ".join(permissions),
            token=token,
        )
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError as e:
            self.logger.debug("Cannot create new user %s: %s", username, e)
            return False

    async def delete_user(self, username):
        """Delete *username*; return True if a row was removed."""
        statement = delete(Users).where(Users.username == username)
        return (await self._rowcount(statement)) != 0

    # -- introspection ---------------------------------------------------

    async def get_usage(self):
        """Return ``{table_name: {"rows": count}}`` for every table in ``Base.metadata``.

        All counts are taken inside a single transaction.
        """
        usage = {}
        async with self.db.begin():
            for name, table in Base.metadata.tables.items():
                statement = select(func.count()).select_from(table)
                self.logger.debug("%s", statement)
                result = await self.db.execute(statement)
                usage[name] = {"rows": result.scalar()}
        return usage

    async def get_query_columns(self):
        """Return the sorted names of the Text columns of the unihash and outhash tables.

        Sorting makes the result deterministic; a plain list(set) order
        varies between processes under string hash randomization.
        """
        columns = {
            c.key
            for table in (UnihashesV2, OuthashesV2)
            for c in table.__table__.columns
            if isinstance(c.type, Text)
        }
        return sorted(columns)