diff --git a/README.md b/README.md index ec6888e..a9dc900 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,10 @@ gis-classification/ ├── requirements.txt # Dependencies ├── data/ # Input data folder ├── output/ # Classification results +├── static/ # Web frontend └── src/ ├── classifier.py # Main classification pipeline + ├── api.py # FastAPI web server ├── data/ │ └── loader.py # Data loading (GeoTIFF, Shapefile) ├── strategies/ @@ -26,8 +28,15 @@ gis-classification/ pip install -r requirements.txt ``` +Or with uv: +```bash +uv pip install -r requirements.txt +``` + ## Usage +### CLI Mode + 1. Place your input files in `data/`: - `landsat.tif` - GeoTIFF from Landsat - `polygons.shp` - Shapefile with class labels @@ -42,6 +51,31 @@ pip install -r requirements.txt python main.py ``` +### Web Interface + +Start the web server: + +```bash +uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload +``` + +Then open http://localhost:8000 in your browser. + +**Features:** +- Upload GeoTIFF raster and Shapefile training data +- Select classification strategy (Random Forest, SVM, Logistic Regression, MLE) +- View training metrics (Accuracy, Cohen's Kappa) +- Interactive map visualization with Leaflet +- Download classified GeoTIFF results + +**API Endpoints:** +- `GET /` - Web interface +- `POST /train` - Train classifier with uploaded files +- `POST /predict` - Run classification +- `GET /result/{session_id}` - Get result metadata +- `GET /result/{session_id}/download` - Download classified GeoTIFF +- `GET /docs` - Interactive API documentation (Swagger UI) + ## Adding Custom Classification Strategy Create a new class implementing `ClassificationStrategy`: @@ -54,17 +88,17 @@ class MyCustomStrategy(ClassificationStrategy): def train(self, X: np.ndarray, y: np.ndarray) -> None: # Your training logic pass - + def predict(self, X: np.ndarray) -> np.ndarray: # Your prediction logic pass - + def predict_proba(self, X: np.ndarray) -> np.ndarray: pass - + def get_params(self) 
-> dict: pass - + @property def name(self) -> str: return "MyCustom" @@ -78,4 +112,11 @@ STRATEGY = MyCustomStrategy() ## Output - `output/classified.tif` - Classified raster (GeoTIFF) -- Console output with accuracy metrics +- Console output with accuracy metrics (Accuracy, Cohen's Kappa) +- Web interface visualization with interactive map + +## Metrics + +The classifier reports: +- **Accuracy**: Overall classification accuracy +- **Cohen's Kappa**: Agreement statistic accounting for chance (values > 0.8 indicate excellent agreement) diff --git a/main.py b/main.py index ad24fda..2b84556 100644 --- a/main.py +++ b/main.py @@ -29,19 +29,19 @@ OUTPUT_PATH = os.path.join("output", "classified.tif") # Change strategy by uncommenting desired option: # Option 1: Random Forest (recommended for GIS) -STRATEGY = RandomForestStrategy( - n_estimators=100, - max_depth=None, - random_state=42, -) - -# # Option 2: Support Vector Machine -# STRATEGY = SVMStrategy( -# kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower -# C=1.0, +# STRATEGY = RandomForestStrategy( +# n_estimators=100, +# max_depth=None, # random_state=42, # ) +# # Option 2: Support Vector Machine +STRATEGY = SVMStrategy( + kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower + C=1.0, + random_state=42, +) + # Option 3: Logistic Regression # STRATEGY = LogisticRegressionStrategy( # penalty="l2", diff --git a/requirements.txt b/requirements.txt index 7f959bc..97e2270 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +# Core dependencies rasterio>=1.3.0 geopandas>=0.12.0 shapely>=2.0.0 @@ -5,3 +6,12 @@ scikit-learn>=1.3.0 scipy>=1.10.0 numpy>=1.24.0 pandas>=2.0.0 + +# Web API +fastapi>=0.104.0 +uvicorn[standard]>=0.24.0 +python-multipart>=0.0.6 +pydantic>=2.0.0 + +# Visualization (optional, for tile generation) +matplotlib>=3.7.0 diff --git a/src/api.py b/src/api.py new file mode 100644 index 0000000..ed324fa --- /dev/null +++ b/src/api.py 
"""FastAPI web API for GIS classification."""

import io
import os
import shutil
import tempfile
import uuid
from typing import Optional

