[DB] metadata CD; D from CRUD

This commit is contained in:
Andrew 2023-03-11 23:32:50 +07:00
parent 0522918fdd
commit a533dddf1a
2 changed files with 124 additions and 24 deletions

114
db.py
View file

@ -3,7 +3,7 @@ from psycopg import Connection
from psycopg.rows import namedtuple_row, dict_row from psycopg.rows import namedtuple_row, dict_row
from psycopg.sql import SQL, Identifier, Literal from psycopg.sql import SQL, Identifier, Literal
from psycopg_pool import ConnectionPool from psycopg_pool import ConnectionPool
from columns import ColumnDefinition from columns import *
class _ExtendedConnection(Connection): class _ExtendedConnection(Connection):
@ -12,39 +12,111 @@ class _ExtendedConnection(Connection):
self.row_factory = dict_row self.row_factory = dict_row
class ColumnCondition:
def __init__(
self, column: str, value: Any, isString: bool = False, isLike: bool = False
):
self.column = column
self.value = value
self.isString = isString
self.isLike = isLike
def sql(self):
if self.isString:
if self.isLike:
return SQL("{} LIKE {}").format(
Identifier(self.column), Literal(self.value)
)
else:
return SQL("{} NOT LIKE {}").format(
Identifier(self.column), Literal(self.value)
)
return SQL("{} = {}").format(Identifier(self.column), Literal(self.value))
class ColumnUpdate:
def __init__(self, column: str, value: Any):
self.column = column
self.value = value
def sql(self):
return SQL("{} = {}").format(Identifier(self.column), Literal(self.value))
class DBConnector: class DBConnector:
def __init__(self, conninfo: str): def __init__(self, conninfo: str):
self._pool = ConnectionPool(conninfo, connection_class=_ExtendedConnection) self._pool = ConnectionPool(conninfo, connection_class=_ExtendedConnection)
if not self.tableExists("tables_metadata"):
columns = [
PrimaryUUIDColumnDefinition("table_id"),
TextColumnDefinition("table_name"),
TextColumnDefinition("columns"),
BooleanColumnDefinition("system", False),
BooleanColumnDefinition("hidden", False),
]
self.createTable("tables_metadata", columns, system=True, hidden=True)
def connection(self): def connection(self):
return self._pool.connection() return self._pool.connection()
def addTableInfoToHstore(self, table_name: str, columns: list[ColumnDefinition]): def saveTableMetadata(
# TODO: implement this self,
... table_name: str,
columns: list[ColumnDefinition],
system: bool = False,
hidden: bool = False,
):
self.insertIntoTable(
"tables_metadata",
{
"table_name": table_name,
"columns": ",".join(map(lambda c: c.serialize(), columns)),
"system": system,
"hidden": hidden,
},
)
def removeTableMetadata(self, table_name: str):
self.deleteFromTable(
"tables_metadata",
[ColumnCondition("table_name", table_name, isString=True, isLike=True)],
)
def tableExists(self, table_name: str): def tableExists(self, table_name: str):
stmt = SQL( stmt = SQL(
"""SELECT EXISTS ( """
SELECT FROM information_schema.tables SELECT EXISTS (
WHERE table_schema LIKE 'public' AND table_type LIKE 'BASE TABLE' AND table_name = {} SELECT 1
);""" FROM tables_metadata
WHERE table_name = {}
)
"""
).format(Literal(table_name)) ).format(Literal(table_name))
with self.connection() as conn: with self.connection() as conn:
result = conn.execute(stmt).fetchone() result = conn.execute(stmt).fetchone()
print(result)
return False if result is None else result["exists"] return False if result is None else result["exists"]
def createTable(self, table_name: str, columns: list[ColumnDefinition]): def createTable(
self,
table_name: str,
columns: list[ColumnDefinition],
system: bool = False,
hidden: bool = False,
):
stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format( stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format(
Identifier(table_name), SQL(", ").join(map(lambda c: c.sql(), columns)) Identifier(table_name), SQL(", ").join(map(lambda c: c.sql(), columns))
) )
with self.connection() as conn: with self.connection() as conn:
conn.execute(stmt) conn.execute(stmt)
self.saveTableMetadata(table_name, columns, system, hidden)
def dropTable(self, table_name: str): def dropTable(self, table_name: str):
stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name)) stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name))
with self.connection() as conn: with self.connection() as conn:
conn.execute(stmt) conn.execute(stmt)
self.removeTableMetadata(table_name)
def insertIntoTable(self, table_name: str, columns: dict[str, Any]): def insertIntoTable(self, table_name: str, columns: dict[str, Any]):
if len(columns) == 0: if len(columns) == 0:
@ -72,22 +144,12 @@ class DBConnector:
return conn.execute(stmt).fetchall() return conn.execute(stmt).fetchall()
def updateDataInTable( def updateDataInTable(
self, table_name: str, columns: dict[str, Any], where: dict[str, Any] self, table_name: str, columns: list[ColumnUpdate], where: list[ColumnCondition]
): ):
stmt = SQL("UPDATE {} SET {} WHERE {}").format( stmt = SQL("UPDATE {} SET {} WHERE {}").format(
Identifier(table_name), Identifier(table_name),
SQL(", ").join( SQL(", ").join(map(lambda cp: cp.sql(), columns)),
map( SQL(", ").join(map(lambda cp: cp.sql(), where)),
lambda c: Identifier(c) + SQL(" = ") + Literal(columns[c]),
columns.keys(),
)
),
SQL(", ").join(
map(
lambda c: Identifier(c) + SQL(" LIKE ") + Literal(where[c]),
where.keys(),
)
),
) )
with self.connection() as conn: with self.connection() as conn:
conn.execute(stmt) conn.execute(stmt)
@ -106,3 +168,11 @@ class DBConnector:
) )
with self.connection() as conn: with self.connection() as conn:
return conn.execute(stmt).fetchall() return conn.execute(stmt).fetchall()
def deleteFromTable(self, table_name: str, where: list[ColumnCondition]):
stmt = SQL("DELETE FROM {} WHERE {}").format(
Identifier(table_name),
SQL(", ").join(map(lambda cp: cp.sql(), where)),
)
with self.connection() as conn:
conn.execute(stmt)

