Add the rest of the project

This commit is contained in:
Andrew 2023-03-20 18:38:45 +07:00
parent 6a2109ce3d
commit 49f2d8e924
6 changed files with 1080 additions and 0 deletions

302
app.py Normal file
View file

@ -0,0 +1,302 @@
import io
from typing import Any
from fastapi import FastAPI, status, Header, UploadFile
from starlette.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from based import db
import psycopg
from hashlib import sha256
from secrets import token_hex
from minio import Minio
from minio.error import S3Error
from minio.helpers import ObjectWriteResult
from urllib3 import HTTPResponse
from dba import *
from models import (
ColumnsDefinitionList,
ItemDeletionDefinitionList,
ItemsFieldSelectorList,
UserDefinition,
)
from utils import (
check_if_admin_access_token,
parse_columns_from_definition,
)
conninfo = "postgresql://postgres:asarch6122@localhost"
connector = db.DBConnector(conninfo)
bootstrapDB(connector)
BUCKET_NAME = "tuuli-files"
minioClient = Minio(
"localhost:8090",
access_key="mxR0F5PK8CpCM8SA",
secret_key="yFJsG70xLU3BiIMslinz6dhqKHqNpUc6",
secure=False,
)
found = minioClient.bucket_exists(BUCKET_NAME)
if found:
print(f"Bucket '{BUCKET_NAME}' already exists")
else:
minioClient.make_bucket(BUCKET_NAME)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/api/listTables")
async def listTables(access_token: str | None = Header(default=None)):
is_admin = check_if_admin_access_token(connector, access_token)
if not is_admin:
return {"error": "Not allowed"}
tables = connector.tables()
return tables
@app.post("/api/createTable/{tableName}")
async def createTable(
tableName: str,
columns: ColumnsDefinitionList,
access_token: str | None = Header(default=None),
):
is_admin = check_if_admin_access_token(connector, access_token)
if not is_admin:
return {"error": "Not allowed"}
try:
columnsDefinition = parse_columns_from_definition(",".join(columns.columns))
create_table(connector, tableName, columnsDefinition)
except psycopg.errors.UniqueViolation:
return {"error": "Username already exists"}
except Exception as e:
return {"error": str(e)}
return {"ok": True}
@app.post("/api/dropTable/{tableName}")
async def dropTable(
tableName: str,
access_token: str | None = Header(default=None),
):
is_admin = check_if_admin_access_token(connector, access_token)
if not is_admin:
return {"error": "Not allowed"}
try:
drop_table(connector, tableName)
except Exception as e:
return {"error": str(e)}
return {"ok": True}
@app.post("/api/createUser")
async def createUser(
user: UserDefinition,
access_token: str | None = Header(default=None),
):
is_admin = check_if_admin_access_token(connector, access_token)
if not is_admin:
return {"error": "Not allowed"}
try:
create_user(connector, user.username, user.password)
except psycopg.errors.UniqueViolation:
return {"error": "Username already exists"}
except Exception as e:
return {"error": str(e)}
return {"ok": True}
@app.post("/items/{tableName}")
async def items(
tableName: str,
selector: ItemsFieldSelectorList,
access_token: str | None = Header(default=None),
):
table_info = connector.getTable(tableName)
if not table_info:
return {"error": "Not allowed"}
is_admin = check_if_admin_access_token(connector, access_token)
if table_info["system"] and not is_admin:
return {"error": "Not allowed"}
columns = parse_columns_from_definition(table_info["columns"])
columnsNames = set(column.name for column in columns)
userSelectedColumns = list(set(selector.fields)) if selector.fields else ["*"]
if userSelectedColumns != ["*"]:
for column in userSelectedColumns:
if column not in columnsNames:
return {"error": f"Column {column} not found on table {tableName}"}
else:
userSelectedColumns = columnsNames
user, group = get_user_by_access_token(connector, access_token)
if not user:
return {"error": "Not allowed"}
if not is_admin:
allowedColumns = get_allowed_columns_for_group(
connector, tableName, group.id if group else -1
)
if not allowedColumns:
return {"error": "Not allowed"}
elif len(allowedColumns) == 1 and allowedColumns[0] == "*":
pass
else:
for column in userSelectedColumns:
if column not in allowedColumns:
return {"error": "Not allowed"}
print(columnsNames)
print(selector)
table_items = connector.selectFromTable(
tableName, selector.fields if selector.fields else ["*"]
)
return table_items
@app.post("/items/{tableName}/+")
async def itemsCreate(
tableName: str,
item: dict[str, str],
access_token: str | None = Header(default=None),
):
table_info = connector.getTable(tableName)
if not table_info:
return {"error": "Not found"}
is_admin = check_if_admin_access_token(connector, access_token)
if table_info["system"] and not is_admin:
return {"error": "Not allowed"}
user, group = get_user_by_access_token(connector, access_token)
if not is_admin:
allowedColumns = get_allowed_columns_for_group(
connector, tableName, group.id if group else -1
)
if not allowedColumns:
return {"error": "Not allowed"}
elif len(allowedColumns) == 1 and allowedColumns[0] == "*":
pass
else:
for column in item:
if column not in allowedColumns:
return {"error": "Not allowed"}
try:
connector.insertIntoTable(tableName, item)
except psycopg.errors.UndefinedColumn:
return {"error": "Column not found"}
except psycopg.errors.UniqueViolation:
return {"error": "Unique constraint violation"}
except Exception as e:
return {"error": str(e)}
return {"ok": True}
@app.post("/items/{tableName}/-")
async def itemsDelete(
tableName: str,
deleteWhere: ItemDeletionDefinitionList,
access_token: str | None = Header(default=None),
):
table_info = connector.getTable(tableName)
if not table_info:
return {"error": "Not found"}
is_admin = check_if_admin_access_token(connector, access_token)
if table_info["system"] and not is_admin:
return {"error": "Not allowed"}
user, group = get_user_by_access_token(connector, access_token)
if not is_admin:
allowedColumns = get_allowed_columns_for_group(
connector, tableName, group.id if group else -1
)
if not allowedColumns:
return {"error": "Not allowed"}
elif len(allowedColumns) == 1 and allowedColumns[0] == "*":
pass
else:
return {"error": "Not allowed"}
try:
connector.deleteFromTable(
tableName,
[
ColumnCondition(where.name, where.value, where.isString, where.isLike)
for where in deleteWhere.defs
],
)
except Exception as e:
return {"error": str(e)}
return {"ok": True}
@app.get("/assets/{fid}")
async def getAsset(fid: str, access_token: str | None = Header(default=None)):
asset = get_asset(connector, access_token, fid)
if not asset:
return status.HTTP_404_NOT_FOUND
response: HTTPResponse | None = None
try:
response = minioClient.get_object(BUCKET_NAME, asset.name, version_id=asset.fid)
if response is None:
return status.HTTP_404_NOT_FOUND
return StreamingResponse(
content=io.BytesIO(response.data),
media_type=response.getheader("Content-Type"),
status_code=status.HTTP_200_OK,
)
finally:
if response is not None:
response.close()
response.release_conn()
@app.post("/assets/+")
async def createAsset(
asset: UploadFile,
access_token: str | None = Header(default=None),
):
user, _ = get_user_by_access_token(connector, access_token)
if not user:
return {"error": "Not allowed"}
filename = asset.filename
if not filename:
filename = f"unnamed"
filename = f"{token_hex()}_{filename}"
result: ObjectWriteResult = minioClient.put_object(
BUCKET_NAME,
filename,
data=asset.file,
content_type=(
asset.content_type if asset.content_type else "application/octet-stream"
),
length=asset.size,
)
if not create_asset(connector, filename, "", str(result.version_id)):
return {"error": "Failed to create asset"}
return {"ok": True, "fid": result.version_id}

