178 lines
6.1 KiB
Python
178 lines
6.1 KiB
Python
from typing import Any, NamedTuple
|
|
from psycopg import Connection
|
|
from psycopg.rows import namedtuple_row, dict_row
|
|
from psycopg.sql import SQL, Identifier, Literal
|
|
from psycopg_pool import ConnectionPool
|
|
from columns import *
|
|
|
|
|
|
class _ExtendedConnection(Connection):
    """Connection subclass that installs ``dict_row`` as its row factory.

    Passed to the pool as ``connection_class`` so that rows fetched through
    pooled connections can be read by column name.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Connection`` and then set the row factory."""
        super().__init__(*args, **kwargs)
        # Assigned after the base initialiser has run so that its own
        # default row factory does not overwrite this one.
        self.row_factory = dict_row
|
|
|
|
|
|
class ColumnCondition:
    """A single ``column <op> value`` predicate for use in a WHERE clause.

    The operator is chosen from the two flags:

    * ``isString`` false                  -> ``column = value``
    * ``isString`` true, ``isLike`` true  -> ``column LIKE value``
    * ``isString`` true, ``isLike`` false -> ``column NOT LIKE value``

    ``isLike`` is ignored when ``isString`` is false.
    """

    def __init__(
        self, column: str, value: Any, isString: bool = False, isLike: bool = False
    ):
        """Store the column name, the comparison value and the operator flags."""
        self.column, self.value = column, value
        self.isString, self.isLike = isString, isLike

    def sql(self):
        """Return the predicate as a composed psycopg SQL fragment.

        The column is rendered as a quoted identifier and the value as an
        escaped literal.
        """
        if not self.isString:
            template = "{} = {}"
        elif self.isLike:
            template = "{} LIKE {}"
        else:
            # NOTE(review): a string condition without isLike becomes NOT LIKE,
            # not equality -- confirm this is the intended meaning of the flags.
            template = "{} NOT LIKE {}"
        return SQL(template).format(Identifier(self.column), Literal(self.value))
|
|
|
|
|
|
class ColumnUpdate:
    """A single ``column = value`` assignment for the SET list of an UPDATE."""

    def __init__(self, column: str, value: Any):
        """Store the target column name and the value to assign to it."""
        self.column, self.value = column, value

    def sql(self):
        """Return the assignment as a composed psycopg SQL fragment.

        The column is rendered as a quoted identifier and the value as an
        escaped literal.
        """
        target = Identifier(self.column)
        assigned = Literal(self.value)
        return SQL("{} = {}").format(target, assigned)
