from typing import Any, NamedTuple

from psycopg import Connection
from psycopg.rows import namedtuple_row, dict_row
from psycopg.sql import SQL, Identifier, Literal
from psycopg_pool import ConnectionPool

from columns import *


class _ExtendedConnection(Connection):
    """psycopg ``Connection`` whose queries return rows as ``dict`` objects."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dict_row lets callers read results by column name, e.g. row["present"].
        self.row_factory = dict_row


class ColumnCondition:
    """One ``column <operator> value`` predicate for a WHERE clause.

    The operator is chosen from the flags:

    * ``isString=True, isLike=True``  -> ``column LIKE value``
    * ``isString=True, isLike=False`` -> ``column NOT LIKE value``
    * ``isString=False``              -> ``column = value`` (``isLike`` ignored)

    ``column`` is rendered with ``Identifier`` and ``value`` with ``Literal``,
    so neither is spliced into the statement as raw SQL text.
    """

    def __init__(
        self, column: str, value: Any, isString: bool = False, isLike: bool = False
    ):
        self.column = column
        self.value = value
        self.isString = isString
        self.isLike = isLike

    def sql(self):
        """Return the predicate as a composable SQL fragment."""
        if not self.isString:
            operator = "="
        elif self.isLike:
            operator = "LIKE"
        else:
            # NOTE(review): a string condition with isLike=False becomes
            # NOT LIKE rather than "=" -- confirm this is intended.
            operator = "NOT LIKE"
        return SQL("{} {} {}").format(
            Identifier(self.column), SQL(operator), Literal(self.value)
        )


class ColumnUpdate:
    """One ``column = value`` assignment for the SET list of an UPDATE."""

    def __init__(self, column: str, value: Any):
        self.column = column
        self.value = value

    def sql(self):
        """Return the assignment as a composable SQL fragment."""
        return SQL("{} = {}").format(Identifier(self.column), Literal(self.value))


class DBConnector:
    """Pool-backed helper for creating, querying and modifying tables.

    Tables created through ``createTable`` are also recorded in the
    ``tables_metadata`` registry table, which this class creates on first use.
    Each operation borrows a connection from the pool via ``connection()``.
    """

    # Name of the registry table that records tables created via createTable.
    _METADATA_TABLE = "tables_metadata"

    def __init__(self, conninfo: str):
        """Open a connection pool for ``conninfo`` and bootstrap the registry.

        Creates the registry table (and registers it in itself) when it is not
        registered yet, including on a database where it does not exist at all.
        """
        self._pool = ConnectionPool(conninfo, connection_class=_ExtendedConnection)
        if not self.tableExists(self._METADATA_TABLE):
            columns = [
                PrimaryUUIDColumnDefinition("table_id"),
                TextColumnDefinition("table_name"),
                TextColumnDefinition("columns"),
                BooleanColumnDefinition("system", False),
                BooleanColumnDefinition("hidden", False),
            ]
            self.createTable(self._METADATA_TABLE, columns, system=True, hidden=True)

    def connection(self):
        """Return a pooled-connection context manager; use it with ``with``."""
        return self._pool.connection()

    def saveTableMetadata(
        self,
        table_name: str,
        columns: list[ColumnDefinition],
        system: bool = False,
        hidden: bool = False,
    ):
        """Insert a registry row describing ``table_name``.

        ``columns`` is stored as the comma-joined ``serialize()`` output of each
        column definition.
        """
        self.insertIntoTable(
            self._METADATA_TABLE,
            {
                "table_name": table_name,
                "columns": ",".join(c.serialize() for c in columns),
                "system": system,
                "hidden": hidden,
            },
        )

    def removeTableMetadata(self, table_name: str):
        """Delete the registry row(s) whose name is exactly ``table_name``."""
        # Match with "=" rather than LIKE: table names routinely contain "_",
        # which LIKE treats as a single-character wildcard (so "a_b" would also
        # remove the entry for "axb").
        self.deleteFromTable(
            self._METADATA_TABLE, [ColumnCondition("table_name", table_name)]
        )

    def tableExists(self, table_name: str) -> bool:
        """Return True if ``table_name`` is recorded in the metadata registry.

        Returns False (instead of raising) when the registry table itself has
        not been created yet, as on a fresh database.
        """
        registry_stmt = SQL("SELECT to_regclass({}) IS NOT NULL AS present").format(
            Literal(self._METADATA_TABLE)
        )
        lookup_stmt = SQL(
            "SELECT EXISTS (SELECT 1 FROM {} WHERE table_name = {}) AS present"
        ).format(Identifier(self._METADATA_TABLE), Literal(table_name))
        with self.connection() as conn:
            # Querying a missing relation is an error in PostgreSQL, so check
            # that the registry table exists before looking inside it.
            registry = conn.execute(registry_stmt).fetchone()
            if registry is None or not registry["present"]:
                return False
            result = conn.execute(lookup_stmt).fetchone()
        return result is not None and bool(result["present"])

    def createTable(
        self,
        table_name: str,
        columns: list[ColumnDefinition],
        system: bool = False,
        hidden: bool = False,
    ):
        """Create ``table_name`` if missing and register it in the metadata table.

        The table is registered only once. Calling this again for an already
        registered table leaves the registry untouched.
        """
        stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format(
            Identifier(table_name), SQL(", ").join(c.sql() for c in columns)
        )
        with self.connection() as conn:
            conn.execute(stmt)
        # CREATE TABLE IF NOT EXISTS silently skips an existing table, so guard
        # the insert to avoid piling up duplicate registry rows.
        if not self.tableExists(table_name):
            self.saveTableMetadata(table_name, columns, system, hidden)

    def dropTable(self, table_name: str):
        """Drop ``table_name`` if it exists and remove its registry entry."""
        stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name))
        with self.connection() as conn:
            conn.execute(stmt)
        self.removeTableMetadata(table_name)

    def insertIntoTable(self, table_name: str, columns: dict[str, Any]):
        """Insert one row, where ``columns`` maps column names to values.

        An empty mapping inserts a row made entirely of column defaults.
        """
        if not columns:
            stmt = SQL("INSERT INTO {} DEFAULT VALUES").format(Identifier(table_name))
        else:
            stmt = SQL("INSERT INTO {} ({}) VALUES ({})").format(
                Identifier(table_name),
                SQL(", ").join(Identifier(name) for name in columns),
                SQL(", ").join(Literal(value) for value in columns.values()),
            )
        with self.connection() as conn:
            conn.execute(stmt)

    def selectFromTable(
        self, table_name: str, columns: list[str]
    ) -> list[dict[str, Any]]:
        """Return every row of ``table_name``, restricted to ``columns``.

        Pass ``["*"]`` to select all columns.
        """
        if len(columns) == 1 and columns[0] == "*":
            column_list = SQL("*")
        else:
            column_list = SQL(", ").join(Identifier(c) for c in columns)
        stmt = SQL("SELECT {} FROM {}").format(column_list, Identifier(table_name))
        with self.connection() as conn:
            return conn.execute(stmt).fetchall()

    def updateDataInTable(
        self, table_name: str, columns: list[ColumnUpdate], where: list[ColumnCondition]
    ):
        """Apply the ``columns`` assignments to rows matching *all* ``where`` conditions.

        Raises:
            ValueError: if ``columns`` or ``where`` is empty, which would
                otherwise produce an invalid statement.
        """
        if not columns:
            raise ValueError("at least one column update is required")
        stmt = SQL("UPDATE {} SET {} WHERE {}").format(
            Identifier(table_name),
            SQL(", ").join(update.sql() for update in columns),
            self._whereClause(where),
        )
        with self.connection() as conn:
            conn.execute(stmt)

    def selectJoinedTables(
        self, base_table: str, join_table: str, columns: list[str], join_column: str
    ) -> list[dict[str, Any]]:
        """Inner-join two tables on ``join_column`` and return ``columns``.

        The selected column names are not table-qualified. A name present in
        both tables (including ``join_column`` itself) is therefore ambiguous
        in the resulting query.
        """
        stmt = SQL("SELECT {} FROM {} JOIN {} ON {} = {}").format(
            SQL(", ").join(Identifier(c) for c in columns),
            Identifier(base_table),
            Identifier(join_table),
            Identifier(base_table, join_column),
            Identifier(join_table, join_column),
        )
        with self.connection() as conn:
            return conn.execute(stmt).fetchall()

    def deleteFromTable(self, table_name: str, where: list[ColumnCondition]):
        """Delete rows of ``table_name`` matching *all* ``where`` conditions.

        Raises:
            ValueError: if ``where`` is empty.
        """
        stmt = SQL("DELETE FROM {} WHERE {}").format(
            Identifier(table_name), self._whereClause(where)
        )
        with self.connection() as conn:
            conn.execute(stmt)

    @staticmethod
    def _whereClause(where: list[ColumnCondition]):
        """Combine the ``where`` predicates with AND into one WHERE-clause body."""
        if not where:
            raise ValueError("at least one WHERE condition is required")
        # A comma-separated predicate list is not valid WHERE syntax; the
        # conditions must be conjoined explicitly.
        return SQL(" AND ").join(condition.sql() for condition in where)