137
db_models.py Normal file
View file

@ -0,0 +1,137 @@
import enum
from based.columns import (
PrimarySerialColumnDefinition,
TextColumnDefinition,
IntegerColumnDefinition,
make_column_unique,
)
from pydantic import BaseModel
class AccessType(enum.Enum):
READ = "read"
WRITE = "write"
READ_WRITE = "read_write"
NONE = "none"
META_INFO_TABLE_NAME = "meta_info"
META_INFO_TABLE_SCHEMA = [
PrimarySerialColumnDefinition("id"),
make_column_unique(TextColumnDefinition("name")),
TextColumnDefinition("value"),
TextColumnDefinition("allowed_columns", default="*"),
]
class MetaInfo(BaseModel):
id: int
name: str
value: str
allowed_columns: str
USER_GROUP_TABLE_NAME = "user_group"
USER_GROUP_TABLE_SCHEMA = [
PrimarySerialColumnDefinition("id"),
make_column_unique(TextColumnDefinition("name")),
TextColumnDefinition("description", default=""),
]
class UserGroup(BaseModel):
id: int
name: str
description: str
USERS_TABLE_NAME = "users"
USERS_TABLE_SCHEMA = [
PrimarySerialColumnDefinition("id"),
make_column_unique(TextColumnDefinition("username")),
TextColumnDefinition("password"),
TextColumnDefinition("access_token"),
]
class User(BaseModel):
id: int
username: str
password: str
access_token: str
def to_dict(self, safe=True):
d = {
"id": self.id,
"username": self.username,
}
if not safe:
d["access_token"] = self.access_token
d["password"] = self.password
return d
USER_IN_USER_GROUP_JOIN_TABLE_NAME = "user_in_user_group"
USER_IN_USER_GROUP_JOIN_TABLE_SCHEMA = [
PrimarySerialColumnDefinition("id"),
IntegerColumnDefinition("user_id"),
IntegerColumnDefinition("user_group_id"),
]
class UserInUserGroup(BaseModel):
id: int
user_id: int
user_group_id: int
TABLE_ACCESS_TABLE_NAME = "table_access"
TABLE_ACCESS_TABLE_SCHEMA = [
PrimarySerialColumnDefinition("id"),
IntegerColumnDefinition("user_group_id"),
TextColumnDefinition("table_name"),
TextColumnDefinition("access_type"),
TextColumnDefinition("allowed_columns", default="*"),
]
class TableAccess(BaseModel):
id: int
user_group_id: int
table_name: str
access_type: str
allowed_columns: str
ASSETS_TABLE_NAME = "assets"
ASSETS_TABLE_SCHEMA = [
PrimarySerialColumnDefinition("id"),
TextColumnDefinition("name"),
TextColumnDefinition("description", default=""),
TextColumnDefinition("fid"),
TextColumnDefinition("catalog", default="/root"),
]
class Asset(BaseModel):
id: int
name: str
description: str
fid: str
catalog: str
ASSET_ACCESS_TABLE_NAME = "asset_access"
ASSET_ACCESS_TABLE_SCHEMA = [
PrimarySerialColumnDefinition("id"),
IntegerColumnDefinition("user_group_id"),
IntegerColumnDefinition("asset_id"),
TextColumnDefinition("access_type"),
]
class AssetAccess(BaseModel):
id: int
user_group_id: int
asset_id: int
access_type: str

