blob: f7b0226a7a39f3a37b5f4bf627fe2c8bdcecd5c3 [file] [log] [blame]
Patrick Williamsac13d5f2023-11-24 18:59:46 -06001#! /usr/bin/env python3
2#
3# Copyright (C) 2023 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7
8import logging
9from datetime import datetime
10from . import User
11
12from sqlalchemy.ext.asyncio import create_async_engine
13from sqlalchemy.pool import NullPool
14from sqlalchemy import (
15 MetaData,
16 Column,
17 Table,
18 Text,
19 Integer,
20 UniqueConstraint,
21 DateTime,
22 Index,
23 select,
24 insert,
25 exists,
26 literal,
27 and_,
28 delete,
29 update,
30 func,
Patrick Williams73bd93f2024-02-20 08:07:48 -060031 inspect,
Patrick Williamsac13d5f2023-11-24 18:59:46 -060032)
33import sqlalchemy.engine
34from sqlalchemy.orm import declarative_base
35from sqlalchemy.exc import IntegrityError
Patrick Williams73bd93f2024-02-20 08:07:48 -060036from sqlalchemy.dialects.postgresql import insert as postgres_insert
Patrick Williamsac13d5f2023-11-24 18:59:46 -060037
# Declarative base for the current schema. DatabaseEngine.create() creates
# every table registered on Base.metadata, and Database.get_usage() reports a
# row count for each of them.
Base = declarative_base()
39
40
class UnihashesV3(Base):
    """Maps a (method, taskhash) pair to its unihash, plus a GC mark.

    Column order here determines the column order of the created table.
    """

    __tablename__ = "unihashes_v3"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)
    # Compared against the "gc-mark" config value: gc_mark() stamps matching
    # rows with the current mark, gc_sweep() deletes rows whose mark differs.
    gc_mark = Column(Text, nullable=False)

    __table_args__ = (
        # At most one unihash per (method, taskhash)
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v4", "method", "taskhash"),
        # Index for lookups by unihash alone (e.g. Database.unihash_exists)
        Index("unihash_lookup_v1", "unihash"),
    )
54
55
class OuthashesV2(Base):
    """An output hash reported for a (method, taskhash) pair.

    Several rows may share a (method, taskhash) as long as their outhash
    differs; the Database lookups pick the row with the earliest ``created``.
    """

    __tablename__ = "outhashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)
    # Used to order lookups (oldest first) and as the age cutoff in
    # Database.clean_unused(); nullable.
    created = Column(DateTime)
    # Descriptive metadata stored as supplied by the client. This module only
    # uses these generically (as remove() filters / get_query_columns()).
    # Presumably recipe PN/PV/PR and task name -- verify against callers.
    owner = Column(Text)
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        UniqueConstraint("method", "taskhash", "outhash"),
        # Index for lookups by (method, outhash)
        Index("outhash_lookup_v3", "method", "outhash"),
    )
74
75
class Users(Base):
    """A registered user with an access token and a permission list."""

    __tablename__ = "users"
    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(Text, nullable=False)
    # Value written by new_user()/set_user_token() and returned by
    # lookup_user_token(); whether it is hashed is decided by the callers.
    token = Column(Text, nullable=False)
    # Space-separated permission names (see set_user_perms / map_user).
    # NOTE: nullable at the schema level.
    permissions = Column(Text)

    __table_args__ = (UniqueConstraint("username"),)
84
85
class Config(Base):
    """Named key/value settings stored in the database (e.g. "gc-mark")."""

    __tablename__ = "config"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=False)
    # NULL is a meaningful value: gc_sweep() resets "gc-mark" to None.
    value = Column(Text)
    __table_args__ = (
        UniqueConstraint("name"),
        Index("config_lookup", "name"),
    )
95
96
#
# Old table versions
#
# Kept on a separate declarative base so that Base.metadata.create_all() in
# DatabaseEngine.create() never creates them; create() only checks whether
# they exist, copies their rows into the current tables and drops them.
DeprecatedBase = declarative_base()
101
102
class UnihashesV2(DeprecatedBase):
    """Previous unihash table layout (no ``gc_mark`` column).

    Only used by DatabaseEngine.create() to migrate existing rows into
    UnihashesV3 before the table is dropped.
    """

    __tablename__ = "unihashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )
