feat: initial release of GIS classification project with strategy-based classifiers selector

This commit is contained in:
Andrew 2026-03-15 11:35:50 +07:00
commit af365cfe68
14 changed files with 1115 additions and 0 deletions

187
.gitignore vendored Normal file
View file

@ -0,0 +1,187 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,gis
# Edit at https://www.toptal.com/developers/gitignore?templates=python,gis
### GIS ###
*.gpx
*.kml
*.tif
*.tif.aux.xml
*.tiff
*.shx
*.shp
*.prj
*.dbf
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python,gis

81
README.md Normal file
View file

@ -0,0 +1,81 @@
# GIS Classification
Python project for land cover classification using GIS data (GeoTIFF + Shapefile).
## Project Structure
```
gis-classification/
├── main.py # Main script with parameters
├── requirements.txt # Dependencies
├── data/ # Input data folder
├── output/ # Classification results
└── src/
├── classifier.py # Main classification pipeline
├── data/
│ └── loader.py # Data loading (GeoTIFF, Shapefile)
├── strategies/
│ ├── base.py # Strategy interface
│ └── classifiers.py # Built-in strategies (RF, SVM, LR)
└── utils/
```
## Installation
```bash
pip install -r requirements.txt
```
## Usage
1. Place your input files in `data/`:
- `landsat.tif` - GeoTIFF from Landsat
- `polygons.shp` - Shapefile with class labels
2. Configure parameters in `main.py`:
- Input/output paths
- Classification strategy (RandomForest, SVM, LogisticRegression)
- Training parameters
3. Run:
```bash
python main.py
```
## Adding a Custom Classification Strategy
Create a new class implementing `ClassificationStrategy`:
```python
from src.strategies import ClassificationStrategy
import numpy as np
class MyCustomStrategy(ClassificationStrategy):
def train(self, X: np.ndarray, y: np.ndarray) -> None:
# Your training logic
pass
def predict(self, X: np.ndarray) -> np.ndarray:
# Your prediction logic
pass
def predict_proba(self, X: np.ndarray) -> np.ndarray:
pass
def get_params(self) -> dict:
pass
@property
def name(self) -> str:
return "MyCustom"
```
Then use in `main.py`:
```python
STRATEGY = MyCustomStrategy()
```
## Output
- `output/classified.tif` - Classified raster (GeoTIFF)
- Console output with accuracy metrics

0
data/.gitkeep Normal file
View file

116
main.py Normal file
View file

@ -0,0 +1,116 @@
"""Main script for GIS classification.
Configure parameters below to run classification.
"""
import os
from src import (
GISClassifier,
RandomForestStrategy,
SVMStrategy,
LogisticRegressionStrategy,
MLEStrategy,
)
# ==================== PARAMETERS ====================
# Input files
RASTER_PATH = os.path.join("data", "landsat.tif") # Path to GeoTIFF (Landsat)
VECTOR_PATH = os.path.join("data", "polygons.shx") # Path to Shapefile
# Column in Shapefile containing class labels
CLASS_COLUMN = "macroclass"
# Output file for classified raster
OUTPUT_PATH = os.path.join("output", "classified.tif")
# Classification strategy parameters
# Change strategy by uncommenting desired option:
# Option 1: Random Forest (recommended for GIS)
STRATEGY = RandomForestStrategy(
n_estimators=100,
max_depth=None,
random_state=42,
)
# Option 2: Support Vector Machine
# STRATEGY = SVMStrategy(
# kernel="rbf",
# C=1.0,
# gamma="scale",
# random_state=42,
# )
# Option 3: Logistic Regression
# STRATEGY = LogisticRegressionStrategy(
# penalty="l2",
# C=1.0,
# max_iter=1000,
# random_state=42,
# )
# Option 4: Maximum Likelihood Estimation (classic for GIS)
# STRATEGY = MLEStrategy(
# reg_covar=1e-6,
# )
# Training parameters
TEST_SIZE = 0.2 # Fraction for validation
RANDOM_STATE = 42 # Random seed
# ==================== RUN CLASSIFICATION ====================
def main():
# Check input files exist
if not os.path.exists(RASTER_PATH):
print(f"Error: Raster file not found: {RASTER_PATH}")
return
if not os.path.exists(VECTOR_PATH):
print(f"Error: Vector file not found: {VECTOR_PATH}")
return
# Create output directory if needed
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
# Initialize classifier with strategy
print(f"Using classification strategy: {STRATEGY.name}")
print(f"Strategy parameters: {STRATEGY.get_params()}")
print()
classifier = GISClassifier(strategy=STRATEGY)
# Train
print("Training classifier...")
metrics = classifier.train(
raster_path=RASTER_PATH,
vector_path=VECTOR_PATH,
class_column=CLASS_COLUMN,
test_size=TEST_SIZE,
random_state=RANDOM_STATE,
)
print(f"Training samples: {metrics['train_samples']}")
print(f"Validation samples: {metrics['val_samples']}")
print(f"Accuracy: {metrics['accuracy']:.2%}")
print(f"Classes: {metrics['classes']}")
print()
# Predict
print(f"Classifying raster: {RASTER_PATH}")
result = classifier.predict(
raster_path=RASTER_PATH,
output_path=OUTPUT_PATH,
)
print(f"Classification complete!")
print(f"Output saved to: {OUTPUT_PATH}")
print(f"Output shape: {result.predicted_array.shape}")
print(f"Classes in output: {result.classes}")
if __name__ == "__main__":
main()