459
dba.py Normal file
View file

@ -0,0 +1,459 @@
from hashlib import sha256
import logging
from secrets import token_hex
from based.db import DBConnector, ColumnCondition, ColumnUpdate, ColumnDefinition
from db_models import *
logger = logging.getLogger(__name__)
def bootstrapDB(conn: DBConnector):
if not conn.tableExists(META_INFO_TABLE_NAME):
conn.createTable(
META_INFO_TABLE_NAME, META_INFO_TABLE_SCHEMA, system=True, hidden=True
)
if not conn.tableExists(USER_GROUP_TABLE_NAME):
conn.createTable(USER_GROUP_TABLE_NAME, USER_GROUP_TABLE_SCHEMA, system=True)
if not conn.tableExists(USERS_TABLE_NAME):
conn.createTable(USERS_TABLE_NAME, USERS_TABLE_SCHEMA, system=True)
if not conn.tableExists(USER_IN_USER_GROUP_JOIN_TABLE_NAME):
conn.createTable(
USER_IN_USER_GROUP_JOIN_TABLE_NAME,
USER_IN_USER_GROUP_JOIN_TABLE_SCHEMA,
system=True,
)
if not conn.tableExists(TABLE_ACCESS_TABLE_NAME):
conn.createTable(
TABLE_ACCESS_TABLE_NAME,
TABLE_ACCESS_TABLE_SCHEMA,
system=True,
)
if not conn.tableExists(ASSETS_TABLE_NAME):
conn.createTable(
ASSETS_TABLE_NAME,
ASSETS_TABLE_SCHEMA,
system=True,
)
if not conn.tableExists(ASSET_ACCESS_TABLE_NAME):
conn.createTable(
ASSET_ACCESS_TABLE_NAME,
ASSET_ACCESS_TABLE_SCHEMA,
system=True,
)
meta = get_metadata(conn, "admin_created")
testAdminCreated = meta and meta.value == "yes"
if not testAdminCreated:
create_user(conn, "admin", "admin")
create_group(conn, "admin")
users = list_users(conn)
groups = list_groups(conn)
set_user_group(conn, users[0].id, groups[0].id)
add_metadata(conn, "admin_created", "yes")
def add_metadata(conn: DBConnector, name: str, value: str):
try:
conn.insertIntoTable(META_INFO_TABLE_NAME, {"name": name, "value": value})
return True, None
except Exception as e:
logger.exception(e)
return False, e
def get_metadata(conn: DBConnector, name: str):
try:
metadata = conn.filterFromTable(
META_INFO_TABLE_NAME, ["*"], [ColumnCondition("name", name)]
)
if len(metadata) == 0:
logger.warning(f"Metadata {name} not found")
return None
return MetaInfo.parse_obj(metadata[0])
except Exception as e:
logger.exception(e)
return None
def create_user(conn: DBConnector, username: str, password: str):
try:
hashedPwd = sha256(password.encode("utf-8"), usedforsecurity=True).hexdigest()
conn.insertIntoTable(
USERS_TABLE_NAME,
{"username": username, "password": hashedPwd, "access_token": token_hex()},
)
return True, None
except Exception as e:
logger.exception(e)
return False, e
def get_user_by_username(conn: DBConnector, username: str):
try:
users = conn.filterFromTable(
USERS_TABLE_NAME, ["*"], [ColumnCondition("username", username)]
)
if len(users) == 0:
logger.warning(f"User {username} not found")
return None
return User.parse_obj(users[0])
except Exception as e:
logger.exception(e)
return None
def get_user_by_id(conn: DBConnector, user_id: int):
try:
users = conn.filterFromTable(
USERS_TABLE_NAME, ["*"], [ColumnCondition("id", user_id)]
)
if len(users) == 0:
logger.warning(f"User with id {user_id} not found")
return None
return User.parse_obj(users[0])
except Exception as e:
logger.exception(e)
return None
def get_user_by_access_token(conn: DBConnector, access_token: str | None):
try:
users = conn.filterFromTable(
USERS_TABLE_NAME, ["*"], [ColumnCondition("access_token", access_token)]
)
if len(users) == 0:
logger.warning("Invalid access token")
return None, None
user = User.parse_obj(users[0])
user_group = get_user_group(conn, user.id)
return user, user_group
except Exception as e:
logger.exception(e)
return None, None
def check_user(conn: DBConnector, username: str, password: str):
try:
hashedPwd = sha256(password.encode("utf-8"), usedforsecurity=True).hexdigest()
user = conn.filterFromTable(
USERS_TABLE_NAME,
["*"],
[
ColumnCondition("username", username),
ColumnCondition("password", hashedPwd),
],
)
if len(user) == 0:
logger.warning("Invalid username or password")
return None
return User.parse_obj(user[0])
except Exception as e:
logger.exception(e)
return None
def create_group(conn: DBConnector, name: str, description: str = ""):
try:
conn.insertIntoTable(
USER_GROUP_TABLE_NAME, {"name": name, "description": description}
)
return True, None
except Exception as e:
logger.exception(e)
return False, e
def get_group_by_name(conn: DBConnector, name: str):
try:
groups = conn.filterFromTable(
USER_GROUP_TABLE_NAME, ["*"], [ColumnCondition("name", name)]
)
if len(groups) == 0:
logger.warning(f"Group {name} not found")
return None
return UserGroup.parse_obj(groups[0])
except Exception as e:
logger.exception(e)
return None
def get_group_by_id(conn: DBConnector, group_id: int):
try:
groups = conn.filterFromTable(
USER_GROUP_TABLE_NAME, ["*"], [ColumnCondition("id", group_id)]
)
if len(groups) == 0:
logger.warning(f"Group with id {group_id} not found")
return None
return UserGroup.parse_obj(groups[0])
except Exception as e:
logger.exception(e)
return None
def set_user_group(conn: DBConnector, user_id: int, group_id: int):
try:
if not conn.filterFromTable(
USER_IN_USER_GROUP_JOIN_TABLE_NAME,
["*"],
[
ColumnCondition("user_id", user_id),
],
):
conn.insertIntoTable(
USER_IN_USER_GROUP_JOIN_TABLE_NAME,
{"user_id": user_id, "user_group_id": group_id},
)
else:
conn.updateDataInTable(
USER_IN_USER_GROUP_JOIN_TABLE_NAME,
[
ColumnUpdate("user_group_id", group_id),
],
[
ColumnCondition("user_id", user_id),
],
)
return True, None
except Exception as e:
logger.exception(e)
return False, e
def get_user_group(conn: DBConnector, user_id: int):
try:
grp_usr_joint = conn.filterFromTable(
USER_IN_USER_GROUP_JOIN_TABLE_NAME,
["*"],
[ColumnCondition("user_id", user_id)],
)
if len(grp_usr_joint) == 0:
logger.warning(f"User with id {user_id} not found, so no group")
return None
uiug = UserInUserGroup.parse_obj(grp_usr_joint[0])
return get_group_by_id(conn, uiug.user_group_id)
except Exception as e:
logger.exception(e)
return None
def get_group_users(conn: DBConnector, group_id: int) -> list[User]:
try:
users = conn.filterFromTable(
USER_IN_USER_GROUP_JOIN_TABLE_NAME,
["*"],
[ColumnCondition("user_group_id", group_id)],
)
return [*map(User.parse_obj, users)]
except Exception as e:
logger.exception(e)
return []
def list_users(conn: DBConnector) -> list[User]:
try:
users = conn.selectFromTable(USERS_TABLE_NAME, ["*"])
return [*map(User.parse_obj, users)]
except Exception as e:
logger.exception(e)
return []
def list_groups(conn: DBConnector):
try:
groups = conn.selectFromTable(USER_GROUP_TABLE_NAME, ["*"])
return [*map(UserGroup.parse_obj, groups)]
except Exception as e:
logger.exception(e)
return []
def create_table(conn: DBConnector, table_name: str, schema: list[ColumnDefinition]):
try:
conn.createTable(table_name, schema)
return True
except Exception as e:
logger.exception(e)
return False
def get_table_access_level(
conn: DBConnector, table_name: str, user_id: int
) -> AccessType:
try:
user_group = get_user_group(conn, user_id)
if not user_group:
return AccessType.NONE
elif user_group.name == "admin":
return AccessType.READ_WRITE
access = conn.filterFromTable(
TABLE_ACCESS_TABLE_NAME,
["*"],
[
ColumnCondition("table_name", table_name),
ColumnCondition("user_group_id", user_group.id),
],
)
if not access:
return AccessType.NONE
access = TableAccess.parse_obj(access[0])
if access.access_type == "r":
return AccessType.READ
elif access.access_type == "w":
return AccessType.WRITE
elif access.access_type == "rw":
return AccessType.READ_WRITE
else:
return AccessType.NONE
except Exception as e:
logger.exception(e)
return AccessType.NONE
def get_allowed_columns_for_group(
conn: DBConnector, table_name: str, group_id: int
) -> list[str]:
try:
allowed_columns = conn.filterFromTable(
TABLE_ACCESS_TABLE_NAME,
["*"],
[
ColumnCondition("table_name", table_name),
ColumnCondition("user_group_id", group_id),
],
)
if not allowed_columns:
return []
allowed_columns = allowed_columns[0]["allowed_columns"]
if allowed_columns == "*":
return ["*"]
return allowed_columns
except Exception as e:
logger.exception(e)
return []
def drop_table(conn: DBConnector, table_name: str):
try:
if table_name == META_INFO_TABLE_NAME:
raise Exception("Cannot drop meta info table")
if table_name == USER_GROUP_TABLE_NAME:
raise Exception("Cannot drop user group table")
if table_name == USERS_TABLE_NAME:
raise Exception("Cannot drop users table")
if table_name == USER_IN_USER_GROUP_JOIN_TABLE_NAME:
raise Exception("Cannot drop user in user group join table")
if table_name == TABLE_ACCESS_TABLE_NAME:
raise Exception("Cannot drop table access table")
if table_name == ASSETS_TABLE_NAME:
raise Exception("Cannot drop assets table")
if not conn.tableExists(table_name):
raise Exception("Table does not exist")
conn.dropTable(table_name)
return True, None
except Exception as e:
logger.exception(e)
return False, e
def create_asset(conn: DBConnector, name: str, description: str, fid: str):
try:
conn.insertIntoTable(
ASSETS_TABLE_NAME,
{
"name": name,
"description": description,
"fid": fid,
},
)
# TODO: add asset access
# TODO: add asset to seaweedfs
return True
except Exception as e:
logger.exception(e)
return None
def remove_asset(conn: DBConnector, token: str | None, asset_id: int):
try:
conn.deleteFromTable(ASSETS_TABLE_NAME, [ColumnCondition("id", asset_id)])
# TODO: remove asset access
# TODO: remove asset from seaweedfs
return True
except Exception as e:
logger.exception(e)
return False
def get_asset(conn: DBConnector, token: str | None, fid: str):
try:
user, group = get_user_by_access_token(conn, token)
assets = conn.filterFromTable(
ASSETS_TABLE_NAME, ["*"], [ColumnCondition("fid", fid)]
)
print(assets)
if len(assets) == 0:
return None
asset = Asset.parse_obj(assets[0])
asset_access = get_asset_access(conn, asset.id)
# TODO: check if user has access to asset
return asset
except Exception as e:
logger.exception(e)
return None
def create_asset_access(conn: DBConnector, asset_id: int, user_group_id: int):
try:
conn.insertIntoTable(
ASSET_ACCESS_TABLE_NAME,
{"asset_id": asset_id, "user_group_id": user_group_id},
)
return True
except Exception as e:
# NOTE: this should not happen ever
logger.exception(e)
return False
def get_asset_access(conn: DBConnector, asset_id: int):
try:
access = conn.filterFromTable(
ASSET_ACCESS_TABLE_NAME, ["*"], [ColumnCondition("asset_id", asset_id)]
)
if not access:
return AccessType.NONE
access = AssetAccess.parse_obj(access[0])
if access.access_type == "r":
return AccessType.READ
elif access.access_type == "w":
return AccessType.WRITE
elif access.access_type == "rw":
return AccessType.READ_WRITE
else:
return AccessType.NONE
except Exception as e:
logger.exception(e)
return AccessType.NONE
def change_asset_access(
conn: DBConnector, asset_id: int, user_group_id: int, access_type: AccessType
):
# TODO: implement
raise NotImplementedError()

