"""Database helpers for the system tables: metadata, users, groups, table access and assets."""

import logging
from hashlib import sha256
from secrets import token_hex

from based.db import DBConnector, ColumnCondition, ColumnUpdate, ColumnDefinition

# NOTE(review): the star import supplies the *_TABLE_NAME / *_TABLE_SCHEMA
# constants and the model classes (User, UserGroup, AccessType, ...) used below;
# kept as-is because the exact exported names are not visible from this file.
from db_models import *

logger = logging.getLogger(__name__)


# System tables created by bootstrapDB, in creation order, together with any
# extra keyword arguments passed to createTable (all are created system=True).
_SYSTEM_TABLES = (
    (META_INFO_TABLE_NAME, META_INFO_TABLE_SCHEMA, {"hidden": True}),
    (USER_GROUP_TABLE_NAME, USER_GROUP_TABLE_SCHEMA, {}),
    (USERS_TABLE_NAME, USERS_TABLE_SCHEMA, {}),
    (USER_IN_USER_GROUP_JOIN_TABLE_NAME, USER_IN_USER_GROUP_JOIN_TABLE_SCHEMA, {}),
    (TABLE_ACCESS_TABLE_NAME, TABLE_ACCESS_TABLE_SCHEMA, {}),
    (ASSETS_TABLE_NAME, ASSETS_TABLE_SCHEMA, {}),
    (ASSET_ACCESS_TABLE_NAME, ASSET_ACCESS_TABLE_SCHEMA, {}),
)


def _hash_password(password: str) -> str:
    """Return the hex SHA-256 digest of *password* as stored in the users table."""
    # SECURITY NOTE(review): unsalted SHA-256 is fast to brute-force. Moving to a
    # salted KDF (hashlib.scrypt / pbkdf2_hmac) needs a schema change and a
    # migration of existing hashes, so it is only flagged here.
    return sha256(password.encode("utf-8"), usedforsecurity=True).hexdigest()


def _parse_access_type(code) -> AccessType:
    """Map a stored access code ("r", "w", "rw") to AccessType; anything else is NONE."""
    if code == "r":
        return AccessType.READ
    if code == "w":
        return AccessType.WRITE
    if code == "rw":
        return AccessType.READ_WRITE
    return AccessType.NONE


def _ensure_admin(conn: DBConnector) -> bool:
    """Create the default "admin" user and "admin" group if missing, and link them.

    Returns True only when both exist and the membership was stored.
    """
    admin = get_user_by_username(conn, "admin")
    if admin is None:
        # SECURITY NOTE(review): default credentials admin/admin.
        create_user(conn, "admin", "admin")
        admin = get_user_by_username(conn, "admin")

    group = get_group_by_name(conn, "admin")
    if group is None:
        create_group(conn, "admin")
        group = get_group_by_name(conn, "admin")

    if admin is None or group is None:
        logger.error("Failed to create the default admin user or group")
        return False

    ok, _ = set_user_group(conn, admin.id, group.id)
    return ok


def bootstrapDB(conn: DBConnector):
    """Create any missing system tables and, on first run, the default admin account.

    The ``admin_created`` metadata flag is written only after the admin user has
    been placed in the admin group, so a failed bootstrap is retried next time.
    """
    for name, schema, extra in _SYSTEM_TABLES:
        if not conn.tableExists(name):
            conn.createTable(name, schema, system=True, **extra)

    meta = get_metadata(conn, "admin_created")
    if meta is not None and meta.value == "yes":
        return

    if _ensure_admin(conn):
        add_metadata(conn, "admin_created", "yes")


def add_metadata(conn: DBConnector, name: str, value: str) -> tuple[bool, Exception | None]:
    """Insert a ``name``/``value`` row into the meta info table.

    Returns ``(True, None)`` on success, ``(False, error)`` if the insert raised.
    """
    try:
        conn.insertIntoTable(META_INFO_TABLE_NAME, {"name": name, "value": value})
        return True, None
    except Exception as e:
        logger.exception(e)
        return False, e


def get_metadata(conn: DBConnector, name: str):
    """Return the first MetaInfo row named *name*, or None if absent or on error."""
    try:
        rows = conn.filterFromTable(
            META_INFO_TABLE_NAME, ["*"], [ColumnCondition("name", name)]
        )
        if not rows:
            logger.warning("Metadata %s not found", name)
            return None
        return MetaInfo.parse_obj(rows[0])
    except Exception as e:
        logger.exception(e)
        return None