|
|
|
|
|
|
class DBConnector:
    """Helper that runs basic table operations through a psycopg connection pool.

    Every table created through :meth:`createTable` is also recorded as a row
    in the ``tables_metadata`` table, which is created on first use.
    Statements are composed with ``psycopg.sql`` so identifiers are quoted and
    values are escaped.
    """

    def __init__(self, conninfo: str):
        """Create the connection pool and make sure ``tables_metadata`` exists.

        Args:
            conninfo: Connection string handed to ``ConnectionPool``.
        """
        self._pool = ConnectionPool(conninfo, connection_class=_ExtendedConnection)

        # Bootstrap: the metadata table registers itself as a hidden system table.
        if not self.tableExists("tables_metadata"):
            columns = [
                PrimaryUUIDColumnDefinition("table_id"),
                TextColumnDefinition("table_name"),
                TextColumnDefinition("columns"),
                BooleanColumnDefinition("system", False),
                BooleanColumnDefinition("hidden", False),
            ]
            self.createTable("tables_metadata", columns, system=True, hidden=True)

    def connection(self):
        """Return the pool's connection context manager."""
        return self._pool.connection()

    def saveTableMetadata(
        self,
        table_name: str,
        columns: list[ColumnDefinition],
        system: bool = False,
        hidden: bool = False,
    ):
        """Insert a ``tables_metadata`` row describing ``table_name``.

        The column definitions are stored as one comma-separated string built
        from each definition's ``serialize()`` result.
        """
        self.insertIntoTable(
            "tables_metadata",
            {
                "table_name": table_name,
                "columns": ",".join(c.serialize() for c in columns),
                "system": system,
                "hidden": hidden,
            },
        )

    def removeTableMetadata(self, table_name: str):
        """Delete the ``tables_metadata`` row(s) registered for ``table_name``."""
        # Exact match on purpose: with LIKE, "_" and "%" inside a table name
        # act as wildcards and would also remove other tables' metadata.
        self.deleteFromTable(
            "tables_metadata",
            [ColumnCondition("table_name", table_name)],
        )

    def tableExists(self, table_name: str):
        """Return whether ``table_name`` is registered in ``tables_metadata``.

        This checks the metadata table only, not the database catalog.
        Returns False when ``tables_metadata`` itself does not exist yet, which
        is the case on the very first start.
        """
        # Local import keeps this block self-contained; psycopg is already a
        # dependency of this module.
        from psycopg.errors import UndefinedTable

        stmt = SQL(
            """
            SELECT EXISTS (
                SELECT 1
                FROM tables_metadata
                WHERE table_name = {}
            )
            """
        ).format(Literal(table_name))
        try:
            with self.connection() as conn:
                result = conn.execute(stmt).fetchone()
        except UndefinedTable:
            # Nothing can be registered if the metadata table is missing.
            return False
        return False if result is None else result["exists"]

    def createTable(
        self,
        table_name: str,
        columns: list[ColumnDefinition],
        system: bool = False,
        hidden: bool = False,
    ):
        """Create ``table_name`` if it is missing and register its metadata.

        Args:
            table_name: Name of the table to create.
            columns: Column definitions; each one's ``sql()`` result is placed
                in the column list of the CREATE TABLE statement.
            system: Stored in the metadata row's ``system`` column.
            hidden: Stored in the metadata row's ``hidden`` column.
        """
        stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format(
            Identifier(table_name), SQL(", ").join(c.sql() for c in columns)
        )
        with self.connection() as conn:
            conn.execute(stmt)
        # CREATE ... IF NOT EXISTS may have been a no-op; do not add a second
        # metadata row for a table that is already registered.
        if not self.tableExists(table_name):
            self.saveTableMetadata(table_name, columns, system, hidden)

    def dropTable(self, table_name: str):
        """Drop ``table_name`` if it exists and remove its metadata."""
        stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name))
        with self.connection() as conn:
            conn.execute(stmt)
        self.removeTableMetadata(table_name)

    def insertIntoTable(self, table_name: str, columns: dict[str, Any]):
        """Insert one row into ``table_name``.

        Args:
            table_name: Target table.
            columns: Mapping of column name to value. An empty mapping inserts
                a row made of default values only.
        """
        if not columns:
            stmt = SQL("INSERT INTO {} DEFAULT VALUES").format(Identifier(table_name))
        else:
            stmt = SQL("INSERT INTO {} ({}) VALUES ({})").format(
                Identifier(table_name),
                SQL(", ").join(map(Identifier, columns.keys())),
                SQL(", ").join(map(Literal, columns.values())),
            )
        with self.connection() as conn:
            conn.execute(stmt)

    def selectFromTable(
        self, table_name: str, columns: list[str]
    ) -> list[dict[str, Any]]:
        """Return all rows of ``table_name`` restricted to ``columns``.

        Passing exactly ``["*"]`` selects every column; any other list is
        rendered as quoted column identifiers.
        """
        if len(columns) == 1 and columns[0] == "*":
            stmt = SQL("SELECT * FROM {}").format(Identifier(table_name))
        else:
            stmt = SQL("SELECT {} FROM {}").format(
                SQL(", ").join(map(Identifier, columns)),
                Identifier(table_name),
            )
        with self.connection() as conn:
            return conn.execute(stmt).fetchall()

    def updateDataInTable(
        self, table_name: str, columns: list[ColumnUpdate], where: list[ColumnCondition]
    ):
        """Update rows of ``table_name`` that satisfy every condition in ``where``.

        Args:
            table_name: Target table.
            columns: Assignments for the SET list.
            where: Conditions combined with AND. An empty list leaves the
                WHERE clause empty, which PostgreSQL rejects as a syntax error.
        """
        stmt = SQL("UPDATE {} SET {} WHERE {}").format(
            Identifier(table_name),
            SQL(", ").join(cu.sql() for cu in columns),
            # Conditions must be AND-ed; a comma-separated list is not valid SQL.
            SQL(" AND ").join(cc.sql() for cc in where),
        )
        with self.connection() as conn:
            conn.execute(stmt)

    def selectJoinedTables(
        self, base_table: str, join_table: str, columns: list[str], join_column: str
    ):
        """Return the rows of ``base_table`` joined to ``join_table``.

        The join condition is equality of ``join_column`` in both tables.

        NOTE(review): the selected columns are emitted unqualified, so a name
        present in both tables (``join_column`` included) is presumably
        ambiguous -- confirm against callers.
        """
        stmt = SQL("SELECT {} FROM {} JOIN {} ON {}.{} = {}.{}").format(
            SQL(", ").join(map(Identifier, columns)),
            Identifier(base_table),
            Identifier(join_table),
            Identifier(base_table),
            Identifier(join_column),
            Identifier(join_table),
            Identifier(join_column),
        )
        with self.connection() as conn:
            return conn.execute(stmt).fetchall()

    def deleteFromTable(self, table_name: str, where: list[ColumnCondition]):
        """Delete rows of ``table_name`` that satisfy every condition in ``where``.

        Args:
            table_name: Target table.
            where: Conditions combined with AND. An empty list leaves the
                WHERE clause empty, which PostgreSQL rejects as a syntax error,
                so this method never deletes every row by accident.
        """
        stmt = SQL("DELETE FROM {} WHERE {}").format(
            Identifier(table_name),
            # Conditions must be AND-ed; a comma-separated list is not valid SQL.
            SQL(" AND ").join(cc.sql() for cc in where),
        )
        with self.connection() as conn:
            conn.execute(stmt)
|