diff --git a/db.py b/db.py index 853f683..394d80e 100644 --- a/db.py +++ b/db.py @@ -3,7 +3,7 @@ from psycopg import Connection from psycopg.rows import namedtuple_row, dict_row from psycopg.sql import SQL, Identifier, Literal from psycopg_pool import ConnectionPool -from columns import ColumnDefinition +from columns import * class _ExtendedConnection(Connection): @@ -12,39 +12,111 @@ class _ExtendedConnection(Connection): self.row_factory = dict_row +class ColumnCondition: + def __init__( + self, column: str, value: Any, isString: bool = False, isLike: bool = False + ): + self.column = column + self.value = value + self.isString = isString + self.isLike = isLike + + def sql(self): + if self.isString: + if self.isLike: + return SQL("{} LIKE {}").format( + Identifier(self.column), Literal(self.value) + ) + else: + return SQL("{} NOT LIKE {}").format( + Identifier(self.column), Literal(self.value) + ) + return SQL("{} = {}").format(Identifier(self.column), Literal(self.value)) + + +class ColumnUpdate: + def __init__(self, column: str, value: Any): + self.column = column + self.value = value + + def sql(self): + return SQL("{} = {}").format(Identifier(self.column), Literal(self.value)) + + class DBConnector: def __init__(self, conninfo: str): self._pool = ConnectionPool(conninfo, connection_class=_ExtendedConnection) + if not self.tableExists("tables_metadata"): + columns = [ + PrimaryUUIDColumnDefinition("table_id"), + TextColumnDefinition("table_name"), + TextColumnDefinition("columns"), + BooleanColumnDefinition("system", False), + BooleanColumnDefinition("hidden", False), + ] + self.createTable("tables_metadata", columns, system=True, hidden=True) + def connection(self): return self._pool.connection() - def addTableInfoToHstore(self, table_name: str, columns: list[ColumnDefinition]): - # TODO: implement this - ... + def saveTableMetadata( + self, + table_name: str, + columns: list[ColumnDefinition], + system: bool = False, + hidden: bool = False, + ): + self.insertIntoTable( + "tables_metadata", + { + "table_name": table_name, + "columns": ",".join(map(lambda c: c.serialize(), columns)), + "system": system, + "hidden": hidden, + }, + ) + + def removeTableMetadata(self, table_name: str): + self.deleteFromTable( + "tables_metadata", + [ColumnCondition("table_name", table_name, isString=True, isLike=True)], + ) def tableExists(self, table_name: str): stmt = SQL( - """SELECT EXISTS ( - SELECT FROM information_schema.tables - WHERE table_schema LIKE 'public' AND table_type LIKE 'BASE TABLE' AND table_name = {} - );""" + """ + SELECT EXISTS ( + SELECT 1 + FROM tables_metadata + WHERE table_name = {} + ) + """ ).format(Literal(table_name)) with self.connection() as conn: result = conn.execute(stmt).fetchone() + print(result) return False if result is None else result["exists"] - def createTable(self, table_name: str, columns: list[ColumnDefinition]): + def createTable( + self, + table_name: str, + columns: list[ColumnDefinition], + system: bool = False, + hidden: bool = False, + ): stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format( Identifier(table_name), SQL(", ").join(map(lambda c: c.sql(), columns)) ) with self.connection() as conn: conn.execute(stmt) + self.saveTableMetadata(table_name, columns, system, hidden) def dropTable(self, table_name: str): stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name)) with self.connection() as conn: conn.execute(stmt) + self.removeTableMetadata(table_name) def insertIntoTable(self, table_name: str, columns: dict[str, Any]): if len(columns) == 0: @@ -72,22 +144,12 @@ class DBConnector: return conn.execute(stmt).fetchall() def updateDataInTable( - self, table_name: str, columns: dict[str, Any], where: dict[str, Any] + self, table_name: str, columns: list[ColumnUpdate], where: list[ColumnCondition] ): stmt = SQL("UPDATE {} SET {} WHERE {}").format( Identifier(table_name), - SQL(", ").join( - map( - lambda c: Identifier(c) + SQL(" = ") + Literal(columns[c]), - columns.keys(), - ) - ), - SQL(", ").join( - map( - lambda c: Identifier(c) + SQL(" LIKE ") + Literal(where[c]), - where.keys(), - ) - ), + SQL(", ").join(map(lambda cp: cp.sql(), columns)), + SQL(", ").join(map(lambda cp: cp.sql(), where)), ) with self.connection() as conn: conn.execute(stmt) @@ -106,3 +168,11 @@ class DBConnector: ) with self.connection() as conn: return conn.execute(stmt).fetchall() + + def deleteFromTable(self, table_name: str, where: list[ColumnCondition]): + stmt = SQL("DELETE FROM {} WHERE {}").format( + Identifier(table_name), + SQL(", ").join(map(lambda cp: cp.sql(), where)), + ) + with self.connection() as conn: + conn.execute(stmt) diff --git a/db_test.py b/db_test.py index 7ce2bb8..d587e2e 100644 --- a/db_test.py +++ b/db_test.py @@ -1,6 +1,6 @@ import unittest import uuid -from db import DBConnector +from db import DBConnector, ColumnUpdate, ColumnCondition from columns import * conninfo = "postgresql://postgres:asarch6122@localhost" @@ -133,12 +133,42 @@ class TestDBConnector(unittest.TestCase): rows = connector.selectFromTable("test_table", ["name"]) self.assertEqual(rows[0]["name"], "John Doe") connector.updateDataInTable( - "test_table", {"name": "John"}, {"name": "John Doe"} + "test_table", + [ + ColumnUpdate("name", "John"), + ], + [ + ColumnCondition("name", "John Doe"), + ], ) rows = connector.selectFromTable("test_table", ["name"]) self.assertEqual(rows[0]["name"], "John") connector.dropTable("test_table") + def test_tableDeleteFrom(self): + connector.dropTable("test_table") + connector.createTable( + "test_table", + [ + PrimarySerialColumnDefinition("id"), + TextColumnDefinition("name", "John Doe"), + ], + ) + connector.insertIntoTable("test_table", {"name": "John Doe"}) + connector.insertIntoTable("test_table", {"name": "Jane Doe"}) + connector.insertIntoTable("test_table", {"name": "Mikhail Prokopenko"}) + rows = connector.selectFromTable("test_table", ["name"]) + self.assertEqual(len(rows), 3) + self.assertEqual(rows[0]["name"], "John Doe") + self.assertEqual(rows[1]["name"], "Jane Doe") + self.assertEqual(rows[2]["name"], "Mikhail Prokopenko") + connector.deleteFromTable("test_table", [ColumnCondition("name", "John Doe")]) + rows = connector.selectFromTable("test_table", ["name"]) + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0]["name"], "Jane Doe") + self.assertEqual(rows[1]["name"], "Mikhail Prokopenko") + connector.dropTable("test_table") + if __name__ == "__main__": unittest.main()