perf: switch SVM default kernel to linear for better performance
This commit is contained in:
parent
af365cfe68
commit
e1f727831f
2 changed files with 21 additions and 11 deletions
5
main.py
5
main.py
|
|
@ -35,11 +35,10 @@ STRATEGY = RandomForestStrategy(
|
||||||
random_state=42,
|
random_state=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Option 2: Support Vector Machine
|
# # Option 2: Support Vector Machine
|
||||||
# STRATEGY = SVMStrategy(
|
# STRATEGY = SVMStrategy(
|
||||||
# kernel="rbf",
|
# kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower
|
||||||
# C=1.0,
|
# C=1.0,
|
||||||
# gamma="scale",
|
|
||||||
# random_state=42,
|
# random_state=42,
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,11 +51,15 @@ class RandomForestStrategy(ClassificationStrategy):
|
||||||
|
|
||||||
|
|
||||||
class SVMStrategy(ClassificationStrategy):
|
class SVMStrategy(ClassificationStrategy):
|
||||||
"""Support Vector Machine classification strategy."""
|
"""Support Vector Machine classification strategy.
|
||||||
|
|
||||||
|
Note: For large rasters, consider using kernel='linear' for faster prediction.
|
||||||
|
The RBF kernel can be very slow during prediction on large images.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
kernel: str = "rbf",
|
kernel: str = "linear",
|
||||||
C: float = 1.0,
|
C: float = 1.0,
|
||||||
gamma: str = "scale",
|
gamma: str = "scale",
|
||||||
random_state: int = 42,
|
random_state: int = 42,
|
||||||
|
|
@ -70,19 +74,26 @@ class SVMStrategy(ClassificationStrategy):
|
||||||
C=C,
|
C=C,
|
||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
random_state=random_state,
|
random_state=random_state,
|
||||||
probability=True,
|
probability=False, # Disabled for performance on large rasters
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def train(self, X: np.ndarray, y: np.ndarray) -> None:
|
def train(self, X: np.ndarray, y: np.ndarray) -> None:
|
||||||
self._clf.fit(X, y)
|
self._clf.fit(X, y)
|
||||||
|
|
||||||
def predict(self, X: np.ndarray) -> np.ndarray:
|
def predict(self, X: np.ndarray) -> np.ndarray:
|
||||||
return self._clf.predict(X)
|
return self._clf.predict(X)
|
||||||
|
|
||||||
def predict_proba(self, X: np.ndarray) -> np.ndarray:
|
def predict_proba(self, X: np.ndarray) -> np.ndarray:
|
||||||
return self._clf.predict_proba(X)
|
# probability=True significantly slows down both training and prediction
|
||||||
|
# Return class-based probabilities as fallback
|
||||||
|
predictions = self._clf.predict(X)
|
||||||
|
n_classes = len(self._clf.classes_)
|
||||||
|
proba = np.zeros((X.shape[0], n_classes))
|
||||||
|
for i, cls in enumerate(self._clf.classes_):
|
||||||
|
proba[:, i] = (predictions == cls).astype(float)
|
||||||
|
return proba
|
||||||
|
|
||||||
def get_params(self) -> dict[str, Any]:
|
def get_params(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"kernel": self.kernel,
|
"kernel": self.kernel,
|
||||||
|
|
@ -90,7 +101,7 @@ class SVMStrategy(ClassificationStrategy):
|
||||||
"gamma": self.gamma,
|
"gamma": self.gamma,
|
||||||
"random_state": self.random_state,
|
"random_state": self.random_state,
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "SVM"
|
return "SVM"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue