"""Data loading module for GIS classification.""" from dataclasses import dataclass from typing import Optional import rasterio import geopandas as gpd import numpy as np from rasterio.mask import mask @dataclass class RasterData: """Container for raster data.""" array: np.ndarray transform: rasterio.transform.Affine crs: rasterio.crs.CRS nodata: Optional[float] = None bands: int = 1 @dataclass class VectorData: """Container for vector data.""" gdf: gpd.GeoDataFrame class_column: str def load_raster(path: str) -> RasterData: """Load GeoTIFF raster file. Args: path: Path to the GeoTIFF file. Returns: RasterData object with array, transform, CRS and metadata. """ with rasterio.open(path) as src: # Read all bands array = src.read() transform = src.transform crs = src.crs nodata = src.nodata return RasterData( array=array, transform=transform, crs=crs, nodata=nodata, bands=array.shape[0] if len(array.shape) == 3 else 1, ) def load_vector(path: str, class_column: str = "class") -> VectorData: """Load Shapefile vector data. Args: path: Path to the Shapefile. class_column: Name of the column containing class labels. Returns: VectorData object with GeoDataFrame and class column name. Raises: ValueError: If class_column doesn't exist in the Shapefile. """ gdf = gpd.read_file(path) if class_column not in gdf.columns: raise ValueError( f"Column '{class_column}' not found. Available columns: {list(gdf.columns)}" ) return VectorData(gdf=gdf, class_column=class_column) def extract_raster_values_by_polygons( raster_path: str, vector: VectorData ) -> tuple[np.ndarray, np.ndarray]: """Extract raster values for each polygon in the vector layer. Args: raster_path: Path to the GeoTIFF raster file. vector: VectorData object. Returns: Tuple of (features, labels) arrays for classification. features: 2D array of shape (n_samples, n_bands) labels: 1D array of shape (n_samples,) """ features_list = [] labels_list = [] with rasterio.open(raster_path) as src: nodata = src.nodata for idx, row in vector.gdf.iterrows(): geometry = row[vector.gdf.geometry.name] label = row[vector.class_column] # Mask raster by polygon geometry try: masked_data, _ = mask(src, [geometry], crop=True) except Exception as e: print(f"Skipping polygon {idx}: {e}") continue # Get valid pixels (not nodata) if nodata is not None: valid_mask = masked_data != nodata else: valid_mask = np.ones_like(masked_data, dtype=bool) # Extract valid pixel values for all bands if len(masked_data.shape) == 3: # Multi-band: reshape to (bands, pixels) band_values = masked_data[:, valid_mask[0]] else: # Single band band_values = masked_data[valid_mask] if band_values.size > 0: # Transpose to (pixels, bands) if len(band_values.shape) == 1: band_values = band_values.reshape(-1, 1) else: band_values = band_values.T features_list.append(band_values) labels_list.append(np.full(band_values.shape[0], label)) if not features_list: raise ValueError("No valid pixels extracted from polygons") features = np.vstack(features_list) labels = np.concatenate(labels_list) return features, labels