feat: add FastAPI web interface for GIS classification
This commit is contained in:
parent
5a9b8469bd
commit
6815769d2b
5 changed files with 1458 additions and 15 deletions
51
README.md
51
README.md
|
|
@ -10,8 +10,10 @@ gis-classification/
|
||||||
├── requirements.txt # Dependencies
|
├── requirements.txt # Dependencies
|
||||||
├── data/ # Input data folder
|
├── data/ # Input data folder
|
||||||
├── output/ # Classification results
|
├── output/ # Classification results
|
||||||
|
├── static/ # Web frontend
|
||||||
└── src/
|
└── src/
|
||||||
├── classifier.py # Main classification pipeline
|
├── classifier.py # Main classification pipeline
|
||||||
|
├── api.py # FastAPI web server
|
||||||
├── data/
|
├── data/
|
||||||
│ └── loader.py # Data loading (GeoTIFF, Shapefile)
|
│ └── loader.py # Data loading (GeoTIFF, Shapefile)
|
||||||
├── strategies/
|
├── strategies/
|
||||||
|
|
@ -26,8 +28,15 @@ gis-classification/
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Or with uv:
|
||||||
|
```bash
|
||||||
|
uv pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
### CLI Mode
|
||||||
|
|
||||||
1. Place your input files in `data/`:
|
1. Place your input files in `data/`:
|
||||||
- `landsat.tif` - GeoTIFF from Landsat
|
- `landsat.tif` - GeoTIFF from Landsat
|
||||||
- `polygons.shp` - Shapefile with class labels
|
- `polygons.shp` - Shapefile with class labels
|
||||||
|
|
@ -42,6 +51,31 @@ pip install -r requirements.txt
|
||||||
python main.py
|
python main.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Web Interface
|
||||||
|
|
||||||
|
Start the web server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
Then open http://localhost:8000 in your browser.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Upload GeoTIFF raster and Shapefile training data
|
||||||
|
- Select classification strategy (Random Forest, SVM, Logistic Regression, MLE)
|
||||||
|
- View training metrics (Accuracy, Cohen's Kappa)
|
||||||
|
- Interactive map visualization with Leaflet
|
||||||
|
- Download classified GeoTIFF results
|
||||||
|
|
||||||
|
**API Endpoints:**
|
||||||
|
- `GET /` - Web interface
|
||||||
|
- `POST /train` - Train classifier with uploaded files
|
||||||
|
- `POST /predict` - Run classification
|
||||||
|
- `GET /result/{session_id}` - Get result metadata
|
||||||
|
- `GET /result/{session_id}/download` - Download classified GeoTIFF
|
||||||
|
- `GET /docs` - Interactive API documentation (Swagger UI)
|
||||||
|
|
||||||
## Adding Custom Classification Strategy
|
## Adding Custom Classification Strategy
|
||||||
|
|
||||||
Create a new class implementing `ClassificationStrategy`:
|
Create a new class implementing `ClassificationStrategy`:
|
||||||
|
|
@ -54,17 +88,17 @@ class MyCustomStrategy(ClassificationStrategy):
|
||||||
def train(self, X: np.ndarray, y: np.ndarray) -> None:
|
def train(self, X: np.ndarray, y: np.ndarray) -> None:
|
||||||
# Your training logic
|
# Your training logic
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def predict(self, X: np.ndarray) -> np.ndarray:
|
def predict(self, X: np.ndarray) -> np.ndarray:
|
||||||
# Your prediction logic
|
# Your prediction logic
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def predict_proba(self, X: np.ndarray) -> np.ndarray:
|
def predict_proba(self, X: np.ndarray) -> np.ndarray:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_params(self) -> dict:
|
def get_params(self) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "MyCustom"
|
return "MyCustom"
|
||||||
|
|
@ -78,4 +112,11 @@ STRATEGY = MyCustomStrategy()
|
||||||
## Output
|
## Output
|
||||||
|
|
||||||
- `output/classified.tif` - Classified raster (GeoTIFF)
|
- `output/classified.tif` - Classified raster (GeoTIFF)
|
||||||
- Console output with accuracy metrics
|
- Console output with accuracy metrics (Accuracy, Cohen's Kappa)
|
||||||
|
- Web interface visualization with interactive map
|
||||||
|
|
||||||
|
## Metrics
|
||||||
|
|
||||||
|
The classifier reports:
|
||||||
|
- **Accuracy**: Overall classification accuracy
|
||||||
|
- **Cohen's Kappa**: Agreement statistic accounting for chance (values > 0.8 indicate excellent agreement)
|
||||||
|
|
|
||||||
20
main.py
20
main.py
|
|
@ -29,19 +29,19 @@ OUTPUT_PATH = os.path.join("output", "classified.tif")
|
||||||
# Change strategy by uncommenting desired option:
|
# Change strategy by uncommenting desired option:
|
||||||
|
|
||||||
# Option 1: Random Forest (recommended for GIS)
|
# Option 1: Random Forest (recommended for GIS)
|
||||||
STRATEGY = RandomForestStrategy(
|
# STRATEGY = RandomForestStrategy(
|
||||||
n_estimators=100,
|
# n_estimators=100,
|
||||||
max_depth=None,
|
# max_depth=None,
|
||||||
random_state=42,
|
|
||||||
)
|
|
||||||
|
|
||||||
# # Option 2: Support Vector Machine
|
|
||||||
# STRATEGY = SVMStrategy(
|
|
||||||
# kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower
|
|
||||||
# C=1.0,
|
|
||||||
# random_state=42,
|
# random_state=42,
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
# # Option 2: Support Vector Machine
|
||||||
|
STRATEGY = SVMStrategy(
|
||||||
|
kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower
|
||||||
|
C=1.0,
|
||||||
|
random_state=42,
|
||||||
|
)
|
||||||
|
|
||||||
# Option 3: Logistic Regression
|
# Option 3: Logistic Regression
|
||||||
# STRATEGY = LogisticRegressionStrategy(
|
# STRATEGY = LogisticRegressionStrategy(
|
||||||
# penalty="l2",
|
# penalty="l2",
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
# Core dependencies
|
||||||
rasterio>=1.3.0
|
rasterio>=1.3.0
|
||||||
geopandas>=0.12.0
|
geopandas>=0.12.0
|
||||||
shapely>=2.0.0
|
shapely>=2.0.0
|
||||||
|
|
@ -5,3 +6,12 @@ scikit-learn>=1.3.0
|
||||||
scipy>=1.10.0
|
scipy>=1.10.0
|
||||||
numpy>=1.24.0
|
numpy>=1.24.0
|
||||||
pandas>=2.0.0
|
pandas>=2.0.0
|
||||||
|
|
||||||
|
# Web API
|
||||||
|
fastapi>=0.104.0
|
||||||
|
uvicorn[standard]>=0.24.0
|
||||||
|
python-multipart>=0.0.6
|
||||||
|
pydantic>=2.0.0
|
||||||
|
|
||||||
|
# Visualization (optional, for tile generation)
|
||||||
|
matplotlib>=3.7.0
|
||||||
|
|
|
||||||
398
src/api.py
Normal file
398
src/api.py
Normal file
|
|
@ -0,0 +1,398 @@
|
||||||
|
"""FastAPI web API for GIS classification."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
|
||||||
|
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse, Response
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import numpy as np
|
||||||
|
import rasterio
|
||||||
|
from rasterio.io import MemoryFile
|
||||||
|
|
||||||
|
from .classifier import GISClassifier
|
||||||
|
from .data import load_raster, load_vector, extract_raster_values_by_polygons
|
||||||
|
from .strategies.classifiers import (
|
||||||
|
RandomForestStrategy,
|
||||||
|
SVMStrategy,
|
||||||
|
LogisticRegressionStrategy,
|
||||||
|
MLEStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="GIS Classification API",
|
||||||
|
description="Land cover classification using machine learning",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store for active classifiers (in production, use proper session management)
|
||||||
|
_classifiers: dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class TrainRequest(BaseModel):
|
||||||
|
"""Request model for training."""
|
||||||
|
|
||||||
|
strategy: str = "random_forest"
|
||||||
|
class_column: str = "class"
|
||||||
|
test_size: float = 0.2
|
||||||
|
random_state: int = 42
|
||||||
|
# Strategy-specific parameters
|
||||||
|
n_estimators: int = 100
|
||||||
|
max_depth: Optional[int] = None
|
||||||
|
kernel: str = "linear"
|
||||||
|
C: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class TrainResponse(BaseModel):
|
||||||
|
"""Response model for training."""
|
||||||
|
|
||||||
|
status: str
|
||||||
|
message: str
|
||||||
|
metrics: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_strategy(name: str, params: dict) -> object:
|
||||||
|
"""Create classification strategy by name."""
|
||||||
|
strategies = {
|
||||||
|
"random_forest": RandomForestStrategy,
|
||||||
|
"svm": SVMStrategy,
|
||||||
|
"logistic_regression": LogisticRegressionStrategy,
|
||||||
|
"mle": MLEStrategy,
|
||||||
|
}
|
||||||
|
|
||||||
|
if name not in strategies:
|
||||||
|
raise ValueError(f"Unknown strategy: {name}. Available: {list(strategies.keys())}")
|
||||||
|
|
||||||
|
StrategyClass = strategies[name]
|
||||||
|
|
||||||
|
# Map parameters based on strategy
|
||||||
|
if name == "random_forest":
|
||||||
|
return StrategyClass(
|
||||||
|
n_estimators=params.get("n_estimators", 100),
|
||||||
|
max_depth=params.get("max_depth"),
|
||||||
|
random_state=params.get("random_state", 42),
|
||||||
|
)
|
||||||
|
elif name == "svm":
|
||||||
|
return StrategyClass(
|
||||||
|
kernel=params.get("kernel", "linear"),
|
||||||
|
C=params.get("C", 1.0),
|
||||||
|
random_state=params.get("random_state", 42),
|
||||||
|
)
|
||||||
|
elif name == "logistic_regression":
|
||||||
|
return StrategyClass(
|
||||||
|
C=params.get("C", 1.0),
|
||||||
|
max_iter=params.get("max_iter", 1000),
|
||||||
|
random_state=params.get("random_state", 42),
|
||||||
|
)
|
||||||
|
elif name == "mle":
|
||||||
|
return StrategyClass(
|
||||||
|
reg_covar=params.get("reg_covar", 1e-6),
|
||||||
|
)
|
||||||
|
|
||||||
|
return StrategyClass()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/", response_class=HTMLResponse)
|
||||||
|
async def root():
|
||||||
|
"""Serve the frontend HTML page."""
|
||||||
|
static_path = os.path.join(os.path.dirname(__file__), "..", "static")
|
||||||
|
html_path = os.path.join(static_path, "index.html")
|
||||||
|
|
||||||
|
if os.path.exists(html_path):
|
||||||
|
with open(html_path, "r", encoding="utf-8") as f:
|
||||||
|
return HTMLResponse(content=f.read())
|
||||||
|
|
||||||
|
return HTMLResponse(
|
||||||
|
content="""
|
||||||
|
<html>
|
||||||
|
<head><title>GIS Classification API</title></head>
|
||||||
|
<body>
|
||||||
|
<h1>GIS Classification API</h1>
|
||||||
|
<p>API is running. Visit <a href="/docs">/docs</a> for interactive documentation.</p>
|
||||||
|
<h2>Endpoints:</h2>
|
||||||
|
<ul>
|
||||||
|
<li><code>POST /train</code> - Train classifier with uploaded files</li>
|
||||||
|
<li><code>POST /predict</code> - Classify a raster</li>
|
||||||
|
<li><code>GET /result/{session_id}</code> - Get classification result</li>
|
||||||
|
<li><code>GET /result/{session_id}/download</code> - Download classified GeoTIFF</li>
|
||||||
|
</ul>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/train")
|
||||||
|
async def train(
|
||||||
|
raster: UploadFile = File(..., description="GeoTIFF raster file"),
|
||||||
|
vector_files: list[UploadFile] = File(..., description="Shapefile files (.shp, .shx, .dbf, .prj)"),
|
||||||
|
strategy: str = Form("random_forest"),
|
||||||
|
class_column: str = Form("class"),
|
||||||
|
test_size: float = Form(0.2),
|
||||||
|
random_state: int = Form(42),
|
||||||
|
n_estimators: int = Form(100),
|
||||||
|
kernel: str = Form("linear"),
|
||||||
|
C: float = Form(1.0),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Train a classification model.
|
||||||
|
|
||||||
|
- **raster**: GeoTIFF file with multispectral data
|
||||||
|
- **vector_files**: Shapefile files (.shp, .shx, .dbf, .prj) - upload all together
|
||||||
|
- **strategy**: Classification algorithm (random_forest, svm, logistic_regression, mle)
|
||||||
|
- **class_column**: Column name in shapefile containing class labels
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create temp directory for session
|
||||||
|
temp_dir = tempfile.mkdtemp(prefix=f"gis_{session_id}_")
|
||||||
|
|
||||||
|
# Save uploaded raster
|
||||||
|
raster_path = os.path.join(temp_dir, "raster.tif")
|
||||||
|
with open(raster_path, "wb") as f:
|
||||||
|
f.write(await raster.read())
|
||||||
|
|
||||||
|
# Save uploaded shapefile files with consistent base name
|
||||||
|
vector_base = os.path.join(temp_dir, "vector")
|
||||||
|
vector_path = f"{vector_base}.shp"
|
||||||
|
|
||||||
|
for file in vector_files:
|
||||||
|
# Get file extension
|
||||||
|
ext = os.path.splitext(file.filename)[1].lower()
|
||||||
|
# Save with consistent base name
|
||||||
|
file_path = f"{vector_base}{ext}"
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(await file.read())
|
||||||
|
|
||||||
|
# Verify required shapefile files exist
|
||||||
|
required_exts = ['.shp', '.shx', '.dbf']
|
||||||
|
missing = [ext for ext in required_exts if not os.path.exists(f"{vector_base}{ext}")]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"Missing required shapefile files: {', '.join(missing)}")
|
||||||
|
|
||||||
|
# Create strategy
|
||||||
|
strategy_obj = _get_strategy(
|
||||||
|
strategy,
|
||||||
|
{
|
||||||
|
"n_estimators": n_estimators,
|
||||||
|
"kernel": kernel,
|
||||||
|
"C": C,
|
||||||
|
"random_state": random_state,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train classifier
|
||||||
|
classifier = GISClassifier(strategy=strategy_obj)
|
||||||
|
metrics = classifier.train(
|
||||||
|
raster_path=raster_path,
|
||||||
|
vector_path=vector_path,
|
||||||
|
class_column=class_column,
|
||||||
|
test_size=test_size,
|
||||||
|
random_state=random_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store session data
|
||||||
|
_classifiers[session_id] = {
|
||||||
|
"classifier": classifier,
|
||||||
|
"raster_path": raster_path,
|
||||||
|
"vector_path": vector_path,
|
||||||
|
"temp_dir": temp_dir,
|
||||||
|
"result_path": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert numpy types for JSON serialization
|
||||||
|
serializable_metrics = {
|
||||||
|
"train_samples": int(metrics["train_samples"]),
|
||||||
|
"val_samples": int(metrics["val_samples"]),
|
||||||
|
"accuracy": float(metrics["accuracy"]),
|
||||||
|
"kappa": float(metrics.get("kappa", 0)),
|
||||||
|
"classes": [int(c) if hasattr(c, "item") else c for c in metrics["classes"]],
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Training completed with {strategy} strategy",
|
||||||
|
"metrics": serializable_metrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/predict")
|
||||||
|
async def predict(
|
||||||
|
session_id: str = Form(...),
|
||||||
|
output_format: str = Form("geotiff"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run classification on the loaded raster.
|
||||||
|
|
||||||
|
- **session_id**: Session ID from /train endpoint
|
||||||
|
- **output_format**: Output format (geotiff, geojson)
|
||||||
|
"""
|
||||||
|
if session_id not in _classifiers:
|
||||||
|
raise HTTPException(status_code=404, detail="Session not found. Please train first.")
|
||||||
|
|
||||||
|
session = _classifiers[session_id]
|
||||||
|
classifier = session["classifier"]
|
||||||
|
raster_path = session["raster_path"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate output path
|
||||||
|
if output_format == "geotiff":
|
||||||
|
output_path = os.path.join(session["temp_dir"], "classified.tif")
|
||||||
|
else:
|
||||||
|
output_path = os.path.join(session["temp_dir"], "classified.tif")
|
||||||
|
|
||||||
|
# Run prediction
|
||||||
|
result = classifier.predict(raster_path=raster_path, output_path=output_path)
|
||||||
|
|
||||||
|
session["result_path"] = output_path
|
||||||
|
session["result"] = result
|
||||||
|
|
||||||
|
# Get raster bounds for frontend
|
||||||
|
with rasterio.open(raster_path) as src:
|
||||||
|
bounds = src.bounds
|
||||||
|
crs = str(src.crs)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"session_id": session_id,
|
||||||
|
"output_format": output_format,
|
||||||
|
"bounds": {
|
||||||
|
"left": bounds.left,
|
||||||
|
"bottom": bounds.bottom,
|
||||||
|
"right": bounds.right,
|
||||||
|
"top": bounds.top,
|
||||||
|
},
|
||||||
|
"crs": crs,
|
||||||
|
"shape": [int(result.predicted_array.shape[0]), int(result.predicted_array.shape[1])],
|
||||||
|
"classes": [int(c) if hasattr(c, "item") else c for c in result.classes],
|
||||||
|
"message": "Classification completed",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/result/{session_id}")
|
||||||
|
async def get_result(session_id: str):
|
||||||
|
"""Get classification result metadata."""
|
||||||
|
if session_id not in _classifiers:
|
||||||
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
|
|
||||||
|
session = _classifiers[session_id]
|
||||||
|
|
||||||
|
if session.get("result_path") is None:
|
||||||
|
raise HTTPException(status_code=400, detail="No prediction result available")
|
||||||
|
|
||||||
|
result = session.get("result")
|
||||||
|
raster_path = session["raster_path"]
|
||||||
|
|
||||||
|
with rasterio.open(raster_path) as src:
|
||||||
|
bounds = src.bounds
|
||||||
|
crs = str(src.crs)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"bounds": {
|
||||||
|
"left": bounds.left,
|
||||||
|
"bottom": bounds.bottom,
|
||||||
|
"right": bounds.right,
|
||||||
|
"top": bounds.top,
|
||||||
|
},
|
||||||
|
"crs": crs,
|
||||||
|
"classes": [int(c) if hasattr(c, "item") else c for c in result.classes],
|
||||||
|
"output_path": session["result_path"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/result/{session_id}/download")
|
||||||
|
async def download_result(session_id: str):
|
||||||
|
"""Download classified GeoTIFF file."""
|
||||||
|
if session_id not in _classifiers:
|
||||||
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
|
|
||||||
|
session = _classifiers[session_id]
|
||||||
|
result_path = session.get("result_path")
|
||||||
|
|
||||||
|
if result_path is None or not os.path.exists(result_path):
|
||||||
|
raise HTTPException(status_code=404, detail="Result file not found")
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=result_path,
|
||||||
|
media_type="application/geotiff",
|
||||||
|
filename="classified.tif",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/result/{session_id}/tile/{z}/{x}/{y}.png")
|
||||||
|
async def get_tile(session_id: str, z: int, x: int, y: int):
|
||||||
|
"""Get classification result as PNG tile (experimental)."""
|
||||||
|
if session_id not in _classifiers:
|
||||||
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
|
|
||||||
|
session = _classifiers[session_id]
|
||||||
|
result_path = session.get("result_path")
|
||||||
|
|
||||||
|
if result_path is None or not os.path.exists(result_path):
|
||||||
|
raise HTTPException(status_code=404, detail="Result file not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.colors import ListedColormap
|
||||||
|
import rasterio
|
||||||
|
from rasterio.warp import transform_bounds
|
||||||
|
from rasterio.io import MemoryFile
|
||||||
|
|
||||||
|
with rasterio.open(result_path) as src:
|
||||||
|
# Read data
|
||||||
|
data = src.read(1)
|
||||||
|
|
||||||
|
# Create color map
|
||||||
|
n_classes = len(np.unique(data))
|
||||||
|
cmap = plt.cm.get_cmap("tab10", n_classes)
|
||||||
|
|
||||||
|
# Create RGBA image
|
||||||
|
fig, ax = plt.subplots(figsize=(1, 1))
|
||||||
|
ax.imshow(data, cmap=cmap, vmin=0, vmax=n_classes - 1)
|
||||||
|
ax.axis("off")
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
||||||
|
buf.seek(0)
|
||||||
|
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
return Response(content=buf.getvalue(), media_type="image/png")
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise HTTPException(status_code=501, detail="Tile generation requires matplotlib")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
"""Health check endpoint."""
|
||||||
|
return {"status": "healthy", "version": "1.0.0"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def cleanup():
|
||||||
|
"""Clean up temporary files on shutdown."""
|
||||||
|
for session_id, session in _classifiers.items():
|
||||||
|
temp_dir = session.get("temp_dir")
|
||||||
|
if temp_dir and os.path.exists(temp_dir):
|
||||||
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
994
static/index.html
Normal file
994
static/index.html
Normal file
|
|
@ -0,0 +1,994 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>GIS Classification</title>
|
||||||
|
|
||||||
|
<!-- Leaflet CSS -->
|
||||||
|
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" />
|
||||||
|
|
||||||
|
<style>
|
||||||
|
* {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||||
|
background: #1a1a2e;
|
||||||
|
color: #eee;
|
||||||
|
min-height: 100vh;
|
||||||
|
}
|
||||||
|
|
||||||
|
.container {
|
||||||
|
display: flex;
|
||||||
|
height: 100vh;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sidebar {
|
||||||
|
width: 400px;
|
||||||
|
background: #16213e;
|
||||||
|
padding: 20px;
|
||||||
|
overflow-y: auto;
|
||||||
|
border-right: 1px solid #0f3460;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sidebar h1 {
|
||||||
|
font-size: 1.5rem;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
color: #e94560;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sidebar h2 {
|
||||||
|
font-size: 1.1rem;
|
||||||
|
margin: 20px 0 10px;
|
||||||
|
color: #0f3460;
|
||||||
|
border-bottom: 2px solid #e94560;
|
||||||
|
padding-bottom: 5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group {
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group label {
|
||||||
|
display: block;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
color: #aaa;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group input,
|
||||||
|
.form-group select {
|
||||||
|
width: 100%;
|
||||||
|
padding: 10px;
|
||||||
|
border: 1px solid #0f3460;
|
||||||
|
border-radius: 5px;
|
||||||
|
background: #1a1a2e;
|
||||||
|
color: #eee;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group input[type="file"] {
|
||||||
|
padding: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form-group small {
|
||||||
|
display: block;
|
||||||
|
margin-top: 3px;
|
||||||
|
color: #666;
|
||||||
|
font-size: 0.8rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn {
|
||||||
|
width: 100%;
|
||||||
|
padding: 12px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 5px;
|
||||||
|
font-size: 1rem;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.3s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary {
|
||||||
|
background: #e94560;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary:hover {
|
||||||
|
background: #ff6b6b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary:disabled {
|
||||||
|
background: #555;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-secondary {
|
||||||
|
background: #0f3460;
|
||||||
|
color: white;
|
||||||
|
margin-top: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-secondary:hover {
|
||||||
|
background: #1a4a7a;
|
||||||
|
}
|
||||||
|
|
||||||
|
.map-container {
|
||||||
|
flex: 1;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
#map {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status {
|
||||||
|
padding: 10px;
|
||||||
|
border-radius: 5px;
|
||||||
|
margin-top: 15px;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.info {
|
||||||
|
background: #0f3460;
|
||||||
|
border-left: 3px solid #4fc3f7;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.success {
|
||||||
|
background: #1b5e20;
|
||||||
|
border-left: 3px solid #4caf50;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.error {
|
||||||
|
background: #b71c1c;
|
||||||
|
border-left: 3px solid #f44336;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status.loading {
|
||||||
|
background: #0f3460;
|
||||||
|
border-left: 3px solid #ff9800;
|
||||||
|
animation: pulse 1.5s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
|
||||||
|
0%,
|
||||||
|
100% {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
50% {
|
||||||
|
opacity: 0.7;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.metrics {
|
||||||
|
background: #0f3460;
|
||||||
|
padding: 15px;
|
||||||
|
border-radius: 5px;
|
||||||
|
margin-top: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.metrics-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 1fr 1fr;
|
||||||
|
gap: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric {
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric-value {
|
||||||
|
font-size: 1.5rem;
|
||||||
|
font-weight: bold;
|
||||||
|
color: #e94560;
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric-label {
|
||||||
|
font-size: 0.8rem;
|
||||||
|
color: #aaa;
|
||||||
|
}
|
||||||
|
|
||||||
|
.classes-list {
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 5px;
|
||||||
|
margin-top: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.class-badge {
|
||||||
|
padding: 5px 10px;
|
||||||
|
border-radius: 15px;
|
||||||
|
font-size: 0.8rem;
|
||||||
|
background: #e94560;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.hidden {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.file-input-wrapper {
|
||||||
|
position: relative;
|
||||||
|
overflow: hidden;
|
||||||
|
display: inline-block;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.legend {
|
||||||
|
position: absolute;
|
||||||
|
bottom: 30px;
|
||||||
|
right: 10px;
|
||||||
|
background: rgba(22, 33, 62, 0.9);
|
||||||
|
padding: 15px;
|
||||||
|
border-radius: 5px;
|
||||||
|
z-index: 1000;
|
||||||
|
max-width: 200px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.legend h4 {
|
||||||
|
margin-bottom: 10px;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
color: #e94560;
|
||||||
|
}
|
||||||
|
|
||||||
|
.legend-item {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
font-size: 0.8rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.legend-color {
|
||||||
|
width: 20px;
|
||||||
|
height: 20px;
|
||||||
|
margin-right: 10px;
|
||||||
|
border-radius: 3px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.class-editor-item {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 10px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
padding: 8px;
|
||||||
|
background: #1a1a2e;
|
||||||
|
border-radius: 5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.class-editor-item input[type="color"] {
|
||||||
|
width: 40px;
|
||||||
|
height: 30px;
|
||||||
|
border: none;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.class-editor-item input[type="text"] {
|
||||||
|
flex: 1;
|
||||||
|
padding: 8px;
|
||||||
|
border: 1px solid #0f3460;
|
||||||
|
border-radius: 3px;
|
||||||
|
background: #16213e;
|
||||||
|
color: #eee;
|
||||||
|
}
|
||||||
|
|
||||||
|
.class-editor-item span {
|
||||||
|
min-width: 60px;
|
||||||
|
color: #aaa;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="sidebar">
|
||||||
|
<h1>🛰️ GIS Classification</h1>
|
||||||
|
<p style="color: #888; font-size: 0.9rem; margin-bottom: 20px;">
|
||||||
|
Land cover classification using machine learning
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h2>1. Upload Data</h2>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="raster">Raster File (GeoTIFF)</label>
|
||||||
|
<input type="file" id="raster" accept=".tif,.tiff">
|
||||||
|
<small>Select Landsat or similar multispectral imagery</small>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="vector">Shapefile Files (select all: .shp, .shx, .dbf, .prj)</label>
|
||||||
|
<input type="file" id="vector" accept=".shp,.shx,.dbf,.prj" multiple onchange="showSelectedFiles()">
|
||||||
|
<small>Hold Ctrl/Cmd to select multiple files. Required: .shp, .shx, .dbf</small>
|
||||||
|
<div id="fileList" style="margin-top: 8px; font-size: 0.85rem; color: #4fc3f7;"></div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<h2>2. Configure</h2>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="strategy">Classification Strategy</label>
|
||||||
|
<select id="strategy">
|
||||||
|
<option value="random_forest">Random Forest (Recommended)</option>
|
||||||
|
<option value="svm">Support Vector Machine</option>
|
||||||
|
<option value="logistic_regression">Logistic Regression</option>
|
||||||
|
<option value="mle">Maximum Likelihood (MLE)</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="classColumn">Class Column</label>
|
||||||
|
<input type="text" id="classColumn" value="class" placeholder="Column name in shapefile">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="nEstimators">Estimators (Random Forest)</label>
|
||||||
|
<input type="number" id="nEstimators" value="100" min="1" max="500">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="kernel">SVM Kernel</label>
|
||||||
|
<select id="kernel">
|
||||||
|
<option value="linear">Linear (Fast)</option>
|
||||||
|
<option value="rbf">RBF (Accurate but slow)</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<h2>3. Train & Classify</h2>
|
||||||
|
<button class="btn btn-primary" id="trainBtn" onclick="train()">
|
||||||
|
Train Model
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<div id="status" class="status hidden"></div>
|
||||||
|
|
||||||
|
<div id="metrics" class="metrics hidden">
|
||||||
|
<div class="metrics-grid">
|
||||||
|
<div class="metric">
|
||||||
|
<div class="metric-value" id="accuracyValue">-</div>
|
||||||
|
<div class="metric-label">Accuracy</div>
|
||||||
|
</div>
|
||||||
|
<div class="metric">
|
||||||
|
<div class="metric-value" id="kappaValue">-</div>
|
||||||
|
<div class="metric-label">Cohen's Kappa</div>
|
||||||
|
</div>
|
||||||
|
<div class="metric">
|
||||||
|
<div class="metric-value" id="trainSamplesValue">-</div>
|
||||||
|
<div class="metric-label">Train Samples</div>
|
||||||
|
</div>
|
||||||
|
<div class="metric">
|
||||||
|
<div class="metric-value" id="valSamplesValue">-</div>
|
||||||
|
<div class="metric-label">Validation Samples</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div id="classesContainer" class="hidden">
|
||||||
|
<label style="font-size: 0.8rem; color: #aaa; margin-top: 10px; display: block;">Classes:</label>
|
||||||
|
<div class="classes-list" id="classesList"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<h2>4. Class Templates</h2>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="templateSelect">Load Template</label>
|
||||||
|
<select id="templateSelect" onchange="loadTemplate()">
|
||||||
|
<option value="">-- Select Template --</option>
|
||||||
|
</select>
|
||||||
|
<small>Pre-defined class names and colors</small>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div style="display: flex; gap: 10px; margin-bottom: 15px;">
|
||||||
|
<button class="btn btn-secondary" onclick="saveTemplate()" style="flex: 1;">
|
||||||
|
Save Template
|
||||||
|
</button>
|
||||||
|
<button class="btn btn-secondary" onclick="exportTemplate()" style="flex: 1;">
|
||||||
|
Export
|
||||||
|
</button>
|
||||||
|
<label class="btn btn-secondary" style="flex: 1; text-align: center; cursor: pointer;">
|
||||||
|
Import
|
||||||
|
<input type="file" id="importTemplate" accept=".json" onchange="importTemplate(this)"
|
||||||
|
style="display: none;">
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div id="classEditor" class="hidden"
|
||||||
|
style="background: #0f3460; padding: 15px; border-radius: 5px; margin-bottom: 15px;">
|
||||||
|
<h4 style="margin-bottom: 10px; color: #e94560;">Edit Class Names & Colors</h4>
|
||||||
|
<div id="classEditorItems"></div>
|
||||||
|
<button class="btn btn-primary" onclick="applyClassTemplate()" style="margin-top: 15px;">
|
||||||
|
Apply to Map
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="opacitySlider">Layer Opacity: <span id="opacityValue">70%</span></label>
|
||||||
|
<input type="range" id="opacitySlider" min="0" max="100" value="70" oninput="updateOpacity()">
|
||||||
|
<small>Adjust classification layer transparency</small>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button class="btn btn-secondary hidden" id="predictBtn" onclick="predict()">
|
||||||
|
Classify Raster
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<a id="downloadLink" class="btn btn-secondary hidden"
|
||||||
|
style="text-align: center; text-decoration: none; display: block;" download>
|
||||||
|
Download Result
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="map-container">
|
||||||
|
<div id="map"></div>
|
||||||
|
<div id="legend" class="legend hidden">
|
||||||
|
<h4>Classes</h4>
|
||||||
|
<div id="legendItems"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Leaflet JS -->
|
||||||
|
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"></script>
|
||||||
|
|
||||||
|
<!-- GeoRasterLayer for Leaflet -->
|
||||||
|
<script src="https://unpkg.com/georaster"></script>
|
||||||
|
<script src="https://unpkg.com/georaster-layer-for-leaflet"></script>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
// Initialize map
|
||||||
|
const map = L.map('map', {
|
||||||
|
attributionControl: false
|
||||||
|
}).setView([0, 0], 3);
|
||||||
|
|
||||||
|
// Add satellite base layer (Esri World Imagery)
|
||||||
|
L.tileLayer('https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}', {
|
||||||
|
attribution: 'Tiles © Esri — Source: Esri, i-cubed, USDA, USGS, AEX, GeoEye, Getmapping, Aerogrid, IGN, IGP, UPR-EGP, and the GIS User Community',
|
||||||
|
maxZoom: 19
|
||||||
|
}).addTo(map);
|
||||||
|
|
||||||
|
// Add custom attribution control without country flags
|
||||||
|
L.control.attribution({
|
||||||
|
prefix: ''
|
||||||
|
}).addTo(map);
|
||||||
|
|
||||||
|
let sessionId = null;
|
||||||
|
let geoRasterLayer = null;
|
||||||
|
let currentClasses = [];
|
||||||
|
let currentPalette = [];
|
||||||
|
let classTemplates = {};
|
||||||
|
let customColorMapping = {};
|
||||||
|
let currentGeoRaster = null;
|
||||||
|
|
||||||
|
// Load templates from localStorage on init
|
||||||
|
loadTemplatesFromStorage();
|
||||||
|
|
||||||
|
function showSelectedFiles() {
|
||||||
|
const files = document.getElementById('vector').files;
|
||||||
|
const fileList = document.getElementById('fileList');
|
||||||
|
|
||||||
|
if (files.length === 0) {
|
||||||
|
fileList.innerHTML = '';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const fileNames = Array.from(files).map(f => f.name).sort();
|
||||||
|
fileList.innerHTML = '<strong>Selected:</strong> ' + fileNames.join(', ');
|
||||||
|
}
|
||||||
|
|
||||||
|
function showStatus(message, type = 'info') {
|
||||||
|
const status = document.getElementById('status');
|
||||||
|
status.textContent = message;
|
||||||
|
status.className = `status ${type}`;
|
||||||
|
status.classList.remove('hidden');
|
||||||
|
}
|
||||||
|
|
||||||
|
function hideStatus() {
|
||||||
|
document.getElementById('status').classList.add('hidden');
|
||||||
|
}
|
||||||
|
|
||||||
|
async function train() {
|
||||||
|
const rasterFile = document.getElementById('raster').files[0];
|
||||||
|
const vectorFiles = document.getElementById('vector').files;
|
||||||
|
|
||||||
|
if (!rasterFile) {
|
||||||
|
showStatus('Please select a raster file', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!vectorFiles || vectorFiles.length === 0) {
|
||||||
|
showStatus('Please select shapefile files (.shp, .shx, .dbf)', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate shapefile files
|
||||||
|
const extensions = Array.from(vectorFiles).map(f => f.name.toLowerCase().split('.').pop());
|
||||||
|
const required = ['shp', 'shx', 'dbf'];
|
||||||
|
const missing = required.filter(ext => !extensions.includes(ext));
|
||||||
|
|
||||||
|
if (missing.length > 0) {
|
||||||
|
showStatus(`Missing required shapefile files: ${missing.map(e => '.' + e).join(', ')}`, 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append('raster', rasterFile);
|
||||||
|
|
||||||
|
// Append all shapefile files
|
||||||
|
for (let i = 0; i < vectorFiles.length; i++) {
|
||||||
|
formData.append('vector_files', vectorFiles[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
formData.append('strategy', document.getElementById('strategy').value);
|
||||||
|
formData.append('class_column', document.getElementById('classColumn').value);
|
||||||
|
formData.append('n_estimators', document.getElementById('nEstimators').value);
|
||||||
|
formData.append('kernel', document.getElementById('kernel').value);
|
||||||
|
formData.append('test_size', '0.2');
|
||||||
|
formData.append('random_state', '42');
|
||||||
|
|
||||||
|
const trainBtn = document.getElementById('trainBtn');
|
||||||
|
trainBtn.disabled = true;
|
||||||
|
trainBtn.textContent = 'Training...';
|
||||||
|
|
||||||
|
showStatus('Training classifier... This may take a few minutes.', 'loading');
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/train', {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(data.detail || 'Training failed');
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionId = data.session_id;
|
||||||
|
|
||||||
|
// Display metrics
|
||||||
|
const metrics = data.metrics;
|
||||||
|
document.getElementById('accuracyValue').textContent = (metrics.accuracy * 100).toFixed(1) + '%';
|
||||||
|
document.getElementById('kappaValue').textContent = metrics.kappa.toFixed(4);
|
||||||
|
document.getElementById('trainSamplesValue').textContent = metrics.train_samples.toLocaleString();
|
||||||
|
document.getElementById('valSamplesValue').textContent = metrics.val_samples.toLocaleString();
|
||||||
|
|
||||||
|
// Display classes
|
||||||
|
const classesList = document.getElementById('classesList');
|
||||||
|
classesList.innerHTML = metrics.classes.map(c =>
|
||||||
|
`<span class="class-badge">Class ${c}</span>`
|
||||||
|
).join('');
|
||||||
|
document.getElementById('classesContainer').classList.remove('hidden');
|
||||||
|
|
||||||
|
document.getElementById('metrics').classList.remove('hidden');
|
||||||
|
document.getElementById('predictBtn').classList.remove('hidden');
|
||||||
|
|
||||||
|
showStatus(`Training completed! Accuracy: ${(metrics.accuracy * 100).toFixed(1)}%, Kappa: ${metrics.kappa.toFixed(4)}. You can classify raster and customize visualization with templates!`, 'success');
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
showStatus(`Error: ${error.message}`, 'error');
|
||||||
|
} finally {
|
||||||
|
trainBtn.disabled = false;
|
||||||
|
trainBtn.textContent = 'Train Model';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function predict() {
|
||||||
|
if (!sessionId) {
|
||||||
|
showStatus('No trained model available', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const predictBtn = document.getElementById('predictBtn');
|
||||||
|
predictBtn.disabled = true;
|
||||||
|
predictBtn.textContent = 'Classifying...';
|
||||||
|
|
||||||
|
showStatus('Classifying raster... This may take a moment.', 'loading');
|
||||||
|
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append('session_id', sessionId);
|
||||||
|
formData.append('output_format', 'geotiff');
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/predict', {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(data.detail || 'Classification failed');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update map view to show raster bounds
|
||||||
|
const bounds = [
|
||||||
|
[data.bounds.bottom, data.bounds.left],
|
||||||
|
[data.bounds.top, data.bounds.right]
|
||||||
|
];
|
||||||
|
|
||||||
|
map.fitBounds(bounds);
|
||||||
|
|
||||||
|
// Load and display classified raster
|
||||||
|
showStatus('Loading classification result...', 'loading');
|
||||||
|
|
||||||
|
// Download and display the classified raster
|
||||||
|
const downloadResponse = await fetch(`/result/${sessionId}/download`);
|
||||||
|
const blob = await downloadResponse.blob();
|
||||||
|
const arrayBuffer = await blob.arrayBuffer();
|
||||||
|
|
||||||
|
// Parse GeoTIFF
|
||||||
|
const geoRaster = await parseGeoraster(arrayBuffer);
|
||||||
|
currentGeoRaster = geoRaster;
|
||||||
|
|
||||||
|
// Remove existing layer
|
||||||
|
if (geoRasterLayer) {
|
||||||
|
map.removeLayer(geoRasterLayer);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create color palette for classes
|
||||||
|
const palette = generatePalette(data.classes.length);
|
||||||
|
|
||||||
|
// Add GeoRaster layer
|
||||||
|
geoRasterLayer = createGeoRasterLayer(geoRaster, palette);
|
||||||
|
geoRasterLayer.addTo(map);
|
||||||
|
|
||||||
|
// Apply initial opacity from slider
|
||||||
|
const initialOpacity = document.getElementById('opacitySlider').value / 100;
|
||||||
|
geoRasterLayer.setOpacity(initialOpacity);
|
||||||
|
|
||||||
|
// Zoom to layer bounds
|
||||||
|
if (geoRasterLayer.getBounds()) {
|
||||||
|
map.fitBounds(geoRasterLayer.getBounds());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update legend
|
||||||
|
currentClasses = data.classes;
|
||||||
|
currentPalette = palette;
|
||||||
|
updateLegend(data.classes, palette);
|
||||||
|
|
||||||
|
// Show class editor for customizing names
|
||||||
|
document.getElementById('classEditor').classList.remove('hidden');
|
||||||
|
const editorItems = document.getElementById('classEditorItems');
|
||||||
|
editorItems.innerHTML = '';
|
||||||
|
data.classes.forEach((cls, i) => {
|
||||||
|
const item = document.createElement('div');
|
||||||
|
item.className = 'class-editor-item';
|
||||||
|
item.innerHTML = `
|
||||||
|
<span>Class ${cls}</span>
|
||||||
|
<input type="color" value="${palette[i]}" data-class="${cls}" data-type="color">
|
||||||
|
<input type="text" value="Class ${cls}" data-class="${cls}" data-type="name" placeholder="Class name">
|
||||||
|
`;
|
||||||
|
editorItems.appendChild(item);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Update download link
|
||||||
|
const downloadLink = document.getElementById('downloadLink');
|
||||||
|
downloadLink.href = URL.createObjectURL(blob);
|
||||||
|
downloadLink.download = 'classified_result.tif';
|
||||||
|
downloadLink.classList.remove('hidden');
|
||||||
|
|
||||||
|
showStatus('Classification complete!', 'success');
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
showStatus(`Error: ${error.message}`, 'error');
|
||||||
|
console.error(error);
|
||||||
|
} finally {
|
||||||
|
predictBtn.disabled = false;
|
||||||
|
predictBtn.textContent = 'Classify Raster';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function generatePalette(nClasses) {
|
||||||
|
// Generate distinct colors for classes
|
||||||
|
const colors = [
|
||||||
|
'#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00',
|
||||||
|
'#ffff33', '#a65628', '#f781bf', '#999999', '#66c2a5',
|
||||||
|
'#fc8d62', '#8da0cb', '#e78ac3', '#a6d854', '#ffd92f',
|
||||||
|
'#e5c494', '#b3b3b3', '#1b9e77', '#d95f02', '#7570b3'
|
||||||
|
];
|
||||||
|
|
||||||
|
const palette = [];
|
||||||
|
for (let i = 0; i < nClasses; i++) {
|
||||||
|
palette.push(colors[i % colors.length]);
|
||||||
|
}
|
||||||
|
return palette;
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateLegend(classes, palette) {
|
||||||
|
const legend = document.getElementById('legend');
|
||||||
|
const legendItems = document.getElementById('legendItems');
|
||||||
|
|
||||||
|
legendItems.innerHTML = classes.map((cls, i) => `
|
||||||
|
<div class="legend-item">
|
||||||
|
<div class="legend-color" style="background: ${palette[i]}"></div>
|
||||||
|
<span>Class ${cls}</span>
|
||||||
|
</div>
|
||||||
|
`).join('');
|
||||||
|
|
||||||
|
legend.classList.remove('hidden');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to parse GeoTIFF
|
||||||
|
async function parseGeoraster(arrayBuffer) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
parseGeorasterInternal(arrayBuffer).then(resolve).catch(reject);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseGeorasterInternal(arrayBuffer) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
try {
|
||||||
|
const georaster = new GeoRaster(arrayBuffer);
|
||||||
|
resolve(georaster);
|
||||||
|
} catch (e) {
|
||||||
|
// Fallback: use global parseGeoraster if available
|
||||||
|
if (typeof window.parseGeoraster === 'function') {
|
||||||
|
window.parseGeoraster(arrayBuffer).then(resolve).catch(reject);
|
||||||
|
} else {
|
||||||
|
reject(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== CLASS TEMPLATE FUNCTIONS ====================
|
||||||
|
|
||||||
|
function loadTemplatesFromStorage() {
|
||||||
|
const stored = localStorage.getItem('gis_class_templates');
|
||||||
|
if (stored) {
|
||||||
|
try {
|
||||||
|
classTemplates = JSON.parse(stored);
|
||||||
|
} catch (e) {
|
||||||
|
classTemplates = {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Add default templates if none exist
|
||||||
|
if (Object.keys(classTemplates).length === 0) {
|
||||||
|
classTemplates = {
|
||||||
|
'Rainbow': {
|
||||||
|
'1': { name: 'Red', color: '#FF0000' },
|
||||||
|
'2': { name: 'Orange', color: '#FF7F00' },
|
||||||
|
'3': { name: 'Yellow', color: '#FFFF00' },
|
||||||
|
'4': { name: 'Green', color: '#00FF00' },
|
||||||
|
'5': { name: 'Blue', color: '#0000FF' },
|
||||||
|
'6': { name: 'Violet', color: '#8B00FF' }
|
||||||
|
},
|
||||||
|
};
|
||||||
|
saveTemplatesToStorage();
|
||||||
|
}
|
||||||
|
updateTemplateSelect();
|
||||||
|
}
|
||||||
|
|
||||||
|
function saveTemplatesToStorage() {
|
||||||
|
localStorage.setItem('gis_class_templates', JSON.stringify(classTemplates));
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateTemplateSelect() {
|
||||||
|
const select = document.getElementById('templateSelect');
|
||||||
|
select.innerHTML = '<option value="">-- Select Template --</option>';
|
||||||
|
Object.keys(classTemplates).forEach(name => {
|
||||||
|
const option = document.createElement('option');
|
||||||
|
option.value = name;
|
||||||
|
option.textContent = name;
|
||||||
|
select.appendChild(option);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function loadTemplate() {
|
||||||
|
const select = document.getElementById('templateSelect');
|
||||||
|
const templateName = select.value;
|
||||||
|
if (!templateName || !classTemplates[templateName]) return;
|
||||||
|
|
||||||
|
const template = classTemplates[templateName];
|
||||||
|
const editorItems = document.getElementById('classEditorItems');
|
||||||
|
editorItems.innerHTML = '';
|
||||||
|
|
||||||
|
// Get current classes or use template keys
|
||||||
|
const classesToShow = currentClasses.length > 0 ? currentClasses : Object.keys(template).sort((a, b) => parseInt(a) - parseInt(b));
|
||||||
|
|
||||||
|
classesToShow.forEach(cls => {
|
||||||
|
const item = document.createElement('div');
|
||||||
|
item.className = 'class-editor-item';
|
||||||
|
const templateData = template[cls.toString()] || { name: `Class ${cls}`, color: '#808080' };
|
||||||
|
item.innerHTML = `
|
||||||
|
<span>Class ${cls}</span>
|
||||||
|
<input type="color" value="${templateData.color}" data-class="${cls}" data-type="color">
|
||||||
|
<input type="text" value="${templateData.name}" data-class="${cls}" data-type="name" placeholder="Class name">
|
||||||
|
`;
|
||||||
|
editorItems.appendChild(item);
|
||||||
|
});
|
||||||
|
|
||||||
|
document.getElementById('classEditor').classList.remove('hidden');
|
||||||
|
|
||||||
|
// Auto-apply the template
|
||||||
|
applyClassTemplate();
|
||||||
|
}
|
||||||
|
|
||||||
|
function saveTemplate() {
|
||||||
|
const name = prompt('Enter template name:');
|
||||||
|
if (!name) return;
|
||||||
|
|
||||||
|
if (!currentClasses.length) {
|
||||||
|
showStatus('No classes available. Run classification first.', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const template = {};
|
||||||
|
const editorItems = document.querySelectorAll('.class-editor-item');
|
||||||
|
|
||||||
|
editorItems.forEach(item => {
|
||||||
|
const classId = item.querySelector('input[data-type="color"]').dataset.class;
|
||||||
|
const color = item.querySelector('input[data-type="color"]').value;
|
||||||
|
const className = item.querySelector('input[data-type="name"]').value;
|
||||||
|
template[classId] = { name: className, color: color };
|
||||||
|
});
|
||||||
|
|
||||||
|
classTemplates[name] = template;
|
||||||
|
saveTemplatesToStorage();
|
||||||
|
updateTemplateSelect();
|
||||||
|
document.getElementById('templateSelect').value = name;
|
||||||
|
showStatus(`Template "${name}" saved!`, 'success');
|
||||||
|
}
|
||||||
|
|
||||||
|
function exportTemplate() {
|
||||||
|
const select = document.getElementById('templateSelect');
|
||||||
|
const templateName = select.value;
|
||||||
|
if (!templateName || !classTemplates[templateName]) {
|
||||||
|
showStatus('Please select a template to export', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const template = { name: templateName, classes: classTemplates[templateName] };
|
||||||
|
const blob = new Blob([JSON.stringify(template, null, 2)], { type: 'application/json' });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = url;
|
||||||
|
a.download = `${templateName.replace(/\s+/g, '_')}_template.json`;
|
||||||
|
a.click();
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
}
|
||||||
|
|
||||||
|
function importTemplate(input) {
|
||||||
|
const file = input.files[0];
|
||||||
|
if (!file) return;
|
||||||
|
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = function (e) {
|
||||||
|
try {
|
||||||
|
const template = JSON.parse(e.target.result);
|
||||||
|
if (!template.name || !template.classes) {
|
||||||
|
throw new Error('Invalid template format');
|
||||||
|
}
|
||||||
|
classTemplates[template.name] = template.classes;
|
||||||
|
saveTemplatesToStorage();
|
||||||
|
updateTemplateSelect();
|
||||||
|
document.getElementById('templateSelect').value = template.name;
|
||||||
|
showStatus(`Template "${template.name}" imported!`, 'success');
|
||||||
|
} catch (err) {
|
||||||
|
showStatus('Invalid template file: ' + err.message, 'error');
|
||||||
|
}
|
||||||
|
};
|
||||||
|
reader.readAsText(file);
|
||||||
|
input.value = '';
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyClassTemplate() {
|
||||||
|
const editorItems = document.querySelectorAll('.class-editor-item');
|
||||||
|
const customMapping = {};
|
||||||
|
|
||||||
|
editorItems.forEach(item => {
|
||||||
|
const classId = item.querySelector('input[data-type="color"]').dataset.class;
|
||||||
|
const color = item.querySelector('input[data-type="color"]').value;
|
||||||
|
const className = item.querySelector('input[data-type="name"]').value;
|
||||||
|
customMapping[classId] = { name: className, color: color };
|
||||||
|
});
|
||||||
|
|
||||||
|
// Update legend with custom names
|
||||||
|
updateLegendWithCustomNames(currentClasses, currentPalette, customMapping);
|
||||||
|
showStatus('Class template applied!', 'success');
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateLegendWithCustomNames(classes, palette, customMapping) {
|
||||||
|
const legend = document.getElementById('legend');
|
||||||
|
const legendItems = document.getElementById('legendItems');
|
||||||
|
|
||||||
|
legendItems.innerHTML = classes.map((cls, i) => {
|
||||||
|
const custom = customMapping[cls.toString()] || customMapping[i.toString()];
|
||||||
|
const name = custom ? custom.name : `Class ${cls}`;
|
||||||
|
const color = custom ? custom.color : palette[i];
|
||||||
|
return `
|
||||||
|
<div class="legend-item">
|
||||||
|
<div class="legend-color" style="background: ${color}"></div>
|
||||||
|
<span>${name}</span>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
}).join('');
|
||||||
|
|
||||||
|
legend.classList.remove('hidden');
|
||||||
|
}
|
||||||
|
|
||||||
|
function createGeoRasterLayer(geoRaster, palette, customMapping = {}) {
|
||||||
|
const layer = new GeoRasterLayer({
|
||||||
|
georaster: geoRaster,
|
||||||
|
opacity: 0.7,
|
||||||
|
pixelValuesToColorFn: function (pixelValues) {
|
||||||
|
const value = pixelValues[0];
|
||||||
|
if (value === null || value === undefined || value === geoRaster.nodata) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const classIndex = Math.floor(value) - 1;
|
||||||
|
if (classIndex < 0 || classIndex >= palette.length) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
// Use custom color if available, otherwise use palette
|
||||||
|
const className = (classIndex + 1).toString();
|
||||||
|
if (customMapping[className] && customMapping[className].color) {
|
||||||
|
return customMapping[className].color;
|
||||||
|
}
|
||||||
|
return palette[classIndex];
|
||||||
|
},
|
||||||
|
resolution: Math.max(64, Math.min(256, geoRaster.height / 10))
|
||||||
|
});
|
||||||
|
// Force Leaflet to treat this as a new layer
|
||||||
|
layer._leaflet_id = 'georaster_' + Date.now();
|
||||||
|
return layer;
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyClassTemplate() {
|
||||||
|
const editorItems = document.querySelectorAll('.class-editor-item');
|
||||||
|
const customMapping = {};
|
||||||
|
|
||||||
|
editorItems.forEach(item => {
|
||||||
|
const classId = item.querySelector('input[data-type="color"]').dataset.class;
|
||||||
|
const color = item.querySelector('input[data-type="color"]').value;
|
||||||
|
const className = item.querySelector('input[data-type="name"]').value;
|
||||||
|
customMapping[classId] = { name: className, color: color };
|
||||||
|
});
|
||||||
|
|
||||||
|
// Store custom mapping
|
||||||
|
window.currentCustomMapping = customMapping;
|
||||||
|
|
||||||
|
// Update legend first
|
||||||
|
updateLegendWithCustomNames(currentClasses, currentPalette, customMapping);
|
||||||
|
|
||||||
|
// Re-create GeoRaster layer with custom colors
|
||||||
|
if (currentGeoRaster && geoRasterLayer) {
|
||||||
|
// Remove old layer
|
||||||
|
map.removeLayer(geoRasterLayer);
|
||||||
|
|
||||||
|
// Create new layer with custom colors
|
||||||
|
geoRasterLayer = createGeoRasterLayer(currentGeoRaster, currentPalette, customMapping);
|
||||||
|
geoRasterLayer.addTo(map);
|
||||||
|
|
||||||
|
// Force complete redraw by triggering multiple refresh methods
|
||||||
|
setTimeout(() => {
|
||||||
|
// Clear cache of the layer
|
||||||
|
if (geoRasterLayer.clearCache) {
|
||||||
|
geoRasterLayer.clearCache();
|
||||||
|
}
|
||||||
|
// Try to redraw the layer
|
||||||
|
if (geoRasterLayer.redraw) {
|
||||||
|
geoRasterLayer.redraw();
|
||||||
|
}
|
||||||
|
// Trigger map move event to force tile refresh
|
||||||
|
map.fire('moveend');
|
||||||
|
}, 50);
|
||||||
|
}
|
||||||
|
|
||||||
|
showStatus('Class template applied! Zoom in/out if colors do not update immediately.', 'success');
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateOpacity() {
|
||||||
|
const slider = document.getElementById('opacitySlider');
|
||||||
|
const valueDisplay = document.getElementById('opacityValue');
|
||||||
|
const opacity = slider.value / 100;
|
||||||
|
|
||||||
|
valueDisplay.textContent = slider.value + '%';
|
||||||
|
|
||||||
|
if (geoRasterLayer) {
|
||||||
|
geoRasterLayer.setOpacity(opacity);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
||||||
Loading…
Add table
Add a link
Reference in a new issue