0
output/.gitkeep Normal file
View file

7
requirements.txt Normal file
View file

@ -0,0 +1,7 @@
rasterio>=1.3.0
geopandas>=0.12.0
shapely>=2.0.0
scikit-learn>=1.3.0
scipy>=1.10.0
numpy>=1.24.0
pandas>=2.0.0

25
src/__init__.py Normal file
View file

@ -0,0 +1,25 @@
"""GIS Classification Package."""
from .classifier import GISClassifier, ClassificationResult
from .data import RasterData, VectorData, load_raster, load_vector
from .strategies import (
ClassificationStrategy,
RandomForestStrategy,
SVMStrategy,
LogisticRegressionStrategy,
MLEStrategy,
)
__all__ = [
"GISClassifier",
"ClassificationResult",
"RasterData",
"VectorData",
"load_raster",
"load_vector",
"ClassificationStrategy",
"RandomForestStrategy",
"SVMStrategy",
"LogisticRegressionStrategy",
"MLEStrategy",
]

225
src/classifier.py Normal file
View file

@ -0,0 +1,225 @@
"""Main classification pipeline."""
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import rasterio
from rasterio.transform import from_bounds
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
from .data import RasterData, VectorData, load_raster, load_vector, extract_raster_values_by_polygons
from .strategies import ClassificationStrategy
@dataclass
class ClassificationResult:
    """Container for classification results."""

    # 2D array (height, width) of predicted class codes (cast to uint8
    # by GISClassifier before being stored here)
    predicted_array: np.ndarray
    # Affine geotransform of the source raster
    transform: rasterio.transform.Affine
    # Coordinate reference system of the source raster
    crs: rasterio.crs.CRS
    # Class labels known to the classifier at prediction time
    classes: np.ndarray
    # Optional evaluation artifacts (left as None unless computed)
    accuracy: float | None = None
    report: str | None = None
    confusion_matrix: np.ndarray | None = None
    # Strategy hyperparameters used to produce this result (get_params())
    metadata: dict[str, Any] = field(default_factory=dict)