View file

@ -1,6 +1,6 @@
import unittest import unittest
import uuid import uuid
from db import DBConnector from db import DBConnector, ColumnUpdate, ColumnCondition
from columns import * from columns import *
conninfo = "postgresql://postgres:asarch6122@localhost" conninfo = "postgresql://postgres:asarch6122@localhost"
@ -133,12 +133,42 @@ class TestDBConnector(unittest.TestCase):
rows = connector.selectFromTable("test_table", ["name"]) rows = connector.selectFromTable("test_table", ["name"])
self.assertEqual(rows[0]["name"], "John Doe") self.assertEqual(rows[0]["name"], "John Doe")
connector.updateDataInTable( connector.updateDataInTable(
"test_table", {"name": "John"}, {"name": "John Doe"} "test_table",
[
ColumnUpdate("name", "John"),
],
[
ColumnCondition("name", "John Doe"),
],
) )
rows = connector.selectFromTable("test_table", ["name"]) rows = connector.selectFromTable("test_table", ["name"])
self.assertEqual(rows[0]["name"], "John") self.assertEqual(rows[0]["name"], "John")
connector.dropTable("test_table") connector.dropTable("test_table")
def test_tableDeleteFrom(self):
connector.dropTable("test_table")
connector.createTable(
"test_table",
[
PrimarySerialColumnDefinition("id"),
TextColumnDefinition("name", "John Doe"),
],
)
connector.insertIntoTable("test_table", {"name": "John Doe"})
connector.insertIntoTable("test_table", {"name": "Jane Doe"})
connector.insertIntoTable("test_table", {"name": "Mikhail Prokopenko"})
rows = connector.selectFromTable("test_table", ["name"])
self.assertEqual(len(rows), 3)
self.assertEqual(rows[0]["name"], "John Doe")
self.assertEqual(rows[1]["name"], "Jane Doe")
self.assertEqual(rows[2]["name"], "Mikhail Prokopenko")
connector.deleteFromTable("test_table", [ColumnCondition("name", "John Doe")])
rows = connector.selectFromTable("test_table", ["name"])
self.assertEqual(len(rows), 2)
self.assertEqual(rows[0]["name"], "Jane Doe")
self.assertEqual(rows[1]["name"], "Mikhail Prokopenko")
connector.dropTable("test_table")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()