"""Main script for GIS classification.

Configure parameters below to run classification.
"""

import os
import sys

from src import (
    GISClassifier,
    RandomForestStrategy,
    SVMStrategy,
    LogisticRegressionStrategy,
    MLEStrategy,
)


# ==================== PARAMETERS ====================

# Input files (relative paths are resolved against the current working directory)
RASTER_PATH = os.path.join("data", "landsat.tif")  # Path to GeoTIFF (Landsat)
# NOTE(review): ".shx" is the shapefile index sidecar; the main geometry file
# normally ends in ".shp". Value left unchanged -- confirm which file the
# vector reader behind GISClassifier.train() actually expects.
VECTOR_PATH = os.path.join("data", "polygons.shx")  # Path to Shapefile

# Column in Shapefile containing class labels
CLASS_COLUMN = "macroclass"

# Output file for classified raster
OUTPUT_PATH = os.path.join("output", "classified.tif")

# Classification strategy parameters
# Change strategy by uncommenting desired option.
# Keep exactly one STRATEGY assignment active: comment out the current one
# before uncommenting another, otherwise the last assignment silently wins.

# Option 1: Random Forest (recommended for GIS)
STRATEGY = RandomForestStrategy(
    n_estimators=100,
    max_depth=None,
    random_state=42,
)

# Option 2: Support Vector Machine
# STRATEGY = SVMStrategy(
#     kernel="linear",  # Fast prediction; use 'rbf' for better accuracy but much slower
#     C=1.0,
#     random_state=42,
# )

# Option 3: Logistic Regression
# STRATEGY = LogisticRegressionStrategy(
#     penalty="l2",
#     C=1.0,
#     max_iter=1000,
#     random_state=42,
# )

# Option 4: Maximum Likelihood Estimation (classic for GIS)
# STRATEGY = MLEStrategy(
#     reg_covar=1e-6,
# )

# Training parameters (forwarded to GISClassifier.train() by main())
TEST_SIZE = 0.2  # Fraction for validation
RANDOM_STATE = 42  # Random seed


# ==================== RUN CLASSIFICATION ====================


def main() -> int:
    """Train the configured classifier and write the classified raster.

    Uses the module-level settings RASTER_PATH, VECTOR_PATH, CLASS_COLUMN,
    OUTPUT_PATH, STRATEGY, TEST_SIZE and RANDOM_STATE. Progress, the metrics
    returned by ``GISClassifier.train`` and a summary of the prediction result
    are printed to stdout; missing-input errors are printed to stderr.

    Returns:
        0 when the run completes, 1 when an input file is missing.
    """
    # Check input files exist
    if not os.path.exists(RASTER_PATH):
        print(f"Error: Raster file not found: {RASTER_PATH}", file=sys.stderr)
        return 1

    if not os.path.exists(VECTOR_PATH):
        print(f"Error: Vector file not found: {VECTOR_PATH}", file=sys.stderr)
        return 1

    # Create output directory if needed. dirname() is "" when OUTPUT_PATH is a
    # bare filename, and os.makedirs("") raises FileNotFoundError, so skip it.
    output_dir = os.path.dirname(OUTPUT_PATH)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Initialize classifier with strategy
    print(f"Using classification strategy: {STRATEGY.name}")
    print(f"Strategy parameters: {STRATEGY.get_params()}")
    print()

    classifier = GISClassifier(strategy=STRATEGY)

    # Train
    print("Training classifier...")
    metrics = classifier.train(
        raster_path=RASTER_PATH,
        vector_path=VECTOR_PATH,
        class_column=CLASS_COLUMN,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
    )

    print(f"Training samples: {metrics['train_samples']}")
    print(f"Validation samples: {metrics['val_samples']}")
    print(f"Accuracy: {metrics['accuracy']:.2%}")
    print(f"Cohen's Kappa: {metrics['kappa']:.4f}")
    print(f"Classes: {metrics['classes']}")
    print()

    # Predict
    print(f"Classifying raster: {RASTER_PATH}")
    result = classifier.predict(
        raster_path=RASTER_PATH,
        output_path=OUTPUT_PATH,
    )

    print("Classification complete!")
    print(f"Output saved to: {OUTPUT_PATH}")
    print(f"Output shape: {result.predicted_array.shape}")
    print(f"Classes in output: {result.classes}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and pipelines can detect a failed run.
    sys.exit(main())