diff --git a/src/based/__init__.py b/src/based/__init__.py index 3dc1f76..d3ec452 100644 --- a/src/based/__init__.py +++ b/src/based/__init__.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.2.0" diff --git a/src/based/columns.py b/src/based/columns.py index da35c87..2d139d6 100644 --- a/src/based/columns.py +++ b/src/based/columns.py @@ -1,37 +1,23 @@ from datetime import date -from psycopg.sql import SQL, Identifier, Literal, Composed +from psycopg.sql import SQL, Identifier, Literal class ColumnDefinition: - def __init__(self, name: str): + def __init__(self, name: str, unique: bool, has_default: bool = False): self.name = name + self.unique = unique + self.has_default = has_default def sql(self): - raise NotImplementedError + raise NotImplementedError() def serialize(self): - raise NotImplementedError - - -class UniqueColumnDefinition(ColumnDefinition): - def __init__(self, wrapped: ColumnDefinition): - super().__init__(wrapped.name) - self.wrapped = wrapped - - def sql(self): - return SQL("{} UNIQUE").format(self.wrapped.sql()) - - def serialize(self): - return f"{self.wrapped.serialize()}:unique" - - -def make_column_unique(column) -> UniqueColumnDefinition: - return UniqueColumnDefinition(column) + raise NotImplementedError() class PrimarySerialColumnDefinition(ColumnDefinition): def __init__(self, name: str): - super().__init__(name) + super().__init__(name, unique=False) def sql(self): return SQL("{} SERIAL PRIMARY KEY").format(Identifier(self.name)) @@ -40,150 +26,106 @@ class PrimarySerialColumnDefinition(ColumnDefinition): return f"{self.name}:serial:primary" -class PrimaryUUIDColumnDefinition(ColumnDefinition): - def __init__(self, name: str): - super().__init__(name) - - def sql(self): - return SQL("{} uuid DEFAULT gen_random_uuid() PRIMARY KEY").format( - Identifier(self.name) - ) - - def serialize(self): - return f"{self.name}:uuid:primary" - - class TextColumnDefinition(ColumnDefinition): - def __init__(self, name: str, default: str | None = None): - super().__init__(name) + def __init__(self, name: str, default: str | None = None, unique: bool = False): + if default is not None and unique: + raise ValueError("Cannot have a default value and be unique") + super().__init__(name, unique) self.default = default def sql(self): if self.default is None: - return SQL("{} TEXT").format(Identifier(self.name)) + return SQL("{} TEXT UNIQUE" if self.unique else "{} TEXT").format( + Identifier(self.name) + ) else: return SQL("{} TEXT DEFAULT {}").format( Identifier(self.name), Literal(self.default) ) def serialize(self): - return f"{self.name}:str" - - -class BigintColumnDefinition(ColumnDefinition): - def __init__(self, name: str, default: int | None = None): - super().__init__(name) - self.default = default - - def sql(self): - if self.default is None: - return SQL("{} BIGINT").format(Identifier(self.name)) - else: - return SQL("{} BIGINT DEFAULT {}").format( - Identifier(self.name), Literal(self.default) - ) - - def serialize(self): - return f"{self.name}:bigint" + return f"{self.name}:str{':unique' if self.unique else ''}{':default' if self.default is not None else ''}" class BooleanColumnDefinition(ColumnDefinition): - def __init__(self, name: str, default: bool | None = None): - super().__init__(name) + def __init__(self, name: str, default: bool | None = None, unique: bool = False): + if default is not None and unique: + raise ValueError("Cannot have a default value and be unique") + super().__init__(name, unique) self.default = default def sql(self): if self.default is None: - return SQL("{} BOOLEAN").format(Identifier(self.name)) + return SQL("{} BOOLEAN UNIQUE" if self.unique else "{} BOOLEAN").format( + Identifier(self.name) + ) else: return SQL("{} BOOLEAN DEFAULT {}").format( Identifier(self.name), Literal(self.default) ) def serialize(self): - return f"{self.name}:bool" - - -class DateColumnDefinition(ColumnDefinition): - def __init__(self, name: str, default: date | None = None): - super().__init__(name) - self.default = default - - def sql(self): - if self.default is None: - return SQL("{} DATE").format(Identifier(self.name)) - else: - return SQL("{} DATE DEFAULT {}").format( - Identifier(self.name), Literal(self.default) - ) - - def serialize(self): - return f"{self.name}:date" + return f"{self.name}:bool{':unique' if self.unique else ''}{':default' if self.default is not None else ''}" class TimestampColumnDefinition(ColumnDefinition): - def __init__(self, name: str, default: date | None = None): - super().__init__(name) + def __init__(self, name: str, default: date | None = None, unique: bool = False): + if default is not None and unique: + raise ValueError("Cannot have a default value and be unique") + super().__init__(name, unique) self.default = default def sql(self): if self.default is None: - return SQL("{} TIMESTAMP").format(Identifier(self.name)) + return SQL("{} TIMESTAMP UNIQUE" if self.unique else "{} TIMESTAMP").format( + Identifier(self.name) + ) else: return SQL("{} TIMESTAMP DEFAULT {}").format( Identifier(self.name), Literal(self.default) ) def serialize(self): - return f"{self.name}:datetime" + return f"{self.name}:datetime{':unique' if self.unique else ''}{':default' if self.default is not None else ''}" class DoubleColumnDefinition(ColumnDefinition): - def __init__(self, name: str, default: float | None = None): - super().__init__(name) + def __init__(self, name: str, default: float | None = None, unique: bool = False): + if default is not None and unique: + raise ValueError("Cannot have a default value and be unique") + super().__init__(name, unique) self.default = default def sql(self): if self.default is None: - return SQL("{} DOUBLE PRECISION").format(Identifier(self.name)) + return SQL( + "{} DOUBLE PRECISION UNIQUE" if self.unique else "{} DOUBLE PRECISION" + ).format(Identifier(self.name)) else: return SQL("{} DOUBLE PRECISION DEFAULT {}").format( Identifier(self.name), Literal(self.default) ) def serialize(self): - return f"{self.name}:float" + return f"{self.name}:float{':unique' if self.unique else ''}{':default' if self.default is not None else ''}" class IntegerColumnDefinition(ColumnDefinition): - def __init__(self, name: str, default: int | None = None): - super().__init__(name) + def __init__(self, name: str, default: int | None = None, unique: bool = False): + if default is not None and unique: + raise ValueError("Cannot have a default value and be unique") + super().__init__(name, unique) self.default = default def sql(self): if self.default is None: - return SQL("{} INTEGER").format(Identifier(self.name)) + return SQL("{} INTEGER UNIQUE" if self.unique else "{} INTEGER").format( + Identifier(self.name) + ) else: return SQL("{} INTEGER DEFAULT {}").format( Identifier(self.name), Literal(self.default) ) def serialize(self): - return f"{self.name}:int" - - -class UUIDColumnDefinition(ColumnDefinition): - def __init__(self, name: str, default: str | None = None): - super().__init__(name) - self.default = default - - def sql(self): - if self.default is None: - return SQL("{} UUID").format(Identifier(self.name)) - else: - return SQL("{} UUID DEFAULT {}").format( - Identifier(self.name), Literal(self.default) - ) - - def serialize(self): - return f"{self.name}:uuid" + return f"{self.name}:int{':unique' if self.unique else ''}{':default' if self.default is not None else ''}" diff --git a/src/based/db.py b/src/based/db.py index da2ba17..28f189d 100644 --- a/src/based/db.py +++ b/src/based/db.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal as _Literal from psycopg import Connection from psycopg.rows import dict_row from psycopg.sql import SQL, Identifier, Literal @@ -13,26 +13,80 @@ class _ExtendedConnection(Connection): self.row_factory = dict_row +CONDITION_OPERATORS = _Literal[ + "eq", + "ne", + "gt", + "lt", + "ge", + "le", + "contains", + "not_contains", + "starts_with", + "not_starts_with", + "ends_with", + "not_ends_with", +] + + class ColumnCondition: - def __init__( - self, column: str, value: Any, isString: bool = False, isLike: bool = True - ): + def __init__(self, column: str, operator: CONDITION_OPERATORS, value: Any): self.column = column + self.operator = operator self.value = value - self.isString = isString - self.isLike = isLike def sql(self): - if self.isString: - if self.isLike: + match self.operator: + case "eq": + return SQL("{} = {}").format( + Identifier(self.column), Literal(self.value) + ) + case "ne": + return SQL("{} != {}").format( + Identifier(self.column), Literal(self.value) + ) + case "gt": + return SQL("{} > {}").format( + Identifier(self.column), Literal(self.value) + ) + case "lt": + return SQL("{} < {}").format( + Identifier(self.column), Literal(self.value) + ) + case "ge": + return SQL("{} >= {}").format( + Identifier(self.column), Literal(self.value) + ) + case "le": + return SQL("{} <= {}").format( + Identifier(self.column), Literal(self.value) + ) + case "contains": return SQL("{} LIKE {}").format( Identifier(self.column), Literal(self.value) ) - else: + case "not_contains": return SQL("{} NOT LIKE {}").format( Identifier(self.column), Literal(self.value) ) - return SQL("{} = {}").format(Identifier(self.column), Literal(self.value)) + case "starts_with": + return SQL("{} LIKE {}").format( + Identifier(self.column), Literal(f"{self.value}%") + ) + case "not_starts_with": + return SQL("{} NOT LIKE {}").format( + Identifier(self.column), Literal(f"{self.value}%") + ) + case "ends_with": + return SQL("{} LIKE {}").format( + Identifier(self.column), Literal(f"%{self.value}") + ) + case "not_ends_with": + return SQL("{} NOT LIKE {}").format( + Identifier(self.column), Literal(f"%{self.value}") + ) + case _: + raise ValueError(f"Unknown operator: {self.operator}") class ColumnUpdate: @@ -50,7 +104,7 @@ class DBConnector: if not self.__tableExistsInternal("tables_metadata"): columns = [ - PrimaryUUIDColumnDefinition("table_id"), + PrimarySerialColumnDefinition("table_id"), TextColumnDefinition("table_name"), TextColumnDefinition("columns"), BooleanColumnDefinition("system", False), @@ -58,6 +112,7 @@ class DBConnector: ] self.createTable("tables_metadata", columns, system=True, hidden=True) + @property def connection(self): return self._pool.connection() @@ -81,12 +136,12 @@ class DBConnector: def removeTableMetadata(self, table_name: str): self.deleteFromTable( "tables_metadata", - [ColumnCondition("table_name", table_name, isString=True, isLike=True)], + [ColumnCondition("table_name", "eq", table_name)], ) def tables(self): stmt = SQL("SELECT * FROM tables_metadata") - with self.connection() as conn: + with self.connection as conn: return conn.execute(stmt).fetchall() def tableExists(self, table_name: str): @@ -99,7 +154,7 @@ class DBConnector: ) """ ).format(Literal(table_name)) - with self.connection() as conn: + with self.connection as conn: result = conn.execute(stmt).fetchone() return False if result is None else result["exists"] @@ -113,7 +168,7 @@ class DBConnector: ) """ ).format(Literal(table_name)) - with self.connection() as conn: + with self.connection as conn: result = conn.execute(stmt).fetchone() return False if result is None else result["exists"] @@ -127,13 +182,13 @@ class DBConnector: stmt = SQL("CREATE TABLE IF NOT EXISTS {} ({})").format( Identifier(table_name), SQL(", ").join(map(lambda c: c.sql(), columns)) ) - with self.connection() as conn: + with self.connection as conn: conn.execute(stmt) self.saveTableMetadata(table_name, columns, system, hidden) def dropTable(self, table_name: str): stmt = SQL("DROP TABLE IF EXISTS {}").format(Identifier(table_name)) - with self.connection() as conn: + with self.connection as conn: conn.execute(stmt) self.removeTableMetadata(table_name) @@ -141,7 +196,7 @@ class DBConnector: stmt = SQL("SELECT * FROM tables_metadata WHERE table_name = {}").format( Literal(table_name) ) - with self.connection() as conn: + with self.connection as conn: return conn.execute(stmt).fetchone() def insertIntoTable(self, table_name: str, columns: dict[str, Any]): @@ -153,7 +208,7 @@ class DBConnector: SQL(", ").join(map(lambda c: Identifier(c), columns.keys())), SQL(", ").join(map(lambda c: Literal(c), columns.values())), ) - with self.connection() as conn: + with self.connection as conn: conn.execute(stmt) def selectFromTable( @@ -166,7 +221,7 @@ class DBConnector: SQL(", ").join(map(lambda c: Identifier(c), columns)), Identifier(table_name), ) - with self.connection() as conn: + with self.connection as conn: return conn.execute(stmt).fetchall() def filterFromTable( @@ -183,7 +238,7 @@ class DBConnector: Identifier(table_name), SQL(" AND ").join(map(lambda cp: cp.sql(), where)), ) - with self.connection() as conn: + with self.connection as conn: return conn.execute(stmt).fetchall() def updateDataInTable( @@ -194,7 +249,7 @@ class DBConnector: SQL(", ").join(map(lambda cp: cp.sql(), columns)), SQL(" AND ").join(map(lambda cp: cp.sql(), where)), ) - with self.connection() as conn: + with self.connection as conn: conn.execute(stmt) def selectJoinedTables( @@ -209,7 +264,7 @@ class DBConnector: Identifier(join_table), Identifier(join_column), ) - with self.connection() as conn: + with self.connection as conn: return conn.execute(stmt).fetchall() def deleteFromTable(self, table_name: str, where: list[ColumnCondition]): @@ -217,5 +272,5 @@ class DBConnector: Identifier(table_name), SQL(" AND ").join(map(lambda cp: cp.sql(), where)), ) - with self.connection() as conn: + with self.connection as conn: conn.execute(stmt)