def create_user(conn: DBConnector, username: str, password: str) -> tuple[bool, Exception | None]:
    """Insert a user with a hashed password and a freshly generated access token.

    Returns ``(True, None)`` on success, ``(False, error)`` if the insert raised.
    """
    try:
        conn.insertIntoTable(
            USERS_TABLE_NAME,
            {
                "username": username,
                "password": _hash_password(password),
                "access_token": token_hex(),
            },
        )
        return True, None
    except Exception as e:
        logger.exception(e)
        return False, e


def get_user_by_username(conn: DBConnector, username: str) -> User | None:
    """Return the first user named *username*, or None if absent or on error."""
    try:
        rows = conn.filterFromTable(
            USERS_TABLE_NAME, ["*"], [ColumnCondition("username", username)]
        )
        if not rows:
            logger.warning("User %s not found", username)
            return None
        return User.parse_obj(rows[0])
    except Exception as e:
        logger.exception(e)
        return None


def get_user_by_id(conn: DBConnector, user_id: int) -> User | None:
    """Return the user with id *user_id*, or None if absent or on error."""
    try:
        rows = conn.filterFromTable(
            USERS_TABLE_NAME, ["*"], [ColumnCondition("id", user_id)]
        )
        if not rows:
            logger.warning("User with id %s not found", user_id)
            return None
        return User.parse_obj(rows[0])
    except Exception as e:
        logger.exception(e)
        return None


def get_user_by_access_token(conn: DBConnector, access_token: str | None):
    """Resolve an access token to ``(user, user_group)``.

    Returns ``(None, None)`` for a missing/empty token, an unknown token, or a
    failed query; ``user_group`` is None when the user has no group.
    """
    if not access_token:
        # Tokens issued by create_user are never empty, so a missing token
        # cannot identify a user; avoid querying with a NULL/empty condition.
        logger.warning("Invalid access token")
        return None, None
    try:
        rows = conn.filterFromTable(
            USERS_TABLE_NAME, ["*"], [ColumnCondition("access_token", access_token)]
        )
        if not rows:
            logger.warning("Invalid access token")
            return None, None
        user = User.parse_obj(rows[0])
        return user, get_user_group(conn, user.id)
    except Exception as e:
        logger.exception(e)
        return None, None


def check_user(conn: DBConnector, username: str, password: str) -> User | None:
    """Return the user matching *username* and *password*, or None if they don't match."""
    try:
        rows = conn.filterFromTable(
            USERS_TABLE_NAME,
            ["*"],
            [
                ColumnCondition("username", username),
                ColumnCondition("password", _hash_password(password)),
            ],
        )
        if not rows:
            logger.warning("Invalid username or password")
            return None
        return User.parse_obj(rows[0])
    except Exception as e:
        logger.exception(e)
        return None


def create_group(conn: DBConnector, name: str, description: str = "") -> tuple[bool, Exception | None]:
    """Insert a user group. Returns ``(True, None)`` or ``(False, error)``."""
    try:
        conn.insertIntoTable(
            USER_GROUP_TABLE_NAME, {"name": name, "description": description}
        )
        return True, None
    except Exception as e:
        logger.exception(e)
        return False, e


def get_group_by_name(conn: DBConnector, name: str) -> UserGroup | None:
    """Return the first group named *name*, or None if absent or on error."""
    try:
        rows = conn.filterFromTable(
            USER_GROUP_TABLE_NAME, ["*"], [ColumnCondition("name", name)]
        )
        if not rows:
            logger.warning("Group %s not found", name)
            return None
        return UserGroup.parse_obj(rows[0])
    except Exception as e:
        logger.exception(e)
        return None


def get_group_by_id(conn: DBConnector, group_id: int) -> UserGroup | None:
    """Return the group with id *group_id*, or None if absent or on error."""
    try:
        rows = conn.filterFromTable(
            USER_GROUP_TABLE_NAME, ["*"], [ColumnCondition("id", group_id)]
        )
        if not rows:
            logger.warning("Group with id %s not found", group_id)
            return None
        return UserGroup.parse_obj(rows[0])
    except Exception as e:
        logger.exception(e)
        return None


def set_user_group(conn: DBConnector, user_id: int, group_id: int) -> tuple[bool, Exception | None]:
    """Assign *user_id* to *group_id*, replacing any existing assignment.

    Returns ``(True, None)`` on success, ``(False, error)`` if a query raised.
    """
    try:
        existing = conn.filterFromTable(
            USER_IN_USER_GROUP_JOIN_TABLE_NAME,
            ["*"],
            [ColumnCondition("user_id", user_id)],
        )
        if not existing:
            conn.insertIntoTable(
                USER_IN_USER_GROUP_JOIN_TABLE_NAME,
                {"user_id": user_id, "user_group_id": group_id},
            )
        else:
            # A user belongs to at most one group: overwrite the current one.
            conn.updateDataInTable(
                USER_IN_USER_GROUP_JOIN_TABLE_NAME,
                [ColumnUpdate("user_group_id", group_id)],
                [ColumnCondition("user_id", user_id)],
            )
        return True, None
    except Exception as e:
        logger.exception(e)
        return False, e


def get_user_group(conn: DBConnector, user_id: int) -> UserGroup | None:
    """Return the group *user_id* belongs to, or None if unassigned or on error."""
    try:
        rows = conn.filterFromTable(
            USER_IN_USER_GROUP_JOIN_TABLE_NAME,
            ["*"],
            [ColumnCondition("user_id", user_id)],
        )
        if not rows:
            logger.warning("User with id %s not found, so no group", user_id)
            return None
        membership = UserInUserGroup.parse_obj(rows[0])
        return get_group_by_id(conn, membership.user_group_id)
    except Exception as e:
        logger.exception(e)
        return None


def get_group_users(conn: DBConnector, group_id: int) -> list[User]:
    """Return the users assigned to *group_id* (an empty list on error)."""
    try:
        memberships = conn.filterFromTable(
            USER_IN_USER_GROUP_JOIN_TABLE_NAME,
            ["*"],
            [ColumnCondition("user_group_id", group_id)],
        )
        # Join rows only hold user_id/user_group_id, so each user row must be
        # fetched separately; parsing a join row as a User cannot succeed.
        users = (get_user_by_id(conn, row["user_id"]) for row in memberships)
        return [user for user in users if user is not None]
    except Exception as e:
        logger.exception(e)
        return []


def list_users(conn: DBConnector) -> list[User]:
    """Return every user (an empty list on error)."""
    try:
        return [User.parse_obj(row) for row in conn.selectFromTable(USERS_TABLE_NAME, ["*"])]
    except Exception as e:
        logger.exception(e)
        return []


def list_groups(conn: DBConnector) -> list[UserGroup]:
    """Return every user group (an empty list on error)."""
    try:
        return [
            UserGroup.parse_obj(row)
            for row in conn.selectFromTable(USER_GROUP_TABLE_NAME, ["*"])
        ]
    except Exception as e:
        logger.exception(e)
        return []


def create_table(conn: DBConnector, table_name: str, schema: list[ColumnDefinition]) -> bool:
    """Create a (non-system) table; return True on success, False on error."""
    try:
        conn.createTable(table_name, schema)
        return True
    except Exception as e:
        logger.exception(e)
        return False


def get_table_access_level(
    conn: DBConnector, table_name: str, user_id: int
) -> AccessType:
    """Return the access *user_id* has on *table_name* via their group.

    Members of the "admin" group get READ_WRITE everywhere; users without a
    group, without an access row, or on error get NONE.
    """
    try:
        user_group = get_user_group(conn, user_id)
        if not user_group:
            return AccessType.NONE
        if user_group.name == "admin":
            return AccessType.READ_WRITE

        rows = conn.filterFromTable(
            TABLE_ACCESS_TABLE_NAME,
            ["*"],
            [
                ColumnCondition("table_name", table_name),
                ColumnCondition("user_group_id", user_group.id),
            ],
        )
        if not rows:
            return AccessType.NONE
        return _parse_access_type(TableAccess.parse_obj(rows[0]).access_type)
    except Exception as e:
        logger.exception(e)
        return AccessType.NONE


def get_allowed_columns_for_group(
    conn: DBConnector, table_name: str, group_id: int
) -> list[str]:
    """Return the columns *group_id* may access on *table_name*.

    ``["*"]`` means all columns; an empty list means no access row or an error.
    """
    try:
        rows = conn.filterFromTable(
            TABLE_ACCESS_TABLE_NAME,
            ["*"],
            [
                ColumnCondition("table_name", table_name),
                ColumnCondition("user_group_id", group_id),
            ],
        )
        if not rows:
            return []
        allowed_columns = rows[0]["allowed_columns"]
        if allowed_columns == "*":
            return ["*"]
        # NOTE(review): any other value is returned exactly as stored; confirm the
        # column holds a list (a delimited string would violate list[str]).
        return allowed_columns
    except Exception as e:
        logger.exception(e)
        return []


