blob: cee04bffb037c3389b61470bdba33d59d8f386b1 [file] [log] [blame]
Patrick Williamsac13d5f2023-11-24 18:59:46 -06001#! /usr/bin/env python3
2#
3# Copyright (C) 2023 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7
8import logging
9from datetime import datetime
10from . import User
11
12from sqlalchemy.ext.asyncio import create_async_engine
13from sqlalchemy.pool import NullPool
14from sqlalchemy import (
15 MetaData,
16 Column,
17 Table,
18 Text,
19 Integer,
20 UniqueConstraint,
21 DateTime,
22 Index,
23 select,
24 insert,
25 exists,
26 literal,
27 and_,
28 delete,
29 update,
30 func,
31)
32import sqlalchemy.engine
33from sqlalchemy.orm import declarative_base
34from sqlalchemy.exc import IntegrityError
35
# Declarative base shared by every ORM model defined below.
# DatabaseEngine.create() passes Base.metadata.create_all to run_sync, so the
# tables of all models registered on this base are created from it.
Base = declarative_base()
37
38
class UnihashesV2(Base):
    """ORM model for the ``unihashes_v2`` table.

    Stores one ``unihash`` value per (``method``, ``taskhash``) pair; all
    three text columns are NOT NULL.
    """

    __tablename__ = "unihashes_v2"
    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        # At most one row per (method, taskhash); Database.insert_unihash
        # reports the resulting IntegrityError as "already in unihash database".
        UniqueConstraint("method", "taskhash"),
        # Index for the (method, taskhash) lookups.
        # NOTE(review): many backends already back a UNIQUE constraint with an
        # index on the same columns -- confirm this one is still needed.
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )
50
51
class OuthashesV2(Base):
    """ORM model for the ``outhashes_v2`` table.

    Each row records an ``outhash`` for a (``method``, ``taskhash``) pair.
    Only those three columns are NOT NULL; the remaining columns are optional
    metadata. Rows are joined to ``UnihashesV2`` on (method, taskhash) by the
    query methods of ``Database``.
    """

    __tablename__ = "outhashes_v2"
    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)
    # Creation timestamp; Database.insert_outhash converts ISO-8601 strings to
    # datetime before insert. Queries order by this column ascending.
    created = Column(DateTime)
    owner = Column(Text)
    # NOTE(review): presumably recipe name / version / revision -- confirm.
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        # The same (method, taskhash, outhash) triple may be stored only once.
        UniqueConstraint("method", "taskhash", "outhash"),
        # Index for lookups by (method, outhash).
        Index("outhash_lookup_v3", "method", "outhash"),
    )
70
71
class Users(Base):
    """ORM model for the ``users`` table: one row per unique username."""

    __tablename__ = "users"
    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(Text, nullable=False)
    # Stored exactly as passed to Database.new_user / set_user_token.
    # NOTE(review): whether callers store a raw token or a hash of it is not
    # visible here -- confirm.
    token = Column(Text, nullable=False)
    # Space-separated permission names (written with " ".join, read back
    # with str.split). Nullable.
    permissions = Column(Text)

    __table_args__ = (UniqueConstraint("username"),)
80
81
class DatabaseEngine(object):
    """Factory for ``Database`` connections backed by an async SQLAlchemy engine.

    The URL is parsed when the object is constructed. Explicit *username* and
    *password* arguments replace any credentials embedded in *url*.
    ``create()`` must be awaited before ``connect()``, because ``create()`` is
    what sets ``self.engine``.
    """

    def __init__(self, url, username=None, password=None):
        self.logger = logging.getLogger("hashserv.sqlalchemy")
        self.url = sqlalchemy.engine.make_url(url)

        # URL objects are immutable; set() returns an updated copy.
        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the async engine and any missing tables of ``Base``'s models."""
        # Mask the password explicitly: before SQLAlchemy 2.0, str(URL)
        # renders the password in clear text, which would leak it into the log.
        self.logger.info(
            "Using database %s", self.url.render_as_string(hide_password=True)
        )
        # NullPool: connections are not pooled; each connect() opens a new one.
        self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

    def connect(self, logger):
        """Return a new ``Database`` for this engine; it is opened by ``async with``."""
        return Database(self.engine, logger)
104
105
def map_row(row):
    """Convert a result row into a fresh ``dict`` of column name -> value.

    A ``None`` row (for example ``result.first()`` on an empty result) is
    passed straight through as ``None``.
    """
    return dict(**row._mapping) if row is not None else None
110
111
def map_user(row):
    """Build a ``User`` from a users-table row, or return ``None`` for no row.

    ``permissions`` is stored as a space-separated string. The column is
    nullable, so a NULL value is treated as "no permissions" rather than
    raising ``AttributeError`` on ``None.split()``.
    """
    if row is None:
        return None
    return User(
        username=row.username,
        permissions=set((row.permissions or "").split()),
    )
