linfa_svm/
hyperparams.rs

1use crate::{KernelMethod, SolverParams, Svm, SvmError};
2use linfa::{platt_scaling::PlattParams, Float, ParamGuard, Platt};
3use linfa_kernel::{Kernel, KernelParams};
4use std::marker::PhantomData;
5
/// SVM Hyperparameters
///
/// The SVM fitting process can be controlled in different ways. For classification the C and Nu
/// parameters control the ratio of support vectors and accuracy, eps controls the required
/// precision. After setting the desired parameters a model can be fitted by calling `fit`.
///
/// You can specify the expected return type with the turbofish syntax. If you want to enable
/// Platt-Scaling for proper probability values, then use:
/// ```no_run
/// use linfa_svm::Svm;
/// use linfa::dataset::Pr;
/// let model = Svm::<f64, Pr>::params();
/// ```
/// or `bool` if you only want to know the binary decision:
/// ```no_run
/// use linfa_svm::Svm;
/// let model = Svm::<f64, bool>::params();
/// ```
///
/// ## Example
///
/// ```ignore
/// use linfa_svm::Svm;
/// let model = Svm::<_, bool>::params()
///     .eps(0.1f64)
///     .shrinking(true)
///     .nu_weight(0.1)
///     .fit(&dataset);
/// ```
///
#[derive(Debug, Clone, PartialEq)]
pub struct SvmValidParams<F: Float, T> {
    // C penalties: `(c_pos, c_neg)` for classification or `(c, loss_eps)` for
    // regression; the setters keep this mutually exclusive with `nu`.
    c: Option<(F, F)>,
    // Nu parameters: `(nu, nu)` for classification or `(nu, c)` for Nu-regression;
    // mutually exclusive with `c`.
    nu: Option<(F, F)>,
    // Solver stopping threshold (`eps`) and shrinking flag.
    solver_params: SolverParams<F>,
    // Records the target type `T` without storing a value of it.
    phantom: PhantomData<T>,
    // Kernel used to map records into the feature space (default: linear).
    kernel: KernelParams<F>,
    // Parameters for Platt-scaling probability calibration.
    platt: PlattParams<F, ()>,
}
45
46impl<F: Float, T> SvmValidParams<F, T> {
47    pub fn c(&self) -> Option<(F, F)> {
48        self.c
49    }
50
51    pub fn nu(&self) -> Option<(F, F)> {
52        self.nu
53    }
54
55    pub fn solver_params(&self) -> &SolverParams<F> {
56        &self.solver_params
57    }
58
59    pub fn kernel_params(&self) -> &KernelParams<F> {
60        &self.kernel
61    }
62
63    pub fn platt_params(&self) -> &PlattParams<F, ()> {
64        &self.platt
65    }
66}
67
/// Unchecked SVM hyperparameter set.
///
/// Thin wrapper around [`SvmValidParams`]; values are only validated when the
/// `ParamGuard` implementation (`check`/`check_ref`) is invoked.
#[derive(Debug, Clone, PartialEq)]
pub struct SvmParams<F: Float, T>(SvmValidParams<F, T>);
70
71impl<F: Float, T> SvmParams<F, T> {
72    /// Create hyper parameter set
73    ///
74    /// This creates a `SvmParams` and sets it to the default values:
75    ///  * C values of (1, 1)
76    ///  * Eps of 1e-7
77    ///  * No shrinking
78    ///  * Linear kernel
79    pub fn new() -> Self {
80        Self(SvmValidParams {
81            c: Some((F::one(), F::one())),
82            nu: None,
83            solver_params: SolverParams {
84                eps: F::cast(1e-7),
85                shrinking: false,
86            },
87            phantom: PhantomData,
88            kernel: Kernel::params().method(KernelMethod::Linear),
89            platt: Platt::params(),
90        })
91    }
92
93    /// Set stopping condition
94    ///
95    /// This parameter controls the stopping condition. It checks whether the sum of gradients of
96    /// the max violating pair is below this threshold and then stops the optimization proces.
97    pub fn eps(mut self, new_eps: F) -> Self {
98        self.0.solver_params.eps = new_eps;
99        self
100    }
101
102    /// Shrink active variable set
103    ///
104    /// This parameter controls whether the active variable set is shrinked or not. This can speed
105    /// up the optimization process, but may degredade the solution performance.
106    pub fn shrinking(mut self, shrinking: bool) -> Self {
107        self.0.solver_params.shrinking = shrinking;
108        self
109    }
110
111    /// Set the kernel to use for training
112    ///
113    /// This parameter specifies a mapping of input records to a new feature space by means
114    /// of the distance function between any couple of points mapped to such new space.
115    /// The SVM then applies a linear separation in the new feature space that may result in
116    /// a non linear partitioning of the original input space, thus increasing the expressiveness of
117    /// this model. To use the "base" SVM model it suffices to choose a `Linear` kernel.
118    pub fn with_kernel_params(mut self, kernel: KernelParams<F>) -> Self {
119        self.0.kernel = kernel;
120        self
121    }
122
123    /// Set the platt params for probability calibration
124    pub fn with_platt_params(mut self, platt: PlattParams<F, ()>) -> Self {
125        self.0.platt = platt;
126        self
127    }
128
129    /// Sets the model to use the Gaussian kernel. For this kernel the
130    /// distance between two points is computed as: `d(x, x') = exp(-norm(x - x')/eps)`
131    pub fn gaussian_kernel(mut self, eps: F) -> Self {
132        self.0.kernel = Kernel::params().method(KernelMethod::Gaussian(eps));
133        self
134    }
135
136    /// Sets the model to use the Polynomial kernel. For this kernel the
137    /// distance between two points is computed as: `d(x, x') = (<x, x'> + constant)^(degree)`
138    pub fn polynomial_kernel(mut self, constant: F, degree: F) -> Self {
139        self.0.kernel = Kernel::params().method(KernelMethod::Polynomial(constant, degree));
140        self
141    }
142
143    /// Sets the model to use the Linear kernel. For this kernel the
144    /// distance between two points is computed as : `d(x, x') = <x, x'>`
145    pub fn linear_kernel(mut self) -> Self {
146        self.0.kernel = Kernel::params().method(KernelMethod::Linear);
147        self
148    }
149}
150
151impl<F: Float, T> SvmParams<F, T> {
152    /// Set the C value for positive and negative samples.
153    pub fn pos_neg_weights(mut self, c_pos: F, c_neg: F) -> Self {
154        self.0.c = Some((c_pos, c_neg));
155        self.0.nu = None;
156        self
157    }
158
159    /// Set the Nu value for classification
160    ///
161    /// The Nu value should lie in range [0, 1] and sets the relation between support vectors and
162    /// solution performance.
163    pub fn nu_weight(mut self, nu: F) -> Self {
164        self.0.nu = Some((nu, nu));
165        self.0.c = None;
166        self
167    }
168}
169
170impl<F: Float> SvmParams<F, F> {
171    /// Set the C value for regression and solver epsilon stopping condition.
172    /// Loss epsilon value is fixed at 0.1.
173    #[deprecated(since = "0.7.2", note = "Use .c_svr() and .eps()")]
174    pub fn c_eps(mut self, c: F, eps: F) -> Self {
175        self.0.c = Some((c, F::cast(0.1)));
176        self.0.nu = None;
177        self.0.solver_params.eps = eps;
178        self
179    }
180
181    /// Set the Nu value for regression and solver epsilon stopping condition.
182    /// C value used value is fixed at 1.0.
183    #[deprecated(since = "0.7.2", note = "Use .nu_svr() and .eps()")]
184    pub fn nu_eps(mut self, nu: F, eps: F) -> Self {
185        self.0.nu = Some((nu, F::one()));
186        self.0.c = None;
187        self.0.solver_params.eps = eps;
188        self
189    }
190
191    /// Set the C value and optionnaly an epsilon value used in loss function (default 0.1) for regression
192    pub fn c_svr(mut self, c: F, loss_eps: Option<F>) -> Self {
193        self.0.c = Some((c, loss_eps.unwrap_or(F::cast(0.1))));
194        self.0.nu = None;
195        self
196    }
197
198    /// Set the Nu and optionally a C value (default 1.) for regression
199    pub fn nu_svr(mut self, nu: F, c: Option<F>) -> Self {
200        self.0.nu = Some((nu, c.unwrap_or(F::one())));
201        self.0.c = None;
202        self
203    }
204}
205
impl<F: Float, L> Default for SvmParams<F, L> {
    /// Equivalent to [`SvmParams::new`]: C = (1, 1), eps = 1e-7, no shrinking, linear kernel.
    fn default() -> Self {
        Self::new()
    }
}
211
impl<F: Float, L> Svm<F, L> {
    /// Create a hyperparameter set with default values; see [`SvmParams::new`] for the defaults.
    pub fn params() -> SvmParams<F, L> {
        SvmParams::new()
    }
}
217
impl<F: Float, L> ParamGuard for SvmParams<F, L> {
    type Checked = SvmValidParams<F, L>;
    type Error = SvmError;

