1use linfa::dataset::AsSingleTargets;
2use linfa::prelude::Transformer;
3use linfa::{
4 composing::platt_scaling::{platt_newton_method, platt_predict, PlattParams},
5 dataset::{CountedTargets, DatasetBase, Pr},
6 traits::Fit,
7 traits::{Predict, PredictInplace},
8 ParamGuard,
9};
10use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2};
11use std::cmp::Ordering;
12
13use super::error::{Result, SvmError};
14use super::permutable_kernel::{PermutableKernel, PermutableKernelOneClass};
15use super::solver_smo::SolverState;
16use super::SolverParams;
17use super::{Float, Svm, SvmValidParams};
18use linfa_kernel::Kernel;
19
20fn calibrate_with_platt<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = bool>>(
21 mut obj: Svm<F, F>,
22 params: &PlattParams<F, ()>,
23 dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
24) -> Result<Svm<F, Pr>> {
25 let pred = dataset
26 .records()
27 .outer_iter()
28 .map(|x| obj.weighted_sum(&x) - obj.rho)
29 .collect::<Array1<_>>();
30
31 let (a, b) = platt_newton_method(
32 pred.view(),
33 dataset.as_single_targets(),
34 params.check_ref()?,
35 )?;
36 obj.probability_coeffs = Some((a, b));
37
38 Ok(obj.with_phantom())
39}
40
41pub fn fit_c<F: Float>(
58 params: SolverParams<F>,
59 dataset: ArrayView2<F>,
60 kernel: Kernel<F>,
61 targets: &[bool],
62 cpos: F,
63 cneg: F,
64) -> Svm<F, F> {
65 let bounds = targets
66 .iter()
67 .map(|x| if *x { cpos } else { cneg })
68 .collect::<Vec<_>>();
69
70 let kernel = PermutableKernel::new(kernel, targets.to_vec());
71
72 let solver = SolverState::new(
73 vec![F::zero(); targets.len()],
74 vec![-F::one(); targets.len()],
75 targets.to_vec(),
76 dataset,
77 kernel,
78 bounds,
79 params,
80 false,
81 );
82
83 let mut res = solver.solve();
84
85 res.alpha = res
86 .alpha
87 .into_iter()
88 .zip(targets.iter())
89 .map(|(a, b)| if *b { a } else { -a })
90 .collect();
91
92 res
93}
94
95pub fn fit_nu<F: Float>(
111 params: SolverParams<F>,
112 dataset: ArrayView2<F>,
113 kernel: Kernel<F>,
114 targets: &[bool],
115 nu: F,
116) -> Svm<F, F> {
117 let mut sum_pos = nu * F::cast(targets.len()) / F::cast(2.0);
118 let mut sum_neg = nu * F::cast(targets.len()) / F::cast(2.0);
119 let init_alpha = targets
120 .iter()
121 .map(|x| {
122 if *x {
123 let val = F::min(F::one(), sum_pos);
124 sum_pos -= val;
125 val
126 } else {
127 let val = F::min(F::one(), sum_neg);
128 sum_neg -= val;
129 val
130 }
131 })
132 .collect::<Vec<_>>();
133
134 let kernel = PermutableKernel::new(kernel, targets.to_vec());
135
136 let solver = SolverState::new(
137 init_alpha,
138 vec![F::zero(); targets.len()],
139 targets.to_vec(),
140 dataset,
141 kernel,
142 vec![F::one(); targets.len()],
143 params,
144 true,
145 );
146
147 let mut res = solver.solve();
148
149 let r = res.r.unwrap();
150
151 res.alpha = res
152 .alpha
153 .into_iter()
154 .zip(targets.iter())
155 .map(|(a, b)| if *b { a } else { -a })
156 .map(|x| x / r)
157 .collect();
158 res.rho /= r;
159 res.obj /= r * r;
160
161 res
162}
163
164pub fn fit_one_class<F: Float + num_traits::ToPrimitive>(
175 params: SolverParams<F>,
176 dataset: ArrayView2<F>,
177 kernel: Kernel<F>,
178 nu: F,
179) -> Svm<F, F> {
180 let size = kernel.size();
181 let n = (nu * F::cast(size)).to_usize().unwrap();
182
183 let init_alpha = (0..size)
184 .map(|x| match x.cmp(&n) {
185 Ordering::Less => F::one(),
186 Ordering::Greater => F::zero(),
187 Ordering::Equal => nu * F::cast(size) - F::cast(x),
188 })
189 .collect::<Vec<_>>();
190
191 let kernel = PermutableKernelOneClass::new(kernel);
192
193 let solver = SolverState::new(
194 init_alpha,
195 vec![F::zero(); size],
196 vec![true; size],
197 dataset,
198 kernel,
199 vec![F::one(); size],
200 params,
201 false,
202 );
203
204 solver.solve()
205}
206
207macro_rules! impl_classification {
213 ($records:ty, $targets:ty) => {
214 impl<F: Float> Fit<$records, $targets, SvmError> for SvmValidParams<F, Pr> {
215 type Object = Svm<F, Pr>;
216
217 fn fit(&self, dataset: &DatasetBase<$records, $targets>) -> Result<Self::Object> {
218 let kernel = self.kernel_params().transform(dataset.records());
219 let target = dataset.as_single_targets();
220 let target = target.as_slice().unwrap();
221
222 let ret = match (self.c(), self.nu()) {
223 (Some((c_p, c_n)), _) => fit_c(
224 self.solver_params().clone(),
225 dataset.records().view(),
226 kernel,
227 target,
228 c_p,
229 c_n,
230 ),
231 (None, Some((nu, _))) => fit_nu(
232 self.solver_params().clone(),
233 dataset.records().view(),
234 kernel,
235 target,
236 nu,
237 ),
238 _ => panic!("Set either C value or Nu value"),
239 };
240
241 calibrate_with_platt(ret, &self.platt_params(), dataset)
242 }
243 }
244
245 impl<F: Float> Fit<$records, $targets, SvmError> for SvmValidParams<F, bool> {
246 type Object = Svm<F, bool>;
247
248 fn fit(&self, dataset: &DatasetBase<$records, $targets>) -> Result<Self::Object> {
249 let kernel = self.kernel_params().transform(dataset.records());
250 let target = dataset.as_single_targets();
251 let target = target.as_slice().unwrap();
252
253 let ret = match (self.c(), self.nu()) {
254 (Some((c_p, c_n)), _) => fit_c(
255 self.solver_params().clone(),
256 dataset.records().view(),
257 kernel,
258 target,
259 c_p,
260 c_n,
261 ),
262 (None, Some((nu, _))) => fit_nu(
263 self.solver_params().clone(),
264 dataset.records().view(),
265 kernel,
266 target,
267 nu,
268 ),
269 _ => panic!("Set either C value or Nu value"),
270 };
271
272 Ok(ret.with_phantom())
273 }
274 }
275 };
276}
277
278impl_classification!(Array2<F>, Array1<bool>);
279impl_classification!(ArrayView2<'_, F>, ArrayView1<'_, bool>);
280impl_classification!(Array2<F>, CountedTargets<bool, Array1<bool>>);
281impl_classification!(ArrayView2<'_, F>, CountedTargets<bool, Array1<bool>>);
282impl_classification!(ArrayView2<'_, F>, CountedTargets<bool, ArrayView1<'_, bool>>);
283
284macro_rules! impl_oneclass {
289 ($records:ty, $targets:ty) => {
290 impl<F: Float> Fit<$records, $targets, SvmError> for SvmValidParams<F, Pr> {
291 type Object = Svm<F, bool>;
292
293 fn fit(&self, dataset: &DatasetBase<$records, $targets>) -> Result<Self::Object> {
294 let kernel = self.kernel_params().transform(dataset.records());
295 let records = dataset.records().view();
296
297 let ret = match self.nu() {
298 Some((nu, _)) => {
299 fit_one_class(self.solver_params().clone(), records, kernel, nu)
300 }
301 None => panic!("One class needs Nu value"),
302 };
303
304 Ok(ret.with_phantom())
305 }
306 }
307 };
308}
309
310impl_oneclass!(Array2<F>, Array2<()>);
311impl_oneclass!(ArrayView2<'_, F>, ArrayView2<'_, ()>);
312impl_oneclass!(Array2<F>, CountedTargets<(), Array2<()>>);
313impl_oneclass!(Array2<F>, CountedTargets<(), ArrayView2<'_, ()>>);
314impl_oneclass!(Array2<F>, Array1<()>);
315impl_oneclass!(ArrayView2<'_, F>, ArrayView1<'_, ()>);
316impl_oneclass!(Array2<F>, CountedTargets<(), Array1<()>>);
317impl_oneclass!(Array2<F>, CountedTargets<(), ArrayView1<'_, ()>>);
318
319impl<F: Float, D: Data<Elem = F>> Predict<ArrayBase<D, Ix1>, Pr> for Svm<F, Pr> {
321 fn predict(&self, data: ArrayBase<D, Ix1>) -> Pr {
322 let val = self.weighted_sum(&data) - self.rho;
323 let (a, b) = self.probability_coeffs.unwrap();
324
325 platt_predict(val, a, b)
326 }
327}
328
329impl<F: Float, D: Data<Elem = F>> Predict<ArrayBase<D, Ix1>, bool> for Svm<F, bool> {
331 fn predict(&self, data: ArrayBase<D, Ix1>) -> bool {
332 let val = self.weighted_sum(&data) - self.rho;
333
334 val >= F::zero()
335 }
336}
337
338impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>> for Svm<F, Pr> {
362 fn predict_inplace(&self, data: &ArrayBase<D, Ix2>, targets: &mut Array1<Pr>) {
363 assert_eq!(
364 data.nrows(),
365 targets.len(),
366 "The number of data points must match the number of output targets."
367 );
368
369 let (a, b) = self.probability_coeffs.unwrap();
370
371 for (data, target) in data.outer_iter().zip(targets.iter_mut()) {
372 let val = self.weighted_sum(&data) - self.rho;
373 *target = platt_predict(val, a, b);
374 }
375 }
376
377 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<Pr> {
378 Array1::default(x.nrows())
379 }
380}
381
382impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<bool>> for Svm<F, bool> {
387 fn predict_inplace(&self, data: &ArrayBase<D, Ix2>, targets: &mut Array1<bool>) {
388 assert_eq!(
389 data.nrows(),
390 targets.len(),
391 "The number of data points must match the number of output targets."
392 );
393
394 for (data, target) in data.outer_iter().zip(targets.iter_mut()) {
395 let val = self.weighted_sum(&data) - self.rho;
396 *target = val >= F::zero();
397 }
398 }
399
400 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<bool> {
401 Array1::default(x.nrows())
402 }
403}
404
#[cfg(test)]
mod tests {
    use std::f64::consts::TAU;

    use super::Svm;
    use crate::error::Result;
    use approx::assert_abs_diff_eq;
    use linfa::dataset::{Dataset, DatasetBase};
    use linfa::prelude::ToConfusionMatrix;
    use linfa::traits::{Fit, Predict};

    use ndarray::{Array, Array1, Array2, Axis};
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;

    /// Generate `2 * n_points` samples on two concentric noisy rings.
    ///
    /// The first `n_points` rows lie around radius 1 and the remaining rows around radius 5.
    /// A fixed RNG seed keeps the data, and therefore the test outcome, reproducible.
    pub fn generate_convoluted_rings(n_points: usize) -> Array2<f64> {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let mut out = Array::random_using((n_points * 2, 2), Uniform::new(0f64, 1.), &mut rng);
        for (i, mut elm) in out.outer_iter_mut().enumerate() {
            // Column 1 becomes the angle; column 0 becomes a small offset added to both axes.
            let phi = TAU * elm[1];
            let eps = elm[0] / 10.0;

            if i < n_points {
                elm[0] = 1.0 * phi.cos() + eps;
                elm[1] = 1.0 * phi.sin() + eps;
            } else {
                elm[0] = 5.0 * phi.cos() + eps;
                elm[1] = 5.0 * phi.sin() + eps;
            }
        }

        out
    }

    #[test]
    fn test_linear_classification() -> Result<()> {
        // Seeded so the two (always separable) clusters are identical across runs.
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let entries: Array2<f64> = ndarray::concatenate(
            Axis(0),
            &[
                Array::random_using((10, 2), Uniform::new(-1., -0.5), &mut rng).view(),
                Array::random_using((10, 2), Uniform::new(0.5, 1.), &mut rng).view(),
            ],
        )
        .unwrap();
        let targets = (0..20).map(|x| x < 10).collect::<Array1<_>>();
        let dataset = Dataset::new(entries, targets);

        let model = Svm::<_, bool>::params()
            .pos_neg_weights(1.0, 1.0)
            .linear_kernel()
            .fit(&dataset)?;

        let y_est = model.predict(&dataset);

        let cm = y_est.confusion_matrix(&dataset)?;
        assert_abs_diff_eq!(cm.accuracy(), 1.0);

        let model = Svm::<_, bool>::params()
            .nu_weight(0.05)
            .linear_kernel()
            .fit(&dataset)?;

        let valid = model.predict(&dataset);

        let cm = valid.confusion_matrix(&dataset)?;
        assert_abs_diff_eq!(cm.accuracy(), 1.0);

        Ok(())
    }

    #[test]
    fn test_polynomial_classification() -> Result<()> {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let records = Array::random_using((40, 1), Uniform::new(-2f64, 2.), &mut rng);
        let targets = records.map_axis(Axis(1), |x| x[0] * x[0] < 0.5);
        let dataset = Dataset::new(records, targets);

        let model = Svm::<_, bool>::params()
            .pos_neg_weights(1.0, 1.0)
            .polynomial_kernel(0.0, 2.0)
            .fit(&dataset)?;

        let valid = model.predict(&dataset);

        let cm = valid.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.9);

        Ok(())
    }

    #[test]
    fn test_convoluted_rings_classification() -> Result<()> {
        let records = generate_convoluted_rings(10);
        let targets = (0..20).map(|x| x < 10).collect::<Array1<_>>();
        let dataset = (records.view(), targets.view()).into();

        let model = Svm::<_, bool>::params()
            .pos_neg_weights(1.0, 1.0)
            .gaussian_kernel(50.0)
            .fit(&dataset)?;

        let y_est = model.predict(&dataset);

        let cm = y_est.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.9);

        let model = Svm::<_, bool>::params()
            .nu_weight(0.01)
            .gaussian_kernel(50.0)
            .fit(&dataset)?;

        let y_est = model.predict(&dataset);

        let cm = y_est.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.9);

        Ok(())
    }

    /// Cross-validation on the wine-quality dataset (targets: quality > 6).
    #[test]
    fn test_winequality_crossvalidation() {
        let params = Svm::<_, bool>::params()
            .pos_neg_weights(50000., 5000.)
            .gaussian_kernel(40.0);

        let acc_runs = linfa_datasets::winequality()
            .map_targets(|x| *x > 6)
            .iter_fold(1, |v| params.fit(v).unwrap())
            .map(|(model, valid)| {
                let cm = model.predict(&valid).confusion_matrix(&valid).unwrap();

                cm.accuracy()
            })
            .collect::<Array1<_>>();

        assert!(acc_runs[0] > 0.85);
    }

    #[test]
    fn test_reject_classification() -> Result<()> {
        let entries = Array::random((100, 2), Uniform::new(-4., 4.));
        let dataset = Dataset::from(entries);

        let model = Svm::params()
            .nu_weight(1.0)
            .gaussian_kernel(100.0)
            .fit(&dataset)?;

        let valid = DatasetBase::from(Array::random((100, 2), Uniform::new(-10., 10f32)));
        let valid = model.predict(valid);

        // Count how many points far from the training region (distance >= 5) are rejected.
        let mut rejected = 0;
        let mut total = 0;
        for (pred, pos) in valid.targets().iter().zip(valid.records.outer_iter()) {
            let distance = (pos[0] * pos[0] + pos[1] * pos[1]).sqrt();
            if distance >= 5.0 {
                if !pred {
                    rejected += 1;
                }
                total += 1;
            }
        }

        assert!((rejected as f32) / (total as f32) > 0.95);

        Ok(())
    }
}