26
models.py Normal file
View file

@ -0,0 +1,26 @@
from typing import Any
from pydantic import BaseModel
class ItemsFieldSelectorList(BaseModel):
fields: list[str] = []
class ColumnsDefinitionList(BaseModel):
columns: list[str]
class UserDefinition(BaseModel):
username: str
password: str
class ColumnDefinition(BaseModel):
name: str
value: Any
isString: bool = False
isLike: bool = True
class ItemDeletionDefinitionList(BaseModel):
defs: list[ColumnDefinition]

29
requirements.txt Normal file
View file

@ -0,0 +1,29 @@
aiofiles==23.1.0
based @ file:///C:/Users/Admin/Documents/repos/based/wheels/based-0.1.0-py2.py3-none-any.whl
black==23.1.0
blinker==1.5
click==8.1.3
colorama==0.4.6
h11==0.14.0
h2==4.1.0
hpack==4.0.0
hypercorn==0.14.3
hyperframe==6.0.1
itsdangerous==2.1.2
Jinja2==3.1.2
MarkupSafe==2.1.2
mypy-extensions==1.0.0
packaging==23.0
pathspec==0.11.0
platformdirs==3.1.1
priority==2.0.0
psycopg==3.1.8
psycopg-binary==3.1.8
psycopg-pool==3.1.6
quart==0.18.3
toml==0.10.2
tomli==2.0.1
typing_extensions==4.5.0
tzdata==2022.7
Werkzeug==2.2.3
wsproto==1.2.0

