diff --git a/app.py b/app.py new file mode 100644 index 0000000..dcb509d --- /dev/null +++ b/app.py @@ -0,0 +1,302 @@ +import io +from typing import Any +from fastapi import FastAPI, status, Header, UploadFile +from starlette.responses import StreamingResponse +from fastapi.middleware.cors import CORSMiddleware +from based import db +import psycopg +from hashlib import sha256 +from secrets import token_hex +from minio import Minio +from minio.error import S3Error +from minio.helpers import ObjectWriteResult +from urllib3 import HTTPResponse + +from dba import * +from models import ( + ColumnsDefinitionList, + ItemDeletionDefinitionList, + ItemsFieldSelectorList, + UserDefinition, +) +from utils import ( + check_if_admin_access_token, + parse_columns_from_definition, +) + +conninfo = "postgresql://postgres:asarch6122@localhost" +connector = db.DBConnector(conninfo) + +bootstrapDB(connector) + +BUCKET_NAME = "tuuli-files" +minioClient = Minio( + "localhost:8090", + access_key="mxR0F5PK8CpCM8SA", + secret_key="yFJsG70xLU3BiIMslinz6dhqKHqNpUc6", + secure=False, +) +found = minioClient.bucket_exists(BUCKET_NAME) +if found: + print(f"Bucket '{BUCKET_NAME}' already exists") +else: + minioClient.make_bucket(BUCKET_NAME) + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/api/listTables") +async def listTables(access_token: str | None = Header(default=None)): + is_admin = check_if_admin_access_token(connector, access_token) + if not is_admin: + return {"error": "Not allowed"} + + tables = connector.tables() + return tables + + +@app.post("/api/createTable/{tableName}") +async def createTable( + tableName: str, + columns: ColumnsDefinitionList, + access_token: str | None = Header(default=None), +): + is_admin = check_if_admin_access_token(connector, access_token) + if not is_admin: + return {"error": "Not allowed"} + + try: + columnsDefinition = parse_columns_from_definition(",".join(columns.columns)) + create_table(connector, tableName, columnsDefinition) + except psycopg.errors.UniqueViolation: + return {"error": "Username already exists"} + except Exception as e: + return {"error": str(e)} + + return {"ok": True} + + +@app.post("/api/dropTable/{tableName}") +async def dropTable( + tableName: str, + access_token: str | None = Header(default=None), +): + is_admin = check_if_admin_access_token(connector, access_token) + if not is_admin: + return {"error": "Not allowed"} + + try: + drop_table(connector, tableName) + except Exception as e: + return {"error": str(e)} + + return {"ok": True} + + +@app.post("/api/createUser") +async def createUser( + user: UserDefinition, + access_token: str | None = Header(default=None), +): + is_admin = check_if_admin_access_token(connector, access_token) + if not is_admin: + return {"error": "Not allowed"} + + try: + create_user(connector, user.username, user.password) + except psycopg.errors.UniqueViolation: + return {"error": "Username already exists"} + except Exception as e: + return {"error": str(e)} + + return {"ok": True} + + +@app.post("/items/{tableName}") +async def items( + tableName: str, + selector: ItemsFieldSelectorList, + access_token: str | None = Header(default=None), +): + table_info = connector.getTable(tableName) + if not table_info: + return {"error": "Not allowed"} + + is_admin = check_if_admin_access_token(connector, access_token) + if table_info["system"] and not is_admin: + return {"error": "Not allowed"} + + columns = parse_columns_from_definition(table_info["columns"]) + columnsNames = set(column.name for column in columns) + userSelectedColumns = list(set(selector.fields)) if selector.fields else ["*"] + if userSelectedColumns != ["*"]: + for column in userSelectedColumns: + if column not in columnsNames: + return {"error": f"Column {column} not found on table {tableName}"} + else: + userSelectedColumns = columnsNames + + user, group = get_user_by_access_token(connector, access_token) + if not user: + return {"error": "Not allowed"} + + if not is_admin: + allowedColumns = get_allowed_columns_for_group( + connector, tableName, group.id if group else -1 + ) + if not allowedColumns: + return {"error": "Not allowed"} + elif len(allowedColumns) == 1 and allowedColumns[0] == "*": + pass + else: + for column in userSelectedColumns: + if column not in allowedColumns: + return {"error": "Not allowed"} + + print(columnsNames) + print(selector) + + table_items = connector.selectFromTable( + tableName, selector.fields if selector.fields else ["*"] + ) + + return table_items + + +@app.post("/items/{tableName}/+") +async def itemsCreate( + tableName: str, + item: dict[str, str], + access_token: str | None = Header(default=None), +): + table_info = connector.getTable(tableName) + if not table_info: + return {"error": "Not found"} + + is_admin = check_if_admin_access_token(connector, access_token) + if table_info["system"] and not is_admin: + return {"error": "Not allowed"} + + user, group = get_user_by_access_token(connector, access_token) + if not is_admin: + allowedColumns = get_allowed_columns_for_group( + connector, tableName, group.id if group else -1 + ) + if not allowedColumns: + return {"error": "Not allowed"} + elif len(allowedColumns) == 1 and allowedColumns[0] == "*": + pass + else: + for column in item: + if column not in allowedColumns: + return {"error": "Not allowed"} + + try: + connector.insertIntoTable(tableName, item) + except psycopg.errors.UndefinedColumn: + return {"error": "Column not found"} + except psycopg.errors.UniqueViolation: + return {"error": "Unique constraint violation"} + except Exception as e: + return {"error": str(e)} + + return {"ok": True} + + +@app.post("/items/{tableName}/-") +async def itemsDelete( + tableName: str, + deleteWhere: ItemDeletionDefinitionList, + access_token: str | None = Header(default=None), +): + table_info = connector.getTable(tableName) + if not table_info: + return {"error": "Not found"} + + is_admin = check_if_admin_access_token(connector, access_token) + if table_info["system"] and not is_admin: + return {"error": "Not allowed"} + + user, group = get_user_by_access_token(connector, access_token) + if not is_admin: + allowedColumns = get_allowed_columns_for_group( + connector, tableName, group.id if group else -1 + ) + if not allowedColumns: + return {"error": "Not allowed"} + elif len(allowedColumns) == 1 and allowedColumns[0] == "*": + pass + else: + return {"error": "Not allowed"} + + try: + connector.deleteFromTable( + tableName, + [ + ColumnCondition(where.name, where.value, where.isString, where.isLike) + for where in deleteWhere.defs + ], + ) + except Exception as e: + return {"error": str(e)} + + return {"ok": True} + + +@app.get("/assets/{fid}") +async def getAsset(fid: str, access_token: str | None = Header(default=None)): + asset = get_asset(connector, access_token, fid) + if not asset: + return status.HTTP_404_NOT_FOUND + + response: HTTPResponse | None = None + try: + response = minioClient.get_object(BUCKET_NAME, asset.name, version_id=asset.fid) + if response is None: + return status.HTTP_404_NOT_FOUND + + return StreamingResponse( + content=io.BytesIO(response.data), + media_type=response.getheader("Content-Type"), + status_code=status.HTTP_200_OK, + ) + finally: + if response is not None: + response.close() + response.release_conn() + + +@app.post("/assets/+") +async def createAsset( + asset: UploadFile, + access_token: str | None = Header(default=None), +): + user, _ = get_user_by_access_token(connector, access_token) + if not user: + return {"error": "Not allowed"} + + filename = asset.filename + if not filename: + filename = f"unnamed" + filename = f"{token_hex()}_{filename}" + + result: ObjectWriteResult = minioClient.put_object( + BUCKET_NAME, + filename, + data=asset.file, + content_type=( + asset.content_type if asset.content_type else "application/octet-stream" + ), + length=asset.size, + ) + + if not create_asset(connector, filename, "", str(result.version_id)): + return {"error": "Failed to create asset"} + return {"ok": True, "fid": result.version_id} diff --git a/db_models.py b/db_models.py new file mode 100644 index 0000000..ae9bc88 --- /dev/null +++ b/db_models.py @@ -0,0 +1,137 @@ +import enum +from based.columns import ( + PrimarySerialColumnDefinition, + TextColumnDefinition, + IntegerColumnDefinition, + make_column_unique, +) +from pydantic import BaseModel + + +class AccessType(enum.Enum): + READ = "read" + WRITE = "write" + READ_WRITE = "read_write" + NONE = "none" + + +META_INFO_TABLE_NAME = "meta_info" +META_INFO_TABLE_SCHEMA = [ + PrimarySerialColumnDefinition("id"), + make_column_unique(TextColumnDefinition("name")), + TextColumnDefinition("value"), + TextColumnDefinition("allowed_columns", default="*"), +] + + +class MetaInfo(BaseModel): + id: int + name: str + value: str + allowed_columns: str + + +USER_GROUP_TABLE_NAME = "user_group" +USER_GROUP_TABLE_SCHEMA = [ + PrimarySerialColumnDefinition("id"), + make_column_unique(TextColumnDefinition("name")), + TextColumnDefinition("description", default=""), +] + + +class UserGroup(BaseModel): + id: int + name: str + description: str + + +USERS_TABLE_NAME = "users" +USERS_TABLE_SCHEMA = [ + PrimarySerialColumnDefinition("id"), + make_column_unique(TextColumnDefinition("username")), + TextColumnDefinition("password"), + TextColumnDefinition("access_token"), +] + + +class User(BaseModel): + id: int + username: str + password: str + access_token: str + + def to_dict(self, safe=True): + d = { + "id": self.id, + "username": self.username, + } + if not safe: + d["access_token"] = self.access_token + d["password"] = self.password + return d + + +USER_IN_USER_GROUP_JOIN_TABLE_NAME = "user_in_user_group" +USER_IN_USER_GROUP_JOIN_TABLE_SCHEMA = [ + PrimarySerialColumnDefinition("id"), + IntegerColumnDefinition("user_id"), + IntegerColumnDefinition("user_group_id"), +] + + +class UserInUserGroup(BaseModel): + id: int + user_id: int + user_group_id: int + + +TABLE_ACCESS_TABLE_NAME = "table_access" +TABLE_ACCESS_TABLE_SCHEMA = [ + PrimarySerialColumnDefinition("id"), + IntegerColumnDefinition("user_group_id"), + TextColumnDefinition("table_name"), + TextColumnDefinition("access_type"), + TextColumnDefinition("allowed_columns", default="*"), +] + + +class TableAccess(BaseModel): + id: int + user_group_id: int + table_name: str + access_type: str + allowed_columns: str + + +ASSETS_TABLE_NAME = "assets" +ASSETS_TABLE_SCHEMA = [ + PrimarySerialColumnDefinition("id"), + TextColumnDefinition("name"), + TextColumnDefinition("description", default=""), + TextColumnDefinition("fid"), + TextColumnDefinition("catalog", default="/root"), +] + + +class Asset(BaseModel): + id: int + name: str + description: str + fid: str + catalog: str + + +ASSET_ACCESS_TABLE_NAME = "asset_access" +ASSET_ACCESS_TABLE_SCHEMA = [ + PrimarySerialColumnDefinition("id"), + IntegerColumnDefinition("user_group_id"), + IntegerColumnDefinition("asset_id"), + TextColumnDefinition("access_type"), +] + + +class AssetAccess(BaseModel): + id: int + user_group_id: int + asset_id: int + access_type: str diff --git a/dba.py b/dba.py new file mode 100644 index 0000000..6a27b1d --- /dev/null +++ b/dba.py @@ -0,0 +1,459 @@ +from hashlib import sha256 +import logging +from secrets import token_hex +from based.db import DBConnector, ColumnCondition, ColumnUpdate, ColumnDefinition +from db_models import * + +logger = logging.getLogger(__name__) + + +def bootstrapDB(conn: DBConnector): + if not conn.tableExists(META_INFO_TABLE_NAME): + conn.createTable( + META_INFO_TABLE_NAME, META_INFO_TABLE_SCHEMA, system=True, hidden=True + ) + + if not conn.tableExists(USER_GROUP_TABLE_NAME): + conn.createTable(USER_GROUP_TABLE_NAME, USER_GROUP_TABLE_SCHEMA, system=True) + + if not conn.tableExists(USERS_TABLE_NAME): + conn.createTable(USERS_TABLE_NAME, USERS_TABLE_SCHEMA, system=True) + + if not conn.tableExists(USER_IN_USER_GROUP_JOIN_TABLE_NAME): + conn.createTable( + USER_IN_USER_GROUP_JOIN_TABLE_NAME, + USER_IN_USER_GROUP_JOIN_TABLE_SCHEMA, + system=True, + ) + + if not conn.tableExists(TABLE_ACCESS_TABLE_NAME): + conn.createTable( + TABLE_ACCESS_TABLE_NAME, + TABLE_ACCESS_TABLE_SCHEMA, + system=True, + ) + + if not conn.tableExists(ASSETS_TABLE_NAME): + conn.createTable( + ASSETS_TABLE_NAME, + ASSETS_TABLE_SCHEMA, + system=True, + ) + + if not conn.tableExists(ASSET_ACCESS_TABLE_NAME): + conn.createTable( + ASSET_ACCESS_TABLE_NAME, + ASSET_ACCESS_TABLE_SCHEMA, + system=True, + ) + + meta = get_metadata(conn, "admin_created") + testAdminCreated = meta and meta.value == "yes" + if not testAdminCreated: + create_user(conn, "admin", "admin") + create_group(conn, "admin") + + users = list_users(conn) + groups = list_groups(conn) + + set_user_group(conn, users[0].id, groups[0].id) + add_metadata(conn, "admin_created", "yes") + + +def add_metadata(conn: DBConnector, name: str, value: str): + try: + conn.insertIntoTable(META_INFO_TABLE_NAME, {"name": name, "value": value}) + return True, None + except Exception as e: + logger.exception(e) + return False, e + + +def get_metadata(conn: DBConnector, name: str): + try: + metadata = conn.filterFromTable( + META_INFO_TABLE_NAME, ["*"], [ColumnCondition("name", name)] + ) + if len(metadata) == 0: + logger.warning(f"Metadata {name} not found") + return None + return MetaInfo.parse_obj(metadata[0]) + except Exception as e: + logger.exception(e) + return None + + +def create_user(conn: DBConnector, username: str, password: str): + try: + hashedPwd = sha256(password.encode("utf-8"), usedforsecurity=True).hexdigest() + conn.insertIntoTable( + USERS_TABLE_NAME, + {"username": username, "password": hashedPwd, "access_token": token_hex()}, + ) + return True, None + except Exception as e: + logger.exception(e) + return False, e + + +def get_user_by_username(conn: DBConnector, username: str): + try: + users = conn.filterFromTable( + USERS_TABLE_NAME, ["*"], [ColumnCondition("username", username)] + ) + if len(users) == 0: + logger.warning(f"User {username} not found") + return None + return User.parse_obj(users[0]) + except Exception as e: + logger.exception(e) + return None + + +def get_user_by_id(conn: DBConnector, user_id: int): + try: + users = conn.filterFromTable( + USERS_TABLE_NAME, ["*"], [ColumnCondition("id", user_id)] + ) + if len(users) == 0: + logger.warning(f"User with id {user_id} not found") + return None + return User.parse_obj(users[0]) + except Exception as e: + logger.exception(e) + return None + + +def get_user_by_access_token(conn: DBConnector, access_token: str | None): + try: + users = conn.filterFromTable( + USERS_TABLE_NAME, ["*"], [ColumnCondition("access_token", access_token)] + ) + if len(users) == 0: + logger.warning("Invalid access token") + return None, None + + user = User.parse_obj(users[0]) + user_group = get_user_group(conn, user.id) + + return user, user_group + except Exception as e: + logger.exception(e) + return None, None + + +def check_user(conn: DBConnector, username: str, password: str): + try: + hashedPwd = sha256(password.encode("utf-8"), usedforsecurity=True).hexdigest() + user = conn.filterFromTable( + USERS_TABLE_NAME, + ["*"], + [ + ColumnCondition("username", username), + ColumnCondition("password", hashedPwd), + ], + ) + if len(user) == 0: + logger.warning("Invalid username or password") + return None + return User.parse_obj(user[0]) + except Exception as e: + logger.exception(e) + return None + + +def create_group(conn: DBConnector, name: str, description: str = ""): + try: + conn.insertIntoTable( + USER_GROUP_TABLE_NAME, {"name": name, "description": description} + ) + return True, None + except Exception as e: + logger.exception(e) + return False, e + + +def get_group_by_name(conn: DBConnector, name: str): + try: + groups = conn.filterFromTable( + USER_GROUP_TABLE_NAME, ["*"], [ColumnCondition("name", name)] + ) + if len(groups) == 0: + logger.warning(f"Group {name} not found") + return None + return UserGroup.parse_obj(groups[0]) + except Exception as e: + logger.exception(e) + return None + + +def get_group_by_id(conn: DBConnector, group_id: int): + try: + groups = conn.filterFromTable( + USER_GROUP_TABLE_NAME, ["*"], [ColumnCondition("id", group_id)] + ) + if len(groups) == 0: + logger.warning(f"Group with id {group_id} not found") + return None + return UserGroup.parse_obj(groups[0]) + except Exception as e: + logger.exception(e) + return None + + +def set_user_group(conn: DBConnector, user_id: int, group_id: int): + try: + if not conn.filterFromTable( + USER_IN_USER_GROUP_JOIN_TABLE_NAME, + ["*"], + [ + ColumnCondition("user_id", user_id), + ], + ): + conn.insertIntoTable( + USER_IN_USER_GROUP_JOIN_TABLE_NAME, + {"user_id": user_id, "user_group_id": group_id}, + ) + else: + conn.updateDataInTable( + USER_IN_USER_GROUP_JOIN_TABLE_NAME, + [ + ColumnUpdate("user_group_id", group_id), + ], + [ + ColumnCondition("user_id", user_id), + ], + ) + return True, None + except Exception as e: + logger.exception(e) + return False, e + + +def get_user_group(conn: DBConnector, user_id: int): + try: + grp_usr_joint = conn.filterFromTable( + USER_IN_USER_GROUP_JOIN_TABLE_NAME, + ["*"], + [ColumnCondition("user_id", user_id)], + ) + if len(grp_usr_joint) == 0: + logger.warning(f"User with id {user_id} not found, so no group") + return None + + uiug = UserInUserGroup.parse_obj(grp_usr_joint[0]) + return get_group_by_id(conn, uiug.user_group_id) + except Exception as e: + logger.exception(e) + return None + + +def get_group_users(conn: DBConnector, group_id: int) -> list[User]: + try: + users = conn.filterFromTable( + USER_IN_USER_GROUP_JOIN_TABLE_NAME, + ["*"], + [ColumnCondition("user_group_id", group_id)], + ) + return [*map(User.parse_obj, users)] + except Exception as e: + logger.exception(e) + return [] + + +def list_users(conn: DBConnector) -> list[User]: + try: + users = conn.selectFromTable(USERS_TABLE_NAME, ["*"]) + return [*map(User.parse_obj, users)] + except Exception as e: + logger.exception(e) + return [] + + +def list_groups(conn: DBConnector): + try: + groups = conn.selectFromTable(USER_GROUP_TABLE_NAME, ["*"]) + return [*map(UserGroup.parse_obj, groups)] + except Exception as e: + logger.exception(e) + return [] + + +def create_table(conn: DBConnector, table_name: str, schema: list[ColumnDefinition]): + try: + conn.createTable(table_name, schema) + return True + except Exception as e: + logger.exception(e) + return False + + +def get_table_access_level( + conn: DBConnector, table_name: str, user_id: int +) -> AccessType: + try: + user_group = get_user_group(conn, user_id) + if not user_group: + return AccessType.NONE + elif user_group.name == "admin": + return AccessType.READ_WRITE + + access = conn.filterFromTable( + TABLE_ACCESS_TABLE_NAME, + ["*"], + [ + ColumnCondition("table_name", table_name), + ColumnCondition("user_group_id", user_group.id), + ], + ) + if not access: + return AccessType.NONE + access = TableAccess.parse_obj(access[0]) + if access.access_type == "r": + return AccessType.READ + elif access.access_type == "w": + return AccessType.WRITE + elif access.access_type == "rw": + return AccessType.READ_WRITE + else: + return AccessType.NONE + except Exception as e: + logger.exception(e) + return AccessType.NONE + + +def get_allowed_columns_for_group( + conn: DBConnector, table_name: str, group_id: int +) -> list[str]: + try: + allowed_columns = conn.filterFromTable( + TABLE_ACCESS_TABLE_NAME, + ["*"], + [ + ColumnCondition("table_name", table_name), + ColumnCondition("user_group_id", group_id), + ], + ) + if not allowed_columns: + return [] + allowed_columns = allowed_columns[0]["allowed_columns"] + if allowed_columns == "*": + return ["*"] + return allowed_columns + except Exception as e: + logger.exception(e) + return [] + + +def drop_table(conn: DBConnector, table_name: str): + try: + if table_name == META_INFO_TABLE_NAME: + raise Exception("Cannot drop meta info table") + if table_name == USER_GROUP_TABLE_NAME: + raise Exception("Cannot drop user group table") + if table_name == USERS_TABLE_NAME: + raise Exception("Cannot drop users table") + if table_name == USER_IN_USER_GROUP_JOIN_TABLE_NAME: + raise Exception("Cannot drop user in user group join table") + if table_name == TABLE_ACCESS_TABLE_NAME: + raise Exception("Cannot drop table access table") + if table_name == ASSETS_TABLE_NAME: + raise Exception("Cannot drop assets table") + + if not conn.tableExists(table_name): + raise Exception("Table does not exist") + + conn.dropTable(table_name) + return True, None + except Exception as e: + logger.exception(e) + return False, e + + +def create_asset(conn: DBConnector, name: str, description: str, fid: str): + try: + conn.insertIntoTable( + ASSETS_TABLE_NAME, + { + "name": name, + "description": description, + "fid": fid, + }, + ) + # TODO: add asset access + # TODO: add asset to seaweedfs + return True + except Exception as e: + logger.exception(e) + return None + + +def remove_asset(conn: DBConnector, token: str | None, asset_id: int): + try: + conn.deleteFromTable(ASSETS_TABLE_NAME, [ColumnCondition("id", asset_id)]) + # TODO: remove asset access + # TODO: remove asset from seaweedfs + return True + except Exception as e: + logger.exception(e) + return False + + +def get_asset(conn: DBConnector, token: str | None, fid: str): + try: + user, group = get_user_by_access_token(conn, token) + assets = conn.filterFromTable( + ASSETS_TABLE_NAME, ["*"], [ColumnCondition("fid", fid)] + ) + print(assets) + if len(assets) == 0: + return None + asset = Asset.parse_obj(assets[0]) + asset_access = get_asset_access(conn, asset.id) + # TODO: check if user has access to asset + + return asset + except Exception as e: + logger.exception(e) + return None + + +def create_asset_access(conn: DBConnector, asset_id: int, user_group_id: int): + try: + conn.insertIntoTable( + ASSET_ACCESS_TABLE_NAME, + {"asset_id": asset_id, "user_group_id": user_group_id}, + ) + return True + except Exception as e: + # NOTE: this should not happen ever + logger.exception(e) + return False + + +def get_asset_access(conn: DBConnector, asset_id: int): + try: + access = conn.filterFromTable( + ASSET_ACCESS_TABLE_NAME, ["*"], [ColumnCondition("asset_id", asset_id)] + ) + if not access: + return AccessType.NONE + access = AssetAccess.parse_obj(access[0]) + if access.access_type == "r": + return AccessType.READ + elif access.access_type == "w": + return AccessType.WRITE + elif access.access_type == "rw": + return AccessType.READ_WRITE + else: + return AccessType.NONE + except Exception as e: + logger.exception(e) + return AccessType.NONE + + +def change_asset_access( + conn: DBConnector, asset_id: int, user_group_id: int, access_type: AccessType +): + # TODO: implement + raise NotImplementedError() diff --git a/models.py b/models.py new file mode 100644 index 0000000..fc464bb --- /dev/null +++ b/models.py @@ -0,0 +1,26 @@ +from typing import Any +from pydantic import BaseModel + + +class ItemsFieldSelectorList(BaseModel): + fields: list[str] = [] + + +class ColumnsDefinitionList(BaseModel): + columns: list[str] + + +class UserDefinition(BaseModel): + username: str + password: str + + +class ColumnDefinition(BaseModel): + name: str + value: Any + isString: bool = False + isLike: bool = True + + +class ItemDeletionDefinitionList(BaseModel): + defs: list[ColumnDefinition] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f883c32 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,29 @@ +aiofiles==23.1.0 +based @ file:///C:/Users/Admin/Documents/repos/based/wheels/based-0.1.0-py2.py3-none-any.whl +black==23.1.0 +blinker==1.5 +click==8.1.3 +colorama==0.4.6 +h11==0.14.0 +h2==4.1.0 +hpack==4.0.0 +hypercorn==0.14.3 +hyperframe==6.0.1 +itsdangerous==2.1.2 +Jinja2==3.1.2 +MarkupSafe==2.1.2 +mypy-extensions==1.0.0 +packaging==23.0 +pathspec==0.11.0 +platformdirs==3.1.1 +priority==2.0.0 +psycopg==3.1.8 +psycopg-binary==3.1.8 +psycopg-pool==3.1.6 +quart==0.18.3 +toml==0.10.2 +tomli==2.0.1 +typing_extensions==4.5.0 +tzdata==2022.7 +Werkzeug==2.2.3 +wsproto==1.2.0 diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..88e1dd4 --- /dev/null +++ b/utils.py @@ -0,0 +1,127 @@ +from based.db import DBConnector +from based.columns import ( + ColumnDefinition, + make_column_unique, + PrimarySerialColumnDefinition, + PrimaryUUIDColumnDefinition, + TextColumnDefinition, + BigintColumnDefinition, + BooleanColumnDefinition, + DateColumnDefinition, + TimestampColumnDefinition, + DoubleColumnDefinition, + IntegerColumnDefinition, + UUIDColumnDefinition, +) +from dba import get_user_by_access_token + + +def check_if_admin_access_token( + connector: DBConnector, access_token: str | None +) -> bool: + if access_token is None: + return False + + user, group = get_user_by_access_token(connector, access_token) + if user is None or group is None or group.name != "admin": + return False + + return True + + +def get_column_from_definition(definition: str) -> ColumnDefinition | None: + match definition.split(":"): + case [name, "serial", "primary"]: + return PrimarySerialColumnDefinition(name) + + case [name, "uuid", "primary"]: + return PrimaryUUIDColumnDefinition(name) + + case [name, "str"]: + return TextColumnDefinition(name) + case [name, "str", "unique"]: + return make_column_unique(TextColumnDefinition(name)) + case [name, "str", "default", default]: + return TextColumnDefinition(name, default=default) + case [name, "str", "default", default, "unique"]: + return make_column_unique(TextColumnDefinition(name, default=default)) + + case [name, "bigint"]: + return BigintColumnDefinition(name) + case [name, "bigint", "unique"]: + return make_column_unique(BigintColumnDefinition(name)) + case [name, "bigint", "default", default]: + return BigintColumnDefinition(name, default=int(default)) + case [name, "bigint", "default", default, "unique"]: + return make_column_unique( + BigintColumnDefinition(name, default=int(default)) + ) + + case [name, "bool"]: + return BooleanColumnDefinition(name) + case [name, "bool", "unique"]: + return make_column_unique(BooleanColumnDefinition(name)) + case [name, "bool", "default", default]: + return BooleanColumnDefinition(name, default=bool(default)) + case [name, "bool", "default", default, "unique"]: + return make_column_unique( + BooleanColumnDefinition(name, default=bool(default)) + ) + + case [name, "date"]: + return DateColumnDefinition(name) + case [name, "date", "unique"]: + return make_column_unique(DateColumnDefinition(name)) + # TODO: Add default value for date + + case [name, "datetime"]: + return TimestampColumnDefinition(name) + case [name, "datetime", "unique"]: + return make_column_unique(TimestampColumnDefinition(name)) + # TODO: Add default value for timestamp + + case [name, "float"]: + return DoubleColumnDefinition(name) + case [name, "float", "unique"]: + return make_column_unique(DoubleColumnDefinition(name)) + case [name, "float", "default", default]: + return DoubleColumnDefinition(name, default=float(default)) + case [name, "float", "default", default, "unique"]: + return make_column_unique( + DoubleColumnDefinition(name, default=float(default)) + ) + + case [name, "int"]: + return IntegerColumnDefinition(name) + case [name, "int", "unique"]: + return make_column_unique(IntegerColumnDefinition(name)) + case [name, "int", "default", default]: + return IntegerColumnDefinition(name, default=int(default)) + case [name, "int", "default", default, "unique"]: + return make_column_unique( + IntegerColumnDefinition(name, default=int(default)) + ) + + case [name, "uuid"]: + return UUIDColumnDefinition(name) + case [name, "uuid", "unique"]: + return make_column_unique(UUIDColumnDefinition(name)) + case [name, "uuid", "default", default]: + return UUIDColumnDefinition(name, default=default) + case [name, "uuid", "default", default, "unique"]: + return make_column_unique(UUIDColumnDefinition(name, default=default)) + + return None + + +def parse_columns_from_definition(definition: str) -> list[ColumnDefinition]: + columns = [] + for column_definition in definition.split(","): + if column_definition == "": + continue + column = get_column_from_definition(column_definition) + if column is None: + raise ValueError(f"Invalid column definition: {column_definition}") + columns.append(column) + + return columns