perf: switch SVM default kernel to linear for better performance
This commit is contained in:
parent
af365cfe68
commit
e1f727831f
2 changed files with 21 additions and 11 deletions
5
main.py
5
main.py
|
|
@ -35,11 +35,10 @@ STRATEGY = RandomForestStrategy(
|
|||
random_state=42,
|
||||
)
|
||||
|
||||
# Option 2: Support Vector Machine
|
||||
# # Option 2: Support Vector Machine
|
||||
# STRATEGY = SVMStrategy(
|
||||
# kernel="rbf",
|
||||
# kernel="linear", # Fast prediction; use 'rbf' for better accuracy but much slower
|
||||
# C=1.0,
|
||||
# gamma="scale",
|
||||
# random_state=42,
|
||||
# )
|
||||
|
||||
|
|
|
|||
|
|
@ -51,11 +51,15 @@ class RandomForestStrategy(ClassificationStrategy):
|
|||
|
||||
|
||||
class SVMStrategy(ClassificationStrategy):
|
||||
"""Support Vector Machine classification strategy."""
|
||||
"""Support Vector Machine classification strategy.
|
||||
|
||||
Note: For large rasters, consider using kernel='linear' for faster prediction.
|
||||
The RBF kernel can be very slow during prediction on large images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel: str = "rbf",
|
||||
kernel: str = "linear",
|
||||
C: float = 1.0,
|
||||
gamma: str = "scale",
|
||||
random_state: int = 42,
|
||||
|
|
@ -70,19 +74,26 @@ class SVMStrategy(ClassificationStrategy):
|
|||
C=C,
|
||||
gamma=gamma,
|
||||
random_state=random_state,
|
||||
probability=True,
|
||||
probability=False, # Disabled for performance on large rasters
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def train(self, X: np.ndarray, y: np.ndarray) -> None:
|
||||
self._clf.fit(X, y)
|
||||
|
||||
|
||||
def predict(self, X: np.ndarray) -> np.ndarray:
|
||||
return self._clf.predict(X)
|
||||
|
||||
|
||||
def predict_proba(self, X: np.ndarray) -> np.ndarray:
|
||||
return self._clf.predict_proba(X)
|
||||
|
||||
# probability=True significantly slows down both training and prediction
|
||||
# Return class-based probabilities as fallback
|
||||
predictions = self._clf.predict(X)
|
||||
n_classes = len(self._clf.classes_)
|
||||
proba = np.zeros((X.shape[0], n_classes))
|
||||
for i, cls in enumerate(self._clf.classes_):
|
||||
proba[:, i] = (predictions == cls).astype(float)
|
||||
return proba
|
||||
|
||||
def get_params(self) -> dict[str, Any]:
|
||||
return {
|
||||
"kernel": self.kernel,
|
||||
|
|
@ -90,7 +101,7 @@ class SVMStrategy(ClassificationStrategy):
|
|||
"gamma": self.gamma,
|
||||
"random_state": self.random_state,
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "SVM"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue