// linfa_svm/regression.rs

//! Support Vector Regression
use linfa::{
    dataset::{AsSingleTargets, DatasetBase},
    traits::Fit,
    traits::Transformer,
    traits::{Predict, PredictInplace},
};
use linfa_kernel::Kernel;
use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix2};

use super::error::{Result, SvmError};
use super::permutable_kernel::PermutableKernelRegression;
use super::solver_smo::SolverState;
use super::SolverParams;
use super::{Float, Svm, SvmValidParams};

/// Support Vector Regression with epsilon tolerance
///
/// This method solves the epsilon support vector regression problem by reformulating it as a
/// binary SVC problem. The penalty parameter epsilon lies in `(0, inf)` and defines the margin
/// of tolerance: errors smaller than epsilon incur no penalty.
///
/// # Parameters
///
/// * `params` - Solver parameters (threshold etc.)
/// * `dataset` - the observations to fit on
/// * `kernel` - the kernel matrix `Q`
/// * `target` - the continuous targets `y_i`
/// * `c` - C value for all targets
/// * `p` - epsilon value for all targets
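///
/// # Formulation
///
/// A rough sketch of the reformulation (mirroring the loop in the body below, not an exact
/// statement of the solver internals): every sample `i` is duplicated into a positive instance
/// with linear term `p - y_i` and a negative instance with linear term `p + y_i`, and the
/// solver then optimizes `2 * target.len()` dual variables of the form
///
/// ```text
/// min_alpha   0.5 * alpha^T Q alpha + linear_term^T alpha
/// subject to  0 <= alpha_i <= c   (plus the usual SMO equality constraint)
/// ```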
pub fn fit_epsilon<F: Float>(
    params: SolverParams<F>,
    dataset: ArrayView2<F>,
    kernel: Kernel<F>,
    target: &[F],
    c: F,
    p: F,
) -> Svm<F, F> {
    let mut linear_term = vec![F::zero(); 2 * target.len()];
    let mut targets = vec![true; 2 * target.len()];

    // Duplicate each sample into a positive and a negative instance
    // (see the formulation sketch above)
    for i in 0..target.len() {
        linear_term[i] = p - target[i];
        targets[i] = true;

        linear_term[i + target.len()] = p + target[i];
        targets[i + target.len()] = false;
    }

    let kernel = PermutableKernelRegression::new(kernel);
    let solver = SolverState::new(
        vec![F::zero(); 2 * target.len()],
        linear_term,
        targets,
        dataset,
        kernel,
        vec![c; 2 * target.len()],
        params,
        false,
    );

    let res = solver.solve();

    res.with_phantom()
}

/// Support Vector Regression with nu parameter
///
/// This method solves the nu support vector regression problem, again reformulated as a binary
/// SVC problem. The parameter `nu` bounds the fraction of support vectors and must lie in the
/// range `(0, 1)`.
///
/// # Parameters
///
/// * `params` - Solver parameters (threshold etc.)
/// * `dataset` - the observations to fit on
/// * `kernel` - the kernel matrix `Q`
/// * `target` - the continuous targets `y_i`
/// * `nu` - nu value for all targets
/// * `c` - C value for all targets
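///
/// # Initialization
///
/// A sketch of the warm start performed below: the total dual weight `c * nu * target.len() / 2`
/// is distributed over the variables in order, each clamped to at most `c`, with both halves of
/// the duplicated sample set receiving the same initial value:
///
/// ```text
/// sum = c * nu * l / 2
/// for i in 0..l { alpha[i] = alpha[i + l] = min(sum, c); sum -= alpha[i] }
/// ```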
pub fn fit_nu<F: Float>(
    params: SolverParams<F>,
    dataset: ArrayView2<F>,
    kernel: Kernel<F>,
    target: &[F],
    nu: F,
    c: F,
) -> Svm<F, F> {
    let mut alpha = vec![F::zero(); 2 * target.len()];
    let mut linear_term = vec![F::zero(); 2 * target.len()];
    let mut targets = vec![true; 2 * target.len()];

    // Warm-start the dual variables (see the initialization sketch above)
    let mut sum = c * nu * F::cast(target.len()) / F::cast(2.0);
    for i in 0..target.len() {
        alpha[i] = F::min(sum, c);
        alpha[i + target.len()] = F::min(sum, c);
        sum -= alpha[i];

        linear_term[i] = -target[i];
        targets[i] = true;

        linear_term[i + target.len()] = target[i];
        targets[i + target.len()] = false;
    }

    let kernel = PermutableKernelRegression::new(kernel);
    let solver = SolverState::new(
        alpha,
        linear_term,
        targets,
        dataset,
        kernel,
        vec![c; 2 * target.len()],
        params,
        false,
    );

    let res = solver.solve();

    res.with_phantom()
}

/// Regress observations
///
/// Implements [`Fit`] for the supported record and target types, producing a model that maps
/// observations to continuous target values; see the example below.
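///
/// # Example
///
/// A minimal sketch of fitting through this implementation, mirroring the tests at the bottom
/// of this file (`linfa_svm::Svm` is the public path assumed here):
///
/// ```ignore
/// use linfa::dataset::Dataset;
/// use linfa::traits::Fit;
/// use linfa_svm::Svm;
/// use ndarray::Array;
///
/// // 100 points on the line y = x, as a single-feature dataset
/// let targets = Array::linspace(0f64, 10., 100);
/// let records = targets.clone().into_shape((100, 1)).unwrap();
/// let dataset = Dataset::new(records, targets);
///
/// // Epsilon-SVR with C = 5; use `nu_svr(0.5, Some(1.))` for nu-SVR instead
/// let model = Svm::params().c_svr(5., None).linear_kernel().fit(&dataset)?;
/// ```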
macro_rules! impl_regression {
    ($records:ty, $targets:ty, $f:ty) => {
        impl Fit<$records, $targets, SvmError> for SvmValidParams<$f, $f> {
            type Object = Svm<$f, $f>;

            fn fit(&self, dataset: &DatasetBase<$records, $targets>) -> Result<Self::Object> {
                let kernel = self.kernel_params().transform(dataset.records());
                let target = dataset.as_single_targets();
                let target = target.as_slice().unwrap();

                // Dispatch on the configured parameters: a C value implies
                // epsilon-SVR, a nu value implies nu-SVR
                let ret = match (self.c(), self.nu()) {
                    (Some((c, p)), _) => fit_epsilon(
                        self.solver_params().clone(),
                        dataset.records().view(),
                        kernel,
                        target,
                        c,
                        p,
                    ),
                    (None, Some((nu, c))) => fit_nu(
                        self.solver_params().clone(),
                        dataset.records().view(),
                        kernel,
                        target,
                        nu,
                        c,
                    ),
                    _ => panic!("Set either C value or Nu value"),
                };

                Ok(ret)
            }
        }
    };
}

impl_regression!(Array2<f32>, Array1<f32>, f32);
impl_regression!(Array2<f64>, Array1<f64>, f64);
impl_regression!(ArrayView2<'_, f32>, ArrayView1<'_, f32>, f32);
impl_regression!(ArrayView2<'_, f64>, ArrayView1<'_, f64>, f64);

macro_rules! impl_predict {
    ( $($t:ty),* ) => {
    $(
        /// Predict the continuous target value for a single feature vector
        impl Predict<Array1<$t>, $t> for Svm<$t, $t> {
            fn predict(&self, data: Array1<$t>) -> $t {
                // Decision function: kernel-weighted sum over the support
                // vectors, shifted by the offset `rho`
                self.weighted_sum(&data) - self.rho
            }
        }
        /// Predict the continuous target value for a single feature vector view
        impl<'a> Predict<ArrayView1<'a, $t>, $t> for Svm<$t, $t> {
            fn predict(&self, data: ArrayView1<'a, $t>) -> $t {
                self.weighted_sum(&data) - self.rho
            }
        }

        /// Regress observations
        ///
        /// Takes a batch of observations and predicts a continuous target value
        /// for each of them; see the example below.
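        ///
        /// # Example
        ///
        /// A minimal sketch, assuming an already fitted `model: Svm<f64, f64>` and a matching
        /// `dataset` (both hypothetical here):
        ///
        /// ```ignore
        /// use linfa::traits::Predict;
        ///
        /// let predictions = model.predict(dataset.records());
        /// ```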
        impl<D: Data<Elem = $t>> PredictInplace<ArrayBase<D, Ix2>, Array1<$t>> for Svm<$t, $t> {
            fn predict_inplace(&'_ self, data: &ArrayBase<D, Ix2>, targets: &mut Array1<$t>) {
                assert_eq!(data.nrows(), targets.len(), "The number of data points must match the number of output targets.");

                for (data, target) in data.outer_iter().zip(targets.iter_mut()) {
                    *target = self.weighted_sum(&data) - self.rho;
                }
            }

            fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<$t> {
                Array1::zeros(x.nrows())
            }
        }

    ) *
    }
}

impl_predict!(f32, f64);

#[cfg(test)]
pub mod tests {
    use super::Svm;
    use crate::error::Result;

    use linfa::dataset::Dataset;
    use linfa::metrics::SingleTargetRegression;
    use linfa::traits::{Fit, Predict};
    use linfa::DatasetBase;
    use ndarray::{Array, Array1, Array2};

    fn _check_model(model: Svm<f64, f64>, dataset: &DatasetBase<Array2<f64>, Array1<f64>>) {
        println!("{model}");
        let predicted = model.predict(dataset.records());
        let err = predicted.mean_squared_error(&dataset).unwrap();
        println!("err={err}");
        assert!(err < 1e-2);
    }

    #[test]
    fn test_epsilon_regression_linear() -> Result<()> {
        // simple 2d straight line
        let targets = Array::linspace(0f64, 10., 100);
        let records = targets.clone().into_shape((100, 1)).unwrap();
        let dataset = Dataset::new(records, targets);

        let model = Svm::params()
            .c_svr(5., None)
            .linear_kernel()
            .fit(&dataset)?;
        _check_model(model, &dataset);

        // Old API
        #[allow(deprecated)]
        let model2 = Svm::params()
            .c_eps(5., 1e-3)
            .linear_kernel()
            .fit(&dataset)?;
        _check_model(model2, &dataset);

        Ok(())
    }

    #[test]
    fn test_nu_regression_linear() -> Result<()> {
        // simple 2d straight line
        let targets = Array::linspace(0f64, 10., 100);
        let records = targets.clone().into_shape((100, 1)).unwrap();
        let dataset = Dataset::new(records, targets);

        // Test the precomputed dot product in the linear kernel case
        let model = Svm::params()
            .nu_svr(0.5, Some(1.))
            .linear_kernel()
            .fit(&dataset)?;
        _check_model(model, &dataset);

        // Old API
        #[allow(deprecated)]
        let model2 = Svm::params()
            .nu_eps(0.5, 1e-3)
            .linear_kernel()
            .fit(&dataset)?;
        _check_model(model2, &dataset);
        Ok(())
    }

    #[test]
    fn test_epsilon_regression_gaussian() -> Result<()> {
        let records = Array::linspace(0f64, 10., 100)
            .into_shape((100, 1))
            .unwrap();
        let sin_curve = records.mapv(|v| v.sin()).into_shape((100,)).unwrap();
        let dataset = Dataset::new(records, sin_curve);

        let model = Svm::params()
            .c_svr(100., Some(0.1))
            .gaussian_kernel(10.)
            .eps(1e-3)
            .fit(&dataset)?;
        _check_model(model, &dataset);
        Ok(())
    }

    #[test]
    fn test_nu_regression_polynomial() -> Result<()> {
        let n = 100;
        let records = Array::linspace(0f64, 5., n).into_shape((n, 1)).unwrap();
        let sin_curve = records.mapv(|v| v.sin()).into_shape((n,)).unwrap();
        let dataset = Dataset::new(records, sin_curve);

        let model = Svm::params()
            .nu_svr(0.01, None)
            .polynomial_kernel(1., 3.)
            .eps(1e-3)
            .fit(&dataset)?;
        _check_model(model, &dataset);
        Ok(())
    }
}