linfa_svm/lib.rs
//! # Support Vector Machines
//!
//! Support Vector Machines are a major branch of machine learning models and offer classification or
//! regression analysis of labeled datasets. They seek a discriminant which separates the data in an
//! optimal way, e.g. one that has the fewest misclassifications and maximizes the margin between the
//! positive and negative classes. A support vector contributes to the discriminant and is therefore
//! important for the classification/regression task. The balance between the number of support
//! vectors and model performance can be controlled with hyperparameters.
//!
//! More details can be found [here](https://en.wikipedia.org/wiki/Support_vector_machine)
//!
//! ## Available parameters in Classification and Regression
//!
//! For supervised classification tasks the C or Nu values are used to control this balance. In
//! [fit_c](SVClassify/fn.fit_c.html) the C value controls the penalty given to misclassification and
//! should be in the interval (0, inf). In [fit_nu](SVClassify/fn.fit_nu.html) the Nu value controls
//! the number of support vectors and should be in the interval (0, 1].
//!
//! For supervised classification with just one class of data a special classifier is available in
//! [fit_one_class](SVClassify/fn.fit_one_class.html). It also accepts a Nu value.
//!
//! For support vector regression two flavors are available. With
//! [fit_epsilon](SVRegress/fn.fit_epsilon.html) a regression task is learned while minimizing deviations
//! larger than epsilon. In [fit_nu](SVRegress/fn.fit_nu.html) the parameter epsilon is replaced with Nu
//! again and should be in the interval (0, 1].
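//!
//! As a rough orientation, the following sketch shows how such hyperparameters can be supplied
//! through the `SvmParams` builder that is also used in the tests of this crate; the concrete
//! values are illustrative only:
//!
//! ```rust, ignore
//! use linfa_svm::Svm;
//!
//! // C-classification with penalties of 7.0 / 0.6 for the positive / negative class
//! let clf_params = Svm::<f64, bool>::params()
//!     .pos_neg_weights(7.0, 0.6)
//!     .gaussian_kernel(80.0);
//!
//! // epsilon-regression with penalty C = 100.0 and tolerated deviation eps = 1.0
//! let reg_params = Svm::params()
//!     .linear_kernel()
//!     .c_eps(100.0, 1.0);
//! ```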
//!
//! ## Kernel Methods
//! Normally the resulting discriminant is linear, but with [Kernel Methods](https://en.wikipedia.org/wiki/Kernel_method) non-linear relations between the input features
//! can be learned in order to improve the performance of the model.
//!
//! For example to transform a dataset with an RBF (Gaussian) kernel you can
//! use `linfa_kernel`:
//! ```rust, ignore
//! use linfa_kernel::{Kernel, KernelMethod};
//! let train_kernel = Kernel::params()
//!     .method(KernelMethod::Gaussian(30.0))
//!     .transform(&train);
//! ```
//!
//! # The solver
//! This implementation uses Sequential Minimal Optimization, a widely used optimization tool for
//! convex problems. In each optimization step it selects a pair of variables and updates them.
//! Each step performs the following:
//!
//! 1. Find a variable which violates the KKT conditions of the optimization problem
//! 2. Pick a second variable and create a pair (a1, a2)
//! 3. Optimize the pair (a1, a2), as sketched below
//!
//! After a sufficient number of such iterations the solution approaches the optimum.
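//!
//! The pair optimization in step 3 has a closed-form solution. The following self-contained sketch
//! illustrates that update for labels in {-1, +1}; the function and its parameters are illustrative
//! only and not part of this crate's API:
//!
//! ```rust, ignore
//! /// Analytic SMO update of one multiplier pair (a1, a2).
//! ///
//! /// `e1`, `e2` are the current prediction errors, `y1`, `y2` the labels,
//! /// `k11`, `k12`, `k22` the kernel entries of the pair and `c` the penalty.
//! fn update_pair(
//!     a1: f64, a2: f64,
//!     y1: f64, y2: f64,
//!     e1: f64, e2: f64,
//!     k11: f64, k12: f64, k22: f64,
//!     c: f64,
//! ) -> (f64, f64) {
//!     // box constraints for a2 so that both multipliers stay inside [0, C]
//!     let (lo, hi) = if (y1 - y2).abs() > f64::EPSILON {
//!         ((a2 - a1).max(0.0), (c + a2 - a1).min(c))
//!     } else {
//!         ((a1 + a2 - c).max(0.0), (a1 + a2).min(c))
//!     };
//!     // curvature of the objective along the constraint line
//!     let eta = k11 + k22 - 2.0 * k12;
//!     // unconstrained optimum for a2, clipped to the box
//!     let a2_new = (a2 + y2 * (e1 - e2) / eta).clamp(lo, hi);
//!     // keep the linear constraint y1 * a1 + y2 * a2 constant
//!     let a1_new = a1 + y1 * y2 * (a2 - a2_new);
//!     (a1_new, a2_new)
//! }
//! ```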
//!
//! # Example
//! The wine quality data consists of 11 features, like "acid", "sugar" and "sulfur dioxide", and
//! rates the quality from 3 (worst) to 8 (best). The ratings are binned into good (7-8) and bad
//! (3-6) to obtain a binary classification task.
//!
//! With an RBF kernel and C-Support Vector Classification an accuracy of 88.7% is reached after
//! 79535 iterations with 316 support vectors. You can find the
//! example [here](https://github.com/rust-ml/linfa/blob/master/linfa-svm/examples/winequality.rs).
//! ```ignore
//! Fit SVM classifier with #1440 training points
//! Exited after 79535 iterations with obj = -46317.55802870996 and 316 support vectors
//!
//! classes | bad  | good
//! bad     | 133  | 9
//! good    | 9    | 8
//!
//! accuracy 0.8867925, MCC 0.40720797
//! ```
use linfa::Float;
use ndarray::{ArrayBase, Data, Ix1};

use std::fmt;
use std::marker::PhantomData;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

mod classification;
pub mod error;
pub mod hyperparams;
mod permutable_kernel;
mod regression;
pub mod solver_smo;

pub use error::{Result, SvmError};
pub use hyperparams::{SvmParams, SvmValidParams};
use linfa_kernel::KernelMethod;
pub use solver_smo::{SeparatingHyperplane, SolverParams};

use std::ops::Mul;

/// Reason for stopping
///
/// SMO can either exit because a threshold is reached or the iterations are maxed out. To
/// differentiate between both this flag is passed with the solution.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExitReason {
    ReachedThreshold,
    ReachedIterations,
}

/// Fitted Support Vector Machines model
///
/// This is the result of the SMO optimizer and contains the support vectors, quality of solution
/// and optionally the linear hyperplane.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct Svm<F: Float, T> {
    pub alpha: Vec<F>,
    pub rho: F,
    r: Option<F>,
    exit_reason: ExitReason,
    iterations: usize,
    obj: F,
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "KernelMethod<F>: Serialize",
            deserialize = "KernelMethod<F>: Deserialize<'de>"
        ))
    )]
    // the only thing I need the kernel for after the training is to
    // compute the distances, but for that I only need the kernel method
    // and not the whole inner matrix
    kernel_method: KernelMethod<F>,
    sep_hyperplane: SeparatingHyperplane<F>,
    probability_coeffs: Option<(F, F)>,
    phantom: PhantomData<T>,
}

impl<F: Float, T> Svm<F, T> {
    /// Returns the number of support vectors
    ///
    /// This function returns the number of support vectors which have an influence on the decision
    /// outcome greater than zero.
    pub fn nsupport(&self) -> usize {
        self.alpha
            .iter()
            // around 1e-5 for f32 and 2e-14 for f64
            .filter(|x| x.abs() > F::cast(100.) * F::epsilon())
            .count()
    }
    pub(crate) fn with_phantom<S>(self) -> Svm<F, S> {
        Svm {
            alpha: self.alpha,
            rho: self.rho,
            r: self.r,
            exit_reason: self.exit_reason,
            obj: self.obj,
            iterations: self.iterations,
            sep_hyperplane: self.sep_hyperplane,
            kernel_method: self.kernel_method,
            probability_coeffs: self.probability_coeffs,
            phantom: PhantomData,
        }
    }

    /// Sums the inner product of `sample` and every one of the support vectors.
    ///
    /// ## Parameters
    ///
    /// * `sample`: the input sample
    ///
    /// ## Returns
    ///
    /// The sum of all inner products of `sample` and every one of the support vectors, scaled by their weight.
    ///
    /// ## Panics
    ///
    /// If the shape of `sample` is not compatible with the
    /// shape of the support vectors
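    ///
    /// ## Example (sketch)
    ///
    /// Assuming `model` is a fitted `Svm` and `sample` has the same length as the training
    /// features, the raw decision value of a classifier can be derived from this sum using the
    /// usual libsvm convention of subtracting the bias `rho` (illustrative, not a doctest):
    ///
    /// ```ignore
    /// let decision_value = model.weighted_sum(&sample) - model.rho;
    /// ```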
    pub fn weighted_sum<D: Data<Elem = F>>(&self, sample: &ArrayBase<D, Ix1>) -> F {
        match self.sep_hyperplane {
            SeparatingHyperplane::Linear(ref x) => x.mul(sample).sum(),
            SeparatingHyperplane::WeightedCombination(ref supp_vecs) => supp_vecs
                .outer_iter()
                .zip(
                    self.alpha
                        .iter()
                        .filter(|a| a.abs() > F::cast(100.) * F::epsilon()),
                )
                .map(|(x, a)| self.kernel_method.distance(x, sample.view()) * *a)
                .sum(),
        }
    }
}

/// Display solution
///
/// In order to understand the solution of the SMO solver, the objective value, the number of
/// iterations and the number of support vectors are printed here.
impl<F: Float, T> fmt::Display for Svm<F, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.exit_reason {
            ExitReason::ReachedThreshold => write!(
                f,
                "Exited after {} iterations with obj = {} and {} support vectors",
                self.iterations,
                self.obj,
                self.nsupport()
            ),
            ExitReason::ReachedIterations => write!(
                f,
                "Reached maximal iterations {} with obj = {} and {} support vectors",
                self.iterations,
                self.obj,
                self.nsupport()
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{Svm, SvmParams, SvmValidParams};
    use linfa::prelude::*;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<Svm<f64, usize>>();
        has_autotraits::<SvmParams<f64, usize>>();
        has_autotraits::<SvmValidParams<f64, usize>>();
    }

    #[test]
    fn test_iter_folding_for_classification() {
        let mut dataset = linfa_datasets::winequality().map_targets(|x| *x > 6);
        let params = Svm::<_, bool>::params()
            .pos_neg_weights(7., 0.6)
            .gaussian_kernel(80.0);

        let avg_acc = dataset
            .iter_fold(4, |training_set| params.fit(training_set).unwrap())
            .map(|(model, valid)| {
                model
                    .predict(valid.view())
                    .confusion_matrix(&valid)
                    .unwrap()
                    .accuracy()
            })
            .sum::<f32>()
            / 4_f32;
        assert!(avg_acc >= 0.5)
    }

    /*#[test]
    fn test_iter_folding_for_regression() {
        let mut dataset: Dataset<f64, f64> = linfa_datasets::diabetes();
        let params = Svm::params().linear_kernel().c_eps(100., 1.);

        let _avg_r2 = dataset
            .iter_fold(4, |training_set| params.fit(&training_set).unwrap())
            .map(|(model, valid)| Array1::from(model.predict(valid.view())).r2(valid.targets()))
            .sum::<f64>()
            / 4_f64;
    }*/
268}