feat: add FastAPI web interface for GIS classification

This commit is contained in:
Andrew 2026-03-15 14:28:51 +07:00
parent 5a9b8469bd
commit 6815769d2b
5 changed files with 1458 additions and 15 deletions

View file

@ -10,8 +10,10 @@ gis-classification/
├── requirements.txt # Dependencies ├── requirements.txt # Dependencies
├── data/ # Input data folder ├── data/ # Input data folder
├── output/ # Classification results ├── output/ # Classification results
├── static/ # Web frontend
└── src/ └── src/
├── classifier.py # Main classification pipeline ├── classifier.py # Main classification pipeline
├── api.py # FastAPI web server
├── data/ ├── data/
│ └── loader.py # Data loading (GeoTIFF, Shapefile) │ └── loader.py # Data loading (GeoTIFF, Shapefile)
├── strategies/ ├── strategies/
@ -26,8 +28,15 @@ gis-classification/
pip install -r requirements.txt pip install -r requirements.txt
``` ```
Or with uv:
```bash
uv pip install -r requirements.txt
```
## Usage ## Usage
### CLI Mode
1. Place your input files in `data/`: 1. Place your input files in `data/`:
- `landsat.tif` - GeoTIFF from Landsat - `landsat.tif` - GeoTIFF from Landsat
- `polygons.shp` - Shapefile with class labels - `polygons.shp` - Shapefile with class labels
@ -42,6 +51,31 @@ pip install -r requirements.txt
python main.py python main.py
``` ```
### Web Interface
Start the web server:
```bash
uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload
```
Then open http://localhost:8000 in your browser.
**Features:**
- Upload GeoTIFF raster and Shapefile training data
- Select classification strategy (Random Forest, SVM, Logistic Regression, MLE)
- View training metrics (Accuracy, Cohen's Kappa)
- Interactive map visualization with Leaflet
- Download classified GeoTIFF results
**API Endpoints:**
- `GET /` - Web interface
- `POST /train` - Train classifier with uploaded files
- `POST /predict` - Run classification
- `GET /result/{session_id}` - Get result metadata
- `GET /result/{session_id}/download` - Download classified GeoTIFF
- `GET /docs` - Interactive API documentation (Swagger UI)
## Adding Custom Classification Strategy ## Adding Custom Classification Strategy
Create a new class implementing `ClassificationStrategy`: Create a new class implementing `ClassificationStrategy`:
@ -54,17 +88,17 @@ class MyCustomStrategy(ClassificationStrategy):
def train(self, X: np.ndarray, y: np.ndarray) -> None: def train(self, X: np.ndarray, y: np.ndarray) -> None:
# Your training logic # Your training logic
pass pass
def predict(self, X: np.ndarray) -> np.ndarray: def predict(self, X: np.ndarray) -> np.ndarray:
# Your prediction logic # Your prediction logic
pass pass
def predict_proba(self, X: np.ndarray) -> np.ndarray: def predict_proba(self, X: np.ndarray) -> np.ndarray:
pass pass
def get_params(self) -> dict: def get_params(self) -> dict:
pass pass
@property @property
def name(self) -> str: def name(self) -> str:
return "MyCustom" return "MyCustom"
@ -78,4 +112,11 @@ STRATEGY = MyCustomStrategy()
## Output ## Output
- `output/classified.tif` - Classified raster (GeoTIFF) - `output/classified.tif` - Classified raster (GeoTIFF)
- Console output with accuracy metrics - Console output with accuracy metrics (Accuracy, Cohen's Kappa)
- Web interface visualization with interactive map
## Metrics
The classifier reports:
- **Accuracy**: Overall classification accuracy
- **Cohen's Kappa**: Agreement statistic accounting for chance (values > 0.8 indicate excellent agreement)

20
main.py
View file

@ -29,19 +29,19 @@ OUTPUT_PATH = os.path.join("output", "classified.tif")
# Change strategy by uncommenting desired option: # Change strategy by uncommenting desired option:
# Option 1: Random Forest (recommended for GIS) # Option 1: Random Forest (recommended for GIS)
STRATEGY = RandomForestStrategy( # STRATEGY = RandomForestStrategy(
n_estimators=100, # n_estimators=100,
max_depth=None, # max_depth=None,
random_state=42,
)
# # Option 2: Support Vector Machine
# STRATEGY = SVMStrategy(
# kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower
# C=1.0,
# random_state=42, # random_state=42,
# ) # )
# # Option 2: Support Vector Machine
STRATEGY = SVMStrategy(
kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower
C=1.0,
random_state=42,
)
# Option 3: Logistic Regression # Option 3: Logistic Regression
# STRATEGY = LogisticRegressionStrategy( # STRATEGY = LogisticRegressionStrategy(
# penalty="l2", # penalty="l2",

View file

@ -1,3 +1,4 @@
# Core dependencies
rasterio>=1.3.0 rasterio>=1.3.0
geopandas>=0.12.0 geopandas>=0.12.0
shapely>=2.0.0 shapely>=2.0.0
@ -5,3 +6,12 @@ scikit-learn>=1.3.0
scipy>=1.10.0 scipy>=1.10.0
numpy>=1.24.0 numpy>=1.24.0
pandas>=2.0.0 pandas>=2.0.0
# Web API
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
pydantic>=2.0.0
# Visualization (optional, for tile generation)
matplotlib>=3.7.0

398
src/api.py Normal file
View file

@ -0,0 +1,398 @@
"""FastAPI web API for GIS classification."""
import os
import io
import tempfile
import shutil
from typing import Optional
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse, Response
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import numpy as np
import rasterio
from rasterio.io import MemoryFile
from .classifier import GISClassifier
from .data import load_raster, load_vector, extract_raster_values_by_polygons
from .strategies.classifiers import (
RandomForestStrategy,
SVMStrategy,
LogisticRegressionStrategy,
MLEStrategy,
)
app = FastAPI(
title="GIS Classification API",
description="Land cover classification using machine learning",
version="1.0.0",
)
# Store for active classifiers (in production, use proper session management)
_classifiers: dict[str, dict] = {}
class TrainRequest(BaseModel):
"""Request model for training."""
strategy: str = "random_forest"
class_column: str = "class"
test_size: float = 0.2
random_state: int = 42
# Strategy-specific parameters
n_estimators: int = 100
max_depth: Optional[int] = None
kernel: str = "linear"
C: float = 1.0
class TrainResponse(BaseModel):
"""Response model for training."""
status: str
message: str
metrics: Optional[dict] = None
def _get_strategy(name: str, params: dict) -> object:
"""Create classification strategy by name."""
strategies = {
"random_forest": RandomForestStrategy,
"svm": SVMStrategy,
"logistic_regression": LogisticRegressionStrategy,
"mle": MLEStrategy,
}
if name not in strategies:
raise ValueError(f"Unknown strategy: {name}. Available: {list(strategies.keys())}")
StrategyClass = strategies[name]
# Map parameters based on strategy
if name == "random_forest":
return StrategyClass(
n_estimators=params.get("n_estimators", 100),
max_depth=params.get("max_depth"),
random_state=params.get("random_state", 42),
)
elif name == "svm":
return StrategyClass(
kernel=params.get("kernel", "linear"),
C=params.get("C", 1.0),
random_state=params.get("random_state", 42),
)
elif name == "logistic_regression":
return StrategyClass(
C=params.get("C", 1.0),
max_iter=params.get("max_iter", 1000),
random_state=params.get("random_state", 42),
)
elif name == "mle":
return StrategyClass(
reg_covar=params.get("reg_covar", 1e-6),
)
return StrategyClass()
@app.get("/", response_class=HTMLResponse)
async def root():
"""Serve the frontend HTML page."""
static_path = os.path.join(os.path.dirname(__file__), "..", "static")
html_path = os.path.join(static_path, "index.html")
if os.path.exists(html_path):
with open(html_path, "r", encoding="utf-8") as f:
return HTMLResponse(content=f.read())
return HTMLResponse(
content="""
<html>
<head><title>GIS Classification API</title></head>
<body>
<h1>GIS Classification API</h1>
<p>API is running. Visit <a href="/docs">/docs</a> for interactive documentation.</p>
<h2>Endpoints:</h2>
<ul>
<li><code>POST /train</code> - Train classifier with uploaded files</li>
<li><code>POST /predict</code> - Classify a raster</li>
<li><code>GET /result/{session_id}</code> - Get classification result</li>
<li><code>GET /result/{session_id}/download</code> - Download classified GeoTIFF</li>
</ul>
</body>
</html>
""",
)
@app.post("/train")
async def train(
raster: UploadFile = File(..., description="GeoTIFF raster file"),
vector_files: list[UploadFile] = File(..., description="Shapefile files (.shp, .shx, .dbf, .prj)"),
strategy: str = Form("random_forest"),
class_column: str = Form("class"),
test_size: float = Form(0.2),
random_state: int = Form(42),
n_estimators: int = Form(100),
kernel: str = Form("linear"),
C: float = Form(1.0),
):
"""
Train a classification model.
- **raster**: GeoTIFF file with multispectral data
- **vector_files**: Shapefile files (.shp, .shx, .dbf, .prj) - upload all together
- **strategy**: Classification algorithm (random_forest, svm, logistic_regression, mle)
- **class_column**: Column name in shapefile containing class labels
"""
import uuid
session_id = str(uuid.uuid4())
try:
# Create temp directory for session
temp_dir = tempfile.mkdtemp(prefix=f"gis_{session_id}_")
# Save uploaded raster
raster_path = os.path.join(temp_dir, "raster.tif")
with open(raster_path, "wb") as f:
f.write(await raster.read())
# Save uploaded shapefile files with consistent base name
vector_base = os.path.join(temp_dir, "vector")
vector_path = f"{vector_base}.shp"
for file in vector_files:
# Get file extension
ext = os.path.splitext(file.filename)[1].lower()
# Save with consistent base name
file_path = f"{vector_base}{ext}"
with open(file_path, "wb") as f:
f.write(await file.read())
# Verify required shapefile files exist
required_exts = ['.shp', '.shx', '.dbf']
missing = [ext for ext in required_exts if not os.path.exists(f"{vector_base}{ext}")]
if missing:
raise ValueError(f"Missing required shapefile files: {', '.join(missing)}")
# Create strategy
strategy_obj = _get_strategy(
strategy,
{
"n_estimators": n_estimators,
"kernel": kernel,
"C": C,
"random_state": random_state,
},
)
# Train classifier
classifier = GISClassifier(strategy=strategy_obj)
metrics = classifier.train(
raster_path=raster_path,
vector_path=vector_path,
class_column=class_column,
test_size=test_size,
random_state=random_state,
)
# Store session data
_classifiers[session_id] = {
"classifier": classifier,
"raster_path": raster_path,
"vector_path": vector_path,
"temp_dir": temp_dir,
"result_path": None,
}
# Convert numpy types for JSON serialization
serializable_metrics = {
"train_samples": int(metrics["train_samples"]),
"val_samples": int(metrics["val_samples"]),
"accuracy": float(metrics["accuracy"]),
"kappa": float(metrics.get("kappa", 0)),
"classes": [int(c) if hasattr(c, "item") else c for c in metrics["classes"]],
}
return {
"status": "success",
"session_id": session_id,
"message": f"Training completed with {strategy} strategy",
"metrics": serializable_metrics,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/predict")
async def predict(
session_id: str = Form(...),
output_format: str = Form("geotiff"),
):
"""
Run classification on the loaded raster.
- **session_id**: Session ID from /train endpoint
- **output_format**: Output format (geotiff, geojson)
"""
if session_id not in _classifiers:
raise HTTPException(status_code=404, detail="Session not found. Please train first.")
session = _classifiers[session_id]
classifier = session["classifier"]
raster_path = session["raster_path"]
try:
# Generate output path
if output_format == "geotiff":
output_path = os.path.join(session["temp_dir"], "classified.tif")
else:
output_path = os.path.join(session["temp_dir"], "classified.tif")
# Run prediction
result = classifier.predict(raster_path=raster_path, output_path=output_path)
session["result_path"] = output_path
session["result"] = result
# Get raster bounds for frontend
with rasterio.open(raster_path) as src:
bounds = src.bounds
crs = str(src.crs)
return {
"status": "success",
"session_id": session_id,
"output_format": output_format,
"bounds": {
"left": bounds.left,
"bottom": bounds.bottom,
"right": bounds.right,
"top": bounds.top,
},
"crs": crs,
"shape": [int(result.predicted_array.shape[0]), int(result.predicted_array.shape[1])],
"classes": [int(c) if hasattr(c, "item") else c for c in result.classes],
"message": "Classification completed",
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/result/{session_id}")
async def get_result(session_id: str):
"""Get classification result metadata."""
if session_id not in _classifiers:
raise HTTPException(status_code=404, detail="Session not found")
session = _classifiers[session_id]
if session.get("result_path") is None:
raise HTTPException(status_code=400, detail="No prediction result available")
result = session.get("result")
raster_path = session["raster_path"]
with rasterio.open(raster_path) as src:
bounds = src.bounds
crs = str(src.crs)
return {
"session_id": session_id,
"bounds": {
"left": bounds.left,
"bottom": bounds.bottom,
"right": bounds.right,
"top": bounds.top,
},
"crs": crs,
"classes": [int(c) if hasattr(c, "item") else c for c in result.classes],
"output_path": session["result_path"],
}
@app.get("/result/{session_id}/download")
async def download_result(session_id: str):
"""Download classified GeoTIFF file."""
if session_id not in _classifiers:
raise HTTPException(status_code=404, detail="Session not found")
session = _classifiers[session_id]
result_path = session.get("result_path")
if result_path is None or not os.path.exists(result_path):
raise HTTPException(status_code=404, detail="Result file not found")
return FileResponse(
path=result_path,
media_type="application/geotiff",
filename="classified.tif",
)
@app.get("/result/{session_id}/tile/{z}/{x}/{y}.png")
async def get_tile(session_id: str, z: int, x: int, y: int):
"""Get classification result as PNG tile (experimental)."""
if session_id not in _classifiers:
raise HTTPException(status_code=404, detail="Session not found")
session = _classifiers[session_id]
result_path = session.get("result_path")
if result_path is None or not os.path.exists(result_path):
raise HTTPException(status_code=404, detail="Result file not found")
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import rasterio
from rasterio.warp import transform_bounds
from rasterio.io import MemoryFile
with rasterio.open(result_path) as src:
# Read data
data = src.read(1)
# Create color map
n_classes = len(np.unique(data))
cmap = plt.cm.get_cmap("tab10", n_classes)
# Create RGBA image
fig, ax = plt.subplots(figsize=(1, 1))
ax.imshow(data, cmap=cmap, vmin=0, vmax=n_classes - 1)
ax.axis("off")
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
buf.seek(0)
plt.close(fig)
return Response(content=buf.getvalue(), media_type="image/png")
except ImportError:
raise HTTPException(status_code=501, detail="Tile generation requires matplotlib")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
"""Health check endpoint."""
return {"status": "healthy", "version": "1.0.0"}
@app.on_event("shutdown")
async def cleanup():
"""Clean up temporary files on shutdown."""
for session_id, session in _classifiers.items():
temp_dir = session.get("temp_dir")
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)

994
static/index.html Normal file
View file

@ -0,0 +1,994 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>GIS Classification</title>
<!-- Leaflet CSS -->
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" />
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: #1a1a2e;
color: #eee;
min-height: 100vh;
}
.container {
display: flex;
height: 100vh;
}
.sidebar {
width: 400px;
background: #16213e;
padding: 20px;
overflow-y: auto;
border-right: 1px solid #0f3460;
}
.sidebar h1 {
font-size: 1.5rem;
margin-bottom: 10px;
color: #e94560;
}
.sidebar h2 {
font-size: 1.1rem;
margin: 20px 0 10px;
color: #0f3460;
border-bottom: 2px solid #e94560;
padding-bottom: 5px;
}
.form-group {
margin-bottom: 15px;
}
.form-group label {
display: block;
margin-bottom: 5px;
font-size: 0.9rem;
color: #aaa;
}
.form-group input,
.form-group select {
width: 100%;
padding: 10px;
border: 1px solid #0f3460;
border-radius: 5px;
background: #1a1a2e;
color: #eee;
font-size: 0.9rem;
}
.form-group input[type="file"] {
padding: 8px;
}
.form-group small {
display: block;
margin-top: 3px;
color: #666;
font-size: 0.8rem;
}
.btn {
width: 100%;
padding: 12px;
border: none;
border-radius: 5px;
font-size: 1rem;
cursor: pointer;
transition: all 0.3s;
}
.btn-primary {
background: #e94560;
color: white;
}
.btn-primary:hover {
background: #ff6b6b;
}
.btn-primary:disabled {
background: #555;
cursor: not-allowed;
}
.btn-secondary {
background: #0f3460;
color: white;
margin-top: 10px;
}
.btn-secondary:hover {
background: #1a4a7a;
}
.map-container {
flex: 1;
position: relative;
}
#map {
width: 100%;
height: 100%;
}
.status {
padding: 10px;
border-radius: 5px;
margin-top: 15px;
font-size: 0.9rem;
}
.status.info {
background: #0f3460;
border-left: 3px solid #4fc3f7;
}
.status.success {
background: #1b5e20;
border-left: 3px solid #4caf50;
}
.status.error {
background: #b71c1c;
border-left: 3px solid #f44336;
}
.status.loading {
background: #0f3460;
border-left: 3px solid #ff9800;
animation: pulse 1.5s infinite;
}
@keyframes pulse {
0%,
100% {
opacity: 1;
}
50% {
opacity: 0.7;
}
}
.metrics {
background: #0f3460;
padding: 15px;
border-radius: 5px;
margin-top: 15px;
}
.metrics-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 10px;
}
.metric {
text-align: center;
}
.metric-value {
font-size: 1.5rem;
font-weight: bold;
color: #e94560;
}
.metric-label {
font-size: 0.8rem;
color: #aaa;
}
.classes-list {
display: flex;
flex-wrap: wrap;
gap: 5px;
margin-top: 10px;
}
.class-badge {
padding: 5px 10px;
border-radius: 15px;
font-size: 0.8rem;
background: #e94560;
color: white;
}
.hidden {
display: none;
}
.file-input-wrapper {
position: relative;
overflow: hidden;
display: inline-block;
width: 100%;
}
.legend {
position: absolute;
bottom: 30px;
right: 10px;
background: rgba(22, 33, 62, 0.9);
padding: 15px;
border-radius: 5px;
z-index: 1000;
max-width: 200px;
}
.legend h4 {
margin-bottom: 10px;
font-size: 0.9rem;
color: #e94560;
}
.legend-item {
display: flex;
align-items: center;
margin-bottom: 5px;
font-size: 0.8rem;
}
.legend-color {
width: 20px;
height: 20px;
margin-right: 10px;
border-radius: 3px;
}
.class-editor-item {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: 10px;
padding: 8px;
background: #1a1a2e;
border-radius: 5px;
}
.class-editor-item input[type="color"] {
width: 40px;
height: 30px;
border: none;
cursor: pointer;
}
.class-editor-item input[type="text"] {
flex: 1;
padding: 8px;
border: 1px solid #0f3460;
border-radius: 3px;
background: #16213e;
color: #eee;
}
.class-editor-item span {
min-width: 60px;
color: #aaa;
font-size: 0.85rem;
}
</style>
</head>
<body>
<div class="container">
<div class="sidebar">
<h1>🛰️ GIS Classification</h1>
<p style="color: #888; font-size: 0.9rem; margin-bottom: 20px;">
Land cover classification using machine learning
</p>
<h2>1. Upload Data</h2>
<div class="form-group">
<label for="raster">Raster File (GeoTIFF)</label>
<input type="file" id="raster" accept=".tif,.tiff">
<small>Select Landsat or similar multispectral imagery</small>
</div>
<div class="form-group">
<label for="vector">Shapefile Files (select all: .shp, .shx, .dbf, .prj)</label>
<input type="file" id="vector" accept=".shp,.shx,.dbf,.prj" multiple onchange="showSelectedFiles()">
<small>Hold Ctrl/Cmd to select multiple files. Required: .shp, .shx, .dbf</small>
<div id="fileList" style="margin-top: 8px; font-size: 0.85rem; color: #4fc3f7;"></div>
</div>
<h2>2. Configure</h2>
<div class="form-group">
<label for="strategy">Classification Strategy</label>
<select id="strategy">
<option value="random_forest">Random Forest (Recommended)</option>
<option value="svm">Support Vector Machine</option>
<option value="logistic_regression">Logistic Regression</option>
<option value="mle">Maximum Likelihood (MLE)</option>
</select>
</div>
<div class="form-group">
<label for="classColumn">Class Column</label>
<input type="text" id="classColumn" value="class" placeholder="Column name in shapefile">
</div>
<div class="form-group">
<label for="nEstimators">Estimators (Random Forest)</label>
<input type="number" id="nEstimators" value="100" min="1" max="500">
</div>
<div class="form-group">
<label for="kernel">SVM Kernel</label>
<select id="kernel">
<option value="linear">Linear (Fast)</option>
<option value="rbf">RBF (Accurate but slow)</option>
</select>
</div>
<h2>3. Train & Classify</h2>
<button class="btn btn-primary" id="trainBtn" onclick="train()">
Train Model
</button>
<div id="status" class="status hidden"></div>
<div id="metrics" class="metrics hidden">
<div class="metrics-grid">
<div class="metric">
<div class="metric-value" id="accuracyValue">-</div>
<div class="metric-label">Accuracy</div>
</div>
<div class="metric">
<div class="metric-value" id="kappaValue">-</div>
<div class="metric-label">Cohen's Kappa</div>
</div>
<div class="metric">
<div class="metric-value" id="trainSamplesValue">-</div>
<div class="metric-label">Train Samples</div>
</div>
<div class="metric">
<div class="metric-value" id="valSamplesValue">-</div>
<div class="metric-label">Validation Samples</div>
</div>
</div>
<div id="classesContainer" class="hidden">
<label style="font-size: 0.8rem; color: #aaa; margin-top: 10px; display: block;">Classes:</label>
<div class="classes-list" id="classesList"></div>
</div>
</div>
<h2>4. Class Templates</h2>
<div class="form-group">
<label for="templateSelect">Load Template</label>
<select id="templateSelect" onchange="loadTemplate()">
<option value="">-- Select Template --</option>
</select>
<small>Pre-defined class names and colors</small>
</div>
<div style="display: flex; gap: 10px; margin-bottom: 15px;">
<button class="btn btn-secondary" onclick="saveTemplate()" style="flex: 1;">
Save Template
</button>
<button class="btn btn-secondary" onclick="exportTemplate()" style="flex: 1;">
Export
</button>
<label class="btn btn-secondary" style="flex: 1; text-align: center; cursor: pointer;">
Import
<input type="file" id="importTemplate" accept=".json" onchange="importTemplate(this)"
style="display: none;">
</label>
</div>
<div id="classEditor" class="hidden"
style="background: #0f3460; padding: 15px; border-radius: 5px; margin-bottom: 15px;">
<h4 style="margin-bottom: 10px; color: #e94560;">Edit Class Names & Colors</h4>
<div id="classEditorItems"></div>
<button class="btn btn-primary" onclick="applyClassTemplate()" style="margin-top: 15px;">
Apply to Map
</button>
</div>
<div class="form-group">
<label for="opacitySlider">Layer Opacity: <span id="opacityValue">70%</span></label>
<input type="range" id="opacitySlider" min="0" max="100" value="70" oninput="updateOpacity()">
<small>Adjust classification layer transparency</small>
</div>
<button class="btn btn-secondary hidden" id="predictBtn" onclick="predict()">
Classify Raster
</button>
<a id="downloadLink" class="btn btn-secondary hidden"
style="text-align: center; text-decoration: none; display: block;" download>
Download Result
</a>
</div>
<div class="map-container">
<div id="map"></div>
<div id="legend" class="legend hidden">
<h4>Classes</h4>
<div id="legendItems"></div>
</div>
</div>
</div>
<!-- Leaflet JS -->
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"></script>
<!-- GeoRasterLayer for Leaflet -->
<script src="https://unpkg.com/georaster"></script>
<script src="https://unpkg.com/georaster-layer-for-leaflet"></script>
<script>
// Initialize map
const map = L.map('map', {
attributionControl: false
}).setView([0, 0], 3);
// Add satellite base layer (Esri World Imagery)
L.tileLayer('https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}', {
attribution: 'Tiles &copy; Esri &mdash; Source: Esri, i-cubed, USDA, USGS, AEX, GeoEye, Getmapping, Aerogrid, IGN, IGP, UPR-EGP, and the GIS User Community',
maxZoom: 19
}).addTo(map);
// Add custom attribution control without country flags
L.control.attribution({
prefix: ''
}).addTo(map);
let sessionId = null;
let geoRasterLayer = null;
let currentClasses = [];
let currentPalette = [];
let classTemplates = {};
let customColorMapping = {};
let currentGeoRaster = null;
// Load templates from localStorage on init
loadTemplatesFromStorage();
function showSelectedFiles() {
const files = document.getElementById('vector').files;
const fileList = document.getElementById('fileList');
if (files.length === 0) {
fileList.innerHTML = '';
return;
}
const fileNames = Array.from(files).map(f => f.name).sort();
fileList.innerHTML = '<strong>Selected:</strong> ' + fileNames.join(', ');
}
function showStatus(message, type = 'info') {
const status = document.getElementById('status');
status.textContent = message;
status.className = `status ${type}`;
status.classList.remove('hidden');
}
function hideStatus() {
document.getElementById('status').classList.add('hidden');
}
async function train() {
const rasterFile = document.getElementById('raster').files[0];
const vectorFiles = document.getElementById('vector').files;
if (!rasterFile) {
showStatus('Please select a raster file', 'error');
return;
}
if (!vectorFiles || vectorFiles.length === 0) {
showStatus('Please select shapefile files (.shp, .shx, .dbf)', 'error');
return;
}
// Validate shapefile files
const extensions = Array.from(vectorFiles).map(f => f.name.toLowerCase().split('.').pop());
const required = ['shp', 'shx', 'dbf'];
const missing = required.filter(ext => !extensions.includes(ext));
if (missing.length > 0) {
showStatus(`Missing required shapefile files: ${missing.map(e => '.' + e).join(', ')}`, 'error');
return;
}
const formData = new FormData();
formData.append('raster', rasterFile);
// Append all shapefile files
for (let i = 0; i < vectorFiles.length; i++) {
formData.append('vector_files', vectorFiles[i]);
}
formData.append('strategy', document.getElementById('strategy').value);
formData.append('class_column', document.getElementById('classColumn').value);
formData.append('n_estimators', document.getElementById('nEstimators').value);
formData.append('kernel', document.getElementById('kernel').value);
formData.append('test_size', '0.2');
formData.append('random_state', '42');
const trainBtn = document.getElementById('trainBtn');
trainBtn.disabled = true;
trainBtn.textContent = 'Training...';
showStatus('Training classifier... This may take a few minutes.', 'loading');
try {
const response = await fetch('/train', {
method: 'POST',
body: formData
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.detail || 'Training failed');
}
sessionId = data.session_id;
// Display metrics
const metrics = data.metrics;
document.getElementById('accuracyValue').textContent = (metrics.accuracy * 100).toFixed(1) + '%';
document.getElementById('kappaValue').textContent = metrics.kappa.toFixed(4);
document.getElementById('trainSamplesValue').textContent = metrics.train_samples.toLocaleString();
document.getElementById('valSamplesValue').textContent = metrics.val_samples.toLocaleString();
// Display classes
const classesList = document.getElementById('classesList');
classesList.innerHTML = metrics.classes.map(c =>
`<span class="class-badge">Class ${c}</span>`
).join('');
document.getElementById('classesContainer').classList.remove('hidden');
document.getElementById('metrics').classList.remove('hidden');
document.getElementById('predictBtn').classList.remove('hidden');
showStatus(`Training completed! Accuracy: ${(metrics.accuracy * 100).toFixed(1)}%, Kappa: ${metrics.kappa.toFixed(4)}. You can classify raster and customize visualization with templates!`, 'success');
} catch (error) {
showStatus(`Error: ${error.message}`, 'error');
} finally {
trainBtn.disabled = false;
trainBtn.textContent = 'Train Model';
}
}
async function predict() {
if (!sessionId) {
showStatus('No trained model available', 'error');
return;
}
const predictBtn = document.getElementById('predictBtn');
predictBtn.disabled = true;
predictBtn.textContent = 'Classifying...';
showStatus('Classifying raster... This may take a moment.', 'loading');
const formData = new FormData();
formData.append('session_id', sessionId);
formData.append('output_format', 'geotiff');
try {
const response = await fetch('/predict', {
method: 'POST',
body: formData
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.detail || 'Classification failed');
}
// Update map view to show raster bounds
const bounds = [
[data.bounds.bottom, data.bounds.left],
[data.bounds.top, data.bounds.right]
];
map.fitBounds(bounds);
// Load and display classified raster
showStatus('Loading classification result...', 'loading');
// Download and display the classified raster
const downloadResponse = await fetch(`/result/${sessionId}/download`);
const blob = await downloadResponse.blob();
const arrayBuffer = await blob.arrayBuffer();
// Parse GeoTIFF
const geoRaster = await parseGeoraster(arrayBuffer);
currentGeoRaster = geoRaster;
// Remove existing layer
if (geoRasterLayer) {
map.removeLayer(geoRasterLayer);
}
// Create color palette for classes
const palette = generatePalette(data.classes.length);
// Add GeoRaster layer
geoRasterLayer = createGeoRasterLayer(geoRaster, palette);
geoRasterLayer.addTo(map);
// Apply initial opacity from slider
const initialOpacity = document.getElementById('opacitySlider').value / 100;
geoRasterLayer.setOpacity(initialOpacity);
// Zoom to layer bounds
if (geoRasterLayer.getBounds()) {
map.fitBounds(geoRasterLayer.getBounds());
}
// Update legend
currentClasses = data.classes;
currentPalette = palette;
updateLegend(data.classes, palette);
// Show class editor for customizing names
document.getElementById('classEditor').classList.remove('hidden');
const editorItems = document.getElementById('classEditorItems');
editorItems.innerHTML = '';
data.classes.forEach((cls, i) => {
const item = document.createElement('div');
item.className = 'class-editor-item';
item.innerHTML = `
<span>Class ${cls}</span>
<input type="color" value="${palette[i]}" data-class="${cls}" data-type="color">
<input type="text" value="Class ${cls}" data-class="${cls}" data-type="name" placeholder="Class name">
`;
editorItems.appendChild(item);
});
// Update download link
const downloadLink = document.getElementById('downloadLink');
downloadLink.href = URL.createObjectURL(blob);
downloadLink.download = 'classified_result.tif';
downloadLink.classList.remove('hidden');
showStatus('Classification complete!', 'success');
} catch (error) {
showStatus(`Error: ${error.message}`, 'error');
console.error(error);
} finally {
predictBtn.disabled = false;
predictBtn.textContent = 'Classify Raster';
}
}
function generatePalette(nClasses) {
// Generate distinct colors for classes
const colors = [
'#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00',
'#ffff33', '#a65628', '#f781bf', '#999999', '#66c2a5',
'#fc8d62', '#8da0cb', '#e78ac3', '#a6d854', '#ffd92f',
'#e5c494', '#b3b3b3', '#1b9e77', '#d95f02', '#7570b3'
];
const palette = [];
for (let i = 0; i < nClasses; i++) {
palette.push(colors[i % colors.length]);
}
return palette;
}
function updateLegend(classes, palette) {
const legend = document.getElementById('legend');
const legendItems = document.getElementById('legendItems');
legendItems.innerHTML = classes.map((cls, i) => `
<div class="legend-item">
<div class="legend-color" style="background: ${palette[i]}"></div>
<span>Class ${cls}</span>
</div>
`).join('');
legend.classList.remove('hidden');
}
// Helper function to parse GeoTIFF
async function parseGeoraster(arrayBuffer) {
return new Promise((resolve, reject) => {
parseGeorasterInternal(arrayBuffer).then(resolve).catch(reject);
});
}
function parseGeorasterInternal(arrayBuffer) {
return new Promise((resolve, reject) => {
try {
const georaster = new GeoRaster(arrayBuffer);
resolve(georaster);
} catch (e) {
// Fallback: use global parseGeoraster if available
if (typeof window.parseGeoraster === 'function') {
window.parseGeoraster(arrayBuffer).then(resolve).catch(reject);
} else {
reject(e);
}
}
});
}
// ==================== CLASS TEMPLATE FUNCTIONS ====================
function loadTemplatesFromStorage() {
const stored = localStorage.getItem('gis_class_templates');
if (stored) {
try {
classTemplates = JSON.parse(stored);
} catch (e) {
classTemplates = {};
}
}
// Add default templates if none exist
if (Object.keys(classTemplates).length === 0) {
classTemplates = {
'Rainbow': {
'1': { name: 'Red', color: '#FF0000' },
'2': { name: 'Orange', color: '#FF7F00' },
'3': { name: 'Yellow', color: '#FFFF00' },
'4': { name: 'Green', color: '#00FF00' },
'5': { name: 'Blue', color: '#0000FF' },
'6': { name: 'Violet', color: '#8B00FF' }
},
};
saveTemplatesToStorage();
}
updateTemplateSelect();
}
function saveTemplatesToStorage() {
localStorage.setItem('gis_class_templates', JSON.stringify(classTemplates));
}
function updateTemplateSelect() {
const select = document.getElementById('templateSelect');
select.innerHTML = '<option value="">-- Select Template --</option>';
Object.keys(classTemplates).forEach(name => {
const option = document.createElement('option');
option.value = name;
option.textContent = name;
select.appendChild(option);
});
}
function loadTemplate() {
const select = document.getElementById('templateSelect');
const templateName = select.value;
if (!templateName || !classTemplates[templateName]) return;
const template = classTemplates[templateName];
const editorItems = document.getElementById('classEditorItems');
editorItems.innerHTML = '';
// Get current classes or use template keys
const classesToShow = currentClasses.length > 0 ? currentClasses : Object.keys(template).sort((a, b) => parseInt(a) - parseInt(b));
classesToShow.forEach(cls => {
const item = document.createElement('div');
item.className = 'class-editor-item';
const templateData = template[cls.toString()] || { name: `Class ${cls}`, color: '#808080' };
item.innerHTML = `
<span>Class ${cls}</span>
<input type="color" value="${templateData.color}" data-class="${cls}" data-type="color">
<input type="text" value="${templateData.name}" data-class="${cls}" data-type="name" placeholder="Class name">
`;
editorItems.appendChild(item);
});
document.getElementById('classEditor').classList.remove('hidden');
// Auto-apply the template
applyClassTemplate();
}
function saveTemplate() {
const name = prompt('Enter template name:');
if (!name) return;
if (!currentClasses.length) {
showStatus('No classes available. Run classification first.', 'error');
return;
}
const template = {};
const editorItems = document.querySelectorAll('.class-editor-item');
editorItems.forEach(item => {
const classId = item.querySelector('input[data-type="color"]').dataset.class;
const color = item.querySelector('input[data-type="color"]').value;
const className = item.querySelector('input[data-type="name"]').value;
template[classId] = { name: className, color: color };
});
classTemplates[name] = template;
saveTemplatesToStorage();
updateTemplateSelect();
document.getElementById('templateSelect').value = name;
showStatus(`Template "${name}" saved!`, 'success');
}
function exportTemplate() {
const select = document.getElementById('templateSelect');
const templateName = select.value;
if (!templateName || !classTemplates[templateName]) {
showStatus('Please select a template to export', 'error');
return;
}
const template = { name: templateName, classes: classTemplates[templateName] };
const blob = new Blob([JSON.stringify(template, null, 2)], { type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `${templateName.replace(/\s+/g, '_')}_template.json`;
a.click();
URL.revokeObjectURL(url);
}
function importTemplate(input) {
const file = input.files[0];
if (!file) return;
const reader = new FileReader();
reader.onload = function (e) {
try {
const template = JSON.parse(e.target.result);
if (!template.name || !template.classes) {
throw new Error('Invalid template format');
}
classTemplates[template.name] = template.classes;
saveTemplatesToStorage();
updateTemplateSelect();
document.getElementById('templateSelect').value = template.name;
showStatus(`Template "${template.name}" imported!`, 'success');
} catch (err) {
showStatus('Invalid template file: ' + err.message, 'error');
}
};
reader.readAsText(file);
input.value = '';
}
function applyClassTemplate() {
const editorItems = document.querySelectorAll('.class-editor-item');
const customMapping = {};
editorItems.forEach(item => {
const classId = item.querySelector('input[data-type="color"]').dataset.class;
const color = item.querySelector('input[data-type="color"]').value;
const className = item.querySelector('input[data-type="name"]').value;
customMapping[classId] = { name: className, color: color };
});
// Update legend with custom names
updateLegendWithCustomNames(currentClasses, currentPalette, customMapping);
showStatus('Class template applied!', 'success');
}
function updateLegendWithCustomNames(classes, palette, customMapping) {
const legend = document.getElementById('legend');
const legendItems = document.getElementById('legendItems');
legendItems.innerHTML = classes.map((cls, i) => {
const custom = customMapping[cls.toString()] || customMapping[i.toString()];
const name = custom ? custom.name : `Class ${cls}`;
const color = custom ? custom.color : palette[i];
return `
<div class="legend-item">
<div class="legend-color" style="background: ${color}"></div>
<span>${name}</span>
</div>
`;
}).join('');
legend.classList.remove('hidden');
}
function createGeoRasterLayer(geoRaster, palette, customMapping = {}) {
const layer = new GeoRasterLayer({
georaster: geoRaster,
opacity: 0.7,
pixelValuesToColorFn: function (pixelValues) {
const value = pixelValues[0];
if (value === null || value === undefined || value === geoRaster.nodata) {
return null;
}
const classIndex = Math.floor(value) - 1;
if (classIndex < 0 || classIndex >= palette.length) {
return null;
}
// Use custom color if available, otherwise use palette
const className = (classIndex + 1).toString();
if (customMapping[className] && customMapping[className].color) {
return customMapping[className].color;
}
return palette[classIndex];
},
resolution: Math.max(64, Math.min(256, geoRaster.height / 10))
});
// Force Leaflet to treat this as a new layer
layer._leaflet_id = 'georaster_' + Date.now();
return layer;
}
function applyClassTemplate() {
const editorItems = document.querySelectorAll('.class-editor-item');
const customMapping = {};
editorItems.forEach(item => {
const classId = item.querySelector('input[data-type="color"]').dataset.class;
const color = item.querySelector('input[data-type="color"]').value;
const className = item.querySelector('input[data-type="name"]').value;
customMapping[classId] = { name: className, color: color };
});
// Store custom mapping
window.currentCustomMapping = customMapping;
// Update legend first
updateLegendWithCustomNames(currentClasses, currentPalette, customMapping);
// Re-create GeoRaster layer with custom colors
if (currentGeoRaster && geoRasterLayer) {
// Remove old layer
map.removeLayer(geoRasterLayer);
// Create new layer with custom colors
geoRasterLayer = createGeoRasterLayer(currentGeoRaster, currentPalette, customMapping);
geoRasterLayer.addTo(map);
// Force complete redraw by triggering multiple refresh methods
setTimeout(() => {
// Clear cache of the layer
if (geoRasterLayer.clearCache) {
geoRasterLayer.clearCache();
}
// Try to redraw the layer
if (geoRasterLayer.redraw) {
geoRasterLayer.redraw();
}
// Trigger map move event to force tile refresh
map.fire('moveend');
}, 50);
}
showStatus('Class template applied! Zoom in/out if colors do not update immediately.', 'success');
}
function updateOpacity() {
const slider = document.getElementById('opacitySlider');
const valueDisplay = document.getElementById('opacityValue');
const opacity = slider.value / 100;
valueDisplay.textContent = slider.value + '%';
if (geoRasterLayer) {
geoRasterLayer.setOpacity(opacity);
}
}
</script>
</body>
</html>