1use super::permutable_kernel::Permutable;
2use super::{ExitReason, Float, Svm};
3
4use ndarray::{Array1, Array2, ArrayView2, Axis};
5#[cfg(feature = "serde")]
6use serde_crate::{Deserialize, Serialize};
7use std::marker::PhantomData;
8
9#[derive(Clone, Debug, PartialEq)]
11pub struct SolverParams<F: Float> {
12 pub eps: F,
14 pub shrinking: bool,
16}
17
18#[derive(Clone, Debug, PartialEq)]
20struct Alpha<F: Float> {
21 value: F,
22 upper_bound: F,
23}
24
25impl<F: Float> Alpha<F> {
26 pub fn from(value: F, upper_bound: F) -> Alpha<F> {
27 Alpha { value, upper_bound }
28 }
29
30 pub fn reached_upper(&self) -> bool {
31 self.value >= self.upper_bound
32 }
33
34 pub fn free_floating(&self) -> bool {
35 self.value < self.upper_bound && self.value > F::zero()
36 }
37
38 pub fn reached_lower(&self) -> bool {
39 self.value == F::zero()
40 }
41
42 pub fn val(&self) -> F {
43 self.value
44 }
45}
46
47#[cfg_attr(
48 feature = "serde",
49 derive(Serialize, Deserialize),
50 serde(crate = "serde_crate")
51)]
52#[derive(Clone, Debug, PartialEq)]
53pub enum SeparatingHyperplane<F: Float> {
54 Linear(Array1<F>),
55 WeightedCombination(Array2<F>),
56}
57
58#[derive(Clone, Debug, PartialEq)]
64pub struct SolverState<'a, F: Float, K: Permutable<F>> {
65 gradient: Vec<F>,
67 gradient_fixed: Vec<F>,
69 alpha: Vec<Alpha<F>>,
71 active_set: Vec<usize>,
73 nactive: usize,
75 unshrink: bool,
76 nu_constraint: bool,
77 r: F,
78
79 dataset: ArrayView2<'a, F>,
81
82 kernel: K,
84 p: Vec<F>,
86 targets: Vec<bool>,
88 bounds: Vec<F>,
90
91 params: SolverParams<F>,
93
94 phantom: PhantomData<&'a K>,
95}
96
97#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
98impl<'a, F: Float, K: 'a + Permutable<F>> SolverState<'a, F, K> {
99 pub fn new(
103 alpha: Vec<F>,
104 p: Vec<F>,
105 targets: Vec<bool>,
106 dataset: ArrayView2<'a, F>,
107 kernel: K,
108 bounds: Vec<F>,
109 params: SolverParams<F>,
110 nu_constraint: bool,
111 ) -> SolverState<'a, F, K> {
112 let alpha = alpha
114 .into_iter()
115 .enumerate()
116 .map(|(i, alpha)| Alpha::from(alpha, bounds[i]))
117 .collect::<Vec<_>>();
118
119 let active_set = (0..alpha.len()).collect::<Vec<_>>();
121
122 let mut gradient = p.clone();
124 let mut gradient_fixed = vec![F::zero(); alpha.len()];
125
126 for i in 0..alpha.len() {
127 if !alpha[i].reached_lower() {
129 let dist_i = kernel.distances(i, alpha.len());
130 let alpha_i = alpha[i].val();
131
132 for j in 0..alpha.len() {
134 gradient[j] += alpha_i * dist_i[j];
135 }
136
137 if alpha[i].reached_upper() {
139 for j in 0..alpha.len() {
140 gradient_fixed[j] += bounds[i] * dist_i[j];
141 }
142 }
143 }
144 }
145
146 SolverState {
147 gradient,
148 gradient_fixed,
149 alpha,
150 p,
151 nactive: active_set.len(),
152 unshrink: false,
153 active_set,
154 dataset,
155 kernel,
156 targets,
157 bounds,
158 params,
159 nu_constraint,
160 r: F::zero(),
161 phantom: PhantomData,
162 }
163 }
164
165 pub fn nactive(&self) -> usize {
167 self.nactive
168 }
169
170 pub fn ntotal(&self) -> usize {
172 self.alpha.len()
173 }
174
175 pub fn target(&self, idx: usize) -> F {
177 if self.targets[idx] {
178 F::one()
179 } else {
180 -F::one()
181 }
182 }
183
184 pub fn bound(&self, idx: usize) -> F {
186 self.bounds[idx]
187 }
188
189 pub fn swap(&mut self, i: usize, j: usize) {
191 self.gradient.swap(i, j);
192 self.gradient_fixed.swap(i, j);
193 self.alpha.swap(i, j);
194 self.p.swap(i, j);
195 self.active_set.swap(i, j);
196 self.kernel.swap_indices(i, j);
197 self.targets.swap(i, j);
198 }
199
200 fn reconstruct_gradient(&mut self) {
205 if self.nactive() == self.ntotal() {
207 return;
208 }
209
210 for j in self.nactive()..self.ntotal() {
212 self.gradient[j] = self.gradient_fixed[j] + self.p[j];
213 }
214
215 let nfree: usize = (0..self.nactive())
216 .filter(|x| self.alpha[*x].free_floating())
217 .count();
218 if nfree * self.ntotal() > 2 * self.nactive() * (self.ntotal() - self.nactive()) {
219 for i in self.nactive()..self.ntotal() {
220 let dist_i = self.kernel.distances(i, self.nactive());
221 for j in 0..self.nactive() {
222 if self.alpha[i].free_floating() {
223 self.gradient[i] += self.alpha[j].val() * dist_i[j];
224 }
225 }
226 }
227 } else {
228 for i in 0..self.nactive() {
229 if self.alpha[i].free_floating() {
230 let dist_i = self.kernel.distances(i, self.ntotal());
231 let alpha_i = self.alpha[i].val();
232 for j in self.nactive()..self.ntotal() {
233 self.gradient[j] += alpha_i * dist_i[j];
234 }
235 }
236 }
237 }
238 }
239
240 pub fn update(&mut self, working_set: (usize, usize)) {
241 let (i, j) = working_set;
243
244 let dist_i = self.kernel.distances(i, self.nactive());
245 let dist_j = self.kernel.distances(j, self.nactive());
246
247 let bound_i = self.bound(i);
248 let bound_j = self.bound(j);
249
250 let old_alpha_i = self.alpha[i].val();
251 let old_alpha_j = self.alpha[j].val();
252
253 if self.targets[i] != self.targets[j] {
254 let mut quad_coef = self.kernel.self_distance(i)
255 + self.kernel.self_distance(j)
256 + (F::one() + F::one()) * dist_i[j];
257 if quad_coef <= F::zero() {
258 quad_coef = F::cast(1e-10);
259 }
260
261 let delta = -(self.gradient[i] + self.gradient[j]) / quad_coef;
262 let diff = self.alpha[i].val() - self.alpha[j].val();
263
264 self.alpha[i].value += delta;
266 self.alpha[j].value += delta;
267
268 if diff > F::zero() {
270 if self.alpha[j].val() < F::zero() {
271 self.alpha[j].value = F::zero();
272 self.alpha[i].value = diff;
273 }
274 } else if self.alpha[i].val() < F::zero() {
275 self.alpha[i].value = F::zero();
276 self.alpha[j].value = -diff;
277 }
278
279 if diff > bound_i - bound_j {
280 if self.alpha[i].val() > bound_i {
281 self.alpha[i].value = bound_i;
282 self.alpha[j].value = bound_i - diff;
283 }
284 } else if self.alpha[j].val() > bound_j {
285 self.alpha[j].value = bound_j;
286 self.alpha[i].value = bound_j + diff;
287 }
288 } else {
289 let mut quad_coef = self.kernel.self_distance(i) + self.kernel.self_distance(j)
291 - F::cast(2.0) * dist_i[j];
292 if quad_coef <= F::zero() {
293 quad_coef = F::cast(1e-10);
294 }
295
296 let delta = (self.gradient[i] - self.gradient[j]) / quad_coef;
297 let sum = self.alpha[i].val() + self.alpha[j].val();
298
299 self.alpha[i].value -= delta;
301 self.alpha[j].value += delta;
302
303 if sum > bound_i {
305 if self.alpha[i].val() > bound_i {
306 self.alpha[i].value = bound_i;
307 self.alpha[j].value = sum - bound_i;
308 }
309 } else if self.alpha[j].val() < F::zero() {
310 self.alpha[j].value = F::zero();
311 self.alpha[i].value = sum;
312 }
313 if sum > bound_j {
314 if self.alpha[j].val() > bound_j {
315 self.alpha[j].value = bound_j;
316 self.alpha[i].value = sum - bound_j;
317 }
318 } else if self.alpha[i].val() < F::zero() {
319 self.alpha[i].value = F::zero();
320 self.alpha[j].value = sum;
321 }
322 }
334
335 let delta_alpha_i = self.alpha[i].val() - old_alpha_i;
337 let delta_alpha_j = self.alpha[j].val() - old_alpha_j;
338
339 for k in 0..self.nactive() {
340 self.gradient[k] += dist_i[k] * delta_alpha_i + dist_j[k] * delta_alpha_j;
341 }
342
343 let ui = self.alpha[i].reached_upper();
345 let uj = self.alpha[j].reached_upper();
346
347 self.alpha[i] = Alpha::from(self.alpha[i].val(), self.bound(i));
348 self.alpha[j] = Alpha::from(self.alpha[j].val(), self.bound(j));
349
350 if ui != self.alpha[i].reached_upper() {
352 let dist_i = self.kernel.distances(i, self.ntotal());
353 let bound_i = self.bound(i);
354 if ui {
355 for k in 0..self.ntotal() {
356 self.gradient_fixed[k] -= bound_i * dist_i[k];
357 }
358 } else {
359 for k in 0..self.ntotal() {
360 self.gradient_fixed[k] += bound_i * dist_i[k];
361 }
362 }
363 }
364
365 if uj != self.alpha[j].reached_upper() {
367 let dist_j = self.kernel.distances(j, self.ntotal());
368 let bound_j = self.bound(j);
369 if uj {
370 for k in 0..self.nactive() {
371 self.gradient_fixed[k] -= bound_j * dist_j[k];
372 }
373 } else {
374 for k in 0..self.nactive() {
375 self.gradient_fixed[k] += bound_j * dist_j[k];
376 }
377 }
378 }
379 }
380
381 pub fn max_violating_pair(&self) -> ((F, isize), (F, isize)) {
383 let mut gmax1 = (-F::infinity(), -1);
385 let mut gmax2 = (-F::infinity(), -1);
387
388 for i in 0..self.nactive() {
389 if self.targets[i] {
390 if !self.alpha[i].reached_upper() && -self.gradient[i] >= gmax1.0 {
391 gmax1 = (-self.gradient[i], i as isize);
392 }
393 if !self.alpha[i].reached_lower() && self.gradient[i] >= gmax2.0 {
394 gmax2 = (self.gradient[i], i as isize);
395 }
396 } else {
397 if !self.alpha[i].reached_upper() && -self.gradient[i] >= gmax2.0 {
398 gmax2 = (-self.gradient[i], i as isize);
399 }
400 if !self.alpha[i].reached_lower() && self.gradient[i] >= gmax1.0 {
401 gmax1 = (self.gradient[i], i as isize);
402 }
403 }
404 }
405
406 (gmax1, gmax2)
407 }
408
409 #[allow(clippy::type_complexity)]
410 pub fn max_violating_pair_nu(&self) -> ((F, isize), (F, isize), (F, isize), (F, isize)) {
411 let mut gmax1 = (-F::infinity(), -1);
412 let mut gmax2 = (-F::infinity(), -1);
413 let mut gmax3 = (-F::infinity(), -1);
414 let mut gmax4 = (-F::infinity(), -1);
415
416 for i in 0..self.nactive() {
417 if self.targets[i] {
418 if !self.alpha[i].reached_upper() && -self.gradient[i] > gmax1.0 {
419 gmax1 = (-self.gradient[i], i as isize);
420 }
421 if !self.alpha[i].reached_lower() && self.gradient[i] > gmax3.0 {
422 gmax3 = (self.gradient[i], i as isize);
423 }
424 } else {
425 if !self.alpha[i].reached_upper() && -self.gradient[i] > gmax4.0 {
426 gmax4 = (-self.gradient[i], i as isize);
427 }
428 if !self.alpha[i].reached_lower() && self.gradient[i] > gmax2.0 {
429 gmax2 = (self.gradient[i], i as isize);
430 }
431 }
432 }
433
434 (gmax1, gmax2, gmax3, gmax4)
435 }
436
437 pub fn select_working_set(&self) -> (usize, usize, bool) {
444 if self.nu_constraint {
445 return self.select_working_set_nu();
446 }
447
448 let (gmax, gmax2) = self.max_violating_pair();
449
450 let mut obj_diff_min = (F::infinity(), -1);
451
452 if gmax.1 != -1 {
453 let dist_i = self.kernel.distances(gmax.1 as usize, self.ntotal());
454
455 for (j, dist_ij) in dist_i.into_iter().enumerate().take(self.nactive()) {
456 if self.targets[j] {
457 if !self.alpha[j].reached_lower() {
458 let grad_diff = gmax.0 + self.gradient[j];
459 if grad_diff > F::zero() {
460 let i = gmax.1 as usize;
462
463 let quad_coef = self.kernel.self_distance(i)
464 + self.kernel.self_distance(j)
465 - F::cast(2.0) * self.target(i) * dist_ij;
466
467 let obj_diff = if quad_coef > F::zero() {
468 -(grad_diff * grad_diff) / quad_coef
469 } else {
470 -(grad_diff * grad_diff) / F::cast(1e-10)
471 };
472
473 if obj_diff <= obj_diff_min.0 {
474 obj_diff_min = (obj_diff, j as isize);
475 }
476 }
477 }
478 } else if !self.alpha[j].reached_upper() {
479 let grad_diff = gmax.0 - self.gradient[j];
480 if grad_diff > F::zero() {
481 let i = gmax.1 as usize;
483
484 let quad_coef = self.kernel.self_distance(i)
485 + self.kernel.self_distance(j)
486 + F::cast(2.0) * self.target(i) * dist_ij;
487
488 let obj_diff = if quad_coef > F::zero() {
489 -(grad_diff * grad_diff) / quad_coef
490 } else {
491 -(grad_diff * grad_diff) / F::cast(1e-10)
492 };
493 if obj_diff <= obj_diff_min.0 {
494 obj_diff_min = (obj_diff, j as isize);
495 }
496 }
497 }
498 }
499 }
500
501 if gmax.0 + gmax2.0 < self.params.eps || obj_diff_min.1 == -1 {
502 (0, 0, true)
503 } else {
504 (gmax.1 as usize, obj_diff_min.1 as usize, false)
505 }
506 }
507
508 pub fn select_working_set_nu(&self) -> (usize, usize, bool) {
515 let (gmaxp1, gmaxn1, gmaxp2, gmaxn2) = self.max_violating_pair_nu();
516
517 let mut obj_diff_min = (F::infinity(), -1);
518
519 let dist_i_p = if gmaxp1.1 != -1 {
520 Some(self.kernel.distances(gmaxp1.1 as usize, self.ntotal()))
521 } else {
522 None
523 };
524
525 let dist_i_n = if gmaxn1.1 != -1 {
526 Some(self.kernel.distances(gmaxn1.1 as usize, self.ntotal()))
527 } else {
528 None
529 };
530
531 for j in 0..self.nactive() {
532 if self.targets[j] {
533 if !self.alpha[j].reached_lower() {
534 let grad_diff = gmaxp1.0 + self.gradient[j];
535 if grad_diff > F::zero() {
536 let dist_i_p = match dist_i_p {
537 Some(ref x) => x,
538 None => continue,
539 };
540
541 let i = gmaxp1.1 as usize;
543
544 let quad_coef = self.kernel.self_distance(i) + self.kernel.self_distance(j)
545 - F::cast(2.0) * dist_i_p[j];
546
547 let obj_diff = if quad_coef > F::zero() {
548 -(grad_diff * grad_diff) / quad_coef
549 } else {
550 -(grad_diff * grad_diff) / F::cast(1e-10)
551 };
552
553 if obj_diff <= obj_diff_min.0 {
554 obj_diff_min = (obj_diff, j as isize);
555 }
556 }
557 }
558 } else if !self.alpha[j].reached_upper() {
559 let grad_diff = gmaxn1.0 - self.gradient[j];
560 if grad_diff > F::zero() {
561 let dist_i_n = match dist_i_n {
562 Some(ref x) => x,
563 None => continue,
564 };
565
566 let i = gmaxn1.1 as usize;
568
569 let quad_coef = self.kernel.self_distance(i) + self.kernel.self_distance(j)
570 - F::cast(2.0) * dist_i_n[j];
571
572 let obj_diff = if quad_coef > F::zero() {
573 -(grad_diff * grad_diff) / quad_coef
574 } else {
575 -(grad_diff * grad_diff) / F::cast(1e-10)
576 };
577 if obj_diff <= obj_diff_min.0 {
578 obj_diff_min = (obj_diff, j as isize);
579 }
580 }
581 }
582 }
583
584 if F::max(gmaxp1.0 + gmaxp2.0, gmaxn1.0 + gmaxn2.0) < self.params.eps
585 || obj_diff_min.1 == -1
586 {
587 (0, 0, true)
588 } else {
589 let out_j = obj_diff_min.1 as usize;
590 let out_i = if self.targets[out_j] {
591 gmaxp1.1 as usize
592 } else {
593 gmaxn1.1 as usize
594 };
595
596 (out_i, out_j, false)
597 }
598 }
599
600 pub fn should_shrunk(&self, i: usize, gmax1: F, gmax2: F) -> bool {
601 if self.alpha[i].reached_upper() {
602 if self.targets[i] {
603 -self.gradient[i] > gmax1
604 } else {
605 -self.gradient[i] > gmax2
606 }
607 } else if self.alpha[i].reached_lower() {
608 if self.targets[i] {
609 self.gradient[i] > gmax2
610 } else {
611 -self.gradient[i] > gmax1
612 }
613 } else {
614 false
615 }
616 }
617
618 pub fn should_shrunk_nu(&self, i: usize, gmax1: F, gmax2: F, gmax3: F, gmax4: F) -> bool {
619 if self.alpha[i].reached_upper() {
620 if self.targets[i] {
621 -self.gradient[i] > gmax1
622 } else {
623 -self.gradient[i] > gmax4
624 }
625 } else if self.alpha[i].reached_lower() {
626 if self.targets[i] {
627 self.gradient[i] > gmax2
628 } else {
629 self.gradient[i] > gmax3
630 }
631 } else {
632 false
633 }
634 }
635
636 pub fn do_shrinking(&mut self) {
637 if self.nu_constraint {
638 self.do_shrinking_nu();
639 return;
640 }
641
642 let (gmax1, gmax2) = self.max_violating_pair();
643 let (gmax1, gmax2) = (gmax1.0, gmax2.0);
644
645 if !self.unshrink && gmax1 + gmax2 <= self.params.eps * F::cast(10.0) {
647 self.unshrink = true;
648 self.reconstruct_gradient();
649 self.nactive = self.ntotal();
650 }
651
652 for i in 0..self.nactive() {
654 if self.should_shrunk(i, gmax1, gmax2) {
655 self.nactive -= 1;
656 while self.nactive > i {
658 if !self.should_shrunk(self.nactive(), gmax1, gmax2) {
659 self.swap(i, self.nactive());
660 break;
661 }
662 self.nactive -= 1;
663 }
664 }
665 }
666 }
667
668 pub fn do_shrinking_nu(&mut self) {
669 let (gmax1, gmax2, gmax3, gmax4) = self.max_violating_pair_nu();
670 let (gmax1, gmax2, gmax3, gmax4) = (gmax1.0, gmax2.0, gmax3.0, gmax4.0);
671
672 if !self.unshrink && F::max(gmax1 + gmax2, gmax3 + gmax4) <= self.params.eps * F::cast(10.0)
674 {
675 self.unshrink = true;
676 self.reconstruct_gradient();
677 self.nactive = self.ntotal();
678 }
679
680 for i in 0..self.nactive() {
682 if self.should_shrunk_nu(i, gmax1, gmax2, gmax3, gmax4) {
683 self.nactive -= 1;
684 while self.nactive > i {
686 if !self.should_shrunk_nu(self.nactive(), gmax1, gmax2, gmax3, gmax4) {
687 self.swap(i, self.nactive());
688 break;
689 }
690 self.nactive -= 1;
691 }
692 }
693 }
694 }
695
696 pub fn calculate_rho(&mut self) -> F {
697 if self.nu_constraint {
699 return self.calculate_rho_nu();
700 }
701
702 let mut nfree = 0;
703 let mut sum_free = F::zero();
704 let mut ub = F::infinity();
705 let mut lb = -F::infinity();
706
707 for i in 0..self.nactive() {
708 let yg = self.target(i) * self.gradient[i];
709
710 if self.alpha[i].reached_upper() {
711 if self.targets[i] {
712 lb = F::max(lb, yg);
713 } else {
714 ub = F::min(ub, yg);
715 }
716 } else if self.alpha[i].reached_lower() {
717 if self.targets[i] {
718 ub = F::min(ub, yg);
719 } else {
720 lb = F::max(lb, yg);
721 }
722 } else {
723 nfree += 1;
724 sum_free += yg;
725 }
726 }
727
728 if nfree > 0 {
729 sum_free / F::cast(nfree)
730 } else {
731 (ub + lb) / F::cast(2.0)
732 }
733 }
734
735 pub fn calculate_rho_nu(&mut self) -> F {
736 let (mut nfree1, mut nfree2) = (0, 0);
737 let (mut sum_free1, mut sum_free2) = (F::zero(), F::zero());
738 let (mut ub1, mut ub2) = (F::infinity(), F::infinity());
739 let (mut lb1, mut lb2) = (-F::infinity(), -F::infinity());
740
741 for i in 0..self.nactive() {
742 if self.targets[i] {
743 if self.alpha[i].reached_upper() {
744 lb1 = F::max(lb1, self.gradient[i]);
745 } else if self.alpha[i].reached_lower() {
746 ub1 = F::max(ub1, self.gradient[i]);
747 } else {
748 nfree1 += 1;
749 sum_free1 += self.gradient[i];
750 }
751 }
752
753 if !self.targets[i] {
754 if self.alpha[i].reached_upper() {
755 lb2 = F::max(lb2, self.gradient[i]);
756 } else if self.alpha[i].reached_lower() {
757 ub2 = F::max(ub2, self.gradient[i]);
758 } else {
759 nfree2 += 1;
760 sum_free2 += self.gradient[i];
761 }
762 }
763 }
764
765 let r1 = if nfree1 > 0 {
766 sum_free1 / F::cast(nfree1)
767 } else {
768 (ub1 + lb1) / F::cast(2.0)
769 };
770 let r2 = if nfree2 > 0 {
771 sum_free2 / F::cast(nfree2)
772 } else {
773 (ub2 + lb2) / F::cast(2.0)
774 };
775
776 self.r = (r1 + r2) / F::cast(2.0);
777
778 (r1 - r2) / F::cast(2.0)
779 }
780
781 pub fn solve(mut self) -> Svm<F, F> {
782 let mut iter = 0;
783 let max_iter = if self.targets.len() > usize::MAX / 100 {
784 usize::MAX
785 } else {
786 100 * self.targets.len()
787 };
788
789 let max_iter = usize::max(10_000_000, max_iter);
790 let mut counter = usize::min(self.targets.len(), 1000) + 1;
791 while iter < max_iter {
792 counter -= 1;
793 if counter == 0 {
794 counter = usize::min(self.ntotal(), 1000);
795 if self.params.shrinking {
796 self.do_shrinking();
797 }
798 }
799
800 let (mut i, mut j, is_optimal) = self.select_working_set();
801 if is_optimal {
802 self.reconstruct_gradient();
803 let (i2, j2, is_optimal) = self.select_working_set();
804 if is_optimal {
805 break;
806 } else {
807 counter = 1;
809 i = i2;
810 j = j2;
811 }
812 }
813
814 iter += 1;
815
816 self.update((i, j));
818 }
819
820 if iter >= max_iter && self.nactive() < self.targets.len() {
821 self.reconstruct_gradient();
822 self.nactive = self.ntotal();
823 }
824
825 let rho = self.calculate_rho();
826 let r = if self.nu_constraint {
827 Some(self.r)
828 } else {
829 None
830 };
831
832 let mut v = F::zero();
834 for i in 0..self.targets.len() {
835 v += self.alpha[i].val() * (self.gradient[i] + self.p[i]);
836 }
837 let obj = v / F::cast(2.0);
838
839 let exit_reason = if max_iter == iter {
840 ExitReason::ReachedIterations
841 } else {
842 ExitReason::ReachedThreshold
843 };
844
845 let mut alpha: Vec<F> = (0..self.ntotal())
847 .map(|i| self.alpha[self.active_set[i]].val())
848 .collect();
849
850 if self.ntotal() > self.dataset.len_of(Axis(0)) {
859 for i in 0..self.dataset.len_of(Axis(0)) {
860 let tmp = alpha[i + self.dataset.len_of(Axis(0))];
861 alpha[i] -= tmp;
862 }
863 alpha.truncate(self.dataset.len_of(Axis(0)));
864 }
865
866 let alpha = alpha;
868
869 let sep_hyperplane = if self.kernel.inner().is_linear() {
875 let mut tmp = Array1::zeros(self.dataset.len_of(Axis(1)));
876
877 for (i, elm) in self.dataset.outer_iter().enumerate() {
878 tmp.scaled_add(self.target(i) * alpha[i], &elm);
879 }
880
881 SeparatingHyperplane::Linear(tmp)
882 } else {
883 let support_vectors = self.dataset.select(
884 Axis(0),
885 &alpha
886 .iter()
887 .enumerate()
888 .filter(|(_, a)| a.abs() > F::cast(100.) * F::epsilon())
889 .map(|(i, _)| i)
890 .collect::<Vec<_>>(),
891 );
892 SeparatingHyperplane::WeightedCombination(support_vectors)
893 };
894
895 Svm {
896 alpha,
897 rho,
898 r,
899 exit_reason,
900 obj,
901 iterations: iter,
902 sep_hyperplane,
903 kernel_method: self.kernel.into_inner().method,
904 probability_coeffs: None,
905 phantom: PhantomData,
906 }
907 }
908}
909
#[cfg(test)]
mod tests {
    use super::{SolverParams, SolverState};
    use crate::permutable_kernel::PermutableKernel;
    use crate::SeparatingHyperplane;

    /// Compile-time check that the solver types are `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn assert_auto_traits<T: Send + Sync + Sized + Unpin>() {}

        assert_auto_traits::<SeparatingHyperplane<f64>>();
        assert_auto_traits::<SolverParams<f64>>();
        assert_auto_traits::<SolverState<f64, PermutableKernel<f64>>>();
    }
}
924