127
utils.py Normal file
View file

@ -0,0 +1,127 @@
from based.db import DBConnector
from based.columns import (
ColumnDefinition,
make_column_unique,
PrimarySerialColumnDefinition,
PrimaryUUIDColumnDefinition,
TextColumnDefinition,
BigintColumnDefinition,
BooleanColumnDefinition,
DateColumnDefinition,
TimestampColumnDefinition,
DoubleColumnDefinition,
IntegerColumnDefinition,
UUIDColumnDefinition,
)
from dba import get_user_by_access_token
def check_if_admin_access_token(
connector: DBConnector, access_token: str | None
) -> bool:
if access_token is None:
return False
user, group = get_user_by_access_token(connector, access_token)
if user is None or group is None or group.name != "admin":
return False
return True
def get_column_from_definition(definition: str) -> ColumnDefinition | None:
match definition.split(":"):
case [name, "serial", "primary"]:
return PrimarySerialColumnDefinition(name)
case [name, "uuid", "primary"]:
return PrimaryUUIDColumnDefinition(name)
case [name, "str"]:
return TextColumnDefinition(name)
case [name, "str", "unique"]:
return make_column_unique(TextColumnDefinition(name))
case [name, "str", "default", default]:
return TextColumnDefinition(name, default=default)
case [name, "str", "default", default, "unique"]:
return make_column_unique(TextColumnDefinition(name, default=default))
case [name, "bigint"]:
return BigintColumnDefinition(name)
case [name, "bigint", "unique"]:
return make_column_unique(BigintColumnDefinition(name))
case [name, "bigint", "default", default]:
return BigintColumnDefinition(name, default=int(default))
case [name, "bigint", "default", default, "unique"]:
return make_column_unique(
BigintColumnDefinition(name, default=int(default))
)
case [name, "bool"]:
return BooleanColumnDefinition(name)
case [name, "bool", "unique"]:
return make_column_unique(BooleanColumnDefinition(name))
case [name, "bool", "default", default]:
return BooleanColumnDefinition(name, default=bool(default))
case [name, "bool", "default", default, "unique"]:
return make_column_unique(
BooleanColumnDefinition(name, default=bool(default))
)
case [name, "date"]:
return DateColumnDefinition(name)
case [name, "date", "unique"]:
return make_column_unique(DateColumnDefinition(name))
# TODO: Add default value for date
case [name, "datetime"]:
return TimestampColumnDefinition(name)
case [name, "datetime", "unique"]:
return make_column_unique(TimestampColumnDefinition(name))
# TODO: Add default value for timestamp
case [name, "float"]:
return DoubleColumnDefinition(name)
case [name, "float", "unique"]:
return make_column_unique(DoubleColumnDefinition(name))
case [name, "float", "default", default]:
return DoubleColumnDefinition(name, default=float(default))
case [name, "float", "default", default, "unique"]:
return make_column_unique(
DoubleColumnDefinition(name, default=float(default))
)
case [name, "int"]:
return IntegerColumnDefinition(name)
case [name, "int", "unique"]:
return make_column_unique(IntegerColumnDefinition(name))
case [name, "int", "default", default]:
return IntegerColumnDefinition(name, default=int(default))
case [name, "int", "default", default, "unique"]:
return make_column_unique(
IntegerColumnDefinition(name, default=int(default))
)
case [name, "uuid"]:
return UUIDColumnDefinition(name)
case [name, "uuid", "unique"]:
return make_column_unique(UUIDColumnDefinition(name))
case [name, "uuid", "default", default]:
return UUIDColumnDefinition(name, default=default)
case [name, "uuid", "default", default, "unique"]:
return make_column_unique(UUIDColumnDefinition(name, default=default))
return None
def parse_columns_from_definition(definition: str) -> list[ColumnDefinition]:
columns = []
for column_definition in definition.split(","):
if column_definition == "":
continue
column = get_column_from_definition(column_definition)
if column is None:
raise ValueError(f"Invalid column definition: {column_definition}")
columns.append(column)
return columns