linfa_svm/
lib.rs

//! # Support Vector Machines
//!
//! Support Vector Machines are a major branch of machine learning models and offer classification or
//! regression analysis of labeled datasets. They seek a discriminant that separates the data in
//! an optimal way, e.g. one with the fewest misclassifications and the largest margin between the
//! positive and negative classes. A support vector contributes to the discriminant and is therefore
//! important for the classification/regression task. The balance between the number of support
//! vectors and model performance can be controlled with hyperparameters.
//!
//! More details can be found [here](https://en.wikipedia.org/wiki/Support_vector_machine).
//!
//! ## Available parameters in Classification and Regression
//!
//! For supervised classification tasks the C or Nu values are used to control this balance. In
//! [fit_c](SVClassify/fn.fit_c.html) the
//! C value controls the penalty given to misclassification and should be in the interval (0, inf). In
//! [fit_nu](SVClassify/fn.fit_nu.html) the Nu value controls the number of support vectors and should be in the interval (0, 1].
//!
//! For supervised classification with just one class of data a special classifier is available in
//! [fit_one_class](SVClassify/fn.fit_one_class.html). It also accepts a Nu value.
//!
//! For support vector regression two flavors are available. With
//! [fit_epsilon](SVRegress/fn.fit_epsilon.html) a regression task is learned while only penalizing
//! deviations larger than epsilon. In [fit_nu](SVRegress/fn.fit_nu.html) the parameter epsilon is
//! replaced with Nu, which again should be in the interval (0, 1].
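//!
//! For instance, a C-classification model with separate penalties for the positive and negative
//! class and a Gaussian (RBF) kernel can be configured with the parameter builder of this crate;
//! the concrete weight and kernel width values below are only illustrative:
//! ```rust, ignore
//! use linfa_svm::Svm;
//!
//! // C-classification: penalty 10 for the positive and 10 for the negative class,
//! // combined with a Gaussian kernel of width 80
//! let params = Svm::<f64, bool>::params()
//!     .pos_neg_weights(10., 10.)
//!     .gaussian_kernel(80.0);
//! ```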
//!
//! ## Kernel Methods
//! Normally the resulting discriminant is linear, but with [Kernel Methods](https://en.wikipedia.org/wiki/Kernel_method) non-linear relations between the input features
//! can be learned in order to improve the performance of the model.
//!
//! For example, to transform a dataset into an RBF (Gaussian) kernel you can
//! use `linfa_kernel`:
//! ```rust, ignore
//! use linfa_kernel::{Kernel, KernelMethod};
//!
//! let train_kernel = Kernel::params()
//!     .method(KernelMethod::Gaussian(30.0))
//!     .transform(&train);
//! ```
//!
//! # The solver
//! This implementation uses Sequential Minimal Optimization, a widely used optimization tool for
//! convex problems. In each optimization step it selects two variables and updates them.
//! Each step performs the following (a schematic sketch is shown after the list):
//!
//! 1. Find a variable which violates the KKT conditions of the optimization problem
//! 2. Pick a second variable and create a pair (a1, a2)
//! 3. Optimize the pair (a1, a2)
//!
//! After a sufficient number of iterations the solution is close to optimal.
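//!
//! The following is only a schematic sketch of that loop, not the actual code in the
//! `solver_smo` module; all names are illustrative:
//! ```rust, ignore
//! loop {
//!     // 1. find the variable that most strongly violates the KKT conditions
//!     let i = most_violating_variable();
//!     // 2. pick a partner variable to form the pair (a1, a2)
//!     let j = second_variable(i);
//!     // 3. solve the two-variable sub-problem analytically and update the pair
//!     optimize_pair(i, j);
//!
//!     // stop when the violation drops below a threshold or the iteration limit is reached
//!     if max_violation() < threshold || iteration >= max_iterations {
//!         break;
//!     }
//! }
//! ```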
//!
//! # Example
//! The wine quality data consists of 11 features, such as "acidity", "sugar" and "sulfur dioxide", and
//! groups the quality into scores from worst (3) to best (8). These are unified into good (7-8) and
//! bad (3-6) to obtain a binary classification task.
//!
//! With an RBF kernel and C-Support Vector Classification an
//! accuracy of 88.7% is reached within 79535 iterations and 316 support vectors. You can find the
//! full example [here](https://github.com/rust-ml/linfa/blob/master/linfa-svm/examples/winequality.rs).
//! ```ignore
//! Fit SVM classifier with #1440 training points
//! Exited after 79535 iterations with obj = -46317.55802870996 and 316 support vectors
//!
//! classes    | bad        | good
//! bad        | 133        | 9
//! good       | 9          | 8
//!
//! accuracy 0.8867925, MCC 0.40720797
//! ```
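//!
//! A condensed version of that example might look like the following sketch; it assumes the
//! `winequality` dataset from `linfa_datasets` and a surrounding function returning a `Result`:
//! ```rust, ignore
//! use linfa::prelude::*;
//! use linfa_svm::Svm;
//!
//! // load the wine quality dataset and binarize the target: quality > 6 is "good"
//! let (train, valid) = linfa_datasets::winequality()
//!     .map_targets(|x| *x > 6)
//!     .split_with_ratio(0.9);
//!
//! // C-classification with class-specific penalties and a Gaussian kernel
//! let model = Svm::<_, bool>::params()
//!     .pos_neg_weights(7., 0.6)
//!     .gaussian_kernel(80.0)
//!     .fit(&train)?;
//!
//! // evaluate on the validation split
//! let pred = model.predict(&valid);
//! let cm = pred.confusion_matrix(&valid)?;
//! println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
//! ```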
use linfa::Float;
use ndarray::{ArrayBase, Data, Ix1};

use std::fmt;
use std::marker::PhantomData;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

mod classification;
pub mod error;
pub mod hyperparams;
mod permutable_kernel;
mod regression;
pub mod solver_smo;

pub use error::{Result, SvmError};
pub use hyperparams::{SvmParams, SvmValidParams};
use linfa_kernel::KernelMethod;
pub use solver_smo::{SeparatingHyperplane, SolverParams};

use std::ops::Mul;

/// Reason for stopping
///
/// SMO can either exit because a threshold is reached or the iterations are maxed out. To
/// differentiate between the two, this flag is passed along with the solution.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExitReason {
    ReachedThreshold,
    ReachedIterations,
}

/// Fitted Support Vector Machines model
///
/// This is the result of the SMO optimizer and contains the support vectors, the quality of the
/// solution and optionally the linear hyperplane.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct Svm<F: Float, T> {
    pub alpha: Vec<F>,
    pub rho: F,
    r: Option<F>,
    exit_reason: ExitReason,
    iterations: usize,
    obj: F,
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "KernelMethod<F>: Serialize",
            deserialize = "KernelMethod<F>: Deserialize<'de>"
        ))
    )]
    // the kernel is only needed after training to compute distances, and for that
    // only the kernel method is required, not the whole inner matrix
    kernel_method: KernelMethod<F>,
    sep_hyperplane: SeparatingHyperplane<F>,
    probability_coeffs: Option<(F, F)>,
    phantom: PhantomData<T>,
}

impl<F: Float, T> Svm<F, T> {
    /// Returns the number of support vectors
    ///
    /// This function returns the number of support vectors which have an influence on the decision
    /// outcome greater than zero.
    pub fn nsupport(&self) -> usize {
        self.alpha
            .iter()
            // around 1e-5 for f32 and 2e-14 for f64
            .filter(|x| x.abs() > F::cast(100.) * F::epsilon())
            .count()
    }

    pub(crate) fn with_phantom<S>(self) -> Svm<F, S> {
        Svm {
            alpha: self.alpha,
            rho: self.rho,
            r: self.r,
            exit_reason: self.exit_reason,
            obj: self.obj,
            iterations: self.iterations,
            sep_hyperplane: self.sep_hyperplane,
            kernel_method: self.kernel_method,
            probability_coeffs: self.probability_coeffs,
            phantom: PhantomData,
        }
    }

    /// Sums the inner product of `sample` and every one of the support vectors.
    ///
    /// ## Parameters
    ///
    /// * `sample`: the input sample
    ///
    /// ## Returns
    ///
    /// The sum of all inner products of `sample` and every one of the support vectors, scaled by their weight.
    ///
    /// ## Panics
    ///
    /// If the shape of `sample` is not compatible with the
    /// shape of the support vectors.
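    ///
    /// ## Example
    ///
    /// A minimal sketch, assuming `model` is a fitted `Svm` classifier and `sample` is a 1-d
    /// array with the same number of features as the training data; the decision value is the
    /// weighted sum shifted by the constant offset `rho`:
    /// ```ignore
    /// let decision_value = model.weighted_sum(&sample) - model.rho;
    /// ```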
    pub fn weighted_sum<D: Data<Elem = F>>(&self, sample: &ArrayBase<D, Ix1>) -> F {
        match self.sep_hyperplane {
            SeparatingHyperplane::Linear(ref x) => x.mul(sample).sum(),
            SeparatingHyperplane::WeightedCombination(ref supp_vecs) => supp_vecs
                .outer_iter()
                .zip(
                    self.alpha
                        .iter()
                        .filter(|a| a.abs() > F::cast(100.) * F::epsilon()),
                )
                .map(|(x, a)| self.kernel_method.distance(x, sample.view()) * *a)
                .sum(),
        }
    }
}

/// Display solution
///
/// In order to understand the solution of the SMO solver, the objective, the number of iterations
/// and the number of support vectors are printed here.
impl<F: Float, T> fmt::Display for Svm<F, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.exit_reason {
            ExitReason::ReachedThreshold => write!(
                f,
                "Exited after {} iterations with obj = {} and {} support vectors",
                self.iterations,
                self.obj,
                self.nsupport()
            ),
            ExitReason::ReachedIterations => write!(
                f,
                "Reached maximal iterations {} with obj = {} and {} support vectors",
                self.iterations,
                self.obj,
                self.nsupport()
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{Svm, SvmParams, SvmValidParams};
    use linfa::prelude::*;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<Svm<f64, usize>>();
        has_autotraits::<SvmParams<f64, usize>>();
        has_autotraits::<SvmValidParams<f64, usize>>();
    }

    #[test]
    fn test_iter_folding_for_classification() {
        let mut dataset = linfa_datasets::winequality().map_targets(|x| *x > 6);
        let params = Svm::<_, bool>::params()
            .pos_neg_weights(7., 0.6)
            .gaussian_kernel(80.0);

        let avg_acc = dataset
            .iter_fold(4, |training_set| params.fit(training_set).unwrap())
            .map(|(model, valid)| {
                model
                    .predict(valid.view())
                    .confusion_matrix(&valid)
                    .unwrap()
                    .accuracy()
            })
            .sum::<f32>()
            / 4_f32;
        assert!(avg_acc >= 0.5)
    }

    /*#[test]
    fn test_iter_folding_for_regression() {
        let mut dataset: Dataset<f64, f64> = linfa_datasets::diabetes();
        let params = Svm::params().linear_kernel().c_eps(100., 1.);

        let _avg_r2 = dataset
            .iter_fold(4, |training_set| params.fit(&training_set).unwrap())
            .map(|(model, valid)| Array1::from(model.predict(valid.view())).r2(valid.targets()))
            .sum::<f64>()
            / 4_f64;
    }*/
}