import numpy as np
import rasterio
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse, Response
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from rasterio.io import MemoryFile

from .classifier import GISClassifier
from .data import load_raster, load_vector, extract_raster_values_by_polygons
from .strategies.classifiers import (
    RandomForestStrategy,
    SVMStrategy,
    LogisticRegressionStrategy,
    MLEStrategy,
)


app = FastAPI(
    title="GIS Classification API",
    description="Land cover classification using machine learning",
    version="1.0.0",
)

# Store for active classifiers (in production, use proper session management).
# Maps session_id -> {"classifier", "raster_path", "vector_path", "temp_dir",
# "result_path"} plus "result" once /predict has run. Entries are only removed
# at shutdown, so the store grows with every successful /train call.
_classifiers: dict[str, dict] = {}

# Uploads are copied to disk in chunks of this size instead of being read into
# memory in one piece.
_UPLOAD_CHUNK_SIZE = 1024 * 1024

# Shapefile components that must be present among the uploaded vector files.
_REQUIRED_SHAPEFILE_EXTS = (".shp", ".shx", ".dbf")

# Page served by "/" when static/index.html does not exist.
# NOTE(review): the markup of the original fallback page was not visible in the
# reviewed source (only its text survived); tags and the endpoint list below
# are reconstructed - confirm against the intended page.
_FALLBACK_HTML = """
    <html>
        <head>
            <title>GIS Classification API</title>
        </head>
        <body>
            <h1>GIS Classification API</h1>
            <p>API is running. Visit <a href="/docs">/docs</a> for interactive documentation.</p>
            <h2>Endpoints:</h2>
            <ul>
                <li>POST /train</li>
                <li>POST /predict</li>
                <li>GET /result/{session_id}</li>
                <li>GET /result/{session_id}/download</li>
            </ul>
        </body>
    </html>
    """


class TrainRequest(BaseModel):
    """Request model for training.

    Not referenced by the endpoints in this module: /train takes the same
    values as individual form fields because it also receives file uploads.
    """

    strategy: str = "random_forest"
    class_column: str = "class"
    test_size: float = 0.2
    random_state: int = 42
    # Strategy-specific parameters
    n_estimators: int = 100
    max_depth: Optional[int] = None
    kernel: str = "linear"
    C: float = 1.0


class TrainResponse(BaseModel):
    """Response model for training."""

    status: str
    message: str
    metrics: Optional[dict] = None


def _get_strategy(name: str, params: dict) -> object:
    """Create a classification strategy by name.

    Args:
        name: One of "random_forest", "svm", "logistic_regression", "mle".
        params: Optional hyper-parameters; keys a strategy does not use are
            ignored and missing keys fall back to the defaults below.

    Returns:
        A newly constructed strategy instance.

    Raises:
        ValueError: If ``name`` is not a known strategy.
    """
    random_state = params.get("random_state", 42)

    # Builders are lazy so that only the requested strategy is constructed.
    strategies = {
        "random_forest": lambda: RandomForestStrategy(
            n_estimators=params.get("n_estimators", 100),
            max_depth=params.get("max_depth"),
            random_state=random_state,
        ),
        "svm": lambda: SVMStrategy(
            kernel=params.get("kernel", "linear"),
            C=params.get("C", 1.0),
            random_state=random_state,
        ),
        "logistic_regression": lambda: LogisticRegressionStrategy(
            C=params.get("C", 1.0),
            max_iter=params.get("max_iter", 1000),
            random_state=random_state,
        ),
        "mle": lambda: MLEStrategy(
            reg_covar=params.get("reg_covar", 1e-6),
        ),
    }

    if name not in strategies:
        raise ValueError(f"Unknown strategy: {name}. Available: {list(strategies.keys())}")

    return strategies[name]()


def _to_builtin(values) -> list:
    """Return ``values`` as a list with NumPy scalars converted to Python scalars.

    ``.item()`` keeps the native kind of each label (int, float or str), so
    string class labels no longer fail the way ``int(label)`` would.
    """
    return [value.item() if hasattr(value, "item") else value for value in values]


def _raster_extent(raster_path: str) -> tuple[dict, str]:
    """Read the bounding box and CRS string of the raster at ``raster_path``."""
    with rasterio.open(raster_path) as src:
        bounds = src.bounds
        crs = str(src.crs)

    extent = {
        "left": bounds.left,
        "bottom": bounds.bottom,
        "right": bounds.right,
        "top": bounds.top,
    }
    return extent, crs


