blob: f93cb2c1dd9ea8af64f6024d4634eca73a1215de [file] [log] [blame]
#! /usr/bin/env python3
#
# Copyright (C) 2023 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
import sqlite3
import logging
from contextlib import closing
from . import User
logger = logging.getLogger("hashserv.sqlite")
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
("taskhash", "TEXT NOT NULL", "UNIQUE"),
("unihash", "TEXT NOT NULL", ""),
)
UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
OUTHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
("taskhash", "TEXT NOT NULL", "UNIQUE"),
("outhash", "TEXT NOT NULL", "UNIQUE"),
("created", "DATETIME", ""),
# Optional fields
("owner", "TEXT", ""),
("PN", "TEXT", ""),
("PV", "TEXT", ""),
("PR", "TEXT", ""),
("task", "TEXT", ""),
("outhash_siginfo", "TEXT", ""),
)
OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
USERS_TABLE_DEFINITION = (
("username", "TEXT NOT NULL", "UNIQUE"),
("token", "TEXT NOT NULL", ""),
("permissions", "TEXT NOT NULL", ""),
)
USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION)
def _make_table(cursor, name, definition):
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS {name} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
{fields}
UNIQUE({unique})
)
""".format(
name=name,
fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
unique=", ".join(
name for name, _, flags in definition if "UNIQUE" in flags
),
)
)
def map_user(row):
if row is None:
return None
return User(
username=row["username"],
permissions=set(row["permissions"].split()),
)
class DatabaseEngine(object):
def __init__(self, dbname, sync):
self.dbname = dbname
self.logger = logger
self.sync = sync
async def create(self):
db = sqlite3.connect(self.dbname)
db.row_factory = sqlite3.Row
with closing(db.cursor()) as cursor:
_make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
_make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
_make_table(cursor, "users", USERS_TABLE_DEFINITION)
cursor.execute("PRAGMA journal_mode = WAL")
cursor.execute(
"PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
)
# Drop old indexes
cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")
# TODO: Upgrade from tasks_v2?
cursor.execute("DROP TABLE IF EXISTS tasks_v2")
# Create new indexes
cursor.execute(
"CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
)
def connect(self, logger):
return Database(logger, self.dbname, self.sync)
class Database(object):
def __init__(self, logger, dbname, sync):
self.dbname = dbname
self.logger = logger
self.db = sqlite3.connect(self.dbname)
self.db.row_factory = sqlite3.Row
with closing(self.db.cursor()) as cursor:
cursor.execute("PRAGMA journal_mode = WAL")
cursor.execute(
"PRAGMA synchronous = %s" % ("NORMAL" if sync else "OFF")
)
cursor.execute("SELECT sqlite_version()")
version = []
for v in cursor.fetchone()[0].split("."):
try:
version.append(int(v))
except ValueError:
version.append(v)
self.sqlite_version = tuple(version)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
async def close(self):
self.db.close()
async def get_unihash_by_taskhash_full(self, method, taskhash):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"""
SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
ORDER BY outhashes_v2.created ASC
LIMIT 1
""",
{
"method": method,
"taskhash": taskhash,
},
)
return cursor.fetchone()
async def get_unihash_by_outhash(self, method, outhash):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"""
SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
ORDER BY outhashes_v2.created ASC
LIMIT 1
""",
{
"method": method,
"outhash": outhash,
},
)
return cursor.fetchone()
async def get_outhash(self, method, outhash):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"""
SELECT * FROM outhashes_v2
WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
ORDER BY outhashes_v2.created ASC
LIMIT 1
""",
{
"method": method,
"outhash": outhash,
},
)
return cursor.fetchone()
async def get_equivalent_for_outhash(self, method, outhash, taskhash):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"""
SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
-- Select any matching output hash except the one we just inserted
WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
-- Pick the oldest hash
ORDER BY outhashes_v2.created ASC
LIMIT 1
""",
{
"method": method,
"outhash": outhash,
"taskhash": taskhash,
},
)
return cursor.fetchone()
async def get_equivalent(self, method, taskhash):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash",
{
"method": method,
"taskhash": taskhash,
},
)
return cursor.fetchone()
async def remove(self, condition):
def do_remove(columns, table_name, cursor):
where = {}
for c in columns:
if c in condition and condition[c] is not None:
where[c] = condition[c]
if where:
query = ("DELETE FROM %s WHERE " % table_name) + " AND ".join(
"%s=:%s" % (k, k) for k in where.keys()
)
cursor.execute(query, where)
return cursor.rowcount
return 0
count = 0
with closing(self.db.cursor()) as cursor:
count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
self.db.commit()
return count
async def clean_unused(self, oldest):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"""
DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
)
""",
{
"oldest": oldest,
},
)
self.db.commit()
return cursor.rowcount
async def insert_unihash(self, method, taskhash, unihash):
with closing(self.db.cursor()) as cursor:
prevrowid = cursor.lastrowid
cursor.execute(
"""
INSERT OR IGNORE INTO unihashes_v2 (method, taskhash, unihash) VALUES(:method, :taskhash, :unihash)
""",
{
"method": method,
"taskhash": taskhash,
"unihash": unihash,
},
)
self.db.commit()
return cursor.lastrowid != prevrowid
async def insert_outhash(self, data):
data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
keys = sorted(data.keys())
query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
fields=", ".join(keys),
values=", ".join(":" + k for k in keys),
)
with closing(self.db.cursor()) as cursor:
prevrowid = cursor.lastrowid
cursor.execute(query, data)
self.db.commit()
return cursor.lastrowid != prevrowid
def _get_user(self, username):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"""
SELECT username, permissions, token FROM users WHERE username=:username
""",
{
"username": username,
},
)
return cursor.fetchone()
async def lookup_user_token(self, username):
row = self._get_user(username)
if row is None:
return None, None
return map_user(row), row["token"]
async def lookup_user(self, username):
return map_user(self._get_user(username))
async def set_user_token(self, username, token):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"""
UPDATE users SET token=:token WHERE username=:username
""",
{
"username": username,
"token": token,
},
)
self.db.commit()
return cursor.rowcount != 0
async def set_user_perms(self, username, permissions):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"""
UPDATE users SET permissions=:permissions WHERE username=:username
""",
{
"username": username,
"permissions": " ".join(permissions),
},
)
self.db.commit()
return cursor.rowcount != 0
async def get_all_users(self):
with closing(self.db.cursor()) as cursor:
cursor.execute("SELECT username, permissions FROM users")
return [map_user(r) for r in cursor.fetchall()]
async def new_user(self, username, permissions, token):
with closing(self.db.cursor()) as cursor:
try:
cursor.execute(
"""
INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions)
""",
{
"username": username,
"token": token,
"permissions": " ".join(permissions),
},
)
self.db.commit()
return True
except sqlite3.IntegrityError:
return False
async def delete_user(self, username):
with closing(self.db.cursor()) as cursor:
cursor.execute(
"""
DELETE FROM users WHERE username=:username
""",
{
"username": username,
},
)
self.db.commit()
return cursor.rowcount != 0
async def get_usage(self):
usage = {}
with closing(self.db.cursor()) as cursor:
if self.sqlite_version >= (3, 33):
table_name = "sqlite_schema"
else:
table_name = "sqlite_master"
cursor.execute(
f"""
SELECT name FROM {table_name} WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
"""
)
for row in cursor.fetchall():
cursor.execute(
"""
SELECT COUNT() FROM %s
"""
% row["name"],
)
usage[row["name"]] = {
"rows": cursor.fetchone()[0],
}
return usage
async def get_query_columns(self):
columns = set()
for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION:
if typ.startswith("TEXT"):
columns.add(name)
return list(columns)