119
120
class Database(object):
    """One connection to the hash-equivalence database.

    Obtained from ``DatabaseEngine.connect()`` and used as an async context
    manager: ``__aenter__`` opens a connection from the engine and
    ``__aexit__`` closes it. Each query method runs inside a ``begin()``
    transaction block on that connection.
    """

    def __init__(self, engine, logger):
        self.engine = engine
        # The open connection, or None while no connection is held.
        self.db = None
        self.logger = logger

    async def __aenter__(self):
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the connection if one is open.

        Idempotent: calling it again, or before ``__aenter__``, is a no-op
        instead of an ``AttributeError`` on ``None.close()``.
        """
        if self.db is not None:
            await self.db.close()
            self.db = None

    async def _execute(self, statement):
        """Log *statement* at debug level, then execute it on ``self.db``."""
        self.logger.debug("%s", statement)
        return await self.db.execute(statement)

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the first outhash row for (*method*, *taskhash*) plus its unihash.

        Rows are ordered by ``created`` ascending. The result is a dict of all
        ``OuthashesV2`` columns plus ``"unihash"`` from the joined
        ``UnihashesV2`` row. It is ``None`` when no outhash row for the pair
        has a matching unihash row (inner join).
        """
        statement = (
            select(
                OuthashesV2,
                UnihashesV2.unihash.label("unihash"),
            )
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.taskhash == taskhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        async with self.db.begin():
            result = await self._execute(statement)
            return map_row(result.first())

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the first outhash row for (*method*, *outhash*) plus its unihash.

        Same shape and ordering as ``get_unihash_by_taskhash_full``, but it
        filters on ``outhash`` instead of ``taskhash``. Returns ``None`` if
        there is no match.
        """
        statement = (
            select(OuthashesV2, UnihashesV2.unihash.label("unihash"))
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        async with self.db.begin():
            result = await self._execute(statement)
            return map_row(result.first())

    async def get_outhash(self, method, outhash):
        """Return the first ``OuthashesV2`` row for (*method*, *outhash*) as a dict.

        Rows are ordered by ``created`` ascending. No join is performed.
        Returns ``None`` if there is no match.
        """
        statement = (
            select(OuthashesV2)
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )

        async with self.db.begin():
            result = await self._execute(statement)
            return map_row(result.first())

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Find another task with the same output hash.

        Returns ``{"taskhash": ..., "unihash": ...}`` for the first
        (by ``created`` ascending) outhash row matching (*method*, *outhash*)
        whose taskhash differs from *taskhash* and which has a unihash row.
        Returns ``None`` if there is none.
        """
        statement = (
            select(
                OuthashesV2.taskhash.label("taskhash"),
                UnihashesV2.unihash.label("unihash"),
            )
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
                OuthashesV2.taskhash != taskhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        async with self.db.begin():
            result = await self._execute(statement)
            return map_row(result.first())

    async def get_equivalent(self, method, taskhash):
        """Return the unihash row for (*method*, *taskhash*).

        The result is a dict with keys ``unihash``, ``method`` and
        ``taskhash``, or ``None`` if there is no row.
        """
        statement = select(
            UnihashesV2.unihash,
            UnihashesV2.method,
            UnihashesV2.taskhash,
        ).where(
            UnihashesV2.method == method,
            UnihashesV2.taskhash == taskhash,
        )
        async with self.db.begin():
            result = await self._execute(statement)
            return map_row(result.first())

    async def remove(self, condition):
        """Delete rows matching *condition* from the unihash and outhash tables.

        *condition* maps column names to values. For each table, only keys
        that name one of that table's columns and have a non-None value are
        used, combined with AND. A table with no applicable keys is skipped,
        so an unfiltered DELETE is never issued. Each table is deleted from in
        its own transaction block.

        Returns the total number of rows deleted.
        """

        async def do_remove(table):
            where = [
                column == condition[column.key]
                for column in table.__table__.columns
                if column.key in condition and condition[column.key] is not None
            ]
            if not where:
                return 0

            statement = delete(table).where(*where)
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount

        count = 0
        count += await do_remove(UnihashesV2)
        count += await do_remove(OuthashesV2)

        return count

    async def clean_unused(self, oldest):
        """Delete outhash rows created before *oldest* that have no unihash row.

        An outhash row is kept if any ``UnihashesV2`` row shares its
        (method, taskhash). Rows whose ``created`` is NULL do not satisfy the
        ``<`` comparison and are kept. Returns the number of rows deleted.
        """
        statement = delete(OuthashesV2).where(
            OuthashesV2.created < oldest,
            ~(
                select(UnihashesV2.id)
                .where(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                )
                .limit(1)
                .exists()
            ),
        )
        async with self.db.begin():
            result = await self._execute(statement)
            return result.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash row.

        Returns True if the row was inserted, or False if the insert raised
        ``IntegrityError`` (for example, (method, taskhash) already exists).
        """
        statement = insert(UnihashesV2).values(
            method=method,
            taskhash=taskhash,
            unihash=unihash,
        )
        try:
            async with self.db.begin():
                await self._execute(statement)
            return True
        except IntegrityError:
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row built from the column keys found in *data*.

        Keys that are not ``OuthashesV2`` columns are ignored, and the
        caller's dict is never modified. A non-None, non-datetime ``created``
        value is parsed with ``datetime.fromisoformat``; ``None`` is stored as
        NULL.

        Returns True if the row was inserted, or False if the insert raised
        ``IntegrityError``.
        """
        outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)

        data = {k: v for k, v in data.items() if k in outhash_columns}

        created = data.get("created")
        if created is not None and not isinstance(created, datetime):
            data["created"] = datetime.fromisoformat(created)

        statement = insert(OuthashesV2).values(**data)
        try:
            async with self.db.begin():
                await self._execute(statement)
            return True
        except IntegrityError:
            # Use .get(): the violation may come from a missing NOT NULL key
            # ("method"/"outhash"), and a KeyError here would mask it.
            self.logger.debug(
                "%s, %s already in outhash database",
                data.get("method"),
                data.get("outhash"),
            )
            return False

    async def _get_user(self, username):
        """Return the raw (username, permissions, token) row for *username*, or None."""
        statement = select(
            Users.username,
            Users.permissions,
            Users.token,
        ).where(
            Users.username == username,
        )
        async with self.db.begin():
            result = await self._execute(statement)
            return result.first()

    async def lookup_user_token(self, username):
        """Return ``(User, token)`` for *username*, or ``(None, None)`` if unknown."""
        row = await self._get_user(username)
        if not row:
            return None, None
        return map_user(row), row.token

    async def lookup_user(self, username):
        """Return the ``User`` for *username*, or None if unknown."""
        return map_user(await self._get_user(username))

    async def set_user_token(self, username, token):
        """Store *token* for *username*; return True if a user row was updated."""
        statement = (
            update(Users)
            .where(
                Users.username == username,
            )
            .values(
                token=token,
            )
        )
        async with self.db.begin():
            result = await self._execute(statement)
            return result.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Replace *username*'s permissions; return True if a row was updated.

        *permissions* is an iterable of strings, stored space-separated.
        """
        statement = (
            update(Users)
            .where(Users.username == username)
            .values(permissions=" ".join(permissions))
        )
        async with self.db.begin():
            result = await self._execute(statement)
            return result.rowcount != 0

    async def get_all_users(self):
        """Return a list of ``User`` objects, one per users-table row."""
        statement = select(
            Users.username,
            Users.permissions,
        )
        async with self.db.begin():
            result = await self._execute(statement)
            return [map_user(row) for row in result]

    async def new_user(self, username, permissions, token):
        """Create a user.

        Returns True on success, or False if the insert raised
        ``IntegrityError`` (for example, the username already exists).
        """
        statement = insert(Users).values(
            username=username,
            permissions=" ".join(permissions),
            token=token,
        )
        try:
            async with self.db.begin():
                await self._execute(statement)
            return True
        except IntegrityError as e:
            self.logger.debug("Cannot create new user %s: %s", username, e)
            return False

    async def delete_user(self, username):
        """Delete *username*; return True if a row was deleted."""
        statement = delete(Users).where(Users.username == username)
        async with self.db.begin():
            result = await self._execute(statement)
            return result.rowcount != 0

    async def get_usage(self):
        """Return ``{table_name: {"rows": row_count}}`` for every table in ``Base.metadata``.

        All counts are taken inside a single transaction block.
        """
        usage = {}
        async with self.db.begin():
            for name, table in Base.metadata.tables.items():
                statement = select(func.count()).select_from(table)
                result = await self._execute(statement)
                usage[name] = {
                    "rows": result.scalar(),
                }

        return usage

    async def get_query_columns(self):
        """Return the names of all ``Text`` columns in the unihash and outhash tables.

        Names are de-duplicated and sorted, so the order is stable across
        interpreter runs (set iteration order of strings is not).
        """
        columns = set()
        for table in (UnihashesV2, OuthashesV2):
            for c in table.__table__.columns:
                if isinstance(c.type, Text):
                    columns.add(c.key)

        return sorted(columns)
427 return list(columns)