139 lines
3.8 KiB
Python
139 lines
3.8 KiB
Python
"""Data loading module for GIS classification."""
|
|
|
|
import warnings
from dataclasses import dataclass
from typing import Optional

import geopandas as gpd
import numpy as np
import rasterio
from rasterio.mask import mask
|
|
|
|
|
|
@dataclass()
class RasterData:
    """In-memory raster together with its georeferencing metadata.

    Instances are built by ``load_raster`` from an opened dataset.
    """

    # Pixel values as returned by the dataset read.
    array: np.ndarray
    # Georeferencing transform taken from the source dataset.
    transform: rasterio.transform.Affine
    # Coordinate reference system of the source dataset.
    crs: rasterio.crs.CRS
    # Value marking missing pixels; None when the source defines none.
    nodata: Optional[float] = None
    # Number of bands held in ``array``.
    bands: int = 1
|
|
|
|
|
|
@dataclass()
class VectorData:
    """A vector layer paired with the name of its class-label column."""

    # Features (geometries plus attribute columns) read from disk.
    gdf: gpd.GeoDataFrame
    # Name of the column in ``gdf`` that holds each feature's class label.
    class_column: str
|
|
|
|
|
|
def load_raster(path: str) -> RasterData:
    """Read a GeoTIFF file fully into memory.

    Args:
        path: Location of the GeoTIFF file.

    Returns:
        RasterData holding every band of the file plus its transform,
        CRS and nodata value.
    """
    with rasterio.open(path) as dataset:
        # No band indexes are passed, so every band is read.
        pixels = dataset.read()
        geo_transform, ref_system, fill_value = (
            dataset.transform,
            dataset.crs,
            dataset.nodata,
        )

    # A 3-D read is (bands, rows, cols); anything else counts as one band.
    band_count = pixels.shape[0] if pixels.ndim == 3 else 1
    return RasterData(
        array=pixels,
        transform=geo_transform,
        crs=ref_system,
        nodata=fill_value,
        bands=band_count,
    )
|
|
|
|
|
|
def load_vector(path: str, class_column: str = "class") -> VectorData:
    """Read a Shapefile and confirm it carries the class-label column.

    Args:
        path: Location of the Shapefile.
        class_column: Attribute column expected to hold the class labels.

    Returns:
        VectorData wrapping the loaded GeoDataFrame and ``class_column``.

    Raises:
        ValueError: If the file has no column named ``class_column``.
    """
    frame = gpd.read_file(path)
    available = list(frame.columns)

    if class_column in frame.columns:
        return VectorData(gdf=frame, class_column=class_column)

    raise ValueError(
        f"Column '{class_column}' not found. Available columns: {available}"
    )
|
|
|
|
|
|
def _valid_pixel_values(masked_data: np.ndarray) -> np.ndarray:
    """Return the unmasked pixels of a masked raster window as (n_pixels, n_bands)."""
    data = np.ma.getdata(masked_data)
    invalid = np.ma.getmaskarray(masked_data)
    if data.ndim == 3:
        # (bands, rows, cols): a pixel is usable only if it is valid in every band.
        valid = ~invalid.any(axis=0)
        return data[:, valid].T
    # Single 2-D band: one feature column.
    return data[~invalid].reshape(-1, 1)


def extract_raster_values_by_polygons(
    raster_path: str, vector: VectorData
) -> tuple[np.ndarray, np.ndarray]:
    """Extract raster values for each polygon in the vector layer.

    Every pixel that lies inside a polygon and is valid in all bands becomes
    one sample labelled with that polygon's class. Two kinds of pixel are
    excluded:

    * pixels inside the polygon's bounding box but outside the polygon itself;
    * pixels masked by the raster (nodata, including NaN nodata, or an
      internal mask).

    These are excluded whether or not the raster defines a nodata value.

    If both the vector layer and the raster have a CRS, the polygons are
    reprojected to the raster CRS before masking so the layers line up.
    Polygons that are empty or cannot be masked are skipped with a warning.

    Args:
        raster_path: Path to the GeoTIFF raster file.
        vector: VectorData object.

    Returns:
        Tuple of (features, labels) arrays for classification.
        features: 2D array of shape (n_samples, n_bands)
        labels: 1D array of shape (n_samples,)

    Raises:
        ValueError: If no valid pixels could be extracted from any polygon.
    """
    features_list = []
    labels_list = []

    with rasterio.open(raster_path) as src:
        gdf = vector.gdf
        # Polygons in a different CRS would miss the raster or hit the wrong
        # pixels, so bring them into the raster's CRS first.
        if gdf.crs is not None and src.crs is not None:
            gdf = gdf.to_crs(src.crs.to_wkt())

        # Zip whole columns instead of iterrows(): avoids building a Series
        # per row and keeps the label's original dtype.
        rows = zip(gdf.index, gdf.geometry, gdf[vector.class_column])
        for idx, geometry, label in rows:
            if geometry is None or geometry.is_empty:
                warnings.warn(f"Skipping polygon {idx}: empty geometry", stacklevel=2)
                continue

            try:
                # filled=False yields a masked array whose mask covers pixels
                # outside the polygon as well as the raster's own nodata/mask.
                # With filled=True and no nodata value, outside pixels would be
                # filled with 0 and indistinguishable from real data.
                masked_data, _ = mask(src, [geometry], crop=True, filled=False)
            except Exception as e:  # deliberate best-effort: skip unmaskable polygons
                warnings.warn(f"Skipping polygon {idx}: {e}", stacklevel=2)
                continue

            band_values = _valid_pixel_values(masked_data)
            if band_values.shape[0] > 0:
                features_list.append(band_values)
                labels_list.append(np.full(band_values.shape[0], label))

    if not features_list:
        raise ValueError("No valid pixels extracted from polygons")

    features = np.vstack(features_list)
    labels = np.concatenate(labels_list)

    return features, labels
|