From af365cfe685e20f82002d9b347fd510e36c6c16b Mon Sep 17 00:00:00 2001 From: Andrew nuark G Date: Sun, 15 Mar 2026 11:35:50 +0700 Subject: [PATCH] feat: initial release of GIS classification project with strategy-based classifiers selector --- .gitignore | 187 +++++++++++++++++++++++++ README.md | 81 +++++++++++ data/.gitkeep | 0 main.py | 116 ++++++++++++++++ output/.gitkeep | 0 requirements.txt | 7 + src/__init__.py | 25 ++++ src/classifier.py | 225 ++++++++++++++++++++++++++++++ src/data/__init__.py | 11 ++ src/data/loader.py | 139 +++++++++++++++++++ src/strategies/__init__.py | 12 ++ src/strategies/base.py | 61 +++++++++ src/strategies/classifiers.py | 250 ++++++++++++++++++++++++++++++++++ src/utils/__init__.py | 1 + 14 files changed, 1115 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 data/.gitkeep create mode 100644 main.py create mode 100644 output/.gitkeep create mode 100644 requirements.txt create mode 100644 src/__init__.py create mode 100644 src/classifier.py create mode 100644 src/data/__init__.py create mode 100644 src/data/loader.py create mode 100644 src/strategies/__init__.py create mode 100644 src/strategies/base.py create mode 100644 src/strategies/classifiers.py create mode 100644 src/utils/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..32e46ed --- /dev/null +++ b/.gitignore @@ -0,0 +1,187 @@ +# Created by https://www.toptal.com/developers/gitignore/api/python,gis +# Edit at https://www.toptal.com/developers/gitignore?templates=python,gis + +### GIS ### +*.gpx +*.kml +*.tif +*.tif.aux.xml +*.tiff +*.shx +*.shp +*.prj +*.dbf + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +# End of https://www.toptal.com/developers/gitignore/api/python,gis \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..ec6888e --- /dev/null +++ b/README.md @@ -0,0 +1,81 @@ +# GIS Classification + +Python project for land cover classification using GIS data (GeoTIFF + Shapefile). + +## Project Structure + +``` +gis-classification/ +├── main.py # Main script with parameters +├── requirements.txt # Dependencies +├── data/ # Input data folder +├── output/ # Classification results +└── src/ + ├── classifier.py # Main classification pipeline + ├── data/ + │ └── loader.py # Data loading (GeoTIFF, Shapefile) + ├── strategies/ + │ ├── base.py # Strategy interface + │ └── classifiers.py # Built-in strategies (RF, SVM, LR) + └── utils/ +``` + +## Installation + +```bash +pip install -r requirements.txt +``` + +## Usage + +1. Place your input files in `data/`: + - `landsat.tif` - GeoTIFF from Landsat + - `polygons.shp` - Shapefile with class labels + +2. Configure parameters in `main.py`: + - Input/output paths + - Classification strategy (RandomForest, SVM, LogisticRegression) + - Training parameters + +3. Run: +```bash +python main.py +``` + +## Adding Custom Classification Strategy + +Create a new class implementing `ClassificationStrategy`: + +```python +from src.strategies import ClassificationStrategy +import numpy as np + +class MyCustomStrategy(ClassificationStrategy): + def train(self, X: np.ndarray, y: np.ndarray) -> None: + # Your training logic + pass + + def predict(self, X: np.ndarray) -> np.ndarray: + # Your prediction logic + pass + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + pass + + def get_params(self) -> dict: + pass + + @property + def name(self) -> str: + return "MyCustom" +``` + +Then use in `main.py`: +```python +STRATEGY = MyCustomStrategy() +``` + +## Output + +- `output/classified.tif` - Classified raster (GeoTIFF) +- Console output with accuracy metrics diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/main.py b/main.py new file mode 100644 index 0000000..cda7835 --- /dev/null +++ b/main.py @@ -0,0 +1,116 @@ +"""Main script for GIS classification. + +Configure parameters below to run classification. +""" + +import os +from src import ( + GISClassifier, + RandomForestStrategy, + SVMStrategy, + LogisticRegressionStrategy, + MLEStrategy, +) + + +# ==================== PARAMETERS ==================== + +# Input files +RASTER_PATH = os.path.join("data", "landsat.tif") # Path to GeoTIFF (Landsat) +VECTOR_PATH = os.path.join("data", "polygons.shx") # Path to Shapefile + +# Column in Shapefile containing class labels +CLASS_COLUMN = "macroclass" + +# Output file for classified raster +OUTPUT_PATH = os.path.join("output", "classified.tif") + +# Classification strategy parameters +# Change strategy by uncommenting desired option: + +# Option 1: Random Forest (recommended for GIS) +STRATEGY = RandomForestStrategy( + n_estimators=100, + max_depth=None, + random_state=42, +) + +# Option 2: Support Vector Machine +# STRATEGY = SVMStrategy( +# kernel="rbf", +# C=1.0, +# gamma="scale", +# random_state=42, +# ) + +# Option 3: Logistic Regression +# STRATEGY = LogisticRegressionStrategy( +# penalty="l2", +# C=1.0, +# max_iter=1000, +# random_state=42, +# ) + +# Option 4: Maximum Likelihood Estimation (classic for GIS) +# STRATEGY = MLEStrategy( +# reg_covar=1e-6, +# ) + +# Training parameters +TEST_SIZE = 0.2 # Fraction for validation +RANDOM_STATE = 42 # Random seed + +# ==================== RUN CLASSIFICATION ==================== + + +def main(): + # Check input files exist + if not os.path.exists(RASTER_PATH): + print(f"Error: Raster file not found: {RASTER_PATH}") + return + + if not os.path.exists(VECTOR_PATH): + print(f"Error: Vector file not found: {VECTOR_PATH}") + return + + # Create output directory if needed + os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True) + + # Initialize classifier with strategy + print(f"Using classification strategy: {STRATEGY.name}") + print(f"Strategy parameters: {STRATEGY.get_params()}") + print() + + classifier = GISClassifier(strategy=STRATEGY) + + # Train + print("Training classifier...") + metrics = classifier.train( + raster_path=RASTER_PATH, + vector_path=VECTOR_PATH, + class_column=CLASS_COLUMN, + test_size=TEST_SIZE, + random_state=RANDOM_STATE, + ) + + print(f"Training samples: {metrics['train_samples']}") + print(f"Validation samples: {metrics['val_samples']}") + print(f"Accuracy: {metrics['accuracy']:.2%}") + print(f"Classes: {metrics['classes']}") + print() + + # Predict + print(f"Classifying raster: {RASTER_PATH}") + result = classifier.predict( + raster_path=RASTER_PATH, + output_path=OUTPUT_PATH, + ) + + print(f"Classification complete!") + print(f"Output saved to: {OUTPUT_PATH}") + print(f"Output shape: {result.predicted_array.shape}") + print(f"Classes in output: {result.classes}") + + +if __name__ == "__main__": + main() diff --git a/output/.gitkeep b/output/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7f959bc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +rasterio>=1.3.0 +geopandas>=0.12.0 +shapely>=2.0.0 +scikit-learn>=1.3.0 +scipy>=1.10.0 +numpy>=1.24.0 +pandas>=2.0.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..2409c46 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,25 @@ +"""GIS Classification Package.""" + +from .classifier import GISClassifier, ClassificationResult +from .data import RasterData, VectorData, load_raster, load_vector +from .strategies import ( + ClassificationStrategy, + RandomForestStrategy, + SVMStrategy, + LogisticRegressionStrategy, + MLEStrategy, +) + +__all__ = [ + "GISClassifier", + "ClassificationResult", + "RasterData", + "VectorData", + "load_raster", + "load_vector", + "ClassificationStrategy", + "RandomForestStrategy", + "SVMStrategy", + "LogisticRegressionStrategy", + "MLEStrategy", +] diff --git a/src/classifier.py b/src/classifier.py new file mode 100644 index 0000000..901be1c --- /dev/null +++ b/src/classifier.py @@ -0,0 +1,225 @@ +"""Main classification pipeline.""" + +from dataclasses import dataclass, field +from typing import Any +import numpy as np +import rasterio +from rasterio.transform import from_bounds +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report, accuracy_score, confusion_matrix + +from .data import RasterData, VectorData, load_raster, load_vector, extract_raster_values_by_polygons +from .strategies import ClassificationStrategy + + +@dataclass +class ClassificationResult: + """Container for classification results.""" + predicted_array: np.ndarray + transform: rasterio.transform.Affine + crs: rasterio.crs.CRS + classes: np.ndarray + accuracy: float | None = None + report: str | None = None + confusion_matrix: np.ndarray | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +class GISClassifier: + """Main classifier for GIS data. + + Uses strategy pattern to allow plugging different classification algorithms. + """ + + def __init__(self, strategy: ClassificationStrategy): + """Initialize classifier with a strategy. + + Args: + strategy: ClassificationStrategy instance to use for classification. + """ + self.strategy = strategy + self._is_trained = False + self._classes: np.ndarray | None = None + + def train( + self, + raster_path: str, + vector_path: str, + class_column: str = "class", + test_size: float = 0.2, + random_state: int = 42, + ) -> dict[str, float]: + """Train the classifier using raster and vector data. + + Args: + raster_path: Path to GeoTIFF file. + vector_path: Path to Shapefile. + class_column: Name of column with class labels in Shapefile. + test_size: Fraction of data to use for validation. + random_state: Random seed for reproducibility. + + Returns: + Dictionary with training metrics. + """ + # Load data + vector = load_vector(vector_path, class_column) + + # Extract features and labels + X, y = extract_raster_values_by_polygons(raster_path, vector) + + # Split for validation + X_train, X_val, y_train, y_val = train_test_split( + X, y, test_size=test_size, random_state=random_state, stratify=y + ) + + # Train the strategy + self.strategy.train(X_train, y_train) + self._is_trained = True + self._classes = np.unique(y) + + # Evaluate + y_pred = self.strategy.predict(X_val) + accuracy = accuracy_score(y_val, y_pred) + + return { + "train_samples": len(X_train), + "val_samples": len(X_val), + "accuracy": accuracy, + "classes": list(self._classes), + } + + def predict(self, raster_path: str, output_path: str | None = None) -> ClassificationResult: + """Classify a raster file. + + Args: + raster_path: Path to GeoTIFF file to classify. + output_path: Optional path to save classified raster. + + Returns: + ClassificationResult with predicted array and metadata. + + Raises: + RuntimeError: If classifier is not trained. + """ + if not self._is_trained: + raise RuntimeError("Classifier must be trained before prediction") + + # Load raster + raster = load_raster(raster_path) + + # Prepare data for prediction + # Reshape from (bands, height, width) to (pixels, bands) + if len(raster.array.shape) == 3: + bands, height, width = raster.array.shape + X = raster.array.transpose(1, 2, 0).reshape(-1, bands) + else: + height, width = raster.array.shape + X = raster.array.reshape(-1, 1) + + # Predict + predictions = self.strategy.predict(X) + + # Reshape back to 2D + predicted_array = predictions.reshape(height, width) + + # Create result + result = ClassificationResult( + predicted_array=predicted_array.astype(np.uint8), + transform=raster.transform, + crs=raster.crs, + classes=self._classes if self._classes is not None else np.unique(predictions), + metadata=self.strategy.get_params(), + ) + + # Save if path provided + if output_path: + self._save_result(result, output_path) + + return result + + def predict_from_data( + self, + raster: RasterData, + output_path: str | None = None, + ) -> ClassificationResult: + """Classify already loaded raster data. + + Args: + raster: RasterData object. + output_path: Optional path to save classified raster. + + Returns: + ClassificationResult with predicted array and metadata. + """ + if not self._is_trained: + raise RuntimeError("Classifier must be trained before prediction") + + # Prepare data for prediction + if len(raster.array.shape) == 3: + bands, height, width = raster.array.shape + X = raster.array.transpose(1, 2, 0).reshape(-1, bands) + else: + height, width = raster.array.shape + X = raster.array.reshape(-1, 1) + + # Predict + predictions = self.strategy.predict(X) + predicted_array = predictions.reshape(height, width) + + result = ClassificationResult( + predicted_array=predicted_array.astype(np.uint8), + transform=raster.transform, + crs=raster.crs, + classes=self._classes if self._classes is not None else np.unique(predictions), + metadata=self.strategy.get_params(), + ) + + if output_path: + self._save_result(result, output_path) + + return result + + def _save_result(self, result: ClassificationResult, output_path: str) -> None: + """Save classification result to GeoTIFF.""" + with rasterio.open( + output_path, + "w", + driver="GTiff", + height=result.predicted_array.shape[0], + width=result.predicted_array.shape[1], + count=1, + dtype=result.predicted_array.dtype, + crs=result.crs, + transform=result.transform, + ) as dst: + dst.write(result.predicted_array, 1) + + def evaluate( + self, + raster_path: str, + vector_path: str, + class_column: str = "class", + ) -> dict[str, Any]: + """Evaluate classifier on validation data. + + Args: + raster_path: Path to GeoTIFF file. + vector_path: Path to Shapefile. + class_column: Name of column with class labels. + + Returns: + Dictionary with evaluation metrics. + """ + if not self._is_trained: + raise RuntimeError("Classifier must be trained before evaluation") + + vector = load_vector(vector_path, class_column) + X, y_true = extract_raster_values_by_polygons(raster_path, vector) + + y_pred = self.strategy.predict(X) + + return { + "accuracy": accuracy_score(y_true, y_pred), + "classification_report": classification_report(y_true, y_pred), + "confusion_matrix": confusion_matrix(y_true, y_pred).tolist(), + } diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..9f5b902 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,11 @@ +"""Data module.""" + +from .loader import RasterData, VectorData, load_raster, load_vector, extract_raster_values_by_polygons + +__all__ = [ + "RasterData", + "VectorData", + "load_raster", + "load_vector", + "extract_raster_values_by_polygons", +] diff --git a/src/data/loader.py b/src/data/loader.py new file mode 100644 index 0000000..c5fcf55 --- /dev/null +++ b/src/data/loader.py @@ -0,0 +1,139 @@ +"""Data loading module for GIS classification.""" + +from dataclasses import dataclass +from typing import Optional +import rasterio +import geopandas as gpd +import numpy as np +from rasterio.mask import mask + + +@dataclass +class RasterData: + """Container for raster data.""" + + array: np.ndarray + transform: rasterio.transform.Affine + crs: rasterio.crs.CRS + nodata: Optional[float] = None + bands: int = 1 + + +@dataclass +class VectorData: + """Container for vector data.""" + + gdf: gpd.GeoDataFrame + class_column: str + + +def load_raster(path: str) -> RasterData: + """Load GeoTIFF raster file. + + Args: + path: Path to the GeoTIFF file. + + Returns: + RasterData object with array, transform, CRS and metadata. + """ + with rasterio.open(path) as src: + # Read all bands + array = src.read() + transform = src.transform + crs = src.crs + nodata = src.nodata + + return RasterData( + array=array, + transform=transform, + crs=crs, + nodata=nodata, + bands=array.shape[0] if len(array.shape) == 3 else 1, + ) + + +def load_vector(path: str, class_column: str = "class") -> VectorData: + """Load Shapefile vector data. + + Args: + path: Path to the Shapefile. + class_column: Name of the column containing class labels. + + Returns: + VectorData object with GeoDataFrame and class column name. + + Raises: + ValueError: If class_column doesn't exist in the Shapefile. + """ + gdf = gpd.read_file(path) + + if class_column not in gdf.columns: + raise ValueError( + f"Column '{class_column}' not found. Available columns: {list(gdf.columns)}" + ) + + return VectorData(gdf=gdf, class_column=class_column) + + +def extract_raster_values_by_polygons( + raster_path: str, vector: VectorData +) -> tuple[np.ndarray, np.ndarray]: + """Extract raster values for each polygon in the vector layer. + + Args: + raster_path: Path to the GeoTIFF raster file. + vector: VectorData object. + + Returns: + Tuple of (features, labels) arrays for classification. + features: 2D array of shape (n_samples, n_bands) + labels: 1D array of shape (n_samples,) + """ + features_list = [] + labels_list = [] + + with rasterio.open(raster_path) as src: + nodata = src.nodata + + for idx, row in vector.gdf.iterrows(): + geometry = row[vector.gdf.geometry.name] + label = row[vector.class_column] + + # Mask raster by polygon geometry + try: + masked_data, _ = mask(src, [geometry], crop=True) + except Exception as e: + print(f"Skipping polygon {idx}: {e}") + continue + + # Get valid pixels (not nodata) + if nodata is not None: + valid_mask = masked_data != nodata + else: + valid_mask = np.ones_like(masked_data, dtype=bool) + + # Extract valid pixel values for all bands + if len(masked_data.shape) == 3: + # Multi-band: reshape to (bands, pixels) + band_values = masked_data[:, valid_mask[0]] + else: + # Single band + band_values = masked_data[valid_mask] + + if band_values.size > 0: + # Transpose to (pixels, bands) + if len(band_values.shape) == 1: + band_values = band_values.reshape(-1, 1) + else: + band_values = band_values.T + + features_list.append(band_values) + labels_list.append(np.full(band_values.shape[0], label)) + + if not features_list: + raise ValueError("No valid pixels extracted from polygons") + + features = np.vstack(features_list) + labels = np.concatenate(labels_list) + + return features, labels diff --git a/src/strategies/__init__.py b/src/strategies/__init__.py new file mode 100644 index 0000000..30b8f9e --- /dev/null +++ b/src/strategies/__init__.py @@ -0,0 +1,12 @@ +"""Strategies module for classification algorithms.""" + +from .base import ClassificationStrategy +from .classifiers import RandomForestStrategy, SVMStrategy, LogisticRegressionStrategy, MLEStrategy + +__all__ = [ + "ClassificationStrategy", + "RandomForestStrategy", + "SVMStrategy", + "LogisticRegressionStrategy", + "MLEStrategy", +] diff --git a/src/strategies/base.py b/src/strategies/base.py new file mode 100644 index 0000000..f8c5cbd --- /dev/null +++ b/src/strategies/base.py @@ -0,0 +1,61 @@ +"""Strategy interface for classification algorithms.""" + +from abc import ABC, abstractmethod +from typing import Any +import numpy as np + + +class ClassificationStrategy(ABC): + """Abstract base class for classification strategies. + + Implement this interface to add new classification algorithms. + """ + + @abstractmethod + def train(self, X: np.ndarray, y: np.ndarray) -> None: + """Train the classifier on provided data. + + Args: + X: Feature array of shape (n_samples, n_features). + y: Target labels of shape (n_samples,). + """ + pass + + @abstractmethod + def predict(self, X: np.ndarray) -> np.ndarray: + """Predict classes for input data. + + Args: + X: Feature array of shape (n_samples, n_features). + + Returns: + Predicted labels of shape (n_samples,). + """ + pass + + @abstractmethod + def predict_proba(self, X: np.ndarray) -> np.ndarray: + """Predict class probabilities for input data. + + Args: + X: Feature array of shape (n_samples, n_features). + + Returns: + Probability array of shape (n_samples, n_classes). + """ + pass + + @abstractmethod + def get_params(self) -> dict[str, Any]: + """Get classifier parameters. + + Returns: + Dictionary of classifier parameters. + """ + pass + + @property + @abstractmethod + def name(self) -> str: + """Return strategy name.""" + pass diff --git a/src/strategies/classifiers.py b/src/strategies/classifiers.py new file mode 100644 index 0000000..1b10196 --- /dev/null +++ b/src/strategies/classifiers.py @@ -0,0 +1,250 @@ +"""Built-in classification strategies.""" + +from typing import Any +import numpy as np +from scipy.stats import multivariate_normal +from sklearn.ensemble import RandomForestClassifier +from sklearn.svm import SVC +from sklearn.linear_model import LogisticRegression +from .base import ClassificationStrategy + + +class RandomForestStrategy(ClassificationStrategy): + """Random Forest classification strategy.""" + + def __init__( + self, + n_estimators: int = 100, + max_depth: int | None = None, + random_state: int = 42, + **kwargs + ): + self.n_estimators = n_estimators + self.max_depth = max_depth + self.random_state = random_state + self._clf = RandomForestClassifier( + n_estimators=n_estimators, + max_depth=max_depth, + random_state=random_state, + **kwargs + ) + + def train(self, X: np.ndarray, y: np.ndarray) -> None: + self._clf.fit(X, y) + + def predict(self, X: np.ndarray) -> np.ndarray: + return self._clf.predict(X) + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + return self._clf.predict_proba(X) + + def get_params(self) -> dict[str, Any]: + return { + "n_estimators": self.n_estimators, + "max_depth": self.max_depth, + "random_state": self.random_state, + } + + @property + def name(self) -> str: + return "RandomForest" + + +class SVMStrategy(ClassificationStrategy): + """Support Vector Machine classification strategy.""" + + def __init__( + self, + kernel: str = "rbf", + C: float = 1.0, + gamma: str = "scale", + random_state: int = 42, + **kwargs + ): + self.kernel = kernel + self.C = C + self.gamma = gamma + self.random_state = random_state + self._clf = SVC( + kernel=kernel, + C=C, + gamma=gamma, + random_state=random_state, + probability=True, + **kwargs + ) + + def train(self, X: np.ndarray, y: np.ndarray) -> None: + self._clf.fit(X, y) + + def predict(self, X: np.ndarray) -> np.ndarray: + return self._clf.predict(X) + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + return self._clf.predict_proba(X) + + def get_params(self) -> dict[str, Any]: + return { + "kernel": self.kernel, + "C": self.C, + "gamma": self.gamma, + "random_state": self.random_state, + } + + @property + def name(self) -> str: + return "SVM" + + +class LogisticRegressionStrategy(ClassificationStrategy): + """Logistic Regression classification strategy.""" + + def __init__( + self, + penalty: str = "l2", + C: float = 1.0, + max_iter: int = 1000, + random_state: int = 42, + **kwargs + ): + self.penalty = penalty + self.C = C + self.max_iter = max_iter + self.random_state = random_state + self._clf = LogisticRegression( + penalty=penalty, + C=C, + max_iter=max_iter, + random_state=random_state, + **kwargs + ) + + def train(self, X: np.ndarray, y: np.ndarray) -> None: + self._clf.fit(X, y) + + def predict(self, X: np.ndarray) -> np.ndarray: + return self._clf.predict(X) + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + return self._clf.predict_proba(X) + + def get_params(self) -> dict[str, Any]: + return { + "penalty": self.penalty, + "C": self.C, + "max_iter": self.max_iter, + "random_state": self.random_state, + } + + @property + def name(self) -> str: + return "LogisticRegression" + + +class MLEStrategy(ClassificationStrategy): + """Maximum Likelihood Estimation classification strategy. + + Assumes each class follows a multivariate normal distribution. + Classic algorithm for GIS/remote sensing classification. + """ + + def __init__(self, reg_covar: float = 1e-6): + """Initialize MLE classifier. + + Args: + reg_covar: Regularization for covariance matrix stability. + """ + self.reg_covar = reg_covar + self._means: dict[Any, np.ndarray] = {} + self._covs: dict[Any, np.ndarray] = {} + self._priors: dict[Any, float] = {} + self._classes: np.ndarray | None = None + + def train(self, X: np.ndarray, y: np.ndarray) -> None: + """Estimate mean, covariance and prior for each class.""" + self._classes = np.unique(y) + self._means = {} + self._covs = {} + self._priors = {} + + n_samples = len(y) + + for cls in self._classes: + X_cls = X[y == cls] + + # Prior probability + self._priors[cls] = len(X_cls) / n_samples + + # Mean vector + self._means[cls] = np.mean(X_cls, axis=0) + + # Covariance matrix with regularization + cov = np.cov(X_cls, rowvar=False) + if cov.ndim == 0: + cov = np.array([[cov]]) + cov += np.eye(cov.shape[0]) * self.reg_covar + self._covs[cls] = cov + + def _compute_log_likelihood(self, X: np.ndarray, cls: Any) -> np.ndarray: + """Compute log-likelihood for a class.""" + mean = self._means[cls] + cov = self._covs[cls] + prior = self._priors[cls] + + try: + rv = multivariate_normal(mean=mean, cov=cov, allow_singular=True) + log_likelihood = rv.logpdf(X) + except Exception: + # Fallback: compute manually + diff = X - mean + try: + cov_inv = np.linalg.inv(cov) + except np.linalg.LinAlgError: + cov_inv = np.linalg.pinv(cov) + + mahalanobis = np.sum(diff @ cov_inv * diff, axis=1) + log_det = np.linalg.slogdet(cov)[1] + log_likelihood = -0.5 * (X.shape[1] * np.log(2 * np.pi) + log_det + mahalanobis) + + return log_likelihood + np.log(prior) + + def predict(self, X: np.ndarray) -> np.ndarray: + """Predict class with maximum likelihood.""" + if self._classes is None: + raise RuntimeError("Classifier not trained") + + # Compute log-likelihoods for all classes + log_likelihoods = np.zeros((X.shape[0], len(self._classes))) + + for i, cls in enumerate(self._classes): + log_likelihoods[:, i] = self._compute_log_likelihood(X, cls) + + # Return class with maximum likelihood + return self._classes[np.argmax(log_likelihoods, axis=1)] + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + """Predict class probabilities using softmax of log-likelihoods.""" + if self._classes is None: + raise RuntimeError("Classifier not trained") + + # Compute log-likelihoods + log_likelihoods = np.zeros((X.shape[0], len(self._classes))) + + for i, cls in enumerate(self._classes): + log_likelihoods[:, i] = self._compute_log_likelihood(X, cls) + + # Convert to probabilities via softmax + log_likelihoods -= np.max(log_likelihoods, axis=1, keepdims=True) + exp_ll = np.exp(log_likelihoods) + probabilities = exp_ll / np.sum(exp_ll, axis=1, keepdims=True) + + return probabilities + + def get_params(self) -> dict[str, Any]: + return { + "reg_covar": self.reg_covar, + } + + @property + def name(self) -> str: + return "MLE" diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..285e1d4 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1 @@ +"""Utilities module."""