blob: f65036be9393e3af8de46d0381a9b4187e52d6d9 [file] [log] [blame]
Patrick Williamsac13d5f2023-11-24 18:59:46 -06001#! /usr/bin/env python3
2#
3# Copyright (C) 2023 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7import sqlite3
8import logging
9from contextlib import closing
10from . import User
11
logger = logging.getLogger("hashserv.sqlite")

# Table schemas. Every entry is a (column name, SQL type, flags) triple; a
# flags value of "UNIQUE" marks the column as part of the table's UNIQUE
# constraint. The *_COLUMNS tuples list just the column names, in order, and
# are used to whitelist keys before they are interpolated into SQL.

UNIHASH_TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL", "UNIQUE"),
    ("taskhash", "TEXT NOT NULL", "UNIQUE"),
    ("unihash", "TEXT NOT NULL", ""),
)

UNIHASH_TABLE_COLUMNS = tuple(column[0] for column in UNIHASH_TABLE_DEFINITION)

OUTHASH_TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL", "UNIQUE"),
    ("taskhash", "TEXT NOT NULL", "UNIQUE"),
    ("outhash", "TEXT NOT NULL", "UNIQUE"),
    ("created", "DATETIME", ""),
    # Optional fields
    ("owner", "TEXT", ""),
    ("PN", "TEXT", ""),
    ("PV", "TEXT", ""),
    ("PR", "TEXT", ""),
    ("task", "TEXT", ""),
    ("outhash_siginfo", "TEXT", ""),
)

OUTHASH_TABLE_COLUMNS = tuple(column[0] for column in OUTHASH_TABLE_DEFINITION)

USERS_TABLE_DEFINITION = (
    ("username", "TEXT NOT NULL", "UNIQUE"),
    ("token", "TEXT NOT NULL", ""),
    ("permissions", "TEXT NOT NULL", ""),
)

USERS_TABLE_COLUMNS = tuple(column[0] for column in USERS_TABLE_DEFINITION)
45
46
def _make_table(cursor, name, definition):
    """Create table *name* from *definition* if it does not already exist.

    :param cursor: sqlite3 cursor used to execute the statement.
    :param name: table name. It is interpolated directly into the SQL, so it
        must come from trusted code, never from user input.
    :param definition: iterable of ``(column, sql_type, flags)`` tuples. All
        columns whose *flags* contain ``"UNIQUE"`` form one composite
        ``UNIQUE(...)`` constraint. When no column is flagged, the clause is
        omitted; emitting ``UNIQUE()`` would be invalid SQL.

    An ``id INTEGER PRIMARY KEY AUTOINCREMENT`` column is always added first.
    """
    # Materialize once: the definition is walked twice below.
    definition = tuple(definition)

    clauses = ["id INTEGER PRIMARY KEY AUTOINCREMENT"]
    clauses.extend("%s %s" % (column, typ) for column, typ, _ in definition)

    unique = [column for column, _, flags in definition if "UNIQUE" in flags]
    if unique:
        clauses.append("UNIQUE(%s)" % ", ".join(unique))

    cursor.execute(
        "CREATE TABLE IF NOT EXISTS %s (\n    %s\n)"
        % (name, ",\n    ".join(clauses))
    )
63
64
def map_user(row):
    """Convert a ``users`` row into a ``User``; a missing row maps to None.

    The row's ``permissions`` column is split on whitespace and returned as a
    set. The ``token`` column is not copied into the ``User``.
    """
    if row is not None:
        username = row["username"]
        permissions = set(row["permissions"].split())
        return User(username=username, permissions=permissions)
    return None
72
73
class DatabaseEngine(object):
    """Owns the database location and settings and hands out connections.

    :param dbname: database path, passed to ``sqlite3.connect``.
    :param sync: when true, connections use ``PRAGMA synchronous = NORMAL``;
        otherwise ``OFF``.
    """

    def __init__(self, dbname, sync):
        self.dbname = dbname
        self.logger = logger
        self.sync = sync

    async def create(self):
        """Create the tables and indexes, and drop obsolete ones.

        Uses a short-lived connection that is always closed afterwards. It
        was previously never closed.
        """
        with closing(sqlite3.connect(self.dbname)) as db:
            db.row_factory = sqlite3.Row

            with closing(db.cursor()) as cursor:
                _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
                _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
                _make_table(cursor, "users", USERS_TABLE_DEFINITION)

                # WAL mode is persisted in the database file. The synchronous
                # setting only affects this connection; each Database connection
                # applies it again.
                cursor.execute("PRAGMA journal_mode = WAL")
                cursor.execute(
                    "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
                )

                # Drop old indexes
                cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
                cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
                cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
                cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")

                # TODO: Upgrade from tasks_v2?
                cursor.execute("DROP TABLE IF EXISTS tasks_v2")

                # Create new indexes
                cursor.execute(
                    "CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)"
                )
                cursor.execute(
                    "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
                )

            # Commit any transaction the sqlite3 module may have opened before
            # the connection is closed. Normally this is a no-op here.
            db.commit()

    def connect(self, logger):
        """Return a new ``Database`` connection that uses this engine's sync setting."""
        # Pass sync through; it was previously dropped, so serving connections
        # always ran with SQLite's default synchronous level.
        return Database(logger, self.dbname, self.sync)
