diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/columns.py b/columns.py new file mode 100644 index 0000000..dec9b68 --- /dev/null +++ b/columns.py @@ -0,0 +1,152 @@ +from datetime import date +from psycopg.sql import SQL, Identifier, Literal, Composed + + +class ColumnDefinition: + def __init__(self, name: str): + self.name = name + + def sql(self): + raise NotImplementedError + + +class UniqueColumnDefinition(ColumnDefinition): + def __init__(self, name: str): + super().__init__(name) + + def sql(self): + return SQL("{} UNIQUE").format(Identifier(self.name)) + + +def make_column_unique(column) -> UniqueColumnDefinition: + return UniqueColumnDefinition(column) + + +class PrimarySerialColumnDefinition(ColumnDefinition): + def __init__(self, name: str): + super().__init__(name) + + def sql(self): + return SQL("{} SERIAL PRIMARY KEY").format(Identifier(self.name)) + + +class PrimaryUUIDColumnDefinition(ColumnDefinition): + def __init__(self, name: str): + super().__init__(name) + + def sql(self): + return SQL("{} uuid DEFAULT gen_random_uuid() PRIMARY KEY").format( + Identifier(self.name) + ) + + +class TextColumnDefinition(ColumnDefinition): + def __init__(self, name: str, default: str | None = None): + super().__init__(name) + self.default = default + + def sql(self): + if self.default is None: + return SQL("{} TEXT").format(Identifier(self.name)) + else: + return SQL("{} TEXT DEFAULT {}").format( + Identifier(self.name), Literal(self.default) + ) + + +class BigintColumnDefinition(ColumnDefinition): + def __init__(self, name: str, default: int | None = None): + super().__init__(name) + self.default = default + + def sql(self): + if self.default is None: + return SQL("{} BIGINT").format(Identifier(self.name)) + else: + return SQL("{} BIGINT DEFAULT {}").format( + Identifier(self.name), Literal(self.default) + ) + + +class BooleanColumnDefinition(ColumnDefinition): + def __init__(self, name: str, default: bool | None = None): + super().__init__(name) + self.default = default + + def sql(self): + if self.default is None: + return SQL("{} BOOLEAN").format(Identifier(self.name)) + else: + return SQL("{} BOOLEAN DEFAULT {}").format( + Identifier(self.name), Literal(self.default) + ) + + +class DateColumnDefinition(ColumnDefinition): + def __init__(self, name: str, default: date | None = None): + super().__init__(name) + self.default = default + + def sql(self): + if self.default is None: + return SQL("{} DATE").format(Identifier(self.name)) + else: + return SQL("{} DATE DEFAULT {}").format( + Identifier(self.name), Literal(self.default) + ) + + +class TimestampColumnDefinition(ColumnDefinition): + def __init__(self, name: str, default: date | None = None): + super().__init__(name) + self.default = default + + def sql(self): + if self.default is None: + return SQL("{} TIMESTAMP").format(Identifier(self.name)) + else: + return SQL("{} TIMESTAMP DEFAULT {}").format( + Identifier(self.name), Literal(self.default) + ) + + +class DoubleColumnDefinition(ColumnDefinition): + def __init__(self, name: str, default: float | None = None): + super().__init__(name) + self.default = default + + def sql(self): + if self.default is None: + return SQL("{} DOUBLE PRECISION").format(Identifier(self.name)) + else: + return SQL("{} DOUBLE PRECISION DEFAULT {}").format( + Identifier(self.name), Literal(self.default) + ) + + +class IntegerColumnDefinition(ColumnDefinition): + def __init__(self, name: str, default: int | None = None): + super().__init__(name) + self.default = default + + def sql(self): + if self.default is None: + return SQL("{} INTEGER").format(Identifier(self.name)) + else: + return SQL("{} INTEGER DEFAULT {}").format( + Identifier(self.name), Literal(self.default) + ) + + +class UUIDColumnDefinition(ColumnDefinition): + def __init__(self, name: str, default: str | None = None): + super().__init__(name) + self.default = default + + def sql(self): + if self.default is None: + return SQL("{} UUID").format(Identifier(self.name)) + else: + return SQL("{} UUID DEFAULT {}").format( + Identifier(self.name), Literal(self.default) + ) diff --git a/db.py b/db.py new file mode 100644 index 0000000..853f683 --- /dev/null +++ b/db.py @@ -0,0 +1,108 @@ +from typing import Any, NamedTuple +from psycopg import Connection +from psycopg.rows import namedtuple_row, dict_row +from psycopg.sql import SQL, Identifier, Literal +from psycopg_pool import ConnectionPool +from columns import ColumnDefinition + + +class _ExtendedConnection(Connection): + def __init__(self, *args, **kwargs): + super(_ExtendedConnection, self).__init__(*args, **kwargs) + self.row_factory = dict_row + + +class DBConnector: + def __init__(self, conninfo: str): + self._pool = ConnectionPool(conninfo, connection_class=_ExtendedConnection) + + def connection(self): + return self._pool.connection() + + def addTableInfoToHstore(self, table_name: str, columns: list[ColumnDefinition]): + # TODO: implement this + ... + + def tableExists(self, table_name: str): + stmt = SQL( + """SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_schema LIKE 'public' AND table_type LIKE 'BASE TABLE' AND table_name = {} + );""" + ).format(Literal(table_name)) + with self.connection() as conn: + result = conn.execute(stmt).fetchone() + return False if result is None else result["exists"] + + def createTable(self, table_name: str, columns: list[ColumnDefinition]): + stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format( + Identifier(table_name), SQL(", ").join(map(lambda c: c.sql(), columns)) + ) + with self.connection() as conn: + conn.execute(stmt) + + def dropTable(self, table_name: str): + stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name)) + with self.connection() as conn: + conn.execute(stmt) + + def insertIntoTable(self, table_name: str, columns: dict[str, Any]): + if len(columns) == 0: + stmt = SQL("INSERT INTO {} DEFAULT VALUES").format(Identifier(table_name)) + else: + stmt = SQL("INSERT INTO {} ({}) VALUES ({})").format( + Identifier(table_name), + SQL(", ").join(map(lambda c: Identifier(c), columns.keys())), + SQL(", ").join(map(lambda c: Literal(c), columns.values())), + ) + with self.connection() as conn: + conn.execute(stmt) + + def selectFromTable( + self, table_name: str, columns: list[str] + ) -> list[dict[str, Any]]: + if len(columns) == 1 and columns[0] == "*": + stmt = SQL("SELECT * FROM {}").format(Identifier(table_name)) + else: + stmt = SQL("SELECT {} FROM {}").format( + SQL(", ").join(map(lambda c: Identifier(c), columns)), + Identifier(table_name), + ) + with self.connection() as conn: + return conn.execute(stmt).fetchall() + + def updateDataInTable( + self, table_name: str, columns: dict[str, Any], where: dict[str, Any] + ): + stmt = SQL("UPDATE {} SET {} WHERE {}").format( + Identifier(table_name), + SQL(", ").join( + map( + lambda c: Identifier(c) + SQL(" = ") + Literal(columns[c]), + columns.keys(), + ) + ), + SQL(", ").join( + map( + lambda c: Identifier(c) + SQL(" LIKE ") + Literal(where[c]), + where.keys(), + ) + ), + ) + with self.connection() as conn: + conn.execute(stmt) + + def selectJoinedTables( + self, base_table: str, join_table: str, columns: list[str], join_column: str + ): + stmt = SQL("SELECT {} FROM {} JOIN {} ON {}.{} = {}.{}").format( + SQL(", ").join(map(lambda c: Identifier(c), columns)), + Identifier(base_table), + Identifier(join_table), + Identifier(base_table), + Identifier(join_column), + Identifier(join_table), + Identifier(join_column), + ) + with self.connection() as conn: + return conn.execute(stmt).fetchall() diff --git a/db_test.py b/db_test.py new file mode 100644 index 0000000..7ce2bb8 --- /dev/null +++ b/db_test.py @@ -0,0 +1,144 @@ +import unittest +import uuid +from db import DBConnector +from columns import * + +conninfo = "postgresql://postgres:asarch6122@localhost" +connector = DBConnector(conninfo) + + +class TestDBConnector(unittest.TestCase): + def test_connection(self): + with connector.connection() as conn: + one: dict = conn.execute('SELECT 1 as "ONE"').fetchone() # type: ignore + self.assertIsNotNone(one) + self.assertEqual(one["ONE"], 1) + + def test_columnDefinition(self): + with connector.connection() as conn: + cdef1 = PrimarySerialColumnDefinition("id") + cdef2 = PrimaryUUIDColumnDefinition("uid") + cdef3 = TextColumnDefinition("name") + cdef4 = TextColumnDefinition("name", "John Doe") + self.assertEqual(cdef1.sql().as_string(conn), '"id" SERIAL PRIMARY KEY') + self.assertEqual( + cdef2.sql().as_string(conn), + '"uid" uuid DEFAULT gen_random_uuid() PRIMARY KEY', + ) + self.assertEqual(cdef3.sql().as_string(conn), '"name" TEXT') + self.assertEqual( + cdef4.sql().as_string(conn), "\"name\" TEXT DEFAULT 'John Doe'" + ) + + def test_tableCreation(self): + connector.dropTable("test_table") + self.assertFalse(connector.tableExists("test_table")) + connector.createTable( + "test_table", + [ + PrimarySerialColumnDefinition("id"), + TextColumnDefinition("name", "John Doe"), + ], + ) + self.assertTrue(connector.tableExists("test_table")) + connector.dropTable("test_table") + self.assertFalse(connector.tableExists("test_table")) + + def test_differentColumnsWorking(self): + connector.dropTable("test_table") + self.assertFalse(connector.tableExists("test_table")) + connector.createTable( + "test_table", + [ + PrimarySerialColumnDefinition("id"), + TextColumnDefinition("textcol", "John Doe"), + BigintColumnDefinition("bigintcol", 2**30), + BooleanColumnDefinition("boolcol", True), + DateColumnDefinition("datecol", date.today()), + TimestampColumnDefinition("tscol", date.today()), + DoubleColumnDefinition("doublecol", 3.14), + IntegerColumnDefinition("intcol", 100), + UUIDColumnDefinition("uuidcol", "00000000-0000-0000-0000-000000000000"), + ], + ) + self.assertTrue(connector.tableExists("test_table")) + connector.insertIntoTable("test_table", {}) + connector.insertIntoTable( + "test_table", + { + "textcol": "Jane Doe", + "bigintcol": 3**30, + "boolcol": False, + "datecol": date.today(), + "tscol": date.today(), + "doublecol": 3.14, + "intcol": 100, + "uuidcol": uuid.uuid4(), + }, + ) + res = connector.selectFromTable("test_table", ["*"]) + self.assertEqual(len(res), 2) + self.assertEqual(res[0]["textcol"], "John Doe") + self.assertEqual(res[0]["bigintcol"], 2**30) + self.assertEqual(res[0]["boolcol"], True) + self.assertIsNotNone(res[0]["datecol"]) + self.assertIsNotNone(res[0]["tscol"]) + self.assertEqual(res[0]["doublecol"], 3.14) + self.assertEqual(res[0]["intcol"], 100) + self.assertEqual( + res[0]["uuidcol"], uuid.UUID("00000000-0000-0000-0000-000000000000") + ) + self.assertEqual(res[1]["textcol"], "Jane Doe") + self.assertEqual(res[1]["bigintcol"], 3**30) + self.assertEqual(res[1]["boolcol"], False) + self.assertIsNotNone(res[1]["datecol"]) + self.assertIsNotNone(res[1]["tscol"]) + self.assertEqual(res[1]["doublecol"], 3.14) + self.assertEqual(res[1]["intcol"], 100) + self.assertIsNotNone(res[1]["uuidcol"]) + connector.dropTable("test_table") + self.assertFalse(connector.tableExists("test_table")) + + def test_tableInsertSelect(self): + connector.dropTable("test_table") + connector.createTable( + "test_table", + [ + PrimarySerialColumnDefinition("id"), + TextColumnDefinition("name", "John Doe"), + ], + ) + connector.insertIntoTable("test_table", {"name": "John Doe"}) + connector.insertIntoTable("test_table", {"name": "Jane Doe"}) + connector.insertIntoTable("test_table", {"name": "John Smith"}) + connector.insertIntoTable("test_table", {"name": "Jane Smith"}) + rows = connector.selectFromTable("test_table", ["name"]) + self.assertEqual(len(rows), 4) + self.assertEqual(rows[0]["name"], "John Doe") + self.assertEqual(rows[1]["name"], "Jane Doe") + self.assertEqual(rows[2]["name"], "John Smith") + self.assertEqual(rows[3]["name"], "Jane Smith") + connector.dropTable("test_table") + + def test_tableUpdate(self): + connector.dropTable("test_table") + connector.createTable( + "test_table", + [ + PrimarySerialColumnDefinition("id"), + TextColumnDefinition("name", "John Doe"), + ], + ) + connector.insertIntoTable("test_table", {"name": "John Doe"}) + rows = connector.selectFromTable("test_table", ["name"]) + self.assertEqual(rows[0]["name"], "John Doe") + connector.updateDataInTable( + "test_table", {"name": "John"}, {"name": "John Doe"} + ) + rows = connector.selectFromTable("test_table", ["name"]) + self.assertEqual(rows[0]["name"], "John") + connector.dropTable("test_table") + + +if __name__ == "__main__": + unittest.main() diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..adf8c4a --- /dev/null +++ b/requirement.txt @@ -0,0 +1,5 @@ +psycopg==3.1.8 +psycopg-binary==3.1.8 +psycopg-pool==3.1.5 +typing_extensions==4.4.0 +tzdata==2022.7