from typing import Any, NamedTuple

from psycopg import Connection
from psycopg.rows import namedtuple_row, dict_row
from psycopg.sql import SQL, Identifier, Literal
from psycopg_pool import ConnectionPool

from columns import ColumnDefinition


class _ExtendedConnection(Connection):
    """psycopg ``Connection`` whose cursors return rows as dicts.

    Passed to the pool as ``connection_class`` so every connection handed out
    by :class:`DBConnector` yields ``dict`` rows keyed by column name.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the row factory chosen by the base class with dict rows.
        self.row_factory = dict_row


class DBConnector:
    """Small helper around a psycopg connection pool for simple table CRUD.

    Table and column names are always quoted via ``Identifier`` and values via
    ``Literal``, so caller-supplied strings are never spliced into SQL as raw
    text. Each method borrows one pooled connection for a single statement,
    using the pool's ``connection()`` context manager.
    """

    def __init__(self, conninfo: str):
        self._pool = ConnectionPool(conninfo, connection_class=_ExtendedConnection)

    def connection(self):
        """Return a context manager that borrows a connection from the pool."""
        return self._pool.connection()

    def close(self) -> None:
        """Close the underlying connection pool."""
        self._pool.close()

    def addTableInfoToHstore(self, table_name: str, columns: list[ColumnDefinition]):
        # TODO: implement this
        # NOTE(review): currently a no-op that returns None; nothing is stored
        # and callers receive no error.
        ...

    def tableExists(self, table_name: str) -> bool:
        """Return whether a base table named *table_name* exists in ``public``.

        Views and other relation types are not counted.
        """
        stmt = SQL(
            """SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_schema = 'public'
                AND table_type = 'BASE TABLE'
                AND table_name = {}
            );"""
        ).format(Literal(table_name))
        with self.connection() as conn:
            result = conn.execute(stmt).fetchone()
        # Rows are dicts (see _ExtendedConnection), keyed by the column name.
        return False if result is None else result["exists"]

    def createTable(self, table_name: str, columns: list[ColumnDefinition]):
        """Create *table_name* with the given column definitions, if absent.

        Each ``ColumnDefinition`` contributes its own SQL fragment via ``.sql()``.
        """
        stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format(
            Identifier(table_name),
            SQL(", ").join(column.sql() for column in columns),
        )
        with self.connection() as conn:
            conn.execute(stmt)

    def dropTable(self, table_name: str):
        """Drop *table_name* if it exists."""
        stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name))
        with self.connection() as conn:
            conn.execute(stmt)

    def insertIntoTable(self, table_name: str, columns: dict[str, Any]):
        """Insert one row into *table_name*.

        Args:
            table_name: Target table.
            columns: Mapping of column name to value. When empty, the row is
                inserted with ``DEFAULT VALUES``.
        """
        if not columns:
            stmt = SQL("INSERT INTO {} DEFAULT VALUES").format(Identifier(table_name))
        else:
            stmt = SQL("INSERT INTO {} ({}) VALUES ({})").format(
                Identifier(table_name),
                SQL(", ").join(map(Identifier, columns.keys())),
                SQL(", ").join(map(Literal, columns.values())),
            )
        with self.connection() as conn:
            conn.execute(stmt)

    def selectFromTable(
        self, table_name: str, columns: list[str]
    ) -> list[dict[str, Any]]:
        """Return all rows of *table_name* as dicts.

        Args:
            table_name: Table to read.
            columns: Column names to select; the single entry ``["*"]``
                selects every column.
        """
        if len(columns) == 1 and columns[0] == "*":
            stmt = SQL("SELECT * FROM {}").format(Identifier(table_name))
        else:
            stmt = SQL("SELECT {} FROM {}").format(
                SQL(", ").join(map(Identifier, columns)),
                Identifier(table_name),
            )
        with self.connection() as conn:
            return conn.execute(stmt).fetchall()

    def updateDataInTable(
        self, table_name: str, columns: dict[str, Any], where: dict[str, Any]
    ):
        """Update rows of *table_name* that match every condition in *where*.

        Args:
            table_name: Table to update.
            columns: Mapping of column name to new value (the SET list).
            where: Mapping of column name to required value. All conditions
                must hold (joined with AND); a ``None`` value matches NULL.

        Raises:
            ValueError: If *columns* or *where* is empty. An empty *where* is
                rejected rather than silently updating every row.
        """
        if not columns:
            raise ValueError("updateDataInTable requires at least one column to set")
        if not where:
            raise ValueError("updateDataInTable requires at least one WHERE condition")

        assignments = SQL(", ").join(
            SQL("{} = {}").format(Identifier(name), Literal(value))
            for name, value in columns.items()
        )
        # Conditions must be combined with AND: a comma-separated WHERE list is
        # a syntax error. Use '=' rather than LIKE so non-text columns work and
        # '%'/'_' in values are not treated as wildcards; '= NULL' never
        # matches, so None maps to IS NULL.
        conditions = SQL(" AND ").join(
            SQL("{} IS NULL").format(Identifier(name))
            if value is None
            else SQL("{} = {}").format(Identifier(name), Literal(value))
            for name, value in where.items()
        )
        stmt = SQL("UPDATE {} SET {} WHERE {}").format(
            Identifier(table_name), assignments, conditions
        )
        with self.connection() as conn:
            conn.execute(stmt)

    def selectJoinedTables(
        self, base_table: str, join_table: str, columns: list[str], join_column: str
    ) -> list[dict[str, Any]]:
        """Inner-join *base_table* and *join_table* on a shared column.

        Args:
            base_table: Left-hand table.
            join_table: Right-hand table.
            columns: Column names to select. They are emitted unqualified, so
                a name present in both tables is ambiguous in SQL.
            join_column: Column with the same name in both tables, compared
                for equality in the ON clause.
        """
        stmt = SQL("SELECT {} FROM {} JOIN {} ON {}.{} = {}.{}").format(
            SQL(", ").join(map(Identifier, columns)),
            Identifier(base_table),
            Identifier(join_table),
            Identifier(base_table),
            Identifier(join_column),
            Identifier(join_table),
            Identifier(join_column),
        )
        with self.connection() as conn:
            return conn.execute(stmt).fetchall()