tuuli_backend/dba.py

490 lines
14 KiB
Python

import logging
from secrets import token_hex
from based.db import DBConnector, ColumnCondition, ColumnUpdate, ColumnDefinition
from db_models import *
from secutils import hash_password
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def bootstrapDB(conn: DBConnector):
    """Create any missing system tables and seed the admin user and group.

    Each system table is created only when ``conn.tableExists`` reports it
    missing. The admin user/group is seeded until the ``admin_created``
    metadata entry reads ``"yes"``. That flag is written only after the admin
    user has been linked to the admin group, so a partial failure is retried
    on the next call.
    """
    # (table name, schema, label used in the log message, extra createTable kwargs)
    system_tables = [
        (META_INFO_TABLE_NAME, META_INFO_TABLE_SCHEMA, "meta info", {"hidden": True}),
        (USER_GROUP_TABLE_NAME, USER_GROUP_TABLE_SCHEMA, "user group", {}),
        (USERS_TABLE_NAME, USERS_TABLE_SCHEMA, "users", {}),
        (
            USER_IN_USER_GROUP_JOIN_TABLE_NAME,
            USER_IN_USER_GROUP_JOIN_TABLE_SCHEMA,
            "user in user group join",
            {},
        ),
        (TABLE_ACCESS_TABLE_NAME, TABLE_ACCESS_TABLE_SCHEMA, "table access", {}),
        (ASSETS_TABLE_NAME, ASSETS_TABLE_SCHEMA, "assets", {}),
        (ASSET_ACCESS_TABLE_NAME, ASSET_ACCESS_TABLE_SCHEMA, "asset access", {}),
    ]
    for table_name, schema, label, extra in system_tables:
        if not conn.tableExists(table_name):
            logger.info("Creating %s table", label)
            conn.createTable(table_name, schema, system=True, **extra)

    meta = get_metadata(conn, "admin_created")
    if meta and meta.value == "yes":
        return

    logger.info("Creating admin user and group")
    # NOTE(review): the seeded admin password is the hard-coded "admin";
    # make sure it is changed after first start-up.
    # Either insert may fail when a previous run got part-way (e.g. the user
    # exists but the flag was never written); the name lookups below handle that.
    create_user(conn, "admin", "admin")
    create_group(conn, "admin")

    # Look the rows up by name: the first row of list_users()/list_groups()
    # is only the admin row when both tables were empty beforehand.
    admin_user = get_user_by_username(conn, "admin")
    admin_group = get_group_by_name(conn, "admin")
    if admin_user is None or admin_group is None:
        logger.error("Admin user or group could not be created; will retry")
        return

    linked, _ = set_user_group(conn, admin_user.id, admin_group.id)
    if not linked:
        logger.error("Could not add admin user to admin group; will retry")
        return
    add_metadata(conn, "admin_created", "yes")
def add_metadata(conn: DBConnector, name: str, value: str):
    """Store a ``name``/``value`` pair in the meta info table.

    Returns:
        ``(True, None)`` when the insert succeeds, otherwise ``(False, err)``
        where ``err`` is the exception raised (also logged).
    """
    row = {"name": name, "value": value}
    try:
        conn.insertIntoTable(META_INFO_TABLE_NAME, row)
    except Exception as err:
        logger.exception(err)
        return False, err
    return True, None
def get_metadata(conn: DBConnector, name: str):
    """Fetch the meta info entry called ``name``.

    Returns:
        The first matching row parsed as ``MetaInfo``, or ``None`` when no
        row matches (a warning is logged) or the lookup raises (logged).
    """
    try:
        rows = conn.filterFromTable(
            META_INFO_TABLE_NAME,
            ["*"],
            [ColumnCondition("name", "eq", name)],
        )
        if len(rows) > 0:
            return MetaInfo.parse_obj(rows[0])
        logger.warning(f"Metadata {name} not found")
        return None
    except Exception as err:
        logger.exception(err)
        return None
def create_user(conn: DBConnector, username: str, password: str):
    """Insert a new user with a hashed password and a fresh access token.

    The password is stored as ``hash_password(password)`` and the access token
    is generated with ``secrets.token_hex()``.

    Returns:
        ``(True, None)`` on success, or ``(False, err)`` if hashing or the
        insert raised (the exception is logged).
    """
    try:
        record = {
            "username": username,
            "password": hash_password(password),
            "access_token": token_hex(),
        }
        conn.insertIntoTable(USERS_TABLE_NAME, record)
    except Exception as err:
        logger.exception(err)
        return False, err
    return True, None
