"""Database helpers for the system tables: metadata, users, groups, table access and assets.

Most mutating helpers return an ``(ok, error)`` tuple and most lookups return
``None`` (or an empty list) on failure, logging the exception instead of
propagating it. ``bootstrapDB`` is the exception: connector errors escape it.
"""

import logging
from secrets import token_hex

from based.db import DBConnector, ColumnCondition, ColumnUpdate, ColumnDefinition

from db_addendum import AssetRefColumnDefinition, UserRefColumnDefinition
from db_models import *
from models import TableDefinition
from secutils import hash_password
import utils

logger = logging.getLogger(__name__)

# ``(ok, error)`` pair returned by the mutating helpers below.
_OpResult = tuple[bool, Exception | None]


def bootstrapDB(conn: DBConnector) -> None:
    """Create any missing system tables and, once, the default admin setup.

    Each system table is created only when ``conn.tableExists`` reports it
    missing. The admin user, the "anonymous" and "admin" groups and the admin
    user's group membership are created while the ``admin_created`` metadata
    flag is not ``"yes"``; the flag is written afterwards.

    Raises:
        RuntimeError: if the admin user or admin group cannot be found after
            they were created.
        Any exception raised by ``conn.tableExists`` / ``conn.createTable``.
    """
    # (table name, schema, log message, extra createTable kwargs), in creation order.
    system_tables = (
        (
            META_INFO_TABLE_NAME,
            META_INFO_TABLE_SCHEMA,
            "Creating meta info table",
            {"hidden": True},
        ),
        (USER_GROUP_TABLE_NAME, USER_GROUP_TABLE_SCHEMA, "Creating user group table", {}),
        (USERS_TABLE_NAME, USERS_TABLE_SCHEMA, "Creating users table", {}),
        (
            USER_IN_USER_GROUP_JOIN_TABLE_NAME,
            USER_IN_USER_GROUP_JOIN_TABLE_SCHEMA,
            "Creating user in user group join table",
            {},
        ),
        (TABLE_ACCESS_TABLE_NAME, TABLE_ACCESS_TABLE_SCHEMA, "Creating table access table", {}),
        (ASSETS_TABLE_NAME, ASSETS_TABLE_SCHEMA, "Creating assets table", {}),
    )
    for table_name, schema, message, extra in system_tables:
        if not conn.tableExists(table_name):
            logger.info(message)
            conn.createTable(table_name, schema, system=True, **extra)

    meta = get_metadata(conn, "admin_created")
    if meta and meta.value == "yes":
        return

    logger.info("Creating admin user and group")
    # NOTE(review): the default credentials are admin/admin; they should be
    # changed after first boot.
    create_user(conn, "admin", "admin")
    create_group(conn, "anonymous", "Default group for anonymous access")
    create_group(conn, "admin", "Administrator group")

    # Resolve the rows by name instead of by position in list_users()/list_groups():
    # their row order is not established here and they may hold pre-existing rows.
    admin_user = get_user_by_username(conn, "admin")
    admin_group = get_group_by_name(conn, "admin")
    if admin_user is None or admin_group is None:
        raise RuntimeError("Failed to create the default admin user or group")
    set_user_group(conn, admin_user.id, admin_group.id)
    add_metadata(conn, "admin_created", "yes")


def _get_first(
    conn: DBConnector,
    table_name: str,
    column: str,
    value: object,
    model: type,
    missing_message: str | None = None,
):
    """Return the first row of ``table_name`` whose ``column`` equals ``value``.

    The row is parsed with ``model.parse_obj``. Returns ``None`` when no row
    matches (logging ``missing_message`` as a warning, if given) or when the
    query or the parse raises (the exception is logged).
    """
    try:
        rows = conn.filterFromTable(table_name, ["*"], [ColumnCondition(column, "eq", value)])
        if not rows:
            if missing_message is not None:
                logger.warning(missing_message)
            return None
        return model.parse_obj(rows[0])
    except Exception as e:
        logger.exception(e)
        return None


def _find_ref_columns(
    conn: DBConnector, column_type: type
) -> list[tuple[str, ColumnDefinition]]:
    """Return ``(table_name, column)`` for every column of ``column_type`` in any table."""
    found: list[tuple[str, ColumnDefinition]] = []
    for table_def in conn.tables():
        table = TableDefinition.parse_obj(table_def)
        for column in utils.parse_columns_from_definition(table.columns):
            # Columns are instances (their ``.name`` is used by callers), so this
            # must be isinstance; an ``is`` test against the class never matches.
            if isinstance(column, column_type):
                found.append((table.table_name, column))
    return found


def _resolve_references(
    conn: DBConnector,
    column_type: type,
    ref_id: int,
    delete_referencing: bool,
    error_message: str,
) -> None:
    """Delete the rows referencing ``ref_id``, or refuse if any exist.

    Every column of ``column_type`` in any table is considered a reference
    column. With ``delete_referencing`` all rows whose reference column equals
    ``ref_id`` are deleted. Otherwise ``Exception(error_message)`` is raised as
    soon as one such row exists.
    """
    # The list is fully built before any deletion starts.
    for table_name, column in _find_ref_columns(conn, column_type):
        conditions = [ColumnCondition(column.name, "eq", ref_id)]
        if delete_referencing:
            conn.deleteFromTable(table_name, conditions)
        elif conn.filterFromTable(table_name, ["*"], conditions):
            raise Exception(error_message)


def add_metadata(conn: DBConnector, name: str, value: str) -> _OpResult:
    """Insert a ``name``/``value`` row into the meta info table.

    Returns ``(True, None)`` on success, ``(False, exc)`` if the insert raised.
    """
    try:
        conn.insertIntoTable(META_INFO_TABLE_NAME, {"name": name, "value": value})
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def get_metadata(conn: DBConnector, name: str):
    """Return the ``MetaInfo`` row named ``name``, or ``None`` if missing or on error."""
    return _get_first(
        conn, META_INFO_TABLE_NAME, "name", name, MetaInfo, f"Metadata {name} not found"
    )


def create_user(conn: DBConnector, username: str, password: str) -> _OpResult:
    """Insert a user with a hashed password and a fresh random access token.

    Returns ``(True, None)`` on success, ``(False, exc)`` on failure.
    """
    try:
        conn.insertIntoTable(
            USERS_TABLE_NAME,
            {
                "username": username,
                "password": hash_password(password),
                "access_token": token_hex(),
            },
        )
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def update_user(conn: DBConnector, id: int, password: str, access_token: str) -> _OpResult:
    """Replace the password (stored hashed) and access token of user ``id``.

    Returns ``(True, None)`` on success, ``(False, exc)`` on failure.
    """
    try:
        conn.updateDataInTable(
            USERS_TABLE_NAME,
            [
                ColumnUpdate("password", hash_password(password)),
                ColumnUpdate("access_token", access_token),
            ],
            [ColumnCondition("id", "eq", id)],
        )
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def delete_user(
    conn: DBConnector,
    id: int,
    check_references: bool = True,
    delete_referencing: bool = False,
) -> _OpResult:
    """Delete user ``id``, optionally handling rows that reference it.

    With ``check_references`` every ``UserRefColumnDefinition`` column of every
    table is inspected. With ``delete_referencing`` the rows pointing at the
    user are deleted first; otherwise the deletion is refused if any exist.
    NOTE(review): if the user-group join table declares ``user_id`` as a
    UserRefColumnDefinition, group membership counts as a reference — confirm.

    Returns ``(True, None)`` on success, ``(False, exc)`` on failure/refusal.
    """
    try:
        if check_references:
            _resolve_references(
                conn,
                UserRefColumnDefinition,
                id,
                delete_referencing,
                "User is referenced in other tables",
            )
        conn.deleteFromTable(USERS_TABLE_NAME, [ColumnCondition("id", "eq", id)])
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def get_user_by_username(conn: DBConnector, username: str):
    """Return the ``User`` named ``username``, or ``None`` if missing or on error."""
    return _get_first(
        conn, USERS_TABLE_NAME, "username", username, User, f"User {username} not found"
    )


def get_user_by_id(conn: DBConnector, user_id: int):
    """Return the ``User`` with id ``user_id``, or ``None`` if missing or on error."""
    return _get_first(
        conn, USERS_TABLE_NAME, "id", user_id, User, f"User with id {user_id} not found"
    )


def get_user_by_access_token(conn: DBConnector, access_token: str | None):
    """Return ``(user, user_group)`` for the owner of ``access_token``.

    Returns ``(None, None)`` when the token is ``None``/empty, matches no user,
    or the lookup fails. ``user_group`` is ``None`` when the user has no group.
    """
    if not access_token:
        # A null/empty token must never be matched against the users table.
        logger.warning("Invalid access token")
        return None, None
    user = _get_first(
        conn, USERS_TABLE_NAME, "access_token", access_token, User, "Invalid access token"
    )
    if user is None:
        return None, None
    return user, get_user_group(conn, user.id)


def check_user(conn: DBConnector, username: str, password: str):
    """Return the ``User`` matching ``username`` and ``password``, else ``None``.

    The password is hashed with ``hash_password`` and compared against the
    stored value in the query itself. NOTE(review): this only works if
    ``hash_password`` is deterministic (unsalted) — confirm in secutils.
    """
    try:
        matches = conn.filterFromTable(
            USERS_TABLE_NAME,
            ["*"],
            [
                ColumnCondition("username", "eq", username),
                ColumnCondition("password", "eq", hash_password(password)),
            ],
        )
        if not matches:
            logger.warning("Invalid username or password")
            return None
        return User.parse_obj(matches[0])
    except Exception as e:
        logger.exception(e)
        return None


