Initial commit
This commit is contained in:
parent
0cfda573ca
commit
df21774463
5 changed files with 409 additions and 0 deletions
0
__init__.py
Normal file
0
__init__.py
Normal file
152
columns.py
Normal file
152
columns.py
Normal file
|
|
@ -0,0 +1,152 @@
|
||||||
|
from datetime import date
|
||||||
|
from psycopg.sql import SQL, Identifier, Literal, Composed
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnDefinition:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class UniqueColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str):
|
||||||
|
super().__init__(name)
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
return SQL("{} UNIQUE").format(Identifier(self.name))
|
||||||
|
|
||||||
|
|
||||||
|
def make_column_unique(column) -> UniqueColumnDefinition:
|
||||||
|
return UniqueColumnDefinition(column)
|
||||||
|
|
||||||
|
|
||||||
|
class PrimarySerialColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str):
|
||||||
|
super().__init__(name)
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
return SQL("{} SERIAL PRIMARY KEY").format(Identifier(self.name))
|
||||||
|
|
||||||
|
|
||||||
|
class PrimaryUUIDColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str):
|
||||||
|
super().__init__(name)
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
return SQL("{} uuid DEFAULT gen_random_uuid() PRIMARY KEY").format(
|
||||||
|
Identifier(self.name)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TextColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str, default: str | None = None):
|
||||||
|
super().__init__(name)
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
if self.default is None:
|
||||||
|
return SQL("{} TEXT").format(Identifier(self.name))
|
||||||
|
else:
|
||||||
|
return SQL("{} TEXT DEFAULT {}").format(
|
||||||
|
Identifier(self.name), Literal(self.default)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BigintColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str, default: int | None = None):
|
||||||
|
super().__init__(name)
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
if self.default is None:
|
||||||
|
return SQL("{} BIGINT").format(Identifier(self.name))
|
||||||
|
else:
|
||||||
|
return SQL("{} BIGINT DEFAULT {}").format(
|
||||||
|
Identifier(self.name), Literal(self.default)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BooleanColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str, default: bool | None = None):
|
||||||
|
super().__init__(name)
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
if self.default is None:
|
||||||
|
return SQL("{} BOOLEAN").format(Identifier(self.name))
|
||||||
|
else:
|
||||||
|
return SQL("{} BOOLEAN DEFAULT {}").format(
|
||||||
|
Identifier(self.name), Literal(self.default)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DateColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str, default: date | None = None):
|
||||||
|
super().__init__(name)
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
if self.default is None:
|
||||||
|
return SQL("{} DATE").format(Identifier(self.name))
|
||||||
|
else:
|
||||||
|
return SQL("{} DATE DEFAULT {}").format(
|
||||||
|
Identifier(self.name), Literal(self.default)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestampColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str, default: date | None = None):
|
||||||
|
super().__init__(name)
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
if self.default is None:
|
||||||
|
return SQL("{} TIMESTAMP").format(Identifier(self.name))
|
||||||
|
else:
|
||||||
|
return SQL("{} TIMESTAMP DEFAULT {}").format(
|
||||||
|
Identifier(self.name), Literal(self.default)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str, default: float | None = None):
|
||||||
|
super().__init__(name)
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
if self.default is None:
|
||||||
|
return SQL("{} DOUBLE PRECISION").format(Identifier(self.name))
|
||||||
|
else:
|
||||||
|
return SQL("{} DOUBLE PRECISION DEFAULT {}").format(
|
||||||
|
Identifier(self.name), Literal(self.default)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IntegerColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str, default: int | None = None):
|
||||||
|
super().__init__(name)
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
if self.default is None:
|
||||||
|
return SQL("{} INTEGER").format(Identifier(self.name))
|
||||||
|
else:
|
||||||
|
return SQL("{} INTEGER DEFAULT {}").format(
|
||||||
|
Identifier(self.name), Literal(self.default)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UUIDColumnDefinition(ColumnDefinition):
|
||||||
|
def __init__(self, name: str, default: str | None = None):
|
||||||
|
super().__init__(name)
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
def sql(self):
|
||||||
|
if self.default is None:
|
||||||
|
return SQL("{} UUID").format(Identifier(self.name))
|
||||||
|
else:
|
||||||
|
return SQL("{} UUID DEFAULT {}").format(
|
||||||
|
Identifier(self.name), Literal(self.default)
|
||||||
|
)
|
||||||
108
db.py
Normal file
108
db.py
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
from typing import Any, NamedTuple
|
||||||
|
from psycopg import Connection
|
||||||
|
from psycopg.rows import namedtuple_row, dict_row
|
||||||
|
from psycopg.sql import SQL, Identifier, Literal
|
||||||
|
from psycopg_pool import ConnectionPool
|
||||||
|
from columns import ColumnDefinition
|
||||||
|
|
||||||
|
|
||||||
|
class _ExtendedConnection(Connection):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(_ExtendedConnection, self).__init__(*args, **kwargs)
|
||||||
|
self.row_factory = dict_row
|
||||||
|
|
||||||
|
|
||||||
|
class DBConnector:
|
||||||
|
def __init__(self, conninfo: str):
|
||||||
|
self._pool = ConnectionPool(conninfo, connection_class=_ExtendedConnection)
|
||||||
|
|
||||||
|
def connection(self):
|
||||||
|
return self._pool.connection()
|
||||||
|
|
||||||
|
def addTableInfoToHstore(self, table_name: str, columns: list[ColumnDefinition]):
|
||||||
|
# TODO: implement this
|
||||||
|
...
|
||||||
|
|
||||||
|
def tableExists(self, table_name: str):
|
||||||
|
stmt = SQL(
|
||||||
|
"""SELECT EXISTS (
|
||||||
|
SELECT FROM information_schema.tables
|
||||||
|
WHERE table_schema LIKE 'public' AND table_type LIKE 'BASE TABLE' AND table_name = {}
|
||||||
|
);"""
|
||||||
|
).format(Literal(table_name))
|
||||||
|
with self.connection() as conn:
|
||||||
|
result = conn.execute(stmt).fetchone()
|
||||||
|
return False if result is None else result["exists"]
|
||||||
|
|
||||||
|
def createTable(self, table_name: str, columns: list[ColumnDefinition]):
|
||||||
|
stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format(
|
||||||
|
Identifier(table_name), SQL(", ").join(map(lambda c: c.sql(), columns))
|
||||||
|
)
|
||||||
|
with self.connection() as conn:
|
||||||
|
conn.execute(stmt)
|
||||||
|
|
||||||
|
def dropTable(self, table_name: str):
|
||||||
|
stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name))
|
||||||
|
with self.connection() as conn:
|
||||||
|
conn.execute(stmt)
|
||||||
|
|
||||||
|
def insertIntoTable(self, table_name: str, columns: dict[str, Any]):
|
||||||
|
if len(columns) == 0:
|
||||||
|
stmt = SQL("INSERT INTO {} DEFAULT VALUES").format(Identifier(table_name))
|
||||||
|
else:
|
||||||
|
stmt = SQL("INSERT INTO {} ({}) VALUES ({})").format(
|
||||||
|
Identifier(table_name),
|
||||||
|
SQL(", ").join(map(lambda c: Identifier(c), columns.keys())),
|
||||||
|
SQL(", ").join(map(lambda c: Literal(c), columns.values())),
|
||||||
|
)
|
||||||
|
with self.connection() as conn:
|
||||||
|
conn.execute(stmt)
|
||||||
|
|
||||||
|
def selectFromTable(
|
||||||
|
self, table_name: str, columns: list[str]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
if len(columns) == 1 and columns[0] == "*":
|
||||||
|
stmt = SQL("SELECT * FROM {}").format(Identifier(table_name))
|
||||||
|
else:
|
||||||
|
stmt = SQL("SELECT {} FROM {}").format(
|
||||||
|
SQL(", ").join(map(lambda c: Identifier(c), columns)),
|
||||||
|
Identifier(table_name),
|
||||||
|
)
|
||||||
|
with self.connection() as conn:
|
||||||
|
return conn.execute(stmt).fetchall()
|
||||||
|
|
||||||
|
def updateDataInTable(
|
||||||
|
self, table_name: str, columns: dict[str, Any], where: dict[str, Any]
|
||||||
|
):
|
||||||
|
stmt = SQL("UPDATE {} SET {} WHERE {}").format(
|
||||||
|
Identifier(table_name),
|
||||||
|
SQL(", ").join(
|
||||||
|
map(
|
||||||
|
lambda c: Identifier(c) + SQL(" = ") + Literal(columns[c]),
|
||||||
|
columns.keys(),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
SQL(", ").join(
|
||||||
|
map(
|
||||||
|
lambda c: Identifier(c) + SQL(" LIKE ") + Literal(where[c]),
|
||||||
|
where.keys(),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
with self.connection() as conn:
|
||||||
|
conn.execute(stmt)
|
||||||
|
|
||||||
|
def selectJoinedTables(
|
||||||
|
self, base_table: str, join_table: str, columns: list[str], join_column: str
|
||||||
|
):
|
||||||
|
stmt = SQL("SELECT {} FROM {} JOIN {} ON {}.{} = {}.{}").format(
|
||||||
|
SQL(", ").join(map(lambda c: Identifier(c), columns)),
|
||||||
|
Identifier(base_table),
|
||||||
|
Identifier(join_table),
|
||||||
|
Identifier(base_table),
|
||||||
|
Identifier(join_column),
|
||||||
|
Identifier(join_table),
|
||||||
|
Identifier(join_column),
|
||||||
|
)
|
||||||
|
with self.connection() as conn:
|
||||||
|
return conn.execute(stmt).fetchall()
|
||||||
144
db_test.py
Normal file
144
db_test.py
Normal file
|
|
@ -0,0 +1,144 @@
|
||||||
|
import unittest
|
||||||
|
import uuid
|
||||||
|
from db import DBConnector
|
||||||
|
from columns import *
|
||||||
|
|
||||||
|
conninfo = "postgresql://postgres:asarch6122@localhost"
|
||||||
|
connector = DBConnector(conninfo)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDBConnector(unittest.TestCase):
|
||||||
|
def test_connection(self):
|
||||||
|
with connector.connection() as conn:
|
||||||
|
one: dict = conn.execute('SELECT 1 as "ONE"').fetchone() # type: ignore
|
||||||
|
self.assertIsNotNone(one)
|
||||||
|
self.assertEqual(one["ONE"], 1)
|
||||||
|
|
||||||
|
def test_columnDefinition(self):
|
||||||
|
with connector.connection() as conn:
|
||||||
|
cdef1 = PrimarySerialColumnDefinition("id")
|
||||||
|
cdef2 = PrimaryUUIDColumnDefinition("uid")
|
||||||
|
cdef3 = TextColumnDefinition("name")
|
||||||
|
cdef4 = TextColumnDefinition("name", "John Doe")
|
||||||
|
self.assertEqual(cdef1.sql().as_string(conn), '"id" SERIAL PRIMARY KEY')
|
||||||
|
self.assertEqual(
|
||||||
|
cdef2.sql().as_string(conn),
|
||||||
|
'"uid" uuid DEFAULT gen_random_uuid() PRIMARY KEY',
|
||||||
|
)
|
||||||
|
self.assertEqual(cdef3.sql().as_string(conn), '"name" TEXT')
|
||||||
|
self.assertEqual(
|
||||||
|
cdef4.sql().as_string(conn), "\"name\" TEXT DEFAULT 'John Doe'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tableCreation(self):
|
||||||
|
connector.dropTable("test_table")
|
||||||
|
self.assertFalse(connector.tableExists("test_table"))
|
||||||
|
connector.createTable(
|
||||||
|
"test_table",
|
||||||
|
[
|
||||||
|
PrimarySerialColumnDefinition("id"),
|
||||||
|
TextColumnDefinition("name", "John Doe"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self.assertTrue(connector.tableExists("test_table"))
|
||||||
|
connector.dropTable("test_table")
|
||||||
|
self.assertFalse(connector.tableExists("test_table"))
|
||||||
|
|
||||||
|
def test_differentColumnsWorking(self):
|
||||||
|
connector.dropTable("test_table")
|
||||||
|
self.assertFalse(connector.tableExists("test_table"))
|
||||||
|
connector.createTable(
|
||||||
|
"test_table",
|
||||||
|
[
|
||||||
|
PrimarySerialColumnDefinition("id"),
|
||||||
|
TextColumnDefinition("textcol", "John Doe"),
|
||||||
|
BigintColumnDefinition("bigintcol", 2**30),
|
||||||
|
BooleanColumnDefinition("boolcol", True),
|
||||||
|
DateColumnDefinition("datecol", date.today()),
|
||||||
|
TimestampColumnDefinition("tscol", date.today()),
|
||||||
|
DoubleColumnDefinition("doublecol", 3.14),
|
||||||
|
IntegerColumnDefinition("intcol", 100),
|
||||||
|
UUIDColumnDefinition("uuidcol", "00000000-0000-0000-0000-000000000000"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self.assertTrue(connector.tableExists("test_table"))
|
||||||
|
connector.insertIntoTable("test_table", {})
|
||||||
|
connector.insertIntoTable(
|
||||||
|
"test_table",
|
||||||
|
{
|
||||||
|
"textcol": "Jane Doe",
|
||||||
|
"bigintcol": 3**30,
|
||||||
|
"boolcol": False,
|
||||||
|
"datecol": date.today(),
|
||||||
|
"tscol": date.today(),
|
||||||
|
"doublecol": 3.14,
|
||||||
|
"intcol": 100,
|
||||||
|
"uuidcol": uuid.uuid4(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
res = connector.selectFromTable("test_table", ["*"])
|
||||||
|
self.assertEqual(len(res), 2)
|
||||||
|
self.assertEqual(res[0]["textcol"], "John Doe")
|
||||||
|
self.assertEqual(res[0]["bigintcol"], 2**30)
|
||||||
|
self.assertEqual(res[0]["boolcol"], True)
|
||||||
|
self.assertIsNotNone(res[0]["datecol"])
|
||||||
|
self.assertIsNotNone(res[0]["tscol"])
|
||||||
|
self.assertEqual(res[0]["doublecol"], 3.14)
|
||||||
|
self.assertEqual(res[0]["intcol"], 100)
|
||||||
|
self.assertEqual(
|
||||||
|
res[0]["uuidcol"], uuid.UUID("00000000-0000-0000-0000-000000000000")
|
||||||
|
)
|
||||||
|
self.assertEqual(res[1]["textcol"], "Jane Doe")
|
||||||
|
self.assertEqual(res[1]["bigintcol"], 3**30)
|
||||||
|
self.assertEqual(res[1]["boolcol"], False)
|
||||||
|
self.assertIsNotNone(res[1]["datecol"])
|
||||||
|
self.assertIsNotNone(res[1]["tscol"])
|
||||||
|
self.assertEqual(res[1]["doublecol"], 3.14)
|
||||||
|
self.assertEqual(res[1]["intcol"], 100)
|
||||||
|
self.assertIsNotNone(res[1]["uuidcol"])
|
||||||
|
connector.dropTable("test_table")
|
||||||
|
self.assertFalse(connector.tableExists("test_table"))
|
||||||
|
|
||||||
|
def test_tableInsertSelect(self):
|
||||||
|
connector.dropTable("test_table")
|
||||||
|
connector.createTable(
|
||||||
|
"test_table",
|
||||||
|
[
|
||||||
|
PrimarySerialColumnDefinition("id"),
|
||||||
|
TextColumnDefinition("name", "John Doe"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
connector.insertIntoTable("test_table", {"name": "John Doe"})
|
||||||
|
connector.insertIntoTable("test_table", {"name": "Jane Doe"})
|
||||||
|
connector.insertIntoTable("test_table", {"name": "John Smith"})
|
||||||
|
connector.insertIntoTable("test_table", {"name": "Jane Smith"})
|
||||||
|
rows = connector.selectFromTable("test_table", ["name"])
|
||||||
|
self.assertEqual(len(rows), 4)
|
||||||
|
self.assertEqual(rows[0]["name"], "John Doe")
|
||||||
|
self.assertEqual(rows[1]["name"], "Jane Doe")
|
||||||
|
self.assertEqual(rows[2]["name"], "John Smith")
|
||||||
|
self.assertEqual(rows[3]["name"], "Jane Smith")
|
||||||
|
connector.dropTable("test_table")
|
||||||
|
|
||||||
|
def test_tableUpdate(self):
|
||||||
|
connector.dropTable("test_table")
|
||||||
|
connector.createTable(
|
||||||
|
"test_table",
|
||||||
|
[
|
||||||
|
PrimarySerialColumnDefinition("id"),
|
||||||
|
TextColumnDefinition("name", "John Doe"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
connector.insertIntoTable("test_table", {"name": "John Doe"})
|
||||||
|
rows = connector.selectFromTable("test_table", ["name"])
|
||||||
|
self.assertEqual(rows[0]["name"], "John Doe")
|
||||||
|
connector.updateDataInTable(
|
||||||
|
"test_table", {"name": "John"}, {"name": "John Doe"}
|
||||||
|
)
|
||||||
|
rows = connector.selectFromTable("test_table", ["name"])
|
||||||
|
self.assertEqual(rows[0]["name"], "John")
|
||||||
|
connector.dropTable("test_table")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
5
requirement.txt
Normal file
5
requirement.txt
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
psycopg==3.1.8
|
||||||
|
psycopg-binary==3.1.8
|
||||||
|
psycopg-pool==3.1.5
|
||||||
|
typing_extensions==4.4.0
|
||||||
|
tzdata==2022.7
|
||||||
Loading…
Add table
Add a link
Reference in a new issue