def update_user(conn: DBConnector, id: int, password: str, access_token: str):
    """Overwrite the password (re-hashed) and access token of user ``id``.

    Returns:
        ``(True, None)`` on success, or ``(False, err)`` if hashing or the
        update raised (the exception is logged).
    """
    try:
        changes = [
            ColumnUpdate("password", hash_password(password)),
            ColumnUpdate("access_token", access_token),
        ]
        where = [ColumnCondition("id", "eq", id)]
        conn.updateDataInTable(USERS_TABLE_NAME, changes, where)
    except Exception as err:
        logger.exception(err)
        return False, err
    return True, None
def get_user_by_username(conn: DBConnector, username: str):
    """Look up a user by ``username``.

    Returns:
        The first matching row parsed as ``User``, or ``None`` when there is
        no match (warning logged) or the lookup raises (logged).
    """
    try:
        matches = conn.filterFromTable(
            USERS_TABLE_NAME,
            ["*"],
            [ColumnCondition("username", "eq", username)],
        )
        if len(matches) > 0:
            return User.parse_obj(matches[0])
        logger.warning(f"User {username} not found")
        return None
    except Exception as err:
        logger.exception(err)
        return None
def get_user_by_id(conn: DBConnector, user_id: int):
    """Look up a user by primary key ``user_id``.

    Returns:
        The first matching row parsed as ``User``, or ``None`` when there is
        no match (warning logged) or the lookup raises (logged).
    """
    try:
        matches = conn.filterFromTable(
            USERS_TABLE_NAME,
            ["*"],
            [ColumnCondition("id", "eq", user_id)],
        )
        if len(matches) > 0:
            return User.parse_obj(matches[0])
        logger.warning(f"User with id {user_id} not found")
        return None
    except Exception as err:
        logger.exception(err)
        return None
def get_user_by_access_token(conn: DBConnector, access_token: str | None):
    """Resolve an access token to ``(user, user_group)``.

    Returns:
        ``(User, group)`` for the first user whose ``access_token`` matches,
        where ``group`` is whatever ``get_user_group`` returns for that user
        (``None`` if the user has no group). Returns ``(None, None)`` when the
        token is ``None``/empty, matches no user (warning logged), or the
        lookup raises (logged).
    """
    # A missing token must never authenticate. Depending on how the DB layer
    # renders an ``eq None`` condition, querying with it could match users
    # whose stored token is NULL, so reject it before touching the DB.
    if not access_token:
        logger.warning("Invalid access token")
        return None, None
    try:
        users = conn.filterFromTable(
            USERS_TABLE_NAME,
            ["*"],
            [ColumnCondition("access_token", "eq", access_token)],
        )
        if len(users) == 0:
            logger.warning("Invalid access token")
            return None, None
        user = User.parse_obj(users[0])
        user_group = get_user_group(conn, user.id)
        return user, user_group
    except Exception as e:
        logger.exception(e)
        return None, None
def check_user(conn: DBConnector, username: str, password: str):
    """Return the user whose username and hashed password both match.

    The password is hashed with ``hash_password`` and compared in the query
    by equality. NOTE(review): this presumably relies on ``hash_password``
    being deterministic (no per-user salt); confirm.

    Returns:
        The first matching row parsed as ``User``, or ``None`` when the
        credentials don't match (warning logged) or the lookup raises (logged).
    """
    try:
        digest = hash_password(password)
        candidates = conn.filterFromTable(
            USERS_TABLE_NAME,
            ["*"],
            [
                ColumnCondition("username", "eq", username),
                ColumnCondition("password", "eq", digest),
            ],
        )
        if len(candidates) > 0:
            return User.parse_obj(candidates[0])
        logger.warning("Invalid username or password")
        return None
    except Exception as err:
        logger.exception(err)
        return None
def create_group(conn: DBConnector, name: str, description: str = ""):
    """Insert a user group with the given ``name`` and ``description``.

    Returns:
        ``(True, None)`` on success, or ``(False, err)`` if the insert raised
        (the exception is logged).
    """
    record = {"name": name, "description": description}
    try:
        conn.insertIntoTable(USER_GROUP_TABLE_NAME, record)
    except Exception as err:
        logger.exception(err)
        return False, err
    return True, None
def get_group_by_name(conn: DBConnector, name: str):
    """Look up a user group by ``name``.

    Returns:
        The first matching row parsed as ``UserGroup``, or ``None`` when there
        is no match (warning logged) or the lookup raises (logged).
    """
    try:
        found = conn.filterFromTable(
            USER_GROUP_TABLE_NAME,
            ["*"],
            [ColumnCondition("name", "eq", name)],
        )
        if len(found) > 0:
            return UserGroup.parse_obj(found[0])
        logger.warning(f"Group {name} not found")
        return None
    except Exception as err:
        logger.exception(err)
        return None
