gis-classification/main.py

115 lines
2.9 KiB
Python

"""Main script for GIS classification.

Trains a GISClassifier on a raster (RASTER_PATH) using the labelled polygons
of a Shapefile (VECTOR_PATH, labels read from CLASS_COLUMN), then classifies
the raster with the trained model and writes the result to OUTPUT_PATH.

Configure the parameters in the PARAMETERS section below, then run this file
as a script.
"""
import os
import sys

from src import (
    GISClassifier,
    RandomForestStrategy,
    SVMStrategy,
    LogisticRegressionStrategy,
    MLEStrategy,
)
# ==================== PARAMETERS ====================
# Input files
RASTER_PATH = os.path.join("data", "landsat.tif")  # Path to GeoTIFF (Landsat)
# Point at the .shp main file, the conventional entry point of a Shapefile;
# the .shx file is only the companion shape index.
VECTOR_PATH = os.path.join("data", "polygons.shp")  # Path to Shapefile
# Column in Shapefile containing class labels
CLASS_COLUMN = "macroclass"
# Output file for classified raster
OUTPUT_PATH = os.path.join("output", "classified.tif")

# Training parameters (defined first so the strategies below can reuse the seed)
TEST_SIZE = 0.2  # Fraction for validation
RANDOM_STATE = 42  # Random seed, passed to GISClassifier.train and the strategies

# Classification strategy parameters
# Exactly one STRATEGY assignment should be active: to change strategy,
# uncomment the desired option AND comment out the currently active one
# (otherwise the later assignment silently wins).
# Option 1: Random Forest (recommended for GIS)
STRATEGY = RandomForestStrategy(
    n_estimators=100,
    max_depth=None,
    random_state=RANDOM_STATE,
)
# Option 2: Support Vector Machine
# STRATEGY = SVMStrategy(
#     kernel="linear",  # Fast prediction; use 'rbf' for better accuracy but much slower
#     C=1.0,
#     random_state=RANDOM_STATE,
# )
# Option 3: Logistic Regression
# STRATEGY = LogisticRegressionStrategy(
#     penalty="l2",
#     C=1.0,
#     max_iter=1000,
#     random_state=RANDOM_STATE,
# )
# Option 4: Maximum Likelihood Estimation (classic for GIS)
# STRATEGY = MLEStrategy(
#     reg_covar=1e-6,
# )
# ==================== RUN CLASSIFICATION ====================
def _ensure_parent_dir(path):
    """Create the directory that will contain *path*, if *path* names one.

    ``os.path.dirname`` returns ``""`` for a bare file name, and
    ``os.makedirs("")`` raises ``FileNotFoundError``; in that case the file is
    written to the current directory and nothing needs to be created.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def main():
    """Train the configured strategy, classify the raster and save the result.

    Uses the module-level parameters RASTER_PATH, VECTOR_PATH, CLASS_COLUMN,
    OUTPUT_PATH, STRATEGY, TEST_SIZE and RANDOM_STATE. Progress and metrics
    are printed to stdout.

    Returns:
        0 on success, or 1 if an input file is missing (the error message is
        printed to stderr and no training is attempted).
    """
    # Check input files exist before doing any expensive work.
    for label, path in (("Raster", RASTER_PATH), ("Vector", VECTOR_PATH)):
        if not os.path.exists(path):
            print(f"Error: {label} file not found: {path}", file=sys.stderr)
            return 1

    # Create output directory if needed (safe for bare file names too).
    _ensure_parent_dir(OUTPUT_PATH)

    # Initialize classifier with strategy
    print(f"Using classification strategy: {STRATEGY.name}")
    print(f"Strategy parameters: {STRATEGY.get_params()}")
    print()
    classifier = GISClassifier(strategy=STRATEGY)

    # Train
    print("Training classifier...")
    metrics = classifier.train(
        raster_path=RASTER_PATH,
        vector_path=VECTOR_PATH,
        class_column=CLASS_COLUMN,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
    )
    print(f"Training samples: {metrics['train_samples']}")
    print(f"Validation samples: {metrics['val_samples']}")
    print(f"Accuracy: {metrics['accuracy']:.2%}")
    print(f"Classes: {metrics['classes']}")
    print()

    # Predict
    print(f"Classifying raster: {RASTER_PATH}")
    result = classifier.predict(
        raster_path=RASTER_PATH,
        output_path=OUTPUT_PATH,
    )
    print("Classification complete!")
    print(f"Output saved to: {OUTPUT_PATH}")
    print(f"Output shape: {result.predicted_array.shape}")
    print(f"Classes in output: {result.classes}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so a failure code
    # it returns reaches the shell; a None return still exits with status 0.
    sys.exit(main())