feat: initial release of GIS classification project with strategy-based classifiers selector
This commit is contained in:
commit
af365cfe68
14 changed files with 1115 additions and 0 deletions
139
src/data/loader.py
Normal file
139
src/data/loader.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
"""Data loading module for GIS classification."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import rasterio
|
||||
import geopandas as gpd
|
||||
import numpy as np
|
||||
from rasterio.mask import mask
|
||||
|
||||
|
||||
@dataclass
class RasterData:
    """Bundle of a raster's pixel array and its georeferencing metadata.

    Attributes:
        array: Pixel values as a NumPy array.
        transform: Affine transform associated with the raster.
        crs: Coordinate reference system of the raster.
        nodata: Value marking missing pixels, or None when none is defined.
        bands: Number of bands held in ``array`` (defaults to 1).
    """

    array: np.ndarray
    transform: rasterio.transform.Affine
    crs: rasterio.crs.CRS
    nodata: Optional[float] = None
    bands: int = 1
|
||||
|
||||
|
||||
@dataclass
class VectorData:
    """Bundle of a vector layer and the name of its label column.

    Attributes:
        gdf: GeoDataFrame holding the features of the layer.
        class_column: Name of the column in ``gdf`` that stores class labels.
    """

    gdf: gpd.GeoDataFrame
    class_column: str
|
||||
|
||||
|
||||
def load_raster(path: str) -> RasterData:
    """Read every band of a GeoTIFF into memory along with its metadata.

    Args:
        path: Path to the GeoTIFF file.

    Returns:
        RasterData holding the pixel array, affine transform, CRS, nodata
        value and band count.
    """
    with rasterio.open(path) as dataset:
        # read() without arguments pulls all bands at once.
        pixels = dataset.read()
        affine = dataset.transform
        reference_system = dataset.crs
        missing_value = dataset.nodata

    # A 3D array is laid out bands-first; anything else counts as one band.
    band_count = pixels.shape[0] if pixels.ndim == 3 else 1

    return RasterData(
        array=pixels,
        transform=affine,
        crs=reference_system,
        nodata=missing_value,
        bands=band_count,
    )
|
||||
|
||||
|
||||
def load_vector(path: str, class_column: str = "class") -> VectorData:
    """Read a Shapefile and verify that it carries the label column.

    Args:
        path: Path to the Shapefile.
        class_column: Name of the column containing class labels.

    Returns:
        VectorData wrapping the loaded GeoDataFrame and the label column name.

    Raises:
        ValueError: If ``class_column`` is not among the layer's columns.
    """
    gdf = gpd.read_file(path)

    # Happy path first: the label column is present.
    if class_column in gdf.columns:
        return VectorData(gdf=gdf, class_column=class_column)

    # Otherwise report which columns are actually available.
    raise ValueError(
        f"Column '{class_column}' not found. Available columns: {list(gdf.columns)}"
    )
|
||||
|
||||
|
||||
def extract_raster_values_by_polygons(
    raster_path: str, vector: VectorData
) -> tuple[np.ndarray, np.ndarray]:
    """Extract raster values for each polygon in the vector layer.

    Every polygon is used to cut a window out of the raster. Pixels that fall
    outside the polygon, or that the dataset itself marks as invalid (nodata),
    are discarded. A pixel is kept only when it is valid in *every* band, so
    each returned sample is a complete feature vector.

    Args:
        raster_path: Path to the GeoTIFF raster file.
        vector: VectorData object providing the polygons and the name of the
            column holding their class labels.

    Returns:
        Tuple of (features, labels) arrays for classification.
            features: 2D array of shape (n_samples, n_bands)
            labels: 1D array of shape (n_samples,)

    Raises:
        ValueError: If no valid pixel could be extracted from any polygon.
    """
    features_list = []
    labels_list = []

    gdf = vector.gdf

    with rasterio.open(raster_path) as src:
        # Walk index, geometry and label in lockstep instead of iterrows(),
        # which builds a full Series per row.
        rows = zip(gdf.index, gdf.geometry, gdf[vector.class_column])
        for idx, geometry, label in rows:
            # Best effort: a polygon that cannot be masked (e.g. it does not
            # overlap the raster) is reported and skipped, not fatal.
            try:
                # filled=False returns a masked array whose mask combines the
                # polygon footprint with the dataset's own validity mask.
                # Filling instead would write the nodata value (or 0 when the
                # raster defines none) outside the polygon, which cannot be
                # told apart from real data afterwards.
                masked_data, _ = mask(src, [geometry], crop=True, filled=False)
            except Exception as e:
                print(f"Skipping polygon {idx}: {e}")
                continue

            # Normalise to (bands, rows, cols) so one code path serves both
            # single-band and multi-band results.
            if masked_data.ndim == 2:
                masked_data = masked_data[np.newaxis, ...]

            # Drop a pixel if it is masked in any band.
            invalid = np.ma.getmaskarray(masked_data).any(axis=0)

            # Boolean indexing yields (bands, pixels); transpose to
            # (pixels, bands) as expected by classifiers.
            band_values = np.ma.getdata(masked_data)[:, ~invalid].T

            if band_values.size == 0:
                continue

            features_list.append(band_values)
            labels_list.append(np.full(band_values.shape[0], label))

    if not features_list:
        raise ValueError("No valid pixels extracted from polygons")

    features = np.vstack(features_list)
    labels = np.concatenate(labels_list)

    return features, labels
|
||||
Loading…
Add table
Add a link
Reference in a new issue