linfa_svm/classification.rs

use linfa::dataset::AsSingleTargets;
use linfa::prelude::Transformer;
use linfa::{
    composing::platt_scaling::{platt_newton_method, platt_predict, PlattParams},
    dataset::{CountedTargets, DatasetBase, Pr},
    traits::Fit,
    traits::{Predict, PredictInplace},
    ParamGuard,
};
use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2};
use std::cmp::Ordering;

use super::error::{Result, SvmError};
use super::permutable_kernel::{PermutableKernel, PermutableKernelOneClass};
use super::solver_smo::SolverState;
use super::SolverParams;
use super::{Float, Svm, SvmValidParams};
use linfa_kernel::Kernel;

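/// Calibrate an SVM's decision values into class probabilities with Platt scaling.
///
/// A sigmoid `P(y = 1 | x) = 1 / (1 + exp(a * f(x) + b))` is fitted to the decision
/// values `f(x)` via Newton's method; the coefficients `(a, b)` are stored on the
/// model and reused by the probabilistic `Predict` implementations below.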
fn calibrate_with_platt<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = bool>>(
    mut obj: Svm<F, F>,
    params: &PlattParams<F, ()>,
    dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
) -> Result<Svm<F, Pr>> {
    let pred = dataset
        .records()
        .outer_iter()
        .map(|x| obj.weighted_sum(&x) - obj.rho)
        .collect::<Array1<_>>();

    let (a, b) = platt_newton_method(
        pred.view(),
        dataset.as_single_targets(),
        params.check_ref()?,
    )?;
    obj.probability_coeffs = Some((a, b));

    Ok(obj.with_phantom())
}

/// Support Vector Classification with C-penalizing parameter
///
/// This method solves a binary SVC problem with a penalizing parameter C in the
/// range (0, inf). The dual problem has the form
/// ```ignore
/// min_a 1/2 a^T Q a - e^T a  s.t.  y^T a = 0, 0 <= a_i <= C_i
/// ```
/// where `Q_ij = y_i y_j K(x_i, x_j)` and `K` is the kernel matrix.
///
/// # Parameters
///
/// * `params` - Solver parameters (threshold etc.)
/// * `dataset` - the observations
/// * `kernel` - the kernel matrix `K`
/// * `targets` - the ground truth targets `y_i`
/// * `cpos` - C for positive targets
/// * `cneg` - C for negative targets
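///
/// # Example
///
/// A minimal sketch (not compiled; `x` stands for a single observation). The
/// sign of `y_i` is folded into `alpha` by this function, so the decision
/// value is computed exactly as in the `Predict` implementations below:
/// ```ignore
/// let model = fit_c(params, dataset.view(), kernel, &targets, c_pos, c_neg);
/// // f(x) = sum_i alpha_i K(x_i, x) - rho
/// let decision = model.weighted_sum(&x) - model.rho;
/// let is_positive = decision >= 0.0;
/// ```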
pub fn fit_c<F: Float>(
    params: SolverParams<F>,
    dataset: ArrayView2<F>,
    kernel: Kernel<F>,
    targets: &[bool],
    cpos: F,
    cneg: F,
) -> Svm<F, F> {
    let bounds = targets
        .iter()
        .map(|x| if *x { cpos } else { cneg })
        .collect::<Vec<_>>();

    let kernel = PermutableKernel::new(kernel, targets.to_vec());

    let solver = SolverState::new(
        vec![F::zero(); targets.len()],
        vec![-F::one(); targets.len()],
        targets.to_vec(),
        dataset,
        kernel,
        bounds,
        params,
        false,
    );

    let mut res = solver.solve();

    // fold the label sign into alpha so that the decision function is
    // sum_i alpha_i K(x_i, x) - rho
    res.alpha = res
        .alpha
        .into_iter()
        .zip(targets.iter())
        .map(|(a, b)| if *b { a } else { -a })
        .collect();

    res
}

/// Support Vector Classification with Nu-penalizing term
///
/// This method solves a binary SVC problem with a penalizing parameter nu in the
/// range (0, 1). The dual problem has the form
/// ```ignore
/// min_a 1/2 a^T Q a  s.t.  y^T a = 0, 0 <= a_i <= 1/l, e^T a >= nu
/// ```
/// where `Q_ij = y_i y_j K(x_i, x_j)` and `K` is the kernel matrix. Nu acts
/// simultaneously as an upper bound on the fraction of margin errors and a
/// lower bound on the fraction of support vectors.
///
/// # Parameters
///
/// * `params` - Solver parameters (threshold etc.)
/// * `dataset` - the observations
/// * `kernel` - the kernel matrix `K`
/// * `targets` - the ground truth targets `y_i`
/// * `nu` - Nu penalizing term
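///
/// # Example
///
/// A minimal sketch (not compiled); after the solver finishes, `alpha` and
/// `rho` are rescaled by `r`, so the decision rule matches the C-SVC case:
/// ```ignore
/// let model = fit_nu(params, dataset.view(), kernel, &targets, nu);
/// let decision = model.weighted_sum(&x) - model.rho;
/// let is_positive = decision >= 0.0;
/// ```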
pub fn fit_nu<F: Float>(
    params: SolverParams<F>,
    dataset: ArrayView2<F>,
    kernel: Kernel<F>,
    targets: &[bool],
    nu: F,
) -> Svm<F, F> {
    // distribute an initial alpha budget of nu*l/2 over each class
    let mut sum_pos = nu * F::cast(targets.len()) / F::cast(2.0);
    let mut sum_neg = nu * F::cast(targets.len()) / F::cast(2.0);
    let init_alpha = targets
        .iter()
        .map(|x| {
            if *x {
                let val = F::min(F::one(), sum_pos);
                sum_pos -= val;
                val
            } else {
                let val = F::min(F::one(), sum_neg);
                sum_neg -= val;
                val
            }
        })
        .collect::<Vec<_>>();

    let kernel = PermutableKernel::new(kernel, targets.to_vec());

    let solver = SolverState::new(
        init_alpha,
        vec![F::zero(); targets.len()],
        targets.to_vec(),
        dataset,
        kernel,
        vec![F::one(); targets.len()],
        params,
        true,
    );

    let mut res = solver.solve();

    let r = res.r.unwrap();

    // fold the label sign into alpha and rescale the solution by r
    res.alpha = res
        .alpha
        .into_iter()
        .zip(targets.iter())
        .map(|(a, b)| if *b { a } else { -a })
        .map(|x| x / r)
        .collect();
    res.rho /= r;
    res.obj /= r * r;

    res
}

/// Support Vector Classification for one-class problems
///
/// This method solves a one-class SVM problem for datasets without targets,
/// which can be useful, for example, to reject outliers.
///
/// # Parameters
///
/// * `params` - Solver parameters (threshold etc.)
/// * `dataset` - the observations
/// * `kernel` - the kernel matrix `K`
/// * `nu` - Nu penalizing term, an upper bound on the fraction of training points treated as outliers
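///
/// # Example
///
/// A minimal sketch (not compiled); a positive decision value marks an inlier,
/// a negative one an outlier:
/// ```ignore
/// let model = fit_one_class(params, dataset.view(), kernel, nu);
/// let is_inlier = model.weighted_sum(&x) - model.rho >= 0.0;
/// ```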
pub fn fit_one_class<F: Float + num_traits::ToPrimitive>(
    params: SolverParams<F>,
    dataset: ArrayView2<F>,
    kernel: Kernel<F>,
    nu: F,
) -> Svm<F, F> {
    let size = kernel.size();
    let n = (nu * F::cast(size)).to_usize().unwrap();

    // initialize the first nu*size alphas to one, with the boundary index
    // carrying the fractional remainder
    let init_alpha = (0..size)
        .map(|x| match x.cmp(&n) {
            Ordering::Less => F::one(),
            Ordering::Greater => F::zero(),
            Ordering::Equal => nu * F::cast(size) - F::cast(x),
        })
        .collect::<Vec<_>>();

    let kernel = PermutableKernelOneClass::new(kernel);

    let solver = SolverState::new(
        init_alpha,
        vec![F::zero(); size],
        vec![true; size],
        dataset,
        kernel,
        vec![F::one(); size],
        params,
        false,
    );

    solver.solve()
}

/// Fit binary classification problem
///
/// For a given dataset with observations as records and a two-class problem as
/// targets, this fits an optimal hyperplane and returns the solution as a model.
/// Depending on the requested output type, the model predicts either hard labels
/// (`bool`) or, via Platt scaling, probabilities (`Pr`) that a sample belongs to
/// the positive class.
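///
/// A minimal usage sketch (not compiled; mirrors the tests at the bottom of
/// this file):
/// ```ignore
/// let model = Svm::<f64, bool>::params()
///     .pos_neg_weights(1.0, 1.0)
///     .gaussian_kernel(50.0)
///     .fit(&dataset)?;
/// let labels = model.predict(&dataset);
/// ```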
macro_rules! impl_classification {
    ($records:ty, $targets:ty) => {
        impl<F: Float> Fit<$records, $targets, SvmError> for SvmValidParams<F, Pr> {
            type Object = Svm<F, Pr>;

            fn fit(&self, dataset: &DatasetBase<$records, $targets>) -> Result<Self::Object> {
                let kernel = self.kernel_params().transform(dataset.records());
                let target = dataset.as_single_targets();
                let target = target.as_slice().unwrap();

                let ret = match (self.c(), self.nu()) {
                    (Some((c_p, c_n)), _) => fit_c(
                        self.solver_params().clone(),
                        dataset.records().view(),
                        kernel,
                        target,
                        c_p,
                        c_n,
                    ),
                    (None, Some((nu, _))) => fit_nu(
                        self.solver_params().clone(),
                        dataset.records().view(),
                        kernel,
                        target,
                        nu,
                    ),
                    _ => panic!("Set either C value or Nu value"),
                };

                calibrate_with_platt(ret, &self.platt_params(), dataset)
            }
        }

        impl<F: Float> Fit<$records, $targets, SvmError> for SvmValidParams<F, bool> {
            type Object = Svm<F, bool>;

            fn fit(&self, dataset: &DatasetBase<$records, $targets>) -> Result<Self::Object> {
                let kernel = self.kernel_params().transform(dataset.records());
                let target = dataset.as_single_targets();
                let target = target.as_slice().unwrap();

                let ret = match (self.c(), self.nu()) {
                    (Some((c_p, c_n)), _) => fit_c(
                        self.solver_params().clone(),
                        dataset.records().view(),
                        kernel,
                        target,
                        c_p,
                        c_n,
                    ),
                    (None, Some((nu, _))) => fit_nu(
                        self.solver_params().clone(),
                        dataset.records().view(),
                        kernel,
                        target,
                        nu,
                    ),
                    _ => panic!("Set either C value or Nu value"),
                };

                Ok(ret.with_phantom())
            }
        }
    };
}

impl_classification!(Array2<F>, Array1<bool>);
impl_classification!(ArrayView2<'_, F>, ArrayView1<'_, bool>);
impl_classification!(Array2<F>, CountedTargets<bool, Array1<bool>>);
impl_classification!(ArrayView2<'_, F>, CountedTargets<bool, Array1<bool>>);
impl_classification!(ArrayView2<'_, F>, CountedTargets<bool, ArrayView1<'_, bool>>);

/// Fit one-class problem
///
/// This fits an SVM model to a dataset without targets, treating all samples as
/// members of a single class, using the one-class implementation of SVM.
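///
/// A minimal usage sketch (not compiled; mirrors `test_reject_classification`
/// below; a dataset without targets dispatches here):
/// ```ignore
/// let dataset = Dataset::from(records);
/// let model = Svm::params()
///     .nu_weight(0.1)
///     .gaussian_kernel(100.0)
///     .fit(&dataset)?;
/// let inliers = model.predict(&dataset);
/// ```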
macro_rules! impl_oneclass {
    ($records:ty, $targets:ty) => {
        impl<F: Float> Fit<$records, $targets, SvmError> for SvmValidParams<F, Pr> {
            type Object = Svm<F, bool>;

            fn fit(&self, dataset: &DatasetBase<$records, $targets>) -> Result<Self::Object> {
                let kernel = self.kernel_params().transform(dataset.records());
                let records = dataset.records().view();

                let ret = match self.nu() {
                    Some((nu, _)) => {
                        fit_one_class(self.solver_params().clone(), records, kernel, nu)
                    }
                    None => panic!("One class needs Nu value"),
                };

                Ok(ret.with_phantom())
            }
        }
    };
}

impl_oneclass!(Array2<F>, Array2<()>);
impl_oneclass!(ArrayView2<'_, F>, ArrayView2<'_, ()>);
impl_oneclass!(Array2<F>, CountedTargets<(), Array2<()>>);
impl_oneclass!(Array2<F>, CountedTargets<(), ArrayView2<'_, ()>>);
impl_oneclass!(Array2<F>, Array1<()>);
impl_oneclass!(ArrayView2<'_, F>, ArrayView1<'_, ()>);
impl_oneclass!(Array2<F>, CountedTargets<(), Array1<()>>);
impl_oneclass!(Array2<F>, CountedTargets<(), ArrayView1<'_, ()>>);

/// Predict the probability that a feature vector belongs to the positive class
impl<F: Float, D: Data<Elem = F>> Predict<ArrayBase<D, Ix1>, Pr> for Svm<F, Pr> {
    fn predict(&self, data: ArrayBase<D, Ix1>) -> Pr {
        let val = self.weighted_sum(&data) - self.rho;
        let (a, b) = self.probability_coeffs.unwrap();

        platt_predict(val, a, b)
    }
}

/// Predict a label for a feature vector
impl<F: Float, D: Data<Elem = F>> Predict<ArrayBase<D, Ix1>, bool> for Svm<F, bool> {
    fn predict(&self, data: ArrayBase<D, Ix1>) -> bool {
        let val = self.weighted_sum(&data) - self.rho;

        val >= F::zero()
    }
}

/// Classify observations
///
/// Takes a set of observations and, for each, predicts the probability that it
/// belongs to the positive class by passing the decision value through the
/// Platt sigmoid fitted during training.
impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>> for Svm<F, Pr> {
    fn predict_inplace(&self, data: &ArrayBase<D, Ix2>, targets: &mut Array1<Pr>) {
        assert_eq!(
            data.nrows(),
            targets.len(),
            "The number of data points must match the number of output targets."
        );

        let (a, b) = self.probability_coeffs.unwrap();

        for (data, target) in data.outer_iter().zip(targets.iter_mut()) {
            let val = self.weighted_sum(&data) - self.rho;
            *target = platt_predict(val, a, b);
        }
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<Pr> {
        Array1::default(x.nrows())
    }
}

/// Classify observations
///
/// Takes a set of observations and predicts, for each, whether it belongs to
/// the positive class by thresholding the decision value at zero.
impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<bool>> for Svm<F, bool> {
    fn predict_inplace(&self, data: &ArrayBase<D, Ix2>, targets: &mut Array1<bool>) {
        assert_eq!(
            data.nrows(),
            targets.len(),
            "The number of data points must match the number of output targets."
        );

        for (data, target) in data.outer_iter().zip(targets.iter_mut()) {
            let val = self.weighted_sum(&data) - self.rho;
            *target = val >= F::zero();
        }
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<bool> {
        Array1::default(x.nrows())
    }
}

#[cfg(test)]
mod tests {
    use std::f64::consts::TAU;

    use super::Svm;
    use crate::error::Result;
    use approx::assert_abs_diff_eq;
    use linfa::dataset::{Dataset, DatasetBase};
    use linfa::prelude::ToConfusionMatrix;
    use linfa::traits::{Fit, Predict};

    use ndarray::{Array, Array1, Array2, Axis};
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;

    pub fn generate_convoluted_rings(n_points: usize) -> Array2<f64> {
        let mut out = Array::random((n_points * 2, 2), Uniform::new(0f64, 1.));
        for (i, mut elm) in out.outer_iter_mut().enumerate() {
            // generate convoluted rings with 1/10th noise
            let phi = TAU * elm[1];
            let eps = elm[0] / 10.0;

            if i < n_points {
                elm[0] = 1.0 * phi.cos() + eps;
                elm[1] = 1.0 * phi.sin() + eps;
            } else {
                elm[0] = 5.0 * phi.cos() + eps;
                elm[1] = 5.0 * phi.sin() + eps;
            }
        }

        out
    }

    #[test]
    fn test_linear_classification() -> Result<()> {
        let entries: Array2<f64> = ndarray::concatenate(
            Axis(0),
            &[
                Array::random((10, 2), Uniform::new(-1., -0.5)).view(),
                Array::random((10, 2), Uniform::new(0.5, 1.)).view(),
            ],
        )
        .unwrap();
        let targets = (0..20).map(|x| x < 10).collect::<Array1<_>>();
        let dataset = Dataset::new(entries, targets);

        // train model with positive and negative weight
        let model = Svm::<_, bool>::params()
            .pos_neg_weights(1.0, 1.0)
            .linear_kernel()
            .fit(&dataset)?;

        let y_est = model.predict(&dataset);

        let cm = y_est.confusion_matrix(&dataset)?;
        assert_abs_diff_eq!(cm.accuracy(), 1.0);

        // train model with Nu parameter
        let model = Svm::<_, bool>::params()
            .nu_weight(0.05)
            .linear_kernel()
            .fit(&dataset)?;

        let valid = model.predict(&dataset);

        let cm = valid.confusion_matrix(&dataset)?;
        assert_abs_diff_eq!(cm.accuracy(), 1.0);

        Ok(())
    }

    #[test]
    fn test_polynomial_classification() -> Result<()> {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        // construct a parabola and classify the middle area as positive and the borders as negative
        let records = Array::random_using((40, 1), Uniform::new(-2f64, 2.), &mut rng);
        let targets = records.map_axis(Axis(1), |x| x[0] * x[0] < 0.5);
        let dataset = Dataset::new(records, targets);

        // train model with positive and negative weight
        let model = Svm::<_, bool>::params()
            .pos_neg_weights(1.0, 1.0)
            .polynomial_kernel(0.0, 2.0)
            .fit(&dataset)?;

        let valid = model.predict(&dataset);

        let cm = valid.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.9);

        Ok(())
    }

    #[test]
    fn test_convoluted_rings_classification() -> Result<()> {
        let records = generate_convoluted_rings(10);
        let targets = (0..20).map(|x| x < 10).collect::<Array1<_>>();
        let dataset = (records.view(), targets.view()).into();

        // train model with positive and negative weight
        let model = Svm::<_, bool>::params()
            .pos_neg_weights(1.0, 1.0)
            .gaussian_kernel(50.0)
            .fit(&dataset)?;

        let y_est = model.predict(&dataset);

        let cm = y_est.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.9);

        // train model with Nu parameter
        let model = Svm::<_, bool>::params()
            .nu_weight(0.01)
            .gaussian_kernel(50.0)
            .fit(&dataset)?;

        let y_est = model.predict(&dataset);

        let cm = y_est.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.9);

        Ok(())
    }

    #[test]
    fn test_winequality_crossvalidation() {
        let params = Svm::<_, bool>::params()
            .pos_neg_weights(50000., 5000.)
            .gaussian_kernel(40.0);

        // perform cross-validation and measure the accuracy
        let acc_runs = linfa_datasets::winequality()
            .map_targets(|x| *x > 6)
            .iter_fold(1, |v| params.fit(v).unwrap())
            .map(|(model, valid)| {
                let cm = model.predict(&valid).confusion_matrix(&valid).unwrap();

                cm.accuracy()
            })
            .collect::<Array1<_>>();

        assert!(acc_runs[0] > 0.85);
    }

    #[test]
    fn test_reject_classification() -> Result<()> {
        // generate 100 samples uniformly inside a box as the inlier class
        let entries = Array::random((100, 2), Uniform::new(-4., 4.));
        let dataset = Dataset::from(entries);

        // train a one-class model with Nu parameter
        let model = Svm::params()
            .nu_weight(1.0)
            .gaussian_kernel(100.0)
            .fit(&dataset)?;

        let valid = DatasetBase::from(Array::random((100, 2), Uniform::new(-10., 10f32)));
        let valid = model.predict(valid);

        // count the number of correctly rejected samples
        let mut rejected = 0;
        let mut total = 0;
        for (pred, pos) in valid.targets().iter().zip(valid.records.outer_iter()) {
            let distance = (pos[0] * pos[0] + pos[1] * pos[1]).sqrt();
            if distance >= 5.0 {
                if !pred {
                    rejected += 1;
                }
                total += 1;
            }
        }

        // more than 95% of the far-away samples should be rejected
        assert!((rejected as f32) / (total as f32) > 0.95);

        Ok(())
    }
}