class GISClassifier:
    """Main classifier for GIS data.

    Uses strategy pattern to allow plugging different classification algorithms.
    """

    def __init__(self, strategy: ClassificationStrategy):
        """Initialize classifier with a strategy.

        Args:
            strategy: ClassificationStrategy instance to use for classification.
        """
        self.strategy = strategy
        self._is_trained = False
        self._classes: np.ndarray | None = None

    def train(
        self,
        raster_path: str,
        vector_path: str,
        class_column: str = "class",
        test_size: float = 0.2,
        random_state: int = 42,
    ) -> dict[str, float]:
        """Train the classifier using raster and vector data.

        Args:
            raster_path: Path to GeoTIFF file.
            vector_path: Path to Shapefile.
            class_column: Name of column with class labels in Shapefile.
            test_size: Fraction of data to use for validation.
            random_state: Random seed for reproducibility.

        Returns:
            Dictionary with training metrics.
        """
        # Load data
        vector = load_vector(vector_path, class_column)
        # Extract features and labels
        X, y = extract_raster_values_by_polygons(raster_path, vector)
        # Split for validation (stratified so rare classes appear in both sets)
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )
        # Train the strategy
        self.strategy.train(X_train, y_train)
        self._is_trained = True
        self._classes = np.unique(y)
        # Evaluate on the held-out split
        y_pred = self.strategy.predict(X_val)
        accuracy = accuracy_score(y_val, y_pred)
        return {
            "train_samples": len(X_train),
            "val_samples": len(X_val),
            "accuracy": accuracy,
            "classes": list(self._classes),
        }

    @staticmethod
    def _flatten_bands(raster: RasterData) -> tuple[np.ndarray, int, int]:
        """Reshape a raster array to a (pixels, bands) feature matrix.

        Returns:
            Tuple of (X, height, width) where X has shape (height*width, bands).
        """
        if len(raster.array.shape) == 3:
            bands, height, width = raster.array.shape
            X = raster.array.transpose(1, 2, 0).reshape(-1, bands)
        else:
            height, width = raster.array.shape
            X = raster.array.reshape(-1, 1)
        return X, height, width

    def predict(self, raster_path: str, output_path: str | None = None) -> ClassificationResult:
        """Classify a raster file.

        Args:
            raster_path: Path to GeoTIFF file to classify.
            output_path: Optional path to save classified raster.

        Returns:
            ClassificationResult with predicted array and metadata.

        Raises:
            RuntimeError: If classifier is not trained.
        """
        if not self._is_trained:
            raise RuntimeError("Classifier must be trained before prediction")
        # Dedupe: load the raster and reuse the in-memory prediction path,
        # which was previously duplicated verbatim here.
        return self.predict_from_data(load_raster(raster_path), output_path)

    def predict_from_data(
        self,
        raster: RasterData,
        output_path: str | None = None,
    ) -> ClassificationResult:
        """Classify already loaded raster data.

        Args:
            raster: RasterData object.
            output_path: Optional path to save classified raster.

        Returns:
            ClassificationResult with predicted array and metadata.

        Raises:
            RuntimeError: If classifier is not trained.
        """
        if not self._is_trained:
            raise RuntimeError("Classifier must be trained before prediction")
        # Prepare data for prediction: (bands, h, w) -> (pixels, bands)
        X, height, width = self._flatten_bands(raster)
        # Predict, then reshape back to the 2D raster grid
        predictions = self.strategy.predict(X)
        predicted_array = predictions.reshape(height, width)
        result = ClassificationResult(
            # NOTE(review): uint8 assumes class codes fit in 0..255 — confirm
            predicted_array=predicted_array.astype(np.uint8),
            transform=raster.transform,
            crs=raster.crs,
            classes=self._classes if self._classes is not None else np.unique(predictions),
            metadata=self.strategy.get_params(),
        )
        if output_path:
            self._save_result(result, output_path)
        return result

    def _save_result(self, result: ClassificationResult, output_path: str) -> None:
        """Save classification result to a single-band GeoTIFF."""
        with rasterio.open(
            output_path,
            "w",
            driver="GTiff",
            height=result.predicted_array.shape[0],
            width=result.predicted_array.shape[1],
            count=1,
            dtype=result.predicted_array.dtype,
            crs=result.crs,
            transform=result.transform,
        ) as dst:
            dst.write(result.predicted_array, 1)

    def evaluate(
        self,
        raster_path: str,
        vector_path: str,
        class_column: str = "class",
    ) -> dict[str, Any]:
        """Evaluate classifier on validation data.

        Args:
            raster_path: Path to GeoTIFF file.
            vector_path: Path to Shapefile.
            class_column: Name of column with class labels.

        Returns:
            Dictionary with evaluation metrics.

        Raises:
            RuntimeError: If classifier is not trained.
        """
        if not self._is_trained:
            raise RuntimeError("Classifier must be trained before evaluation")
        vector = load_vector(vector_path, class_column)
        X, y_true = extract_raster_values_by_polygons(raster_path, vector)
        y_pred = self.strategy.predict(X)
        return {
            "accuracy": accuracy_score(y_true, y_pred),
            "classification_report": classification_report(y_true, y_pred),
            "confusion_matrix": confusion_matrix(y_true, y_pred).tolist(),
        }

11
src/data/__init__.py Normal file
View file

@ -0,0 +1,11 @@
"""Data module."""
from .loader import RasterData, VectorData, load_raster, load_vector, extract_raster_values_by_polygons
__all__ = [
"RasterData",
"VectorData",
"load_raster",
"load_vector",
"extract_raster_values_by_polygons",
]

139
src/data/loader.py Normal file
View file