def create_group(conn: DBConnector, name: str, description: str = "") -> _OpResult:
    """Insert a user group. Returns ``(True, None)`` or ``(False, exc)``."""
    try:
        conn.insertIntoTable(USER_GROUP_TABLE_NAME, {"name": name, "description": description})
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def get_group_by_name(conn: DBConnector, name: str):
    """Return the ``UserGroup`` named ``name``, or ``None`` if missing or on error."""
    return _get_first(
        conn, USER_GROUP_TABLE_NAME, "name", name, UserGroup, f"Group {name} not found"
    )


def get_group_by_id(conn: DBConnector, group_id: int):
    """Return the ``UserGroup`` with id ``group_id``, or ``None`` if missing or on error."""
    return _get_first(
        conn,
        USER_GROUP_TABLE_NAME,
        "id",
        group_id,
        UserGroup,
        f"Group with id {group_id} not found",
    )


def set_user_group(conn: DBConnector, user_id: int, group_id: int) -> _OpResult:
    """Assign ``user_id`` to group ``group_id``.

    Inserts a join-table row when the user has none; otherwise updates the
    group of the user's existing row(s). Returns ``(True, None)`` or ``(False, exc)``.
    """
    try:
        by_user = [ColumnCondition("user_id", "eq", user_id)]
        if conn.filterFromTable(USER_IN_USER_GROUP_JOIN_TABLE_NAME, ["*"], by_user):
            conn.updateDataInTable(
                USER_IN_USER_GROUP_JOIN_TABLE_NAME,
                [ColumnUpdate("user_group_id", group_id)],
                by_user,
            )
        else:
            conn.insertIntoTable(
                USER_IN_USER_GROUP_JOIN_TABLE_NAME,
                {"user_id": user_id, "user_group_id": group_id},
            )
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def get_user_group(conn: DBConnector, user_id: int):
    """Return the ``UserGroup`` of ``user_id``, or ``None`` if it has none or on error."""
    membership = _get_first(
        conn,
        USER_IN_USER_GROUP_JOIN_TABLE_NAME,
        "user_id",
        user_id,
        UserInUserGroup,
        f"User with id {user_id} not found, so no group",
    )
    if membership is None:
        return None
    return get_group_by_id(conn, membership.user_group_id)


def get_group_users(conn: DBConnector, group_id: int) -> list[User]:
    """Return the users that belong to group ``group_id``.

    Join-table rows only carry ids, so each member is looked up in the users
    table. Memberships whose user row cannot be loaded are skipped. Returns
    ``[]`` if the membership query fails.
    """
    try:
        memberships = conn.filterFromTable(
            USER_IN_USER_GROUP_JOIN_TABLE_NAME,
            ["*"],
            [ColumnCondition("user_group_id", "eq", group_id)],
        )
        member_ids = [UserInUserGroup.parse_obj(row).user_id for row in memberships]
    except Exception as e:
        logger.exception(e)
        return []
    members = (get_user_by_id(conn, member_id) for member_id in member_ids)
    return [user for user in members if user is not None]


def list_users(conn: DBConnector) -> list[User]:
    """Return every user, or ``[]`` on error."""
    try:
        return [User.parse_obj(row) for row in conn.selectFromTable(USERS_TABLE_NAME, ["*"])]
    except Exception as e:
        logger.exception(e)
        return []


def list_groups(conn: DBConnector) -> list[UserGroup]:
    """Return every user group, or ``[]`` on error."""
    try:
        return [
            UserGroup.parse_obj(row)
            for row in conn.selectFromTable(USER_GROUP_TABLE_NAME, ["*"])
        ]
    except Exception as e:
        logger.exception(e)
        return []


def create_table(conn: DBConnector, table_name: str, schema: list[ColumnDefinition]) -> _OpResult:
    """Create a (non-system) table. Returns ``(True, None)`` or ``(False, exc)``."""
    try:
        conn.createTable(table_name, schema)
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def get_table_access_level(conn: DBConnector, table_name: str, user_id: int) -> AccessType:
    """Return the access ``user_id`` has on ``table_name`` through its group.

    Members of the group named "admin" always get READ_WRITE. Otherwise the
    table access row for (table, group) decides: "r" -> READ, "w" -> WRITE,
    "rw" -> READ_WRITE. No group, no row, an unknown code or any error -> NONE.
    """
    try:
        user_group = get_user_group(conn, user_id)
        if not user_group:
            return AccessType.NONE
        if user_group.name == "admin":
            return AccessType.READ_WRITE

        rows = conn.filterFromTable(
            TABLE_ACCESS_TABLE_NAME,
            ["*"],
            [
                ColumnCondition("table_name", "eq", table_name),
                ColumnCondition("user_group_id", "eq", user_group.id),
            ],
        )
        if not rows:
            return AccessType.NONE

        access_type = TableAccess.parse_obj(rows[0]).access_type
        if access_type == "r":
            return AccessType.READ
        if access_type == "w":
            return AccessType.WRITE
        if access_type == "rw":
            return AccessType.READ_WRITE
        return AccessType.NONE
    except Exception as e:
        logger.exception(e)
        return AccessType.NONE