def get_group_by_id(conn: DBConnector, group_id: int):
    """Look up a user group by primary key ``group_id``.

    Returns:
        The first matching row parsed as ``UserGroup``, or ``None`` when there
        is no match (warning logged) or the lookup raises (logged).
    """
    try:
        found = conn.filterFromTable(
            USER_GROUP_TABLE_NAME,
            ["*"],
            [ColumnCondition("id", "eq", group_id)],
        )
        if len(found) > 0:
            return UserGroup.parse_obj(found[0])
        logger.warning(f"Group with id {group_id} not found")
        return None
    except Exception as err:
        logger.exception(err)
        return None
def set_user_group(conn: DBConnector, user_id: int, group_id: int):
    """Assign user ``user_id`` to group ``group_id``.

    If the join table already has a row for the user, its ``user_group_id``
    is updated; otherwise a new membership row is inserted. The existence
    check and the write are separate calls, not one atomic upsert.

    Returns:
        ``(True, None)`` on success, or ``(False, err)`` if any DB call raised
        (the exception is logged).
    """
    try:
        existing = conn.filterFromTable(
            USER_IN_USER_GROUP_JOIN_TABLE_NAME,
            ["*"],
            [ColumnCondition("user_id", "eq", user_id)],
        )
        if existing:
            conn.updateDataInTable(
                USER_IN_USER_GROUP_JOIN_TABLE_NAME,
                [ColumnUpdate("user_group_id", group_id)],
                [ColumnCondition("user_id", "eq", user_id)],
            )
        else:
            conn.insertIntoTable(
                USER_IN_USER_GROUP_JOIN_TABLE_NAME,
                {"user_id": user_id, "user_group_id": group_id},
            )
    except Exception as err:
        logger.exception(err)
        return False, err
    return True, None
def get_user_group(conn: DBConnector, user_id: int):
    """Return the group that user ``user_id`` belongs to.

    Reads the user's first membership row from the join table and resolves
    its ``user_group_id`` with ``get_group_by_id``.

    Returns:
        The ``UserGroup`` (or ``None`` if that lookup fails), or ``None`` when
        the user has no membership row (warning logged) or a lookup raises
        (logged).
    """
    try:
        memberships = conn.filterFromTable(
            USER_IN_USER_GROUP_JOIN_TABLE_NAME,
            ["*"],
            [ColumnCondition("user_id", "eq", user_id)],
        )
        if len(memberships) > 0:
            membership = UserInUserGroup.parse_obj(memberships[0])
            return get_group_by_id(conn, membership.user_group_id)
        logger.warning(f"User with id {user_id} not found, so no group")
        return None
    except Exception as err:
        logger.exception(err)
        return None
def get_group_users(conn: DBConnector, group_id: int) -> list[User]:
    """Return the users that belong to group ``group_id``.

    Membership rows in the user/group join table are not user records; they
    carry ``user_id``/``user_group_id`` (see ``set_user_group``). Each
    row's ``user_id`` is therefore resolved to its ``User`` with
    ``get_user_by_id``, which costs one extra query per member. Members whose
    user row cannot be found are skipped.

    Returns:
        The list of ``User`` objects, or ``[]`` if a lookup raises (logged).
    """
    try:
        memberships = conn.filterFromTable(
            USER_IN_USER_GROUP_JOIN_TABLE_NAME,
            ["*"],
            [ColumnCondition("user_group_id", "eq", group_id)],
        )
        members: list[User] = []
        for membership in memberships:
            # Parsing the join row itself as a User would yield wrong or
            # invalid data; follow the foreign key instead.
            user = get_user_by_id(conn, membership["user_id"])
            if user is not None:
                members.append(user)
        return members
    except Exception as e:
        logger.exception(e)
        return []
def list_users(conn: DBConnector) -> list[User]:
    """Return every row of the users table parsed as ``User``.

    Returns ``[]`` if the select or any parse raises (the exception is logged).
    """
    try:
        rows = conn.selectFromTable(USERS_TABLE_NAME, ["*"])
        return [User.parse_obj(row) for row in rows]
    except Exception as err:
        logger.exception(err)
        return []