@ -0,0 +1,139 @@
"""Data loading module for GIS classification."""
from dataclasses import dataclass
from typing import Optional
import rasterio
import geopandas as gpd
import numpy as np
from rasterio.mask import mask
@dataclass
class RasterData:
    """Container for raster data."""

    # Pixel values: (bands, height, width) for multi-band, (height, width) otherwise
    array: np.ndarray
    # Affine geotransform mapping pixel coordinates to CRS coordinates
    transform: rasterio.transform.Affine
    # Coordinate reference system of the raster
    crs: rasterio.crs.CRS
    # Sentinel value marking missing pixels, if the file declares one
    nodata: Optional[float] = None
    # Number of bands (derived from array.shape in load_raster)
    bands: int = 1
@dataclass
class VectorData:
    """Container for vector data."""

    # Polygon geometries plus attribute columns loaded from the Shapefile
    gdf: gpd.GeoDataFrame
    # Name of the attribute column holding class labels
    class_column: str
def load_raster(path: str) -> RasterData:
    """Load a GeoTIFF raster file.

    Args:
        path: Path to the GeoTIFF file.

    Returns:
        RasterData object with array, transform, CRS and metadata.
    """
    with rasterio.open(path) as src:
        pixels = src.read()  # every band: (bands, height, width)
        geo_transform = src.transform
        crs = src.crs
        nodata = src.nodata
    # A 3D array is multi-band; anything else is treated as a single band.
    n_bands = pixels.shape[0] if pixels.ndim == 3 else 1
    return RasterData(
        array=pixels,
        transform=geo_transform,
        crs=crs,
        nodata=nodata,
        bands=n_bands,
    )
def load_vector(path: str, class_column: str = "class") -> VectorData:
    """Load Shapefile vector data.

    Args:
        path: Path to the Shapefile.
        class_column: Name of the column containing class labels.

    Returns:
        VectorData object with GeoDataFrame and class column name.

    Raises:
        ValueError: If class_column doesn't exist in the Shapefile.
    """
    frame = gpd.read_file(path)
    # Guard clause: fail fast when the label column is missing.
    if class_column in frame.columns:
        return VectorData(gdf=frame, class_column=class_column)
    raise ValueError(
        f"Column '{class_column}' not found. Available columns: {list(frame.columns)}"
    )
def extract_raster_values_by_polygons(
    raster_path: str, vector: VectorData
) -> tuple[np.ndarray, np.ndarray]:
    """Extract raster values for each polygon in the vector layer.

    Args:
        raster_path: Path to the GeoTIFF raster file.
        vector: VectorData object.

    Returns:
        Tuple of (features, labels) arrays for classification.
            features: 2D array of shape (n_samples, n_bands)
            labels: 1D array of shape (n_samples,)

    Raises:
        ValueError: If no valid pixels could be extracted.
    """
    features_list = []
    labels_list = []
    with rasterio.open(raster_path) as src:
        nodata = src.nodata
        for idx, row in vector.gdf.iterrows():
            geometry = row[vector.gdf.geometry.name]
            label = row[vector.class_column]
            # Mask raster by polygon geometry (best-effort: bad/out-of-bounds
            # geometries are reported and skipped rather than aborting)
            try:
                masked_data, _ = mask(src, [geometry], crop=True)
            except Exception as e:
                print(f"Skipping polygon {idx}: {e}")
                continue
            # Build a (height, width) validity mask.
            # Bug fix: previously only band 0 was checked (valid_mask[0]),
            # so a pixel that is nodata in any other band leaked into the
            # training samples; require every band to be valid instead.
            if nodata is not None:
                if masked_data.ndim == 3:
                    valid_2d = np.all(masked_data != nodata, axis=0)
                else:
                    valid_2d = masked_data != nodata
            else:
                # NOTE(review): without a declared nodata value, pixels
                # outside the polygon (filled by mask()) cannot be told
                # apart from real data — everything is kept.
                valid_2d = np.ones(masked_data.shape[-2:], dtype=bool)
            # Extract valid pixel values for all bands as (pixels, bands)
            if masked_data.ndim == 3:
                band_values = masked_data[:, valid_2d].T
            else:
                band_values = masked_data[valid_2d].reshape(-1, 1)
            if band_values.size > 0:
                features_list.append(band_values)
                labels_list.append(np.full(band_values.shape[0], label))
    if not features_list:
        raise ValueError("No valid pixels extracted from polygons")
    features = np.vstack(features_list)
    labels = np.concatenate(labels_list)
    return features, labels

View file

