Wow it is better now

This commit is contained in:
Andrew 2023-04-23 00:40:07 +07:00
parent 3afdd7276f
commit edaafcdf4f
3 changed files with 127 additions and 130 deletions

View file

@ -1 +1 @@
__version__ = "0.1.0" __version__ = "0.2.0"

View file

@ -1,37 +1,23 @@
from datetime import date from datetime import date
from psycopg.sql import SQL, Identifier, Literal, Composed from psycopg.sql import SQL, Identifier, Literal
class ColumnDefinition: class ColumnDefinition:
def __init__(self, name: str): def __init__(self, name: str, unique: bool, has_default: bool = False):
self.name = name self.name = name
self.unique = unique
self.has_default = has_default
def sql(self): def sql(self):
raise NotImplementedError raise NotImplementedError()
def serialize(self): def serialize(self):
raise NotImplementedError raise NotImplementedError()
class UniqueColumnDefinition(ColumnDefinition):
def __init__(self, wrapped: ColumnDefinition):
super().__init__(wrapped.name)
self.wrapped = wrapped
def sql(self):
return SQL("{} UNIQUE").format(self.wrapped.sql())
def serialize(self):
return f"{self.wrapped.serialize()}:unique"
def make_column_unique(column) -> UniqueColumnDefinition:
return UniqueColumnDefinition(column)
class PrimarySerialColumnDefinition(ColumnDefinition): class PrimarySerialColumnDefinition(ColumnDefinition):
def __init__(self, name: str): def __init__(self, name: str):
super().__init__(name) super().__init__(name, unique=False)
def sql(self): def sql(self):
return SQL("{} SERIAL PRIMARY KEY").format(Identifier(self.name)) return SQL("{} SERIAL PRIMARY KEY").format(Identifier(self.name))
@ -40,150 +26,106 @@ class PrimarySerialColumnDefinition(ColumnDefinition):
return f"{self.name}:serial:primary" return f"{self.name}:serial:primary"
class PrimaryUUIDColumnDefinition(ColumnDefinition):
def __init__(self, name: str):
super().__init__(name)
def sql(self):
return SQL("{} uuid DEFAULT gen_random_uuid() PRIMARY KEY").format(
Identifier(self.name)
)
def serialize(self):
return f"{self.name}:uuid:primary"
class TextColumnDefinition(ColumnDefinition): class TextColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: str | None = None): def __init__(self, name: str, default: str | None = None, unique: bool = False):
super().__init__(name) if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default self.default = default
def sql(self): def sql(self):
if self.default is None: if self.default is None:
return SQL("{} TEXT").format(Identifier(self.name)) return SQL("{} TEXT UNIQUE" if self.unique else "{} TEXT").format(
Identifier(self.name)
)
else: else:
return SQL("{} TEXT DEFAULT {}").format( return SQL("{} TEXT DEFAULT {}").format(
Identifier(self.name), Literal(self.default) Identifier(self.name), Literal(self.default)
) )
def serialize(self): def serialize(self):
return f"{self.name}:str" return f"{self.name}:str{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"
class BigintColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: int | None = None):
super().__init__(name)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} BIGINT").format(Identifier(self.name))
else:
return SQL("{} BIGINT DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:bigint"
class BooleanColumnDefinition(ColumnDefinition): class BooleanColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: bool | None = None): def __init__(self, name: str, default: bool | None = None, unique: bool = False):
super().__init__(name) if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default self.default = default
def sql(self): def sql(self):
if self.default is None: if self.default is None:
return SQL("{} BOOLEAN").format(Identifier(self.name)) return SQL("{} BOOLEAN UNIQUE" if self.unique else "{} BOOLEAN").format(
Identifier(self.name)
)
else: else:
return SQL("{} BOOLEAN DEFAULT {}").format( return SQL("{} BOOLEAN DEFAULT {}").format(
Identifier(self.name), Literal(self.default) Identifier(self.name), Literal(self.default)
) )
def serialize(self): def serialize(self):
return f"{self.name}:bool" return f"{self.name}:bool{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"
class DateColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: date | None = None):
super().__init__(name)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} DATE").format(Identifier(self.name))
else:
return SQL("{} DATE DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:date"
class TimestampColumnDefinition(ColumnDefinition): class TimestampColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: date | None = None): def __init__(self, name: str, default: date | None = None, unique: bool = False):
super().__init__(name) if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default self.default = default
def sql(self): def sql(self):
if self.default is None: if self.default is None:
return SQL("{} TIMESTAMP").format(Identifier(self.name)) return SQL("{} TIMESTAMP UNIQUE" if self.unique else "{} TIMESTAMP").format(
Identifier(self.name)
)
else: else:
return SQL("{} TIMESTAMP DEFAULT {}").format( return SQL("{} TIMESTAMP DEFAULT {}").format(
Identifier(self.name), Literal(self.default) Identifier(self.name), Literal(self.default)
) )
def serialize(self): def serialize(self):
return f"{self.name}:datetime" return f"{self.name}:datetime{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"
class DoubleColumnDefinition(ColumnDefinition): class DoubleColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: float | None = None): def __init__(self, name: str, default: float | None = None, unique: bool = False):
super().__init__(name) if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default self.default = default
def sql(self): def sql(self):
if self.default is None: if self.default is None:
return SQL("{} DOUBLE PRECISION").format(Identifier(self.name)) return SQL(
"{} DOUBLE PRECISION UNIQUE" if self.unique else "{} DOUBLE PRECISION"
).format(Identifier(self.name))
else: else:
return SQL("{} DOUBLE PRECISION DEFAULT {}").format( return SQL("{} DOUBLE PRECISION DEFAULT {}").format(
Identifier(self.name), Literal(self.default) Identifier(self.name), Literal(self.default)
) )
def serialize(self): def serialize(self):
return f"{self.name}:float" return f"{self.name}:float{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"
class IntegerColumnDefinition(ColumnDefinition): class IntegerColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: int | None = None): def __init__(self, name: str, default: int | None = None, unique: bool = False):
super().__init__(name) if default is not None and unique:
raise ValueError("Cannot have a default value and be unique")
super().__init__(name, unique)
self.default = default self.default = default
def sql(self): def sql(self):
if self.default is None: if self.default is None:
return SQL("{} INTEGER").format(Identifier(self.name)) return SQL("{} INTEGER UNIQUE" if self.unique else "{} INTEGER").format(
Identifier(self.name)
)
else: else:
return SQL("{} INTEGER DEFAULT {}").format( return SQL("{} INTEGER DEFAULT {}").format(
Identifier(self.name), Literal(self.default) Identifier(self.name), Literal(self.default)
) )
def serialize(self): def serialize(self):
return f"{self.name}:int" return f"{self.name}:int{':unique' if self.unique else ''}{':default' if self.default is not None else ''}"
class UUIDColumnDefinition(ColumnDefinition):
def __init__(self, name: str, default: str | None = None):
super().__init__(name)
self.default = default
def sql(self):
if self.default is None:
return SQL("{} UUID").format(Identifier(self.name))
else:
return SQL("{} UUID DEFAULT {}").format(
Identifier(self.name), Literal(self.default)
)
def serialize(self):
return f"{self.name}:uuid"

View file

@ -1,4 +1,4 @@
from typing import Any from typing import Any, Literal as _Literal
from psycopg import Connection from psycopg import Connection
from psycopg.rows import dict_row from psycopg.rows import dict_row
from psycopg.sql import SQL, Identifier, Literal from psycopg.sql import SQL, Identifier, Literal
@ -13,26 +13,80 @@ class _ExtendedConnection(Connection):
self.row_factory = dict_row self.row_factory = dict_row
CONDITION_OPERATORS = _Literal[
"eq",
"ne",
"gt",
"lt",
"ge",
"le",
"contains",
"not_contains",
"starts_with",
"not_starts_with",
"ends_with",
"not_ends_with",
]
class ColumnCondition: class ColumnCondition:
def __init__( def __init__(self, column: str, operator: CONDITION_OPERATORS, value: Any):
self, column: str, value: Any, isString: bool = False, isLike: bool = True
):
self.column = column self.column = column
self.operator = operator
self.value = value self.value = value
self.isString = isString
self.isLike = isLike
def sql(self): def sql(self):
if self.isString: match self.operator:
if self.isLike: case "eq":
return SQL("{} = {}").format(
Identifier(self.column), Literal(self.value)
)
case "ne":
return SQL("{} != {}").format(
Identifier(self.column), Literal(self.value)
)
case "gt":
return SQL("{} > {}").format(
Identifier(self.column), Literal(self.value)
)
case "lt":
return SQL("{} < {}").format(
Identifier(self.column), Literal(self.value)
)
case "ge":
return SQL("{} >= {}").format(
Identifier(self.column), Literal(self.value)
)
case "le":
return SQL("{} <= {}").format(
Identifier(self.column), Literal(self.value)
)
case "contains":
return SQL("{} LIKE {}").format( return SQL("{} LIKE {}").format(
Identifier(self.column), Literal(self.value) Identifier(self.column), Literal(self.value)
) )
else: case "not_contains":
return SQL("{} NOT LIKE {}").format( return SQL("{} NOT LIKE {}").format(
Identifier(self.column), Literal(self.value) Identifier(self.column), Literal(self.value)
) )
return SQL("{} = {}").format(Identifier(self.column), Literal(self.value)) case "starts_with":
return SQL("{} LIKE {}").format(
Identifier(self.column), Literal(f"{self.value}%")
)
case "not_starts_with":
return SQL("{} NOT LIKE {}").format(
Identifier(self.column), Literal(f"{self.value}%")
)
case "ends_with":
return SQL("{} LIKE {}").format(
Identifier(self.column), Literal(f"%{self.value}")
)
case "not_ends_with":
return SQL("{} NOT LIKE {}").format(
Identifier(self.column), Literal(f"%{self.value}")
)
case _:
raise ValueError(f"Unknown operator: {self.operator}")
class ColumnUpdate: class ColumnUpdate:
@ -50,7 +104,7 @@ class DBConnector:
if not self.__tableExistsInternal("tables_metadata"): if not self.__tableExistsInternal("tables_metadata"):
columns = [ columns = [
PrimaryUUIDColumnDefinition("table_id"), PrimarySerialColumnDefinition("table_id"),
TextColumnDefinition("table_name"), TextColumnDefinition("table_name"),
TextColumnDefinition("columns"), TextColumnDefinition("columns"),
BooleanColumnDefinition("system", False), BooleanColumnDefinition("system", False),
@ -58,6 +112,7 @@ class DBConnector:
] ]
self.createTable("tables_metadata", columns, system=True, hidden=True) self.createTable("tables_metadata", columns, system=True, hidden=True)
@property
def connection(self): def connection(self):
return self._pool.connection() return self._pool.connection()
@ -81,12 +136,12 @@ class DBConnector:
def removeTableMetadata(self, table_name: str): def removeTableMetadata(self, table_name: str):
self.deleteFromTable( self.deleteFromTable(
"tables_metadata", "tables_metadata",
[ColumnCondition("table_name", table_name, isString=True, isLike=True)], [ColumnCondition("table_name", "eq", table_name)],
) )
def tables(self): def tables(self):
stmt = SQL("SELECT * FROM tables_metadata") stmt = SQL("SELECT * FROM tables_metadata")
with self.connection() as conn: with self.connection as conn:
return conn.execute(stmt).fetchall() return conn.execute(stmt).fetchall()
def tableExists(self, table_name: str): def tableExists(self, table_name: str):
@ -99,7 +154,7 @@ class DBConnector:
) )
""" """
).format(Literal(table_name)) ).format(Literal(table_name))
with self.connection() as conn: with self.connection as conn:
result = conn.execute(stmt).fetchone() result = conn.execute(stmt).fetchone()
return False if result is None else result["exists"] return False if result is None else result["exists"]
@ -113,7 +168,7 @@ class DBConnector:
) )
""" """
).format(Literal(table_name)) ).format(Literal(table_name))
with self.connection() as conn: with self.connection as conn:
result = conn.execute(stmt).fetchone() result = conn.execute(stmt).fetchone()
return False if result is None else result["exists"] return False if result is None else result["exists"]
@ -127,13 +182,13 @@ class DBConnector:
stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format( stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format(
Identifier(table_name), SQL(", ").join(map(lambda c: c.sql(), columns)) Identifier(table_name), SQL(", ").join(map(lambda c: c.sql(), columns))
) )
with self.connection() as conn: with self.connection as conn:
conn.execute(stmt) conn.execute(stmt)
self.saveTableMetadata(table_name, columns, system, hidden) self.saveTableMetadata(table_name, columns, system, hidden)
def dropTable(self, table_name: str): def dropTable(self, table_name: str):
stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name)) stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name))
with self.connection() as conn: with self.connection as conn:
conn.execute(stmt) conn.execute(stmt)
self.removeTableMetadata(table_name) self.removeTableMetadata(table_name)
@ -141,7 +196,7 @@ class DBConnector:
stmt = SQL("SELECT * FROM tables_metadata WHERE table_name = {}").format( stmt = SQL("SELECT * FROM tables_metadata WHERE table_name = {}").format(
Literal(table_name) Literal(table_name)
) )
with self.connection() as conn: with self.connection as conn:
return conn.execute(stmt).fetchone() return conn.execute(stmt).fetchone()
def insertIntoTable(self, table_name: str, columns: dict[str, Any]): def insertIntoTable(self, table_name: str, columns: dict[str, Any]):
@ -153,7 +208,7 @@ class DBConnector:
SQL(", ").join(map(lambda c: Identifier(c), columns.keys())), SQL(", ").join(map(lambda c: Identifier(c), columns.keys())),
SQL(", ").join(map(lambda c: Literal(c), columns.values())), SQL(", ").join(map(lambda c: Literal(c), columns.values())),
) )
with self.connection() as conn: with self.connection as conn:
conn.execute(stmt) conn.execute(stmt)
def selectFromTable( def selectFromTable(
@ -166,7 +221,7 @@ class DBConnector:
SQL(", ").join(map(lambda c: Identifier(c), columns)), SQL(", ").join(map(lambda c: Identifier(c), columns)),
Identifier(table_name), Identifier(table_name),
) )
with self.connection() as conn: with self.connection as conn:
return conn.execute(stmt).fetchall() return conn.execute(stmt).fetchall()
def filterFromTable( def filterFromTable(
@ -183,7 +238,7 @@ class DBConnector:
Identifier(table_name), Identifier(table_name),
SQL(" AND ").join(map(lambda cp: cp.sql(), where)), SQL(" AND ").join(map(lambda cp: cp.sql(), where)),
) )
with self.connection() as conn: with self.connection as conn:
return conn.execute(stmt).fetchall() return conn.execute(stmt).fetchall()
def updateDataInTable( def updateDataInTable(
@ -194,7 +249,7 @@ class DBConnector:
SQL(", ").join(map(lambda cp: cp.sql(), columns)), SQL(", ").join(map(lambda cp: cp.sql(), columns)),
SQL(" AND ").join(map(lambda cp: cp.sql(), where)), SQL(" AND ").join(map(lambda cp: cp.sql(), where)),
) )
with self.connection() as conn: with self.connection as conn:
conn.execute(stmt) conn.execute(stmt)
def selectJoinedTables( def selectJoinedTables(
@ -209,7 +264,7 @@ class DBConnector:
Identifier(join_table), Identifier(join_table),
Identifier(join_column), Identifier(join_column),
) )
with self.connection() as conn: with self.connection as conn:
return conn.execute(stmt).fetchall() return conn.execute(stmt).fetchall()
def deleteFromTable(self, table_name: str, where: list[ColumnCondition]): def deleteFromTable(self, table_name: str, where: list[ColumnCondition]):
@ -217,5 +272,5 @@ class DBConnector:
Identifier(table_name), Identifier(table_name),
SQL(" AND ").join(map(lambda cp: cp.sql(), where)), SQL(" AND ").join(map(lambda cp: cp.sql(), where)),
) )
with self.connection() as conn: with self.connection as conn:
conn.execute(stmt) conn.execute(stmt)