From 9e1ca9dd43c7f43b09c19ceeb629f399a9b8290d Mon Sep 17 00:00:00 2001 From: Andrew nuark G Date: Sun, 12 Mar 2023 02:44:59 +0700 Subject: [PATCH] Little fixes --- src/based/db.py | 38 +++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/based/db.py b/src/based/db.py index c3480eb..2203e48 100644 --- a/src/based/db.py +++ b/src/based/db.py @@ -15,7 +15,7 @@ class _ExtendedConnection(Connection): class ColumnCondition: def __init__( - self, column: str, value: Any, isString: bool = False, isLike: bool = False + self, column: str, value: Any, isString: bool = False, isLike: bool = True ): self.column = column self.value = value @@ -48,7 +48,7 @@ class DBConnector: def __init__(self, conninfo: str): self._pool = ConnectionPool(conninfo, connection_class=_ExtendedConnection) - if not self.tableExists("tables_metadata"): + if not self.__tableExistsInternal("tables_metadata"): columns = [ PrimaryUUIDColumnDefinition("table_id"), TextColumnDefinition("table_name"), @@ -103,6 +103,20 @@ class DBConnector: result = conn.execute(stmt).fetchone() return False if result is None else result["exists"] + def __tableExistsInternal(self, table_name: str): + stmt = SQL( + """ + SELECT EXISTS ( + SELECT 1 + FROM information_schema.tables + WHERE table_name = {} + ) + """ + ).format(Literal(table_name)) + with self.connection() as conn: + result = conn.execute(stmt).fetchone() + return False if result is None else result["exists"] + def createTable( self, table_name: str, @@ -151,11 +165,17 @@ class DBConnector: def filterFromTable( self, table_name: str, columns: list[str], where: list[ColumnCondition] ) -> list[dict[str, Any]]: - stmt = SQL("SELECT {} FROM {} WHERE {}").format( - SQL(", ").join(map(lambda c: Identifier(c), columns)), - Identifier(table_name), - SQL(", ").join(map(lambda cp: cp.sql(), where)), - ) + if len(columns) == 1 and columns[0] == "*": + stmt = SQL("SELECT * FROM {} WHERE {}").format( + Identifier(table_name), + SQL(" AND ").join(map(lambda cp: cp.sql(), where)), + ) + else: + stmt = SQL("SELECT {} FROM {} WHERE {}").format( + SQL(", ").join(map(lambda c: Identifier(c), columns)), + Identifier(table_name), + SQL(" AND ").join(map(lambda cp: cp.sql(), where)), + ) with self.connection() as conn: return conn.execute(stmt).fetchall() @@ -165,7 +185,7 @@ class DBConnector: stmt = SQL("UPDATE {} SET {} WHERE {}").format( Identifier(table_name), SQL(", ").join(map(lambda cp: cp.sql(), columns)), - SQL(", ").join(map(lambda cp: cp.sql(), where)), + SQL(" AND ").join(map(lambda cp: cp.sql(), where)), ) with self.connection() as conn: conn.execute(stmt) @@ -188,7 +208,7 @@ class DBConnector: def deleteFromTable(self, table_name: str, where: list[ColumnCondition]): stmt = SQL("DELETE FROM {} WHERE {}").format( Identifier(table_name), - SQL(", ").join(map(lambda cp: cp.sql(), where)), + SQL(" AND ").join(map(lambda cp: cp.sql(), where)), ) with self.connection() as conn: conn.execute(stmt)