"""Column-definition helpers that render SQL fragments with ``psycopg.sql``.

Each class describes a single column and exposes ``sql()``, which returns a
composed fragment of the form ``<quoted name> <TYPE> [DEFAULT <literal>]``.
"""

from datetime import date

from psycopg.sql import SQL, Identifier, Literal, Composed


class ColumnDefinition:
    """Base class for a named column whose SQL is produced by ``sql()``."""

    def __init__(self, name: str):
        self.name = name

    def sql(self) -> Composed:
        """Return the SQL fragment for this column.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError


class UniqueColumnDefinition(ColumnDefinition):
    """Column rendered as ``<name> UNIQUE``."""

    # NOTE(review): the emitted fragment carries no data type; confirm that
    # callers only use it where a type is not required.
    def sql(self) -> Composed:
        """Return ``<quoted name> UNIQUE``."""
        return SQL("{} UNIQUE").format(Identifier(self.name))


def make_column_unique(column: str) -> UniqueColumnDefinition:
    """Build a ``UniqueColumnDefinition`` using *column* as the column name."""
    return UniqueColumnDefinition(column)


class PrimarySerialColumnDefinition(ColumnDefinition):
    """Column rendered as ``<name> SERIAL PRIMARY KEY``."""

    def sql(self) -> Composed:
        """Return ``<quoted name> SERIAL PRIMARY KEY``."""
        return SQL("{} SERIAL PRIMARY KEY").format(Identifier(self.name))


class PrimaryUUIDColumnDefinition(ColumnDefinition):
    """Primary-key column of type ``uuid`` defaulting to ``gen_random_uuid()``."""

    def sql(self) -> Composed:
        """Return ``<quoted name> uuid DEFAULT gen_random_uuid() PRIMARY KEY``."""
        return SQL("{} uuid DEFAULT gen_random_uuid() PRIMARY KEY").format(
            Identifier(self.name)
        )


class _DefaultableColumnDefinition(ColumnDefinition):
    """Shared implementation for typed columns with an optional DEFAULT.

    Subclasses set ``_sql_type`` to the SQL type keyword(s) and keep their own
    typed ``__init__`` so each public signature stays explicit.
    """

    # SQL type keyword(s) emitted right after the quoted column name.
    _sql_type: str = ""

    def __init__(self, name: str, default=None):
        super().__init__(name)
        self.default = default

    def sql(self) -> Composed:
        """Return ``<quoted name> <TYPE>``, plus ``DEFAULT <literal>`` if set.

        A ``default`` of ``None`` means "no DEFAULT clause"; any other value
        (including falsy ones such as ``0``, ``False`` or ``""``) is emitted
        as a SQL literal.
        """
        # The template text is assembled to be identical to the per-class
        # literals it replaces, so the composed result is unchanged.
        if self.default is None:
            return SQL("{} " + self._sql_type).format(Identifier(self.name))
        return SQL("{} " + self._sql_type + " DEFAULT {}").format(
            Identifier(self.name), Literal(self.default)
        )


class TextColumnDefinition(_DefaultableColumnDefinition):
    """``TEXT`` column with an optional string default."""

    _sql_type = "TEXT"

    def __init__(self, name: str, default: str | None = None):
        super().__init__(name, default)


class BigintColumnDefinition(_DefaultableColumnDefinition):
    """``BIGINT`` column with an optional integer default."""

    _sql_type = "BIGINT"

    def __init__(self, name: str, default: int | None = None):
        super().__init__(name, default)


class BooleanColumnDefinition(_DefaultableColumnDefinition):
    """``BOOLEAN`` column with an optional boolean default."""

    _sql_type = "BOOLEAN"

    def __init__(self, name: str, default: bool | None = None):
        super().__init__(name, default)


class DateColumnDefinition(_DefaultableColumnDefinition):
    """``DATE`` column with an optional ``date`` default."""

    _sql_type = "DATE"

    def __init__(self, name: str, default: date | None = None):
        super().__init__(name, default)


class TimestampColumnDefinition(_DefaultableColumnDefinition):
    """``TIMESTAMP`` column with an optional default.

    The default is annotated as ``date``; ``datetime`` is a subclass of
    ``date``, so ``datetime`` values satisfy the annotation as well.
    """

    _sql_type = "TIMESTAMP"

    def __init__(self, name: str, default: date | None = None):
        super().__init__(name, default)


class DoubleColumnDefinition(_DefaultableColumnDefinition):
    """``DOUBLE PRECISION`` column with an optional float default."""

    _sql_type = "DOUBLE PRECISION"

    def __init__(self, name: str, default: float | None = None):
        super().__init__(name, default)


class IntegerColumnDefinition(_DefaultableColumnDefinition):
    """``INTEGER`` column with an optional integer default."""

    _sql_type = "INTEGER"

    def __init__(self, name: str, default: int | None = None):
        super().__init__(name, default)


class UUIDColumnDefinition(_DefaultableColumnDefinition):
    """``UUID`` column with an optional default given as a string."""

    _sql_type = "UUID"

    def __init__(self, name: str, default: str | None = None):
        super().__init__(name, default)