    /// Validates the hyperparameters without consuming them.
    ///
    /// Checks, in order: the nested Platt-scaling parameters, that the solver
    /// `eps` is a non-negative finite number, that both C values (if set) are
    /// strictly positive, and that `nu` (if set) lies in `(0, 1]`.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        // The Platt parameters carry their own guard; propagate its error first.
        self.0.platt_params().check_ref()?;

        // eps must be a usable stopping threshold: not negative, NaN or infinite.
        if self.0.solver_params.eps.is_negative()
            || self.0.solver_params.eps.is_nan()
            || self.0.solver_params.eps.is_infinite()
        {
            return Err(SvmError::InvalidEps(
                self.0.solver_params.eps.to_f32().unwrap(),
            ));
        }
        // Both C penalties must be strictly positive.
        if let Some((c1, c2)) = self.0.c {
            if c1 <= F::zero() || c2 <= F::zero() {
                return Err(SvmError::InvalidC((
                    c1.to_f32().unwrap(),
                    c2.to_f32().unwrap(),
                )));
            }
        }
        // Only the nu value itself is range-checked.
        // NOTE(review): the second element (C for nu-SVR, set by `nu_svr`) is not
        // validated here — confirm whether it should also be checked for positivity.
        if let Some((nu, _)) = self.0.nu {
            if nu <= F::zero() || nu > F::one() {
                return Err(SvmError::InvalidNu(nu.to_f32().unwrap()));
            }
        }

        Ok(&self.0)
    }

    /// Validates and unwraps into the checked parameter set.
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}