def list_groups(conn: DBConnector):
    """Return every row of the user group table parsed as ``UserGroup``.

    Returns ``[]`` if the select or any parse raises (the exception is logged).
    """
    try:
        rows = conn.selectFromTable(USER_GROUP_TABLE_NAME, ["*"])
        return [UserGroup.parse_obj(row) for row in rows]
    except Exception as err:
        logger.exception(err)
        return []
def create_table(conn: DBConnector, table_name: str, schema: list[ColumnDefinition]):
    """Create table ``table_name`` with column definitions ``schema``.

    ``createTable`` is called without ``system=True``.

    Returns:
        ``True`` on success, ``False`` if creation raised (the exception is
        logged).
    """
    try:
        conn.createTable(table_name, schema)
    except Exception as err:
        logger.exception(err)
        return False
    return True
def get_table_access_level(
    conn: DBConnector, table_name: str, user_id: int
) -> AccessType:
    """Return the access level user ``user_id`` has on ``table_name``.

    Members of the group named ``"admin"`` always get ``READ_WRITE``. Other
    users get the level stored in the table access row for their group and
    this table: ``"r"`` -> ``READ``, ``"w"`` -> ``WRITE``, ``"rw"`` ->
    ``READ_WRITE``. Any other case (no group, no row, unknown code, or an
    exception, which is logged) yields ``AccessType.NONE``.
    """
    try:
        group = get_user_group(conn, user_id)
        if not group:
            return AccessType.NONE
        if group.name == "admin":
            return AccessType.READ_WRITE

        grants = conn.filterFromTable(
            TABLE_ACCESS_TABLE_NAME,
            ["*"],
            [
                ColumnCondition("table_name", "eq", table_name),
                ColumnCondition("user_group_id", "eq", group.id),
            ],
        )
        if not grants:
            return AccessType.NONE

        code = TableAccess.parse_obj(grants[0]).access_type
        by_code = {
            "r": AccessType.READ,
            "w": AccessType.WRITE,
            "rw": AccessType.READ_WRITE,
        }
        return by_code.get(code, AccessType.NONE)
    except Exception as err:
        logger.exception(err)
        return AccessType.NONE
def get_allowed_columns_for_group(
    conn: DBConnector, table_name: str, group_id: int
) -> list[str]:
    """Return the columns of ``table_name`` that group ``group_id`` may see.

    Reads ``allowed_columns`` from the first table access row for the
    (table, group) pair. A stored ``"*"`` becomes ``["*"]``; any other value
    is returned unchanged. NOTE(review): presumably that value is already a
    list of column names; confirm the column's storage format.

    Returns ``[]`` when no row matches or the lookup raises (logged).
    """
    try:
        grants = conn.filterFromTable(
            TABLE_ACCESS_TABLE_NAME,
            ["*"],
            [
                ColumnCondition("table_name", "eq", table_name),
                ColumnCondition("user_group_id", "eq", group_id),
            ],
        )
        if not grants:
            return []
        columns = grants[0]["allowed_columns"]
        return ["*"] if columns == "*" else columns
    except Exception as err:
        logger.exception(err)
        return []
def drop_table(conn: DBConnector, table_name: str):
    """Drop a non-system table.

    All system tables created by ``bootstrapDB`` are refused, including the
    asset access table. So are tables that do not exist.

    Returns:
        ``(True, None)`` on success, or ``(False, err)`` when the table is
        protected, missing, or the drop raised; ``err`` is logged.
    """
    try:
        # System table -> message of the exception raised when asked to drop it.
        protected = {
            META_INFO_TABLE_NAME: "Cannot drop meta info table",
            USER_GROUP_TABLE_NAME: "Cannot drop user group table",
            USERS_TABLE_NAME: "Cannot drop users table",
            USER_IN_USER_GROUP_JOIN_TABLE_NAME: "Cannot drop user in user group join table",
            TABLE_ACCESS_TABLE_NAME: "Cannot drop table access table",
            ASSETS_TABLE_NAME: "Cannot drop assets table",
            ASSET_ACCESS_TABLE_NAME: "Cannot drop asset access table",
        }
        if table_name in protected:
            raise Exception(protected[table_name])
        if not conn.tableExists(table_name):
            raise Exception("Table does not exist")
        conn.dropTable(table_name)
        return True, None
    except Exception as e:
        logger.exception(e)
        return False, e
def create_asset(conn: DBConnector, name: str, description: str, fid: str):
    """Insert an asset record with the given ``name``, ``description`` and ``fid``.

    Per the TODOs below, no asset access rows are created and nothing is
    uploaded to minio yet.

    Returns:
        ``True`` on success, ``False`` if the insert raised (the exception is
        logged). ``False`` matches ``remove_asset`` and ``create_asset_access``.
    """
    try:
        conn.insertIntoTable(
            ASSETS_TABLE_NAME,
            {
                "name": name,
                "description": description,
                "fid": fid,
            },
        )
        # TODO: add asset access
        # TODO: add asset to minio
        return True
    except Exception as e:
        logger.exception(e)
        return False