# System tables that drop_table refuses to drop, with the error raised for each.
_UNDROPPABLE_TABLES = {
    META_INFO_TABLE_NAME: "Cannot drop meta info table",
    USER_GROUP_TABLE_NAME: "Cannot drop user group table",
    USERS_TABLE_NAME: "Cannot drop users table",
    USER_IN_USER_GROUP_JOIN_TABLE_NAME: "Cannot drop user in user group join table",
    TABLE_ACCESS_TABLE_NAME: "Cannot drop table access table",
    ASSETS_TABLE_NAME: "Cannot drop assets table",
    ASSET_ACCESS_TABLE_NAME: "Cannot drop asset access table",
}


def drop_table(conn: DBConnector, table_name: str) -> tuple[bool, Exception | None]:
    """Drop a user table; system tables and unknown tables are rejected.

    Returns ``(True, None)`` on success, ``(False, error)`` otherwise.
    """
    try:
        message = _UNDROPPABLE_TABLES.get(table_name)
        if message is not None:
            raise Exception(message)
        if not conn.tableExists(table_name):
            raise Exception("Table does not exist")
        conn.dropTable(table_name)
        return True, None
    except Exception as e:
        logger.exception(e)
        return False, e


def create_asset(conn: DBConnector, name: str, description: str, fid: str):
    """Insert an asset row; return True on success, None on error."""
    try:
        conn.insertIntoTable(
            ASSETS_TABLE_NAME,
            {"name": name, "description": description, "fid": fid},
        )
        # TODO: add asset access
        # TODO: add asset to seaweedfs
        return True
    except Exception as e:
        logger.exception(e)
        return None


def remove_asset(conn: DBConnector, token: str | None, asset_id: int) -> bool:
    """Delete the asset row *asset_id*; return True on success, False on error.

    *token* is currently unused (no access check is performed yet).
    """
    try:
        conn.deleteFromTable(ASSETS_TABLE_NAME, [ColumnCondition("id", asset_id)])
        # TODO: remove asset access
        # TODO: remove asset from seaweedfs
        return True
    except Exception as e:
        logger.exception(e)
        return False


def get_asset(conn: DBConnector, token: str | None, fid: str):
    """Return the first Asset with file id *fid*, or None if absent or on error.

    NOTE(review): no access check is enforced yet; the user and access lookups
    below are scaffolding for the TODO.
    """
    try:
        _user, _group = get_user_by_access_token(conn, token)
        assets = conn.filterFromTable(
            ASSETS_TABLE_NAME, ["*"], [ColumnCondition("fid", fid)]
        )
        logger.debug("Assets matching fid %s: %s", fid, assets)
        if not assets:
            return None
        asset = Asset.parse_obj(assets[0])
        _asset_access = get_asset_access(conn, asset.id)
        # TODO: check if user has access to asset
        return asset
    except Exception as e:
        logger.exception(e)
        return None


def create_asset_access(conn: DBConnector, asset_id: int, user_group_id: int) -> bool:
    """Grant *user_group_id* an access row on *asset_id*; True on success.

    NOTE(review): no access_type is written although get_asset_access reads one;
    presumably the schema supplies a default — confirm.
    """
    try:
        conn.insertIntoTable(
            ASSET_ACCESS_TABLE_NAME,
            {"asset_id": asset_id, "user_group_id": user_group_id},
        )
        return True
    except Exception as e:
        # NOTE: this should not happen ever
        logger.exception(e)
        return False


def get_asset_access(conn: DBConnector, asset_id: int) -> AccessType:
    """Return the access type recorded for *asset_id*.

    Only the first access row for the asset is considered (the group is not
    filtered on); a missing row, unknown code, or error yields NONE.
    """
    try:
        rows = conn.filterFromTable(
            ASSET_ACCESS_TABLE_NAME, ["*"], [ColumnCondition("asset_id", asset_id)]
        )
        if not rows:
            return AccessType.NONE
        return _parse_access_type(AssetAccess.parse_obj(rows[0]).access_type)
    except Exception as e:
        logger.exception(e)
        return AccessType.NONE


def change_asset_access(
    conn: DBConnector, asset_id: int, user_group_id: int, access_type: AccessType
):
    """Change a group's access type on an asset. Not implemented yet."""
    # TODO: implement
    raise NotImplementedError()