Wow it is better now

This commit is contained in:
Andrew 2023-04-23 00:40:07 +07:00
parent 3afdd7276f
commit edaafcdf4f
3 changed files with 127 additions and 130 deletions

View file

@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "0.2.0"

View file

@ -1,37 +1,23 @@
from datetime import date
from psycopg.sql import SQL, Identifier, Literal, Composed
from psycopg.sql import SQL, Identifier, Literal
class ColumnDefinition:
def __init__(self, name: str):
def __init__(self, name: str, unique: bool, has_default: bool = False):
self.name = name
self.unique = unique
self.has_default = has_default
def sql(self):
raise NotImplementedError
raise NotImplementedError()
def serialize(self):
raise NotImplementedError
class UniqueColumnDefinition(ColumnDefinition):
def __init__(self, wrapped: ColumnDefinition):
super().__init__(wrapped.name)
self.wrapped = wrapped
def sql(self):
return SQL("{} UNIQUE").format(self.wrapped.sql())
def serialize(self):
return f"{self.wrapped.serialize()}:unique"
def make_column_unique(column) -> UniqueColumnDefinition:
return UniqueColumnDefinition(column)
raise NotImplementedError()
class PrimarySerialColumnDefinition(ColumnDefinition):
def __init__(self, name: str):
super().__init__(name)
super().__init__(name, unique=False)
def sql(self):
return SQL("{} SERIAL PRIMARY KEY").format(Identifier(self.name))
@ -40,150 +26,106 @@ class PrimarySerialColumnDefinition(ColumnDefinition):
return f"{self.name}:serial:primary"
class PrimaryUUIDColumnDefinition(ColumnDefinition):
def __init__(self, name: str):
super().__init__(name)
def sql(self):
return SQL("{} uuid DEFAULT gen_random_uuid() PRIMARY KEY").format(
Identifier(self.name)
)
def serialize(self):
return f"{self.name}:uuid:primary"
class TextColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: str | None = None):
super().__init__(name)
def __init__(self, name: str, default: str | None = None, unique: bool = False):
if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} TEXT").format(Identifier(self.name))
return SQL("{} TEXT UNIQUE" if self.unique else "{} TEXT").format(
Identifier(self.name)
)
else:
return SQL("{} TEXT DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:str"
class BigintColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: int | None = None):
super().__init__(name)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} BIGINT").format(Identifier(self.name))
else:
return SQL("{} BIGINT DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:bigint"
return f"{self.name}:str{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"
class BooleanColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: bool | None = None):
super().__init__(name)
def __init__(self, name: str, default: bool | None = None, unique: bool = False):
if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} BOOLEAN").format(Identifier(self.name))
return SQL("{} BOOLEAN UNIQUE" if self.unique else "{} BOOLEAN").format(
Identifier(self.name)
)
else:
return SQL("{} BOOLEAN DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:bool"
class DateColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: date | None = None):
super().__init__(name)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} DATE").format(Identifier(self.name))
else:
return SQL("{} DATE DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:date"
return f"{self.name}:bool{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"
class TimestampColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: date | None = None):
super().__init__(name)
def __init__(self, name: str, default: date | None = None, unique: bool = False):
if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} TIMESTAMP").format(Identifier(self.name))
return SQL("{} TIMESTAMP UNIQUE" if self.unique else "{} TIMESTAMP").format(
Identifier(self.name)
)
else:
return SQL("{} TIMESTAMP DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:datetime"
return f"{self.name}:datetime{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"
class DoubleColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: float | None = None):
super().__init__(name)
def __init__(self, name: str, default: float | None = None, unique: bool = False):
if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} DOUBLE PRECISION").format(Identifier(self.name))
return SQL(
"{} DOUBLE PRECISION UNIQUE" if self.unique else "{} DOUBLE PRECISION"
).format(Identifier(self.name))
else:
return SQL("{} DOUBLE PRECISION DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:float"
return f"{self.name}:float{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"
class IntegerColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: int | None = None):
super().__init__(name)
def __init__(self, name: str, default: int | None = None, unique: bool = False):
if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} INTEGER").format(Identifier(self.name))
return SQL("{} INTEGER UNIQUE" if self.unique else "{} INTEGER").format(
Identifier(self.name)
)
else:
return SQL("{} INTEGER DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:int"
class UUIDColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: str | None = None):
super().__init__(name)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} UUID").format(Identifier(self.name))
else:
return SQL("{} UUID DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:uuid"
return f"{self.name}:int{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"

View file

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal as _Literal
from psycopg import Connection
from psycopg.rows import dict_row
from psycopg.sql import SQL, Identifier, Literal
@ -13,26 +13,80 @@ class _ExtendedConnection(Connection):
self.row_factory = dict_row
CONDITION_OPERATORS = _Literal[
"eq",
"ne",
"gt",
"lt",
"ge",
"le",
"contains",
"not_contains",
"starts_with",
"not_starts_with",
"ends_with",
"not_ends_with",
]
class ColumnCondition:
def __init__(
self, column: str, value: Any, isString: bool = False, isLike: bool = True
):
def __init__(self, column: str, operator: CONDITION_OPERATORS, value: Any):
self.column = column
self.operator = operator
self.value = value
self.isString = isString
self.isLike = isLike
def sql(self):
if self.isString:
if self.isLike:
match self.operator:
case "eq":
return SQL("{} = {}").format(
Identifier(self.column), Literal(self.value)
)
case "ne":
return SQL("{} != {}").format(
Identifier(self.column), Literal(self.value)
)
case "gt":
return SQL("{} > {}").format(
Identifier(self.column), Literal(self.value)
)
case "lt":
return SQL("{} < {}").format(
Identifier(self.column), Literal(self.value)
)
case "ge":
return SQL("{} >= {}").format(
Identifier(self.column), Literal(self.value)
)
case "le":
return SQL("{} <= {}").format(
Identifier(self.column), Literal(self.value)
)
case "contains":
return SQL("{} LIKE {}").format(
Identifier(self.column), Literal(self.value)
)
else:
case "not_contains":
return SQL("{} NOT LIKE {}").format(
Identifier(self.column), Literal(self.value)
)
return SQL("{} = {}").format(Identifier(self.column), Literal(self.value))
case "starts_with":
return SQL("{} LIKE {}").format(
Identifier(self.column), Literal(f"{self.value}%")
)
case "not_starts_with":
return SQL("{} NOT LIKE {}").format(
Identifier(self.column), Literal(f"{self.value}%")
)
case "ends_with":
return SQL("{} LIKE {}").format(
Identifier(self.column), Literal(f"%{self.value}")
)
case "not_ends_with":
return SQL("{} NOT LIKE {}").format(
Identifier(self.column), Literal(f"%{self.value}")
)
case _:
raise ValueError(f"Unknown operator: {self.operator}")
class ColumnUpdate:
@ -50,7 +104,7 @@ class DBConnector:
if not self.__tableExistsInternal("tables_metadata"):
columns = [
PrimaryUUIDColumnDefinition("table_id"),
PrimarySerialColumnDefinition("table_id"),
TextColumnDefinition("table_name"),
TextColumnDefinition("columns"),
BooleanColumnDefinition("system", False),
@ -58,6 +112,7 @@ class DBConnector:
]
self.createTable("tables_metadata", columns, system=True, hidden=True)
@property
def connection(self):
return self._pool.connection()
@ -81,12 +136,12 @@ class DBConnector:
def removeTableMetadata(self, table_name: str):
self.deleteFromTable(
"tables_metadata",
[ColumnCondition("table_name", table_name, isString=True, isLike=True)],
[ColumnCondition("table_name", "eq", table_name)],
)
def tables(self):
stmt = SQL("SELECT * FROM tables_metadata")
with self.connection() as conn:
with self.connection as conn:
return conn.execute(stmt).fetchall()
def tableExists(self, table_name: str):
@ -99,7 +154,7 @@ class DBConnector:
)
"""
).format(Literal(table_name))
with self.connection() as conn:
with self.connection as conn:
result = conn.execute(stmt).fetchone()
return False if result is None else result["exists"]
@ -113,7 +168,7 @@ class DBConnector:
)
"""
).format(Literal(table_name))
with self.connection() as conn:
with self.connection as conn:
result = conn.execute(stmt).fetchone()
return False if result is None else result["exists"]
@ -127,13 +182,13 @@ class DBConnector:
stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format(
Identifier(table_name), SQL(", ").join(map(lambda c: c.sql(), columns))
)
with self.connection() as conn:
with self.connection as conn:
conn.execute(stmt)
self.saveTableMetadata(table_name, columns, system, hidden)
def dropTable(self, table_name: str):
stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name))
with self.connection() as conn:
with self.connection as conn:
conn.execute(stmt)
self.removeTableMetadata(table_name)
@ -141,7 +196,7 @@ class DBConnector:
stmt = SQL("SELECT * FROM tables_metadata WHERE table_name = {}").format(
Literal(table_name)
)
with self.connection() as conn:
with self.connection as conn:
return conn.execute(stmt).fetchone()
def insertIntoTable(self, table_name: str, columns: dict[str, Any]):
@ -153,7 +208,7 @@ class DBConnector:
SQL(", ").join(map(lambda c: Identifier(c), columns.keys())),
SQL(", ").join(map(lambda c: Literal(c), columns.values())),
)
with self.connection() as conn:
with self.connection as conn:
conn.execute(stmt)
def selectFromTable(
@ -166,7 +221,7 @@ class DBConnector:
SQL(", ").join(map(lambda c: Identifier(c), columns)),
Identifier(table_name),
)
with self.connection() as conn:
with self.connection as conn:
return conn.execute(stmt).fetchall()
def filterFromTable(
@ -183,7 +238,7 @@ class DBConnector:
Identifier(table_name),
SQL(" AND ").join(map(lambda cp: cp.sql(), where)),
)
with self.connection() as conn:
with self.connection as conn:
return conn.execute(stmt).fetchall()
def updateDataInTable(
@ -194,7 +249,7 @@ class DBConnector:
SQL(", ").join(map(lambda cp: cp.sql(), columns)),
SQL(" AND ").join(map(lambda cp: cp.sql(), where)),
)
with self.connection() as conn:
with self.connection as conn:
conn.execute(stmt)
def selectJoinedTables(
@ -209,7 +264,7 @@ class DBConnector:
Identifier(join_table),
Identifier(join_column),
)
with self.connection() as conn:
with self.connection as conn:
return conn.execute(stmt).fetchall()
def deleteFromTable(self, table_name: str, where: list[ColumnCondition]):
@ -217,5 +272,5 @@ class DBConnector:
Identifier(table_name),
SQL(" AND ").join(map(lambda cp: cp.sql(), where)),
)
with self.connection() as conn:
with self.connection as conn:
conn.execute(stmt)