async def _save_upload(upload: UploadFile, destination: str) -> None:
    """Copy an uploaded file to ``destination`` chunk by chunk."""
    with open(destination, "wb") as out:
        while chunk := await upload.read(_UPLOAD_CHUNK_SIZE):
            out.write(chunk)


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the frontend HTML page, or a minimal fallback page if it is missing."""
    static_path = os.path.join(os.path.dirname(__file__), "..", "static")
    html_path = os.path.join(static_path, "index.html")

    try:
        with open(html_path, "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        return HTMLResponse(content=_FALLBACK_HTML)


@app.post("/train")
async def train(
    raster: UploadFile = File(..., description="GeoTIFF raster file"),
    vector_files: list[UploadFile] = File(..., description="Shapefile files (.shp, .shx, .dbf, .prj)"),
    strategy: str = Form("random_forest"),
    class_column: str = Form("class"),
    test_size: float = Form(0.2),
    random_state: int = Form(42),
    n_estimators: int = Form(100),
    kernel: str = Form("linear"),
    C: float = Form(1.0),
):
    """
    Train a classification model.

    - **raster**: GeoTIFF file with multispectral data
    - **vector_files**: Shapefile files (.shp, .shx, .dbf, .prj) - upload all together
    - **strategy**: Classification algorithm (random_forest, svm, logistic_regression, mle)
    - **class_column**: Column name in shapefile containing class labels

    Responds with the new session id and the training metrics. An unknown
    strategy or a missing required shapefile component yields HTTP 400; any
    other failure yields HTTP 500. The session's temporary directory is
    removed whenever training does not complete.
    """
    session_id = str(uuid.uuid4())
    temp_dir = None

    try:
        # Validate the strategy first so a bad name costs no disk I/O.
        try:
            strategy_obj = _get_strategy(
                strategy,
                {
                    "n_estimators": n_estimators,
                    "kernel": kernel,
                    "C": C,
                    "random_state": random_state,
                },
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

        # Create temp directory for session
        temp_dir = tempfile.mkdtemp(prefix=f"gis_{session_id}_")

        # Save uploaded raster
        raster_path = os.path.join(temp_dir, "raster.tif")
        await _save_upload(raster, raster_path)

        # Save uploaded shapefile files with a consistent base name; only the
        # extension of the client-supplied filename is used.
        vector_base = os.path.join(temp_dir, "vector")
        vector_path = f"{vector_base}.shp"

        for file in vector_files:
            # filename may be None for some multipart clients
            ext = os.path.splitext(file.filename or "")[1].lower()
            await _save_upload(file, f"{vector_base}{ext}")

        # Verify required shapefile files exist
        missing = [
            ext
            for ext in _REQUIRED_SHAPEFILE_EXTS
            if not os.path.exists(f"{vector_base}{ext}")
        ]
        if missing:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required shapefile files: {', '.join(missing)}",
            )

        # Train in a worker thread so the event loop keeps serving requests.
        classifier = GISClassifier(strategy=strategy_obj)
        metrics = await run_in_threadpool(
            classifier.train,
            raster_path=raster_path,
            vector_path=vector_path,
            class_column=class_column,
            test_size=test_size,
            random_state=random_state,
        )

        # Convert numpy types for JSON serialization
        serializable_metrics = {
            "train_samples": int(metrics["train_samples"]),
            "val_samples": int(metrics["val_samples"]),
            "accuracy": float(metrics["accuracy"]),
            "kappa": float(metrics.get("kappa", 0)),
            "classes": _to_builtin(metrics["classes"]),
        }

        # Register the session last: nothing after this point can fail, so a
        # registered session always has an intact temp directory.
        _classifiers[session_id] = {
            "classifier": classifier,
            "raster_path": raster_path,
            "vector_path": vector_path,
            "temp_dir": temp_dir,
            "result_path": None,
        }

        return {
            "status": "success",
            "session_id": session_id,
            "message": f"Training completed with {strategy} strategy",
            "metrics": serializable_metrics,
        }

    except HTTPException:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    except Exception as exc:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/predict")
async def predict(
    session_id: str = Form(...),
    output_format: str = Form("geotiff"),
):
    """
    Run classification on the loaded raster.

    - **session_id**: Session ID from /train endpoint
    - **output_format**: Output format (geotiff, geojson)

    The result is written as ``classified.tif`` in the session directory
    whatever ``output_format`` is; the value is only echoed in the response.
    """
    session = _classifiers.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found. Please train first.")

    classifier = session["classifier"]
    raster_path = session["raster_path"]

    try:
        output_path = os.path.join(session["temp_dir"], "classified.tif")

        # Run prediction off the event loop; it processes the whole raster.
        result = await run_in_threadpool(
            classifier.predict,
            raster_path=raster_path,
            output_path=output_path,
        )

        session["result_path"] = output_path
        session["result"] = result

        # Get raster bounds for frontend
        bounds, crs = _raster_extent(raster_path)

        return {
            "status": "success",
            "session_id": session_id,
            "output_format": output_format,
            "bounds": bounds,
            "crs": crs,
            "shape": [int(result.predicted_array.shape[0]), int(result.predicted_array.shape[1])],
            "classes": _to_builtin(result.classes),
            "message": "Classification completed",
        }

    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.get("/result/{session_id}")
async def get_result(session_id: str):
    """Get classification result metadata.

    Responds with HTTP 404 for an unknown session and HTTP 400 when /predict
    has not produced a result for the session yet.
    """
    session = _classifiers.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    result = session.get("result")
    if session.get("result_path") is None or result is None:
        raise HTTPException(status_code=400, detail="No prediction result available")

    bounds, crs = _raster_extent(session["raster_path"])

    return {
        "session_id": session_id,
        "bounds": bounds,
        "crs": crs,
        "classes": _to_builtin(result.classes),
        "output_path": session["result_path"],
    }
def _result_path_or_404(session_id: str) -> str:
    """Return the classified raster path of a session.

    Raises:
        HTTPException: 404 if the session is unknown, or if it has no result
            file on disk.
    """
    session = _classifiers.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    result_path = session.get("result_path")
    if result_path is None or not os.path.exists(result_path):
        raise HTTPException(status_code=404, detail="Result file not found")

    return result_path


@app.get("/result/{session_id}/download")
async def download_result(session_id: str):
    """Download classified GeoTIFF file."""
    return FileResponse(
        path=_result_path_or_404(session_id),
        media_type="application/geotiff",
        filename="classified.tif",
    )


@app.get("/result/{session_id}/tile/{z}/{x}/{y}.png")
async def get_tile(session_id: str, z: int, x: int, y: int):
    """Get classification result as PNG tile (experimental).

    The tile coordinates ``z``/``x``/``y`` are accepted but not used: every
    request renders the whole classified raster as a single image.

    Responds with HTTP 404 when the session or its result file is missing,
    501 when matplotlib is not installed and 500 on any rendering error.
    """
    result_path = _result_path_or_404(session_id)

    try:
        import matplotlib
        from matplotlib.figure import Figure
    except ImportError as exc:
        raise HTTPException(
            status_code=501, detail="Tile generation requires matplotlib"
        ) from exc

    try:
        with rasterio.open(result_path) as src:
            data = src.read(1)

        # Map each distinct raster value to its rank 0..n-1 so every class
        # gets its own colour even when labels are not exactly 0..n-1.
        class_values = np.unique(data)
        n_classes = len(class_values)
        class_index = np.searchsorted(class_values, data)

        # The colormap registry replaces plt.cm.get_cmap, which newer
        # Matplotlib releases no longer provide.
        cmap = matplotlib.colormaps["tab10"].resampled(n_classes)

        # A standalone Figure avoids pyplot's global figure registry, so
        # nothing is left open if rendering fails part-way.
        fig = Figure(figsize=(1, 1))
        ax = fig.subplots()
        ax.imshow(class_index, cmap=cmap, vmin=0, vmax=max(n_classes - 1, 1))
        ax.axis("off")

        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)

        return Response(content=buf.getvalue(), media_type="image/png")

    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.get("/health")
async def health():
    """Health check endpoint."""
    return {"status": "healthy", "version": "1.0.0"}


@app.on_event("shutdown")
async def cleanup():
    """Clean up temporary files on shutdown and forget all sessions."""
    for session in _classifiers.values():
        temp_dir = session.get("temp_dir")
        if temp_dir:
            # ignore_errors also covers a directory that is already gone
            shutil.rmtree(temp_dir, ignore_errors=True)
    _classifiers.clear()
+ + +
+
+ +
+
+ + + + + + + + + + + + \ No newline at end of file