114
115
class DatabaseEngine(object):
    """Owns the SQLAlchemy async engine for the hash equivalence database.

    :param url: SQLAlchemy database URL (string or URL object)
    :param username: if given, overrides the user name in ``url``
    :param password: if given, overrides the password in ``url``
    """

    def __init__(self, url, username=None, password=None):
        self.logger = logging.getLogger("hashserv.sqlalchemy")
        self.url = sqlalchemy.engine.make_url(url)

        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the engine, create missing tables and migrate old ones.

        Must be awaited before :meth:`connect`, since it sets ``self.engine``.
        """

        def check_table_exists(conn, name):
            return inspect(conn).has_table(name)

        # Never write the password to the log: on SQLAlchemy 1.4, str(URL)
        # renders it in clear text.
        self.logger.info(
            "Using database %s", self.url.render_as_string(hide_password=True)
        )
        if self.url.drivername == "postgresql+psycopg":
            # Psygopg 3 (psygopg) driver can handle async connection pooling
            self.engine = create_async_engine(self.url, max_overflow=-1)
        else:
            self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

            if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__):
                self.logger.info("Upgrading Unihashes V2 -> V3...")
                statement = insert(UnihashesV3).from_select(
                    ["id", "method", "unihash", "taskhash", "gc_mark"],
                    select(
                        UnihashesV2.id,
                        UnihashesV2.method,
                        UnihashesV2.unihash,
                        UnihashesV2.taskhash,
                        literal("").label("gc_mark"),
                    ),
                )
                self.logger.debug("%s", statement)
                await conn.execute(statement)

                if self.engine.name == "postgresql":
                    # The copy above wrote explicit ids, which does not advance
                    # the id column's sequence on PostgreSQL. Without this, later
                    # inserts would collide on the primary key (and
                    # insert_unihash would silently report "already present").
                    # setval(..., false) makes the next value max(id) + 1.
                    statement = select(
                        func.setval(
                            func.pg_get_serial_sequence(
                                UnihashesV3.__tablename__, "id"
                            ),
                            func.coalesce(func.max(UnihashesV3.id), 0) + 1,
                            False,
                        )
                    )
                    self.logger.debug("%s", statement)
                    await conn.execute(statement)

                await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__])
                self.logger.info("Upgrade complete")

    def connect(self, logger):
        """Return a new (not yet opened) :class:`Database` on this engine."""
        return Database(self.engine, logger)
163
164
def map_row(row):
    """Return a plain ``dict`` copy of a result row's column mapping.

    A ``None`` row (nothing found) is passed straight through.
    """
    return None if row is None else dict(**row._mapping)
169
170
def map_user(row):
    """Convert a users-table row into a ``User``, or return None for no row.

    ``permissions`` is stored as a space-separated string. The column is
    nullable, so a NULL value is treated as "no permissions" instead of
    crashing on ``None.split()``.
    """
    if row is None:
        return None
    return User(
        username=row.username,
        permissions=set((row.permissions or "").split()),
    )
178
179
Patrick Williams73bd93f2024-02-20 08:07:48 -0600180def _make_condition_statement(table, condition):
181 where = {}
182 for c in table.__table__.columns:
183 if c.key in condition and condition[c.key] is not None:
184 where[c] = condition[c.key]
185
186 return [(k == v) for k, v in where.items()]
187
188
class Database(object):
    """A single connection to the hash equivalence database.

    Meant to be used as an async context manager: the connection is opened on
    entry and closed on exit. Each public method runs its statements in its own
    transaction on that connection.

    :param engine: async engine to connect with (``DatabaseEngine.engine``)
    :param logger: logger for executed statements and diagnostics
    """

    def __init__(self, engine, logger):
        self.engine = engine
        self.db = None
        self.logger = logger

    async def __aenter__(self):
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the connection. Calling it again after closing is a no-op."""
        if self.db is None:
            return
        await self.db.close()
        self.db = None

    async def _execute(self, statement):
        """Log ``statement`` at debug level and execute it on the connection."""
        self.logger.debug("%s", statement)
        return await self.db.execute(statement)

    async def _set_config(self, name, value):
        """Update config ``name`` to ``value``, inserting the row if missing.

        Callers in this class invoke it inside an open transaction.
        """
        while True:
            result = await self._execute(
                update(Config).where(Config.name == name).values(value=value)
            )

            if result.rowcount == 0:
                self.logger.debug("Config '%s' not found. Adding it", name)
                try:
                    await self._execute(insert(Config).values(name=name, value=value))
                except IntegrityError:
                    # Race: another connection inserted the row first, so retry
                    # the UPDATE.
                    # NOTE(review): on PostgreSQL a failed statement aborts the
                    # enclosing transaction, so the retried UPDATE may fail
                    # unless the INSERT runs in a savepoint -- confirm.
                    continue

            break

    def _get_config_subquery(self, name, default=None):
        """Return a scalar subquery yielding config ``name``.

        If ``default`` is not None, the subquery is wrapped in COALESCE so a
        missing/NULL value yields ``default`` instead.
        """
        subquery = select(Config.value).where(Config.name == name).scalar_subquery()
        if default is not None:
            return func.coalesce(subquery, default)
        return subquery

    async def _get_config(self, name):
        """Return the value of config ``name``, or None if it is not set."""
        result = await self._execute(select(Config.value).where(Config.name == name))
        row = result.first()
        if row is None:
            return None
        return row.value

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row for (method, taskhash) plus its unihash.

        :returns: dict of the outhash columns and ``unihash``, or None
        """
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2,
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.taskhash == taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) plus its unihash.

        :returns: dict of the outhash columns and ``unihash``, or None
        """
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def unihash_exists(self, unihash):
        """Return True if any (method, taskhash) maps to ``unihash``."""
        async with self.db.begin():
            result = await self._execute(
                select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1)
            )

        return result.first() is not None

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) as a dict, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2)
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Find another taskhash with the same outhash, and its unihash.

        The oldest matching row whose taskhash differs from ``taskhash`` wins.

        :returns: dict with ``taskhash`` and ``unihash``, or None
        """
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2.taskhash.label("taskhash"),
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                    OuthashesV2.taskhash != taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent(self, method, taskhash):
        """Return ``unihash``/``method``/``taskhash`` for (method, taskhash), or None."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    UnihashesV3.unihash,
                    UnihashesV3.method,
                    UnihashesV3.taskhash,
                ).where(
                    UnihashesV3.method == method,
                    UnihashesV3.taskhash == taskhash,
                )
            )
            return map_row(result.first())

    async def remove(self, condition):
        """Delete unihash and outhash rows matching ``condition``.

        Each table is only touched if ``condition`` names at least one of its
        columns with a non-None value; each delete runs in its own transaction.

        :returns: total number of deleted rows
        """

        async def do_remove(table):
            where = _make_condition_statement(table, condition)
            if where:
                async with self.db.begin():
                    result = await self._execute(delete(table).where(*where))
                    return result.rowcount

            return 0

        count = 0
        count += await do_remove(UnihashesV3)
        count += await do_remove(OuthashesV2)

        return count

    async def get_current_gc_mark(self):
        """Return the current "gc-mark" config value, or None if unset."""
        async with self.db.begin():
            return await self._get_config("gc-mark")

    async def gc_status(self):
        """Return (rows to keep, rows to remove, current mark) for the next sweep."""
        async with self.db.begin():
            gc_mark_subquery = self._get_config_subquery("gc-mark", "")

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark == gc_mark_subquery)
            )
            keep_rows = result.scalar()

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark != gc_mark_subquery)
            )
            remove_rows = result.scalar()

            return (keep_rows, remove_rows, await self._get_config("gc-mark"))

    async def gc_mark(self, mark, condition):
        """Set the current mark to ``mark`` and stamp matching unihash rows.

        The mark is stored even when ``condition`` matches no column.

        :returns: number of rows stamped (0 if ``condition`` is empty)
        """
        async with self.db.begin():
            await self._set_config("gc-mark", mark)

            where = _make_condition_statement(UnihashesV3, condition)
            if not where:
                return 0

            result = await self._execute(
                update(UnihashesV3)
                .values(gc_mark=self._get_config_subquery("gc-mark", ""))
                .where(*where)
            )
            return result.rowcount

    async def gc_sweep(self):
        """Delete unihash rows not carrying the current mark, then clear it.

        :returns: number of deleted rows
        """
        async with self.db.begin():
            result = await self._execute(
                delete(UnihashesV3).where(
                    # A sneaky conditional that provides some errant use
                    # protection: If the config mark is NULL, this will not
                    # match any rows because no default is specified in the
                    # select statement
                    UnihashesV3.gc_mark
                    != self._get_config_subquery("gc-mark")
                )
            )
            await self._set_config("gc-mark", None)

            return result.rowcount

    async def clean_unused(self, oldest):
        """Delete outhash rows older than ``oldest`` that have no unihash row.

        :returns: number of deleted rows
        """
        async with self.db.begin():
            result = await self._execute(
                delete(OuthashesV2).where(
                    OuthashesV2.created < oldest,
                    ~(
                        select(UnihashesV3.id)
                        .where(
                            UnihashesV3.method == OuthashesV2.method,
                            UnihashesV3.taskhash == OuthashesV2.taskhash,
                        )
                        .limit(1)
                        .exists()
                    ),
                )
            )
            return result.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash stamped with the current GC mark.

        :returns: True if a row was inserted, False if (method, taskhash) was
            already present
        """
        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(UnihashesV3)
                .values(
                    method=method,
                    taskhash=taskhash,
                    unihash=unihash,
                    gc_mark=self._get_config_subquery("gc-mark", ""),
                )
                .on_conflict_do_nothing(index_elements=("method", "taskhash"))
            )
        else:
            statement = insert(UnihashesV3).values(
                method=method,
                taskhash=taskhash,
                unihash=unihash,
                gc_mark=self._get_config_subquery("gc-mark", ""),
            )

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row built from the column keys present in ``data``.

        Keys that are not outhash columns are ignored. A string ``created``
        value is parsed with ``datetime.fromisoformat``; None is stored as NULL.

        :returns: True if a row was inserted, False if it already existed
        """
        outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)

        data = {k: v for k, v in data.items() if k in outhash_columns}

        created = data.get("created")
        if created is not None and not isinstance(created, datetime):
            data["created"] = datetime.fromisoformat(created)

        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(OuthashesV2)
                .values(**data)
                .on_conflict_do_nothing(
                    index_elements=("method", "taskhash", "outhash")
                )
            )
        else:
            statement = insert(OuthashesV2).values(**data)

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            # .get(): the error may stem from one of these keys being missing
            self.logger.debug(
                "%s, %s already in outhash database",
                data.get("method"),
                data.get("outhash"),
            )
            return False

    async def _get_user(self, username):
        """Return the (username, permissions, token) row for ``username``, or None."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                    Users.token,
                ).where(
                    Users.username == username,
                )
            )
            return result.first()

    async def lookup_user_token(self, username):
        """Return ``(User, token)`` for ``username``, or ``(None, None)``."""
        row = await self._get_user(username)
        if not row:
            return None, None
        return map_user(row), row.token

    async def lookup_user(self, username):
        """Return the ``User`` for ``username``, or None."""
        return map_user(await self._get_user(username))

    async def set_user_token(self, username, token):
        """Replace the token of ``username``; True if the user exists."""
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(
                    Users.username == username,
                )
                .values(
                    token=token,
                )
            )
            return result.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Replace the permissions of ``username``; True if the user exists."""
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(Users.username == username)
                .values(permissions=" ".join(permissions))
            )
            return result.rowcount != 0

    async def get_all_users(self):
        """Return a list of ``User`` objects for every registered user."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                )
            )
            return [map_user(row) for row in result]

    async def new_user(self, username, permissions, token):
        """Create a user; returns False if it cannot be inserted (e.g. exists)."""
        try:
            async with self.db.begin():
                await self._execute(
                    insert(Users).values(
                        username=username,
                        permissions=" ".join(permissions),
                        token=token,
                    )
                )
                return True
        except IntegrityError as e:
            self.logger.debug("Cannot create new user %s: %s", username, e)
            return False

    async def delete_user(self, username):
        """Delete ``username``; True if a row was removed."""
        async with self.db.begin():
            result = await self._execute(
                delete(Users).where(Users.username == username)
            )
            return result.rowcount != 0

    async def get_usage(self):
        """Return ``{table_name: {"rows": count}}`` for every current table."""
        usage = {}
        async with self.db.begin():
            for name, table in Base.metadata.tables.items():
                result = await self._execute(
                    statement=select(func.count()).select_from(table)
                )
                usage[name] = {
                    "rows": result.scalar(),
                }

        return usage

    async def get_query_columns(self):
        """Return the keys of all Text columns of the unihash and outhash tables."""
        columns = set()
        for table in (UnihashesV3, OuthashesV2):
            for c in table.__table__.columns:
                if not isinstance(c.type, Text):
                    continue
                columns.add(c.key)

        return list(columns)