diff --git a/README.md b/README.md
index ec6888e..a9dc900 100644
--- a/README.md
+++ b/README.md
@@ -10,8 +10,10 @@ gis-classification/
├── requirements.txt # Dependencies
├── data/ # Input data folder
├── output/ # Classification results
+├── static/ # Web frontend
└── src/
├── classifier.py # Main classification pipeline
+ ├── api.py # FastAPI web server
├── data/
│ └── loader.py # Data loading (GeoTIFF, Shapefile)
├── strategies/
@@ -26,8 +28,15 @@ gis-classification/
pip install -r requirements.txt
```
+Or with uv:
+```bash
+uv pip install -r requirements.txt
+```
+
## Usage
+### CLI Mode
+
1. Place your input files in `data/`:
- `landsat.tif` - GeoTIFF from Landsat
- `polygons.shp` - Shapefile with class labels
@@ -42,6 +51,31 @@ pip install -r requirements.txt
python main.py
```
+### Web Interface
+
+Start the web server:
+
+```bash
+uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload
+```
+
+Then open http://localhost:8000 in your browser.
+
+**Features:**
+- Upload GeoTIFF raster and Shapefile training data
+- Select classification strategy (Random Forest, SVM, Logistic Regression, MLE)
+- View training metrics (Accuracy, Cohen's Kappa)
+- Interactive map visualization with Leaflet
+- Download classified GeoTIFF results
+
+**API Endpoints:**
+- `GET /` - Web interface
+- `POST /train` - Train classifier with uploaded files
+- `POST /predict` - Run classification
+- `GET /result/{session_id}` - Get result metadata
+- `GET /result/{session_id}/download` - Download classified GeoTIFF
+- `GET /docs` - Interactive API documentation (Swagger UI)
+
## Adding Custom Classification Strategy
Create a new class implementing `ClassificationStrategy`:
@@ -54,17 +88,17 @@ class MyCustomStrategy(ClassificationStrategy):
def train(self, X: np.ndarray, y: np.ndarray) -> None:
# Your training logic
pass
-
+
def predict(self, X: np.ndarray) -> np.ndarray:
# Your prediction logic
pass
-
+
def predict_proba(self, X: np.ndarray) -> np.ndarray:
pass
-
+
def get_params(self) -> dict:
pass
-
+
@property
def name(self) -> str:
return "MyCustom"
@@ -78,4 +112,11 @@ STRATEGY = MyCustomStrategy()
## Output
- `output/classified.tif` - Classified raster (GeoTIFF)
-- Console output with accuracy metrics
+- Console output with accuracy metrics (Accuracy, Cohen's Kappa)
+- Web interface visualization with interactive map
+
+## Metrics
+
+The classifier reports:
+- **Accuracy**: Overall classification accuracy
+- **Cohen's Kappa**: Agreement statistic accounting for chance (values > 0.8 indicate excellent agreement)
diff --git a/main.py b/main.py
index ad24fda..2b84556 100644
--- a/main.py
+++ b/main.py
@@ -29,19 +29,19 @@ OUTPUT_PATH = os.path.join("output", "classified.tif")
# Change strategy by uncommenting desired option:
# Option 1: Random Forest (recommended for GIS)
-STRATEGY = RandomForestStrategy(
- n_estimators=100,
- max_depth=None,
- random_state=42,
-)
-
-# # Option 2: Support Vector Machine
-# STRATEGY = SVMStrategy(
-# kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower
-# C=1.0,
+# STRATEGY = RandomForestStrategy(
+# n_estimators=100,
+# max_depth=None,
# random_state=42,
# )
+# # Option 2: Support Vector Machine
+STRATEGY = SVMStrategy(
+ kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower
+ C=1.0,
+ random_state=42,
+)
+
# Option 3: Logistic Regression
# STRATEGY = LogisticRegressionStrategy(
# penalty="l2",
diff --git a/requirements.txt b/requirements.txt
index 7f959bc..97e2270 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+# Core dependencies
rasterio>=1.3.0
geopandas>=0.12.0
shapely>=2.0.0
@@ -5,3 +6,12 @@ scikit-learn>=1.3.0
scipy>=1.10.0
numpy>=1.24.0
pandas>=2.0.0
+
+# Web API
+fastapi>=0.104.0
+uvicorn[standard]>=0.24.0
+python-multipart>=0.0.6
+pydantic>=2.0.0
+
+# Visualization (optional, for tile generation)
+matplotlib>=3.7.0
diff --git a/src/api.py b/src/api.py
new file mode 100644
index 0000000..ed324fa
--- /dev/null
+++ b/src/api.py
@@ -0,0 +1,398 @@
+"""FastAPI web API for GIS classification."""
+
+import os
+import io
+import tempfile
+import shutil
+from typing import Optional
+from fastapi import FastAPI, File, UploadFile, HTTPException, Form
+from fastapi.responses import JSONResponse, FileResponse, HTMLResponse, Response
+from fastapi.staticfiles import StaticFiles
+from pydantic import BaseModel
+import numpy as np
+import rasterio
+from rasterio.io import MemoryFile
+
+from .classifier import GISClassifier
+from .data import load_raster, load_vector, extract_raster_values_by_polygons
+from .strategies.classifiers import (
+ RandomForestStrategy,
+ SVMStrategy,
+ LogisticRegressionStrategy,
+ MLEStrategy,
+)
+
+
+app = FastAPI(
+ title="GIS Classification API",
+ description="Land cover classification using machine learning",
+ version="1.0.0",
+)
+
+# Store for active classifiers (in production, use proper session management)
+_classifiers: dict[str, dict] = {}
+
+
+class TrainRequest(BaseModel):
+ """Request model for training."""
+
+ strategy: str = "random_forest"
+ class_column: str = "class"
+ test_size: float = 0.2
+ random_state: int = 42
+ # Strategy-specific parameters
+ n_estimators: int = 100
+ max_depth: Optional[int] = None
+ kernel: str = "linear"
+ C: float = 1.0
+
+
+class TrainResponse(BaseModel):
+ """Response model for training."""
+
+ status: str
+ message: str
+ metrics: Optional[dict] = None
+
+
+def _get_strategy(name: str, params: dict) -> object:
+ """Create classification strategy by name."""
+ strategies = {
+ "random_forest": RandomForestStrategy,
+ "svm": SVMStrategy,
+ "logistic_regression": LogisticRegressionStrategy,
+ "mle": MLEStrategy,
+ }
+
+ if name not in strategies:
+ raise ValueError(f"Unknown strategy: {name}. Available: {list(strategies.keys())}")
+
+ StrategyClass = strategies[name]
+
+ # Map parameters based on strategy
+ if name == "random_forest":
+ return StrategyClass(
+ n_estimators=params.get("n_estimators", 100),
+ max_depth=params.get("max_depth"),
+ random_state=params.get("random_state", 42),
+ )
+ elif name == "svm":
+ return StrategyClass(
+ kernel=params.get("kernel", "linear"),
+ C=params.get("C", 1.0),
+ random_state=params.get("random_state", 42),
+ )
+ elif name == "logistic_regression":
+ return StrategyClass(
+ C=params.get("C", 1.0),
+ max_iter=params.get("max_iter", 1000),
+ random_state=params.get("random_state", 42),
+ )
+ elif name == "mle":
+ return StrategyClass(
+ reg_covar=params.get("reg_covar", 1e-6),
+ )
+
+ return StrategyClass()
+
+
+@app.get("/", response_class=HTMLResponse)
+async def root():
+ """Serve the frontend HTML page."""
+ static_path = os.path.join(os.path.dirname(__file__), "..", "static")
+ html_path = os.path.join(static_path, "index.html")
+
+ if os.path.exists(html_path):
+ with open(html_path, "r", encoding="utf-8") as f:
+ return HTMLResponse(content=f.read())
+
+ return HTMLResponse(
+ content="""
+
+
GIS Classification API
+
+ GIS Classification API
+ API is running. Visit /docs for interactive documentation.
+ Endpoints:
+
+ POST /train - Train classifier with uploaded files
+ POST /predict - Classify a raster
+ GET /result/{session_id} - Get classification result
+ GET /result/{session_id}/download - Download classified GeoTIFF
+
+
+
+ """,
+ )
+
+
+@app.post("/train")
+async def train(
+ raster: UploadFile = File(..., description="GeoTIFF raster file"),
+ vector_files: list[UploadFile] = File(..., description="Shapefile files (.shp, .shx, .dbf, .prj)"),
+ strategy: str = Form("random_forest"),
+ class_column: str = Form("class"),
+ test_size: float = Form(0.2),
+ random_state: int = Form(42),
+ n_estimators: int = Form(100),
+ kernel: str = Form("linear"),
+ C: float = Form(1.0),
+):
+ """
+ Train a classification model.
+
+ - **raster**: GeoTIFF file with multispectral data
+ - **vector_files**: Shapefile files (.shp, .shx, .dbf, .prj) - upload all together
+ - **strategy**: Classification algorithm (random_forest, svm, logistic_regression, mle)
+ - **class_column**: Column name in shapefile containing class labels
+ """
+ import uuid
+
+ session_id = str(uuid.uuid4())
+
+ try:
+ # Create temp directory for session
+ temp_dir = tempfile.mkdtemp(prefix=f"gis_{session_id}_")
+
+ # Save uploaded raster
+ raster_path = os.path.join(temp_dir, "raster.tif")
+ with open(raster_path, "wb") as f:
+ f.write(await raster.read())
+
+ # Save uploaded shapefile files with consistent base name
+ vector_base = os.path.join(temp_dir, "vector")
+ vector_path = f"{vector_base}.shp"
+
+ for file in vector_files:
+ # Get file extension
+ ext = os.path.splitext(file.filename)[1].lower()
+ # Save with consistent base name
+ file_path = f"{vector_base}{ext}"
+ with open(file_path, "wb") as f:
+ f.write(await file.read())
+
+ # Verify required shapefile files exist
+ required_exts = ['.shp', '.shx', '.dbf']
+ missing = [ext for ext in required_exts if not os.path.exists(f"{vector_base}{ext}")]
+ if missing:
+ raise ValueError(f"Missing required shapefile files: {', '.join(missing)}")
+
+ # Create strategy
+ strategy_obj = _get_strategy(
+ strategy,
+ {
+ "n_estimators": n_estimators,
+ "kernel": kernel,
+ "C": C,
+ "random_state": random_state,
+ },
+ )
+
+ # Train classifier
+ classifier = GISClassifier(strategy=strategy_obj)
+ metrics = classifier.train(
+ raster_path=raster_path,
+ vector_path=vector_path,
+ class_column=class_column,
+ test_size=test_size,
+ random_state=random_state,
+ )
+
+ # Store session data
+ _classifiers[session_id] = {
+ "classifier": classifier,
+ "raster_path": raster_path,
+ "vector_path": vector_path,
+ "temp_dir": temp_dir,
+ "result_path": None,
+ }
+
+ # Convert numpy types for JSON serialization
+ serializable_metrics = {
+ "train_samples": int(metrics["train_samples"]),
+ "val_samples": int(metrics["val_samples"]),
+ "accuracy": float(metrics["accuracy"]),
+ "kappa": float(metrics.get("kappa", 0)),
+ "classes": [int(c) if hasattr(c, "item") else c for c in metrics["classes"]],
+ }
+
+ return {
+ "status": "success",
+ "session_id": session_id,
+ "message": f"Training completed with {strategy} strategy",
+ "metrics": serializable_metrics,
+ }
+
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/predict")
+async def predict(
+ session_id: str = Form(...),
+ output_format: str = Form("geotiff"),
+):
+ """
+ Run classification on the loaded raster.
+
+ - **session_id**: Session ID from /train endpoint
+ - **output_format**: Output format (geotiff, geojson)
+ """
+ if session_id not in _classifiers:
+ raise HTTPException(status_code=404, detail="Session not found. Please train first.")
+
+ session = _classifiers[session_id]
+ classifier = session["classifier"]
+ raster_path = session["raster_path"]
+
+ try:
+ # Generate output path
+ if output_format == "geotiff":
+ output_path = os.path.join(session["temp_dir"], "classified.tif")
+ else:
+ output_path = os.path.join(session["temp_dir"], "classified.tif")
+
+ # Run prediction
+ result = classifier.predict(raster_path=raster_path, output_path=output_path)
+
+ session["result_path"] = output_path
+ session["result"] = result
+
+ # Get raster bounds for frontend
+ with rasterio.open(raster_path) as src:
+ bounds = src.bounds
+ crs = str(src.crs)
+
+ return {
+ "status": "success",
+ "session_id": session_id,
+ "output_format": output_format,
+ "bounds": {
+ "left": bounds.left,
+ "bottom": bounds.bottom,
+ "right": bounds.right,
+ "top": bounds.top,
+ },
+ "crs": crs,
+ "shape": [int(result.predicted_array.shape[0]), int(result.predicted_array.shape[1])],
+ "classes": [int(c) if hasattr(c, "item") else c for c in result.classes],
+ "message": "Classification completed",
+ }
+
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/result/{session_id}")
+async def get_result(session_id: str):
+ """Get classification result metadata."""
+ if session_id not in _classifiers:
+ raise HTTPException(status_code=404, detail="Session not found")
+
+ session = _classifiers[session_id]
+
+ if session.get("result_path") is None:
+ raise HTTPException(status_code=400, detail="No prediction result available")
+
+ result = session.get("result")
+ raster_path = session["raster_path"]
+
+ with rasterio.open(raster_path) as src:
+ bounds = src.bounds
+ crs = str(src.crs)
+
+ return {
+ "session_id": session_id,
+ "bounds": {
+ "left": bounds.left,
+ "bottom": bounds.bottom,
+ "right": bounds.right,
+ "top": bounds.top,
+ },
+ "crs": crs,
+ "classes": [int(c) if hasattr(c, "item") else c for c in result.classes],
+ "output_path": session["result_path"],
+ }
+
+
+@app.get("/result/{session_id}/download")
+async def download_result(session_id: str):
+ """Download classified GeoTIFF file."""
+ if session_id not in _classifiers:
+ raise HTTPException(status_code=404, detail="Session not found")
+
+ session = _classifiers[session_id]
+ result_path = session.get("result_path")
+
+ if result_path is None or not os.path.exists(result_path):
+ raise HTTPException(status_code=404, detail="Result file not found")
+
+ return FileResponse(
+ path=result_path,
+ media_type="application/geotiff",
+ filename="classified.tif",
+ )
+
+
+@app.get("/result/{session_id}/tile/{z}/{x}/{y}.png")
+async def get_tile(session_id: str, z: int, x: int, y: int):
+ """Get classification result as PNG tile (experimental)."""
+ if session_id not in _classifiers:
+ raise HTTPException(status_code=404, detail="Session not found")
+
+ session = _classifiers[session_id]
+ result_path = session.get("result_path")
+
+ if result_path is None or not os.path.exists(result_path):
+ raise HTTPException(status_code=404, detail="Result file not found")
+
+ try:
+ import matplotlib
+
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import ListedColormap
+ import rasterio
+ from rasterio.warp import transform_bounds
+ from rasterio.io import MemoryFile
+
+ with rasterio.open(result_path) as src:
+ # Read data
+ data = src.read(1)
+
+ # Create color map
+ n_classes = len(np.unique(data))
+ cmap = plt.cm.get_cmap("tab10", n_classes)
+
+ # Create RGBA image
+ fig, ax = plt.subplots(figsize=(1, 1))
+ ax.imshow(data, cmap=cmap, vmin=0, vmax=n_classes - 1)
+ ax.axis("off")
+
+ buf = io.BytesIO()
+ plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+ buf.seek(0)
+
+ plt.close(fig)
+
+ return Response(content=buf.getvalue(), media_type="image/png")
+
+ except ImportError:
+ raise HTTPException(status_code=501, detail="Tile generation requires matplotlib")
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/health")
+async def health():
+ """Health check endpoint."""
+ return {"status": "healthy", "version": "1.0.0"}
+
+
+@app.on_event("shutdown")
+async def cleanup():
+ """Clean up temporary files on shutdown."""
+ for session_id, session in _classifiers.items():
+ temp_dir = session.get("temp_dir")
+ if temp_dir and os.path.exists(temp_dir):
+ shutil.rmtree(temp_dir, ignore_errors=True)
diff --git a/static/index.html b/static/index.html
new file mode 100644
index 0000000..6fea884
--- /dev/null
+++ b/static/index.html
@@ -0,0 +1,994 @@
+
+
+
+
+
+
+ GIS Classification
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file