gis-classification/src/data/loader.py

139 lines
3.8 KiB
Python

"""Data loading module for GIS classification."""
from dataclasses import dataclass
from typing import Optional
import rasterio
import geopandas as gpd
import numpy as np
from rasterio.mask import mask
@dataclass
class RasterData:
    """Bundle of a raster's pixel array and its georeferencing metadata.

    Attributes:
        array: Pixel values of the raster as a NumPy array.
        transform: Affine transform associated with the raster.
        crs: Coordinate reference system of the raster.
        nodata: Nodata value of the raster, or None when none is set.
        bands: Number of bands; defaults to 1.
    """

    array: np.ndarray
    transform: rasterio.transform.Affine
    crs: rasterio.crs.CRS
    nodata: Optional[float] = None
    bands: int = 1
@dataclass
class VectorData:
    """Pairing of a vector layer with the name of its label column.

    Attributes:
        gdf: GeoDataFrame holding the vector features.
        class_column: Name of the column in ``gdf`` that stores class labels.
    """

    gdf: gpd.GeoDataFrame
    class_column: str
def load_raster(path: str) -> RasterData:
    """Read every band of a GeoTIFF into a RasterData container.

    Args:
        path: Location of the GeoTIFF file.

    Returns:
        RasterData holding the pixel array together with the dataset's
        affine transform, CRS, nodata value and band count.
    """
    with rasterio.open(path) as dataset:
        # No band indexes are passed, so all bands are read.
        pixels = dataset.read()
        geo_transform, ref_system, fill_value = (
            dataset.transform,
            dataset.crs,
            dataset.nodata,
        )

    # A 3-D array is treated as bands-first; anything else counts as one band.
    band_count = pixels.shape[0] if pixels.ndim == 3 else 1
    return RasterData(
        array=pixels,
        transform=geo_transform,
        crs=ref_system,
        nodata=fill_value,
        bands=band_count,
    )
def load_vector(path: str, class_column: str = "class") -> VectorData:
    """Read a Shapefile and check that it carries the label column.

    Args:
        path: Location of the Shapefile.
        class_column: Column expected to hold the class labels.

    Returns:
        VectorData wrapping the loaded GeoDataFrame and the column name.

    Raises:
        ValueError: If ``class_column`` is not among the loaded columns.
    """
    frame = gpd.read_file(path)

    # Happy path first: the label column is present.
    if class_column in frame.columns:
        return VectorData(gdf=frame, class_column=class_column)

    available = list(frame.columns)
    raise ValueError(
        f"Column '{class_column}' not found. Available columns: {available}"
    )
def extract_raster_values_by_polygons(
    raster_path: str, vector: VectorData
) -> tuple[np.ndarray, np.ndarray]:
    """Extract per-pixel raster values for each polygon in the vector layer.

    Each polygon is used to clip the raster. Every pixel that lies inside the
    polygon and is valid in all bands becomes one sample, labelled with the
    polygon's class value.

    Pixel validity is taken from the masked array that ``rasterio.mask.mask``
    returns with ``filled=False`` instead of comparing pixel values against
    the nodata number. This keeps the following out of the samples:

    * pixels of the cropped window that lie outside the polygon, even when
      the raster declares no nodata value (the filled output would pad them
      with 0 and they would be indistinguishable from real data);
    * nodata pixels when the nodata value is NaN (``!=`` never matches NaN);
    * pixels that are nodata in any band, not only in the first one.

    Geometries are passed to rasterio unchanged; no reprojection is done
    here, so they are expected to be in the raster's CRS.

    Args:
        raster_path: Path to the GeoTIFF raster file.
        vector: VectorData object providing the polygons and label column.

    Returns:
        Tuple of (features, labels) arrays for classification.
        features: 2D array of shape (n_samples, n_bands)
        labels: 1D array of shape (n_samples,)

    Raises:
        ValueError: If no polygon yields at least one valid pixel.
    """
    features_list = []
    labels_list = []
    gdf = vector.gdf

    with rasterio.open(raster_path) as src:
        polygons = zip(gdf.index, gdf.geometry, gdf[vector.class_column])
        for idx, geometry, label in polygons:
            # Deliberately best-effort: a polygon rasterio cannot clip (for
            # example one that does not overlap the raster) is reported and
            # skipped instead of aborting the whole extraction.
            try:
                clipped, _ = mask(src, [geometry], crop=True, filled=False)
            except Exception as exc:
                print(f"Skipping polygon {idx}: {exc}")
                continue

            values = np.ma.getdata(clipped)
            invalid = np.ma.getmaskarray(clipped)
            if values.ndim == 2:
                # Normalise a single 2-D band to (1, rows, cols).
                values = values[np.newaxis]
                invalid = invalid[np.newaxis]

            # A pixel is usable only when no band marks it as masked
            # (outside the polygon or nodata).
            usable = ~invalid.any(axis=0)

            # (bands, n_pixels) -> (n_pixels, bands)
            samples = values[:, usable].T
            if samples.size == 0:
                continue

            features_list.append(samples)
            labels_list.append(np.full(samples.shape[0], label))

    if not features_list:
        raise ValueError("No valid pixels extracted from polygons")

    features = np.vstack(features_list)
    labels = np.concatenate(labels_list)
    return features, labels