def get_allowed_columns_for_group(conn: DBConnector, table_name: str, group_id: int) -> list[str]:
    """Return the columns of ``table_name`` that group ``group_id`` may access.

    Returns ``["*"]`` when the stored value is ``"*"``, otherwise the stored
    ``allowed_columns`` value unchanged (presumably a list of column names —
    TODO confirm against the table access schema). Returns ``[]`` if no row
    exists or on error.
    """
    try:
        rows = conn.filterFromTable(
            TABLE_ACCESS_TABLE_NAME,
            ["*"],
            [
                ColumnCondition("table_name", "eq", table_name),
                ColumnCondition("user_group_id", "eq", group_id),
            ],
        )
        if not rows:
            return []
        allowed = rows[0]["allowed_columns"]
        return ["*"] if allowed == "*" else allowed
    except Exception as e:
        logger.exception(e)
        return []


def drop_table(conn: DBConnector, table_name: str) -> _OpResult:
    """Drop a user table; system tables and missing tables are refused.

    Returns ``(True, None)`` on success, ``(False, exc)`` on refusal or failure.
    """
    try:
        protected = {
            META_INFO_TABLE_NAME: "Cannot drop meta info table",
            USER_GROUP_TABLE_NAME: "Cannot drop user group table",
            USERS_TABLE_NAME: "Cannot drop users table",
            USER_IN_USER_GROUP_JOIN_TABLE_NAME: "Cannot drop user in user group join table",
            TABLE_ACCESS_TABLE_NAME: "Cannot drop table access table",
            ASSETS_TABLE_NAME: "Cannot drop assets table",
        }
        if table_name in protected:
            raise Exception(protected[table_name])
        if not conn.tableExists(table_name):
            raise Exception("Table does not exist")
        conn.dropTable(table_name)
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def create_asset(conn: DBConnector, name: str, description: str, fid: str) -> _OpResult:
    """Insert an asset row. Returns ``(True, None)`` or ``(False, exc)``."""
    try:
        conn.insertIntoTable(
            ASSETS_TABLE_NAME,
            {"name": name, "description": description, "fid": fid},
        )
    except Exception as e:
        logger.exception(e)
        # Return the error like every other mutating helper (was ``None``).
        return False, e
    return True, None


def update_asset(conn: DBConnector, asset_id: int, asset_description: str) -> _OpResult:
    """Replace the description of asset ``asset_id``. Returns ``(True, None)`` or ``(False, exc)``."""
    try:
        conn.updateDataInTable(
            ASSETS_TABLE_NAME,
            [ColumnUpdate("description", asset_description)],
            [ColumnCondition("id", "eq", asset_id)],
        )
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def remove_asset(
    conn: DBConnector,
    asset_id: int,
    check_references: bool = True,
    delete_referencing: bool = False,
) -> _OpResult:
    """Delete asset ``asset_id``, optionally handling rows that reference it.

    With ``check_references`` every ``AssetRefColumnDefinition`` column of every
    table is inspected. With ``delete_referencing`` the rows pointing at the
    asset are deleted first; otherwise the deletion is refused if any exist.

    Returns ``(True, None)`` on success, ``(False, exc)`` on failure/refusal.
    """
    try:
        if check_references:
            _resolve_references(
                conn,
                AssetRefColumnDefinition,
                asset_id,
                delete_referencing,
                "Asset is referenced in other tables",
            )
        conn.deleteFromTable(ASSETS_TABLE_NAME, [ColumnCondition("id", "eq", asset_id)])
    except Exception as e:
        logger.exception(e)
        return False, e
    return True, None


def get_asset(conn: DBConnector, fid: str):
    """Return the ``Asset`` with file id ``fid``, or ``None`` if missing or on error."""
    return _get_first(conn, ASSETS_TABLE_NAME, "fid", fid, Asset)


def get_asset_by_id(conn: DBConnector, asset_id: int):
    """Return the ``Asset`` with id ``asset_id``, or ``None`` if missing or on error."""
    return _get_first(conn, ASSETS_TABLE_NAME, "id", asset_id, Asset)