def remove_asset(conn: DBConnector, token: str | None, asset_id: int):
    """Delete the asset row whose id is ``asset_id``.

    NOTE(review): ``token`` is accepted but never checked, so no authorization
    happens here. Per the TODOs, the asset's access rows and its minio object
    are not removed yet.

    Returns:
        ``True`` on success, ``False`` if the delete raised (logged).
    """
    try:
        conn.deleteFromTable(
            ASSETS_TABLE_NAME,
            [ColumnCondition("id", "eq", asset_id)],
        )
    except Exception as err:
        logger.exception(err)
        return False
    # TODO: remove asset access
    # TODO: remove asset from minio
    return True
def get_asset(conn: DBConnector, token: str | None, fid: str):
    """Return the asset whose ``fid`` matches, parsed as ``Asset``.

    The caller (from ``token``) and the asset's access level are looked up,
    but no access check is enforced yet (see TODO). Any asset matching
    ``fid`` is returned regardless of ``token``.

    Returns:
        The first matching ``Asset``, or ``None`` when nothing matches or a
        lookup raises (logged).
    """
    try:
        # Resolved for the pending access check below; not used yet.
        user, group = get_user_by_access_token(conn, token)
        assets = conn.filterFromTable(
            ASSETS_TABLE_NAME, ["*"], [ColumnCondition("fid", "eq", fid)]
        )
        # Was a bare print() to stdout; route through logging instead.
        logger.debug("Assets matching fid %s: %s", fid, assets)
        if len(assets) == 0:
            return None
        asset = Asset.parse_obj(assets[0])
        asset_access = get_asset_access(conn, asset.id)
        # TODO: check if user has access to asset
        return asset
    except Exception as e:
        logger.exception(e)
        return None
def create_asset_access(conn: DBConnector, asset_id: int, user_group_id: int):
    """Insert an asset access row linking ``asset_id`` to ``user_group_id``.

    Only those two columns are written. NOTE(review): no ``access_type`` is
    set here, so the row relies on whatever the schema defaults to; confirm.

    Returns:
        ``True`` on success, ``False`` if the insert raised (logged).
    """
    record = {"asset_id": asset_id, "user_group_id": user_group_id}
    try:
        conn.insertIntoTable(ASSET_ACCESS_TABLE_NAME, record)
    except Exception as err:
        # NOTE: this should not happen ever
        logger.exception(err)
        return False
    return True
def get_asset_access(
    conn: DBConnector, asset_id: int, user_group_id: int | None = None
) -> AccessType:
    """Return the access level recorded for asset ``asset_id``.

    Args:
        conn: Database connection.
        asset_id: Id of the asset whose access rows are read.
        user_group_id: When given, only the access row for this user group is
            considered. When omitted, the first access row returned for the
            asset is used, whichever group it belongs to.

    Returns:
        ``READ``/``WRITE``/``READ_WRITE`` for stored codes
        ``"r"``/``"w"``/``"rw"``; ``AccessType.NONE`` when no row matches,
        the code is unrecognised, or the lookup raises (logged).
    """
    try:
        conditions = [ColumnCondition("asset_id", "eq", asset_id)]
        if user_group_id is not None:
            # Access rows carry both asset_id and user_group_id (see
            # create_asset_access), so narrow to the requested group.
            conditions.append(ColumnCondition("user_group_id", "eq", user_group_id))
        rows = conn.filterFromTable(ASSET_ACCESS_TABLE_NAME, ["*"], conditions)
        if not rows:
            return AccessType.NONE
        code = AssetAccess.parse_obj(rows[0]).access_type
        if code == "r":
            return AccessType.READ
        if code == "w":
            return AccessType.WRITE
        if code == "rw":
            return AccessType.READ_WRITE
        return AccessType.NONE
    except Exception as e:
        logger.exception(e)
        return AccessType.NONE
def change_asset_access(
    conn: DBConnector, asset_id: int, user_group_id: int, access_type: AccessType
):
    """Placeholder for changing a user group's access to an asset.

    Presumably meant to set ``user_group_id``'s access on ``asset_id`` to
    ``access_type``. Not implemented yet: every call raises
    ``NotImplementedError``.
    """
    # TODO: implement
    raise NotImplementedError()