@ -0,0 +1,12 @@
"""Strategies module for classification algorithms."""
from .base import ClassificationStrategy
from .classifiers import RandomForestStrategy, SVMStrategy, LogisticRegressionStrategy, MLEStrategy
__all__ = [
"ClassificationStrategy",
"RandomForestStrategy",
"SVMStrategy",
"LogisticRegressionStrategy",
"MLEStrategy",
]

61
src/strategies/base.py Normal file
View file

@ -0,0 +1,61 @@
"""Strategy interface for classification algorithms."""
from abc import ABC, abstractmethod
from typing import Any
import numpy as np
class ClassificationStrategy(ABC):
    """Abstract base class for classification strategies.

    Implement this interface to add new classification algorithms.
    """

    @abstractmethod
    def train(self, X: np.ndarray, y: np.ndarray) -> None:
        """Train the classifier on provided data.

        Args:
            X: Feature array of shape (n_samples, n_features).
            y: Target labels of shape (n_samples,).
        """
        pass

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict classes for input data.

        Args:
            X: Feature array of shape (n_samples, n_features).

        Returns:
            Predicted labels of shape (n_samples,).
        """
        pass

    @abstractmethod
    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict class probabilities for input data.

        Args:
            X: Feature array of shape (n_samples, n_features).

        Returns:
            Probability array of shape (n_samples, n_classes).
        """
        pass

    @abstractmethod
    def get_params(self) -> dict[str, Any]:
        """Get classifier parameters.

        Returns:
            Dictionary of classifier parameters.
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Return strategy name."""
        pass

View file