113
114
class Database(object):
    """One connection to the hash-equivalence SQLite database.

    Methods are declared ``async`` but run their sqlite3 calls synchronously.
    They never ``await`` anything. Query results are ``sqlite3.Row`` objects,
    or None when nothing matches.
    """

    def __init__(self, logger, dbname, sync=True):
        """Open *dbname* and configure the connection.

        :param logger: logger stored as ``self.logger``.
        :param dbname: database path, passed to ``sqlite3.connect``.
        :param sync: true selects ``PRAGMA synchronous = NORMAL``, false
            selects ``OFF``.
        """
        self.dbname = dbname
        self.logger = logger

        self.db = sqlite3.connect(self.dbname)
        self.db.row_factory = sqlite3.Row

        with closing(self.db.cursor()) as cursor:
            # synchronous is a per-connection setting. Unlike journal_mode=WAL,
            # it is not stored in the file, so it must be set on every
            # connection. Previously *sync* was accepted but ignored.
            cursor.execute(
                "PRAGMA synchronous = %s" % ("NORMAL" if sync else "OFF")
            )

            cursor.execute("SELECT sqlite_version()")

            # Numeric components become ints so versions compare as tuples;
            # any non-numeric component is kept as a string.
            version = []
            for v in cursor.fetchone()[0].split("."):
                try:
                    version.append(int(v))
                except ValueError:
                    version.append(v)

            self.sqlite_version = tuple(version)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the underlying sqlite3 connection."""
        self.db.close()

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row for (method, taskhash), joined with its unihash."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                """,
                {
                    "method": method,
                    "taskhash": taskhash,
                },
            )
            return cursor.fetchone()

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash), joined with its unihash."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                """,
                {
                    "method": method,
                    "outhash": outhash,
                },
            )
            return cursor.fetchone()

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash), without a join."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT * FROM outhashes_v2
                WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                """,
                {
                    "method": method,
                    "outhash": outhash,
                },
            )
            return cursor.fetchone()

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Return (taskhash, unihash) of the oldest other task with the same outhash.

        Rows whose taskhash equals *taskhash* are excluded.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                -- Select any matching output hash except the one we just inserted
                WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
                -- Pick the oldest hash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                """,
                {
                    "method": method,
                    "outhash": outhash,
                    "taskhash": taskhash,
                },
            )
            return cursor.fetchone()

    async def get_equivalent(self, method, taskhash):
        """Return the unihash row (taskhash, method, unihash) for (method, taskhash)."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                "SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash",
                {
                    "method": method,
                    "taskhash": taskhash,
                },
            )
            return cursor.fetchone()

    async def remove(self, condition):
        """Delete outhash and unihash rows matching every non-None key in *condition*.

        Only keys that are columns of a table are applied to that table. A
        table with no applicable keys is left untouched.

        :returns: total number of rows deleted across both tables.
        """

        def do_remove(columns, table_name, cursor):
            where = {}
            for c in columns:
                if c in condition and condition[c] is not None:
                    where[c] = condition[c]

            if where:
                # Column names come from the fixed whitelist in *columns*;
                # values are bound as parameters.
                query = ("DELETE FROM %s WHERE " % table_name) + " AND ".join(
                    "%s=:%s" % (k, k) for k in where.keys()
                )
                cursor.execute(query, where)
                return cursor.rowcount

            return 0

        count = 0
        with closing(self.db.cursor()) as cursor:
            count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
            count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
            self.db.commit()

        return count

    async def clean_unused(self, oldest):
        """Delete outhash rows created before *oldest* that have no matching unihash.

        :returns: number of rows deleted.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
                    SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
                )
                """,
                {
                    "oldest": oldest,
                },
            )
            self.db.commit()
            return cursor.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash unless (method, taskhash) already has one.

        :returns: True if a row was inserted, False if it was ignored.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                INSERT OR IGNORE INTO unihashes_v2 (method, taskhash, unihash) VALUES(:method, :taskhash, :unihash)
                """,
                {
                    "method": method,
                    "taskhash": taskhash,
                    "unihash": unihash,
                },
            )
            self.db.commit()
            # An ignored INSERT leaves rowcount at 0. Comparing lastrowid with
            # a fresh cursor's initial None always reported "inserted", so
            # lastrowid is not used.
            return cursor.rowcount > 0

    async def insert_outhash(self, data):
        """Insert an outhash row built from the known-column keys of *data*.

        Keys that are not outhashes_v2 columns are dropped before the SQL is
        built.

        :returns: True if a row was inserted, False if it was ignored.
        """
        data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
        keys = sorted(data.keys())
        query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
            fields=", ".join(keys),
            values=", ".join(":" + k for k in keys),
        )
        with closing(self.db.cursor()) as cursor:
            cursor.execute(query, data)
            self.db.commit()
            # See insert_unihash: rowcount, not lastrowid, reflects whether
            # the row was actually inserted.
            return cursor.rowcount > 0

    def _get_user(self, username):
        """Return the (username, permissions, token) row for *username*, or None."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                SELECT username, permissions, token FROM users WHERE username=:username
                """,
                {
                    "username": username,
                },
            )
            return cursor.fetchone()

    async def lookup_user_token(self, username):
        """Return ``(User, token)`` for *username*, or ``(None, None)`` if absent."""
        row = self._get_user(username)
        if row is None:
            return None, None
        return map_user(row), row["token"]

    async def lookup_user(self, username):
        """Return the ``User`` for *username*, or None if absent."""
        return map_user(self._get_user(username))

    async def set_user_token(self, username, token):
        """Replace *username*'s token. Returns True if the user existed."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                UPDATE users SET token=:token WHERE username=:username
                """,
                {
                    "username": username,
                    "token": token,
                },
            )
            self.db.commit()
            return cursor.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Replace *username*'s permissions, stored space-separated.

        Returns True if the user existed.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                UPDATE users SET permissions=:permissions WHERE username=:username
                """,
                {
                    "username": username,
                    "permissions": " ".join(permissions),
                },
            )
            self.db.commit()
            return cursor.rowcount != 0

    async def get_all_users(self):
        """Return a list of ``User`` objects for every row in ``users``."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute("SELECT username, permissions FROM users")
            return [map_user(r) for r in cursor.fetchall()]

    async def new_user(self, username, permissions, token):
        """Create a user.

        :returns: True on success, False if the username already exists.
        """
        with closing(self.db.cursor()) as cursor:
            try:
                cursor.execute(
                    """
                    INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions)
                    """,
                    {
                        "username": username,
                        "token": token,
                        "permissions": " ".join(permissions),
                    },
                )
                self.db.commit()
                return True
            except sqlite3.IntegrityError:
                # Roll back so the transaction that sqlite3 implicitly began
                # for the INSERT is not left open on this connection.
                self.db.rollback()
                return False

    async def delete_user(self, username):
        """Delete *username*. Returns True if a row was removed."""
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                DELETE FROM users WHERE username=:username
                """,
                {
                    "username": username,
                },
            )
            self.db.commit()
            return cursor.rowcount != 0

    async def get_usage(self):
        """Return ``{table_name: {"rows": count}}`` for every non-internal table."""
        usage = {}
        with closing(self.db.cursor()) as cursor:
            # sqlite_schema is the name used from SQLite 3.33 onward; older
            # versions only provide sqlite_master.
            if self.sqlite_version >= (3, 33):
                table_name = "sqlite_schema"
            else:
                table_name = "sqlite_master"

            cursor.execute(
                f"""
                SELECT name FROM {table_name} WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
                """
            )
            # fetchall() materializes the names first, so the cursor can be
            # reused for the COUNT queries below.
            for row in cursor.fetchall():
                name = row["name"]
                # Quote the identifier (doubling embedded quotes) so table
                # names containing '-', spaces, etc. do not break the query.
                quoted = '"%s"' % name.replace('"', '""')
                cursor.execute("SELECT COUNT() FROM %s" % quoted)
                usage[name] = {
                    "rows": cursor.fetchone()[0],
                }
        return usage

    async def get_query_columns(self):
        """Return the names of TEXT columns in the unihash and outhash tables (unordered)."""
        columns = set()
        for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION:
            if typ.startswith("TEXT"):
                columns.add(name)
        return list(columns)