feat: initial release of GIS classification project with strategy-based classifiers selector

This commit is contained in:
Andrew 2026-03-15 11:35:50 +07:00
commit af365cfe68
14 changed files with 1115 additions and 0 deletions

139
src/data/loader.py Normal file
View file

@ -0,0 +1,139 @@
"""Data loading module for GIS classification."""
from dataclasses import dataclass
from typing import Optional
import rasterio
import geopandas as gpd
import numpy as np
from rasterio.mask import mask
@dataclass
class RasterData:
"""Container for raster data."""
array: np.ndarray
transform: rasterio.transform.Affine
crs: rasterio.crs.CRS
nodata: Optional[float] = None
bands: int = 1
@dataclass
class VectorData:
"""Container for vector data."""
gdf: gpd.GeoDataFrame
class_column: str
def load_raster(path: str) -> RasterData:
"""Load GeoTIFF raster file.
Args:
path: Path to the GeoTIFF file.
Returns:
RasterData object with array, transform, CRS and metadata.
"""
with rasterio.open(path) as src:
# Read all bands
array = src.read()
transform = src.transform
crs = src.crs
nodata = src.nodata
return RasterData(
array=array,
transform=transform,
crs=crs,
nodata=nodata,
bands=array.shape[0] if len(array.shape) == 3 else 1,
)
def load_vector(path: str, class_column: str = "class") -> VectorData:
"""Load Shapefile vector data.
Args:
path: Path to the Shapefile.
class_column: Name of the column containing class labels.
Returns:
VectorData object with GeoDataFrame and class column name.
Raises:
ValueError: If class_column doesn't exist in the Shapefile.
"""
gdf = gpd.read_file(path)
if class_column not in gdf.columns:
raise ValueError(
f"Column '{class_column}' not found. Available columns: {list(gdf.columns)}"
)
return VectorData(gdf=gdf, class_column=class_column)
def extract_raster_values_by_polygons(
raster_path: str, vector: VectorData
) -> tuple[np.ndarray, np.ndarray]:
"""Extract raster values for each polygon in the vector layer.
Args:
raster_path: Path to the GeoTIFF raster file.
vector: VectorData object.
Returns:
Tuple of (features, labels) arrays for classification.
features: 2D array of shape (n_samples, n_bands)
labels: 1D array of shape (n_samples,)
"""
features_list = []
labels_list = []
with rasterio.open(raster_path) as src:
nodata = src.nodata
for idx, row in vector.gdf.iterrows():
geometry = row[vector.gdf.geometry.name]
label = row[vector.class_column]
# Mask raster by polygon geometry
try:
masked_data, _ = mask(src, [geometry], crop=True)
except Exception as e:
print(f"Skipping polygon {idx}: {e}")
continue
# Get valid pixels (not nodata)
if nodata is not None:
valid_mask = masked_data != nodata
else:
valid_mask = np.ones_like(masked_data, dtype=bool)
# Extract valid pixel values for all bands
if len(masked_data.shape) == 3:
# Multi-band: reshape to (bands, pixels)
band_values = masked_data[:, valid_mask[0]]
else:
# Single band
band_values = masked_data[valid_mask]
if band_values.size > 0:
# Transpose to (pixels, bands)
if len(band_values.shape) == 1:
band_values = band_values.reshape(-1, 1)
else:
band_values = band_values.T
features_list.append(band_values)
labels_list.append(np.full(band_values.shape[0], label))
if not features_list:
raise ValueError("No valid pixels extracted from polygons")
features = np.vstack(features_list)
labels = np.concatenate(labels_list)
return features, labels