@ -0,0 +1,250 @@
"""Built-in classification strategies."""
from typing import Any
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from .base import ClassificationStrategy
class RandomForestStrategy(ClassificationStrategy):
    """Random Forest classification strategy."""

    def __init__(
        self,
        n_estimators: int = 100,
        max_depth: int | None = None,
        random_state: int = 42,
        **kwargs
    ):
        """Record hyperparameters and build the underlying sklearn forest."""
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state
        self._clf = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            random_state=self.random_state,
            **kwargs
        )

    def train(self, X: np.ndarray, y: np.ndarray) -> None:
        """Fit the forest on training samples."""
        self._clf.fit(X, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return predicted class labels for X."""
        return self._clf.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return per-class probability estimates for X."""
        return self._clf.predict_proba(X)

    def get_params(self) -> dict[str, Any]:
        """Expose the configured hyperparameters."""
        return dict(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            random_state=self.random_state,
        )

    @property
    def name(self) -> str:
        """Strategy display name."""
        return "RandomForest"
class SVMStrategy(ClassificationStrategy):
    """Support Vector Machine classification strategy."""

    def __init__(
        self,
        kernel: str = "rbf",
        C: float = 1.0,
        gamma: str = "scale",
        random_state: int = 42,
        **kwargs
    ):
        """Record hyperparameters and build the underlying SVC model."""
        self.kernel = kernel
        self.C = C
        self.gamma = gamma
        self.random_state = random_state
        # probability=True enables predict_proba on SVC
        self._clf = SVC(
            kernel=self.kernel,
            C=self.C,
            gamma=self.gamma,
            random_state=self.random_state,
            probability=True,
            **kwargs
        )

    def train(self, X: np.ndarray, y: np.ndarray) -> None:
        """Fit the SVM on training samples."""
        self._clf.fit(X, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return predicted class labels for X."""
        return self._clf.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return per-class probability estimates for X."""
        return self._clf.predict_proba(X)

    def get_params(self) -> dict[str, Any]:
        """Expose the configured hyperparameters."""
        return dict(
            kernel=self.kernel,
            C=self.C,
            gamma=self.gamma,
            random_state=self.random_state,
        )

    @property
    def name(self) -> str:
        """Strategy display name."""
        return "SVM"
class LogisticRegressionStrategy(ClassificationStrategy):
    """Logistic Regression classification strategy."""

    def __init__(
        self,
        penalty: str = "l2",
        C: float = 1.0,
        max_iter: int = 1000,
        random_state: int = 42,
        **kwargs
    ):
        """Record hyperparameters and build the underlying sklearn model."""
        self.penalty = penalty
        self.C = C
        self.max_iter = max_iter
        self.random_state = random_state
        self._clf = LogisticRegression(
            penalty=self.penalty,
            C=self.C,
            max_iter=self.max_iter,
            random_state=self.random_state,
            **kwargs
        )

    def train(self, X: np.ndarray, y: np.ndarray) -> None:
        """Fit the model on training samples."""
        self._clf.fit(X, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return predicted class labels for X."""
        return self._clf.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return per-class probability estimates for X."""
        return self._clf.predict_proba(X)

    def get_params(self) -> dict[str, Any]:
        """Expose the configured hyperparameters."""
        return dict(
            penalty=self.penalty,
            C=self.C,
            max_iter=self.max_iter,
            random_state=self.random_state,
        )

    @property
    def name(self) -> str:
        """Strategy display name."""
        return "LogisticRegression"
class MLEStrategy(ClassificationStrategy):
    """Maximum Likelihood Estimation classification strategy.

    Assumes each class follows a multivariate normal distribution.
    Classic algorithm for GIS/remote sensing classification.
    """

    def __init__(self, reg_covar: float = 1e-6):
        """Initialize MLE classifier.

        Args:
            reg_covar: Regularization for covariance matrix stability.
        """
        self.reg_covar = reg_covar
        self._means: dict[Any, np.ndarray] = {}
        self._covs: dict[Any, np.ndarray] = {}
        self._priors: dict[Any, float] = {}
        self._classes: np.ndarray | None = None

    def train(self, X: np.ndarray, y: np.ndarray) -> None:
        """Estimate mean, covariance and prior for each class."""
        self._classes = np.unique(y)
        self._means = {}
        self._covs = {}
        self._priors = {}
        n_samples = len(y)
        for cls in self._classes:
            X_cls = X[y == cls]
            # Prior probability (class frequency in training data)
            self._priors[cls] = len(X_cls) / n_samples
            # Mean vector
            self._means[cls] = np.mean(X_cls, axis=0)
            # Covariance matrix with diagonal regularization for stability
            cov = np.cov(X_cls, rowvar=False)
            if cov.ndim == 0:
                # Single feature: promote scalar variance to a 1x1 matrix
                cov = np.array([[cov]])
            cov += np.eye(cov.shape[0]) * self.reg_covar
            self._covs[cls] = cov

    def _compute_log_likelihood(self, X: np.ndarray, cls: Any) -> np.ndarray:
        """Compute log posterior (log-likelihood + log prior) for a class."""
        mean = self._means[cls]
        cov = self._covs[cls]
        prior = self._priors[cls]
        try:
            rv = multivariate_normal(mean=mean, cov=cov, allow_singular=True)
            log_likelihood = rv.logpdf(X)
        except Exception:
            # Fallback: compute the Gaussian log-density manually
            diff = X - mean
            try:
                cov_inv = np.linalg.inv(cov)
            except np.linalg.LinAlgError:
                cov_inv = np.linalg.pinv(cov)
            mahalanobis = np.sum(diff @ cov_inv * diff, axis=1)
            log_det = np.linalg.slogdet(cov)[1]
            log_likelihood = -0.5 * (X.shape[1] * np.log(2 * np.pi) + log_det + mahalanobis)
        return log_likelihood + np.log(prior)

    def _log_posterior_matrix(self, X: np.ndarray) -> np.ndarray:
        """Stack per-class log posteriors into an (n_samples, n_classes) array.

        Shared by predict and predict_proba (previously duplicated in both).

        Raises:
            RuntimeError: If the classifier has not been trained.
        """
        if self._classes is None:
            raise RuntimeError("Classifier not trained")
        log_likelihoods = np.zeros((X.shape[0], len(self._classes)))
        for i, cls in enumerate(self._classes):
            log_likelihoods[:, i] = self._compute_log_likelihood(X, cls)
        return log_likelihoods

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict class with maximum likelihood."""
        log_likelihoods = self._log_posterior_matrix(X)
        # Return class with maximum posterior per sample
        return self._classes[np.argmax(log_likelihoods, axis=1)]

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict class probabilities using softmax of log-likelihoods."""
        log_likelihoods = self._log_posterior_matrix(X)
        # Convert to probabilities via numerically-stable softmax
        log_likelihoods -= np.max(log_likelihoods, axis=1, keepdims=True)
        exp_ll = np.exp(log_likelihoods)
        probabilities = exp_ll / np.sum(exp_ll, axis=1, keepdims=True)
        return probabilities

    def get_params(self) -> dict[str, Any]:
        """Expose the configured hyperparameters."""
        return {
            "reg_covar": self.reg_covar,
        }

    @property
    def name(self) -> str:
        """Strategy display name."""
        return "MLE"

1
src/utils/__init__.py Normal file
View file

@ -0,0 +1 @@
"""Utilities module."""