linfa_svm/solver_smo.rs

use super::permutable_kernel::Permutable;
use super::{ExitReason, Float, Svm};

use ndarray::{Array1, Array2, ArrayView2, Axis};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::marker::PhantomData;

/// Parameters of the solver routine
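///
/// A minimal sketch of typical settings (the values are illustrative, not
/// crate defaults), mirroring the tolerance and shrinking heuristic of
/// libsvm-style solvers:
///
/// ```ignore
/// let params = SolverParams { eps: 1e-3, shrinking: true };
/// ```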
#[derive(Clone, Debug, PartialEq)]
pub struct SolverParams<F: Float> {
    /// Stopping condition
    pub eps: F,
    /// Should we shrink, i.e. ignore bounded alphas
    pub shrinking: bool,
}

/// Status of alpha variables of the solver
#[derive(Clone, Debug, PartialEq)]
struct Alpha<F: Float> {
    value: F,
    upper_bound: F,
}

impl<F: Float> Alpha<F> {
    pub fn from(value: F, upper_bound: F) -> Alpha<F> {
        Alpha { value, upper_bound }
    }

    pub fn reached_upper(&self) -> bool {
        self.value >= self.upper_bound
    }

    pub fn free_floating(&self) -> bool {
        self.value < self.upper_bound && self.value > F::zero()
    }

    pub fn reached_lower(&self) -> bool {
        self.value == F::zero()
    }

    pub fn val(&self) -> F {
        self.value
    }
}

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq)]
pub enum SeparatingHyperplane<F: Float> {
    Linear(Array1<F>),
    WeightedCombination(Array2<F>),
}

/// Current state of the SMO solver
///
/// We are solving the dual problem with linear constraints
/// min_a f(a), s.t. y^T a = d, 0 <= a_t <= C_t, t = 1, ..., l
/// where f(a) = a^T Q a / 2 + p^T a
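///
/// Here Q folds the labels into the kernel, i.e. Q_ij = y_i y_j K(x_i, x_j),
/// p is the linear term of the objective (e.g. a vector of -1 for
/// C-classification) and C_t is the per-variable bound.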
#[derive(Clone, Debug, PartialEq)]
pub struct SolverState<'a, F: Float, K: Permutable<F>> {
    /// Gradient of each variable
    gradient: Vec<F>,
    /// Cached gradient because most of the variables are constant
    gradient_fixed: Vec<F>,
    /// Current value of each variable with respect to its bounds
    alpha: Vec<Alpha<F>>,
    /// Active set of variables
    active_set: Vec<usize>,
    /// Number of active variables
    nactive: usize,
    unshrink: bool,
    nu_constraint: bool,
    r: F,

    /// Training data
    dataset: ArrayView2<'a, F>,

    /// Quadratic term of the problem
    kernel: K,
    /// Linear term of the problem
    p: Vec<F>,
    /// Targets we want to predict
    targets: Vec<bool>,
    /// Bounds per alpha
    bounds: Vec<F>,

    /// Parameters, e.g. the stopping condition
    params: SolverParams<F>,

    phantom: PhantomData<&'a K>,
}

#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
impl<'a, F: Float, K: 'a + Permutable<F>> SolverState<'a, F, K> {
    /// Initialize a solver state
    ///
    /// This is bounded by the lifetime of the kernel matrix, because it can be quite large
    pub fn new(
        alpha: Vec<F>,
        p: Vec<F>,
        targets: Vec<bool>,
        dataset: ArrayView2<'a, F>,
        kernel: K,
        bounds: Vec<F>,
        params: SolverParams<F>,
        nu_constraint: bool,
    ) -> SolverState<'a, F, K> {
        // initialize alpha status according to bound
        let alpha = alpha
            .into_iter()
            .enumerate()
            .map(|(i, alpha)| Alpha::from(alpha, bounds[i]))
            .collect::<Vec<_>>();

        // initialize full active set
        let active_set = (0..alpha.len()).collect::<Vec<_>>();

        // initialize gradient
        let mut gradient = p.clone();
        let mut gradient_fixed = vec![F::zero(); alpha.len()];

        for i in 0..alpha.len() {
            // a variable at the lower bound (alpha = 0) contributes nothing, d(a) = p
            if !alpha[i].reached_lower() {
                let dist_i = kernel.distances(i, alpha.len());
                let alpha_i = alpha[i].val();

                // update gradient as d(a) = p + Q a
                for j in 0..alpha.len() {
                    gradient[j] += alpha_i * dist_i[j];
                }

                // cache the gradient contribution of variables at the upper bound
                if alpha[i].reached_upper() {
                    for j in 0..alpha.len() {
                        gradient_fixed[j] += bounds[i] * dist_i[j];
                    }
                }
            }
        }

        SolverState {
            gradient,
            gradient_fixed,
            alpha,
            p,
            nactive: active_set.len(),
            unshrink: false,
            active_set,
            dataset,
            kernel,
            targets,
            bounds,
            params,
            nu_constraint,
            r: F::zero(),
            phantom: PhantomData,
        }
    }

    /// Return the number of active variables
    pub fn nactive(&self) -> usize {
        self.nactive
    }

    /// Return the total number of variables
    pub fn ntotal(&self) -> usize {
        self.alpha.len()
    }

    /// Return target as positive/negative indicator
    pub fn target(&self, idx: usize) -> F {
        if self.targets[idx] {
            F::one()
        } else {
            -F::one()
        }
    }

    /// Return the bound of the `idx`-th variable
    pub fn bound(&self, idx: usize) -> F {
        self.bounds[idx]
    }

    /// Swap two variables
    pub fn swap(&mut self, i: usize, j: usize) {
        self.gradient.swap(i, j);
        self.gradient_fixed.swap(i, j);
        self.alpha.swap(i, j);
        self.p.swap(i, j);
        self.active_set.swap(i, j);
        self.kernel.swap_indices(i, j);
        self.targets.swap(i, j);
    }

    /// Reconstruct gradient of inactive variables
    ///
    /// A variable is inactive when it is stuck at one of its bounds and has
    /// been shrunk out of the active set.
    fn reconstruct_gradient(&mut self) {
        // if no variable is inactive, skip
        if self.nactive() == self.ntotal() {
            return;
        }

        // start from the cached part: d(a_j) = g_fixed_j + p_j
        for j in self.nactive()..self.ntotal() {
            self.gradient[j] = self.gradient_fixed[j] + self.p[j];
        }

        let nfree: usize = (0..self.nactive())
            .filter(|x| self.alpha[*x].free_floating())
            .count();
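        // Pick the cheaper reconstruction: either walk each inactive kernel
        // row and gather the contributions of the free variables, or walk each
        // free row and scatter its contribution into the inactive gradient
        // entries. The inequality below compares the two traversal costs.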
        if nfree * self.ntotal() > 2 * self.nactive() * (self.ntotal() - self.nactive()) {
            for i in self.nactive()..self.ntotal() {
                let dist_i = self.kernel.distances(i, self.nactive());
                for j in 0..self.nactive() {
                    if self.alpha[j].free_floating() {
                        self.gradient[i] += self.alpha[j].val() * dist_i[j];
                    }
                }
            }
        } else {
            for i in 0..self.nactive() {
                if self.alpha[i].free_floating() {
                    let dist_i = self.kernel.distances(i, self.ntotal());
                    let alpha_i = self.alpha[i].val();
                    for j in self.nactive()..self.ntotal() {
                        self.gradient[j] += alpha_i * dist_i[j];
                    }
                }
            }
        }
    }

    pub fn update(&mut self, working_set: (usize, usize)) {
        // working set indices are called i, j here
        let (i, j) = working_set;

        let dist_i = self.kernel.distances(i, self.nactive());
        let dist_j = self.kernel.distances(j, self.nactive());

        let bound_i = self.bound(i);
        let bound_j = self.bound(j);

        let old_alpha_i = self.alpha[i].val();
        let old_alpha_j = self.alpha[j].val();

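        // Analytic two-variable update (a Newton step along the equality
        // constraint): with opposite labels both alphas move by the same
        // delta, keeping their difference fixed; with equal labels they move
        // in opposite directions, keeping their sum fixed. The step size is
        // the gradient violation divided by the curvature
        // quad_coef = Q_ii + Q_jj +/- 2 Q_ij, clamped to a small positive
        // value when the kernel matrix is not positive definite.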
        if self.targets[i] != self.targets[j] {
            let mut quad_coef = self.kernel.self_distance(i)
                + self.kernel.self_distance(j)
                + F::cast(2.0) * dist_i[j];
            if quad_coef <= F::zero() {
                quad_coef = F::cast(1e-10);
            }

            let delta = -(self.gradient[i] + self.gradient[j]) / quad_coef;
            let diff = self.alpha[i].val() - self.alpha[j].val();

            // update parameters
            self.alpha[i].value += delta;
            self.alpha[j].value += delta;

            // bound to feasible solution
            if diff > F::zero() {
                if self.alpha[j].val() < F::zero() {
                    self.alpha[j].value = F::zero();
                    self.alpha[i].value = diff;
                }
            } else if self.alpha[i].val() < F::zero() {
                self.alpha[i].value = F::zero();
                self.alpha[j].value = -diff;
            }

            if diff > bound_i - bound_j {
                if self.alpha[i].val() > bound_i {
                    self.alpha[i].value = bound_i;
                    self.alpha[j].value = bound_i - diff;
                }
            } else if self.alpha[j].val() > bound_j {
                self.alpha[j].value = bound_j;
                self.alpha[i].value = bound_j + diff;
            }
        } else {
            let mut quad_coef = self.kernel.self_distance(i) + self.kernel.self_distance(j)
                - F::cast(2.0) * dist_i[j];
            if quad_coef <= F::zero() {
                quad_coef = F::cast(1e-10);
            }

            let delta = (self.gradient[i] - self.gradient[j]) / quad_coef;
            let sum = self.alpha[i].val() + self.alpha[j].val();

            // update parameters
            self.alpha[i].value -= delta;
            self.alpha[j].value += delta;

            // bound to feasible solution
            if sum > bound_i {
                if self.alpha[i].val() > bound_i {
                    self.alpha[i].value = bound_i;
                    self.alpha[j].value = sum - bound_i;
                }
            } else if self.alpha[j].val() < F::zero() {
                self.alpha[j].value = F::zero();
                self.alpha[i].value = sum;
            }
            if sum > bound_j {
                if self.alpha[j].val() > bound_j {
                    self.alpha[j].value = bound_j;
                    self.alpha[i].value = sum - bound_j;
                }
            } else if self.alpha[i].val() < F::zero() {
                self.alpha[i].value = F::zero();
                self.alpha[j].value = sum;
            }
        }

        // update gradient
        let delta_alpha_i = self.alpha[i].val() - old_alpha_i;
        let delta_alpha_j = self.alpha[j].val() - old_alpha_j;

        for k in 0..self.nactive() {
            self.gradient[k] += dist_i[k] * delta_alpha_i + dist_j[k] * delta_alpha_j;
        }

        // update alpha status and gradient bar
        let ui = self.alpha[i].reached_upper();
        let uj = self.alpha[j].reached_upper();

        self.alpha[i] = Alpha::from(self.alpha[i].val(), self.bound(i));
        self.alpha[j] = Alpha::from(self.alpha[j].val(), self.bound(j));

        // update the cached gradient if `i` entered or left the upper bound
        if ui != self.alpha[i].reached_upper() {
            let dist_i = self.kernel.distances(i, self.ntotal());
            let bound_i = self.bound(i);
            if ui {
                for k in 0..self.ntotal() {
                    self.gradient_fixed[k] -= bound_i * dist_i[k];
                }
            } else {
                for k in 0..self.ntotal() {
                    self.gradient_fixed[k] += bound_i * dist_i[k];
                }
            }
        }

        // update the cached gradient if `j` entered or left the upper bound
        if uj != self.alpha[j].reached_upper() {
            let dist_j = self.kernel.distances(j, self.ntotal());
            let bound_j = self.bound(j);
            if uj {
                for k in 0..self.ntotal() {
                    self.gradient_fixed[k] -= bound_j * dist_j[k];
                }
            } else {
                for k in 0..self.ntotal() {
                    self.gradient_fixed[k] += bound_j * dist_j[k];
                }
            }
        }
    }

    /// Return the maximal violating pair of gradient values
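    ///
    /// Concretely, gmax1 = max { -y_i grad_i | i in I_up(alpha) } and
    /// gmax2 = max { y_i grad_i | i in I_low(alpha) }. The solution is optimal
    /// up to `eps` once gmax1 + gmax2 < eps, which is exactly the stopping
    /// test in `select_working_set`.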
    pub fn max_violating_pair(&self) -> ((F, isize), (F, isize)) {
        // max { -y_i * grad(f)_i | i in I_up(alpha) }
        let mut gmax1 = (-F::infinity(), -1);
        // max { y_i * grad(f)_i | i in I_low(alpha) }
        let mut gmax2 = (-F::infinity(), -1);

        for i in 0..self.nactive() {
            if self.targets[i] {
                if !self.alpha[i].reached_upper() && -self.gradient[i] >= gmax1.0 {
                    gmax1 = (-self.gradient[i], i as isize);
                }
                if !self.alpha[i].reached_lower() && self.gradient[i] >= gmax2.0 {
                    gmax2 = (self.gradient[i], i as isize);
                }
            } else {
                if !self.alpha[i].reached_upper() && -self.gradient[i] >= gmax2.0 {
                    gmax2 = (-self.gradient[i], i as isize);
                }
                if !self.alpha[i].reached_lower() && self.gradient[i] >= gmax1.0 {
                    gmax1 = (self.gradient[i], i as isize);
                }
            }
        }

        (gmax1, gmax2)
    }

    #[allow(clippy::type_complexity)]
    pub fn max_violating_pair_nu(&self) -> ((F, isize), (F, isize), (F, isize), (F, isize)) {
        let mut gmax1 = (-F::infinity(), -1);
        let mut gmax2 = (-F::infinity(), -1);
        let mut gmax3 = (-F::infinity(), -1);
        let mut gmax4 = (-F::infinity(), -1);

        for i in 0..self.nactive() {
            if self.targets[i] {
                if !self.alpha[i].reached_upper() && -self.gradient[i] > gmax1.0 {
                    gmax1 = (-self.gradient[i], i as isize);
                }
                if !self.alpha[i].reached_lower() && self.gradient[i] > gmax3.0 {
                    gmax3 = (self.gradient[i], i as isize);
                }
            } else {
                if !self.alpha[i].reached_upper() && -self.gradient[i] > gmax4.0 {
                    gmax4 = (-self.gradient[i], i as isize);
                }
                if !self.alpha[i].reached_lower() && self.gradient[i] > gmax2.0 {
                    gmax2 = (self.gradient[i], i as isize);
                }
            }
        }

        (gmax1, gmax2, gmax3, gmax4)
    }

    /// Select the optimal working set
    ///
    /// In each optimization step two variables are selected and then optimized. The indices are
    /// selected such that:
    ///  * i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
    ///  * j: yields the largest estimated decrease of the objective value
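    ///
    /// The choice of j is the second-order rule of Fan, Chen and Lin (2005),
    /// as used in libsvm: for the fixed i, every feasible j is scored by
    /// -(grad_diff)^2 / quad_coef and the minimizer of this score is taken.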
    pub fn select_working_set(&self) -> (usize, usize, bool) {
        if self.nu_constraint {
            return self.select_working_set_nu();
        }

        let (gmax, gmax2) = self.max_violating_pair();

        let mut obj_diff_min = (F::infinity(), -1);

        if gmax.1 != -1 {
            let dist_i = self.kernel.distances(gmax.1 as usize, self.ntotal());

            for (j, dist_ij) in dist_i.into_iter().enumerate().take(self.nactive()) {
                if self.targets[j] {
                    if !self.alpha[j].reached_lower() {
                        let grad_diff = gmax.0 + self.gradient[j];
                        if grad_diff > F::zero() {
                            // safe cast, gmax.1 was checked against -1 above
                            let i = gmax.1 as usize;

                            let quad_coef = self.kernel.self_distance(i)
                                + self.kernel.self_distance(j)
                                - F::cast(2.0) * self.target(i) * dist_ij;

                            let obj_diff = if quad_coef > F::zero() {
                                -(grad_diff * grad_diff) / quad_coef
                            } else {
                                -(grad_diff * grad_diff) / F::cast(1e-10)
                            };

                            if obj_diff <= obj_diff_min.0 {
                                obj_diff_min = (obj_diff, j as isize);
                            }
                        }
                    }
                } else if !self.alpha[j].reached_upper() {
                    let grad_diff = gmax.0 - self.gradient[j];
                    if grad_diff > F::zero() {
                        // safe cast, gmax.1 was checked against -1 above
                        let i = gmax.1 as usize;

                        let quad_coef = self.kernel.self_distance(i)
                            + self.kernel.self_distance(j)
                            + F::cast(2.0) * self.target(i) * dist_ij;

                        let obj_diff = if quad_coef > F::zero() {
                            -(grad_diff * grad_diff) / quad_coef
                        } else {
                            -(grad_diff * grad_diff) / F::cast(1e-10)
                        };
                        if obj_diff <= obj_diff_min.0 {
                            obj_diff_min = (obj_diff, j as isize);
                        }
                    }
                }
            }
        }

        if gmax.0 + gmax2.0 < self.params.eps || obj_diff_min.1 == -1 {
            (0, 0, true)
        } else {
            (gmax.1 as usize, obj_diff_min.1 as usize, false)
        }
    }

    /// Select the optimal working set under the additional nu constraint
    ///
    /// Works like `select_working_set`, but the maximal violating pair is
    /// tracked separately for the positive and the negative class:
    ///  * i: maximizes -y_i * grad(f)_i, i in I_up(\alpha), per class
    ///  * j: yields the largest estimated decrease of the objective value
    pub fn select_working_set_nu(&self) -> (usize, usize, bool) {
        let (gmaxp1, gmaxn1, gmaxp2, gmaxn2) = self.max_violating_pair_nu();

        let mut obj_diff_min = (F::infinity(), -1);

        let dist_i_p = if gmaxp1.1 != -1 {
            Some(self.kernel.distances(gmaxp1.1 as usize, self.ntotal()))
        } else {
            None
        };

        let dist_i_n = if gmaxn1.1 != -1 {
            Some(self.kernel.distances(gmaxn1.1 as usize, self.ntotal()))
        } else {
            None
        };

        for j in 0..self.nactive() {
            if self.targets[j] {
                if !self.alpha[j].reached_lower() {
                    let grad_diff = gmaxp1.0 + self.gradient[j];
                    if grad_diff > F::zero() {
                        let dist_i_p = match dist_i_p {
                            Some(ref x) => x,
                            None => continue,
                        };

                        // safe cast, `dist_i_p` is only `Some` when gmaxp1.1 != -1
                        let i = gmaxp1.1 as usize;

                        let quad_coef = self.kernel.self_distance(i) + self.kernel.self_distance(j)
                            - F::cast(2.0) * dist_i_p[j];

                        let obj_diff = if quad_coef > F::zero() {
                            -(grad_diff * grad_diff) / quad_coef
                        } else {
                            -(grad_diff * grad_diff) / F::cast(1e-10)
                        };

                        if obj_diff <= obj_diff_min.0 {
                            obj_diff_min = (obj_diff, j as isize);
                        }
                    }
                }
            } else if !self.alpha[j].reached_upper() {
                let grad_diff = gmaxn1.0 - self.gradient[j];
                if grad_diff > F::zero() {
                    let dist_i_n = match dist_i_n {
                        Some(ref x) => x,
                        None => continue,
                    };

                    // safe cast, `dist_i_n` is only `Some` when gmaxn1.1 != -1
                    let i = gmaxn1.1 as usize;

                    let quad_coef = self.kernel.self_distance(i) + self.kernel.self_distance(j)
                        - F::cast(2.0) * dist_i_n[j];

                    let obj_diff = if quad_coef > F::zero() {
                        -(grad_diff * grad_diff) / quad_coef
                    } else {
                        -(grad_diff * grad_diff) / F::cast(1e-10)
                    };
                    if obj_diff <= obj_diff_min.0 {
                        obj_diff_min = (obj_diff, j as isize);
                    }
                }
            }
        }

        if F::max(gmaxp1.0 + gmaxp2.0, gmaxn1.0 + gmaxn2.0) < self.params.eps
            || obj_diff_min.1 == -1
        {
            (0, 0, true)
        } else {
            let out_j = obj_diff_min.1 as usize;
            let out_i = if self.targets[out_j] {
                gmaxp1.1 as usize
            } else {
                gmaxn1.1 as usize
            };

            (out_i, out_j, false)
        }
    }

    pub fn should_shrunk(&self, i: usize, gmax1: F, gmax2: F) -> bool {
        if self.alpha[i].reached_upper() {
            if self.targets[i] {
                -self.gradient[i] > gmax1
            } else {
                -self.gradient[i] > gmax2
            }
        } else if self.alpha[i].reached_lower() {
            if self.targets[i] {
                self.gradient[i] > gmax2
            } else {
                self.gradient[i] > gmax1
            }
        } else {
            false
        }
    }

    pub fn should_shrunk_nu(&self, i: usize, gmax1: F, gmax2: F, gmax3: F, gmax4: F) -> bool {
        if self.alpha[i].reached_upper() {
            if self.targets[i] {
                -self.gradient[i] > gmax1
            } else {
                -self.gradient[i] > gmax4
            }
        } else if self.alpha[i].reached_lower() {
            if self.targets[i] {
                self.gradient[i] > gmax2
            } else {
                self.gradient[i] > gmax3
            }
        } else {
            false
        }
    }

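    /// Shrink the active set
    ///
    /// Variables stuck at a bound whose gradient points further into that
    /// bound cannot become part of a violating pair again, so they are
    /// swapped behind `nactive` and ignored until the gradient is
    /// reconstructed.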
    pub fn do_shrinking(&mut self) {
        if self.nu_constraint {
            self.do_shrinking_nu();
            return;
        }

        let (gmax1, gmax2) = self.max_violating_pair();
        let (gmax1, gmax2) = (gmax1.0, gmax2.0);

        // work on all variables again once the violation drops below 10*eps
        if !self.unshrink && gmax1 + gmax2 <= self.params.eps * F::cast(10.0) {
            self.unshrink = true;
            self.reconstruct_gradient();
            self.nactive = self.ntotal();
        }

        // swap items until the working set is homogeneous
        for i in 0..self.nactive() {
            if self.should_shrunk(i, gmax1, gmax2) {
                self.nactive -= 1;
                // only consider items behind this one
                while self.nactive > i {
                    if !self.should_shrunk(self.nactive(), gmax1, gmax2) {
                        self.swap(i, self.nactive());
                        break;
                    }
                    self.nactive -= 1;
                }
            }
        }
    }

    pub fn do_shrinking_nu(&mut self) {
        let (gmax1, gmax2, gmax3, gmax4) = self.max_violating_pair_nu();
        let (gmax1, gmax2, gmax3, gmax4) = (gmax1.0, gmax2.0, gmax3.0, gmax4.0);

        // work on all variables again once the violation drops below 10*eps
        if !self.unshrink && F::max(gmax1 + gmax2, gmax3 + gmax4) <= self.params.eps * F::cast(10.0)
        {
            self.unshrink = true;
            self.reconstruct_gradient();
            self.nactive = self.ntotal();
        }

        // swap items until the working set is homogeneous
        for i in 0..self.nactive() {
            if self.should_shrunk_nu(i, gmax1, gmax2, gmax3, gmax4) {
                self.nactive -= 1;
                // only consider items behind this one
                while self.nactive > i {
                    if !self.should_shrunk_nu(self.nactive(), gmax1, gmax2, gmax3, gmax4) {
                        self.swap(i, self.nactive());
                        break;
                    }
                    self.nactive -= 1;
                }
            }
        }
    }

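    /// Calculate the offset rho from the KKT conditions
    ///
    /// For every free support vector y_i * grad_i equals the offset exactly,
    /// so it is averaged over the free set; if no variable is free, the
    /// midpoint of the feasible interval [lb, ub] spanned by the bounded
    /// variables is used instead.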
    pub fn calculate_rho(&mut self) -> F {
        // with the additional nu constraint, call the specialized version
        if self.nu_constraint {
            return self.calculate_rho_nu();
        }

        let mut nfree = 0;
        let mut sum_free = F::zero();
        let mut ub = F::infinity();
        let mut lb = -F::infinity();

        for i in 0..self.nactive() {
            let yg = self.target(i) * self.gradient[i];

            if self.alpha[i].reached_upper() {
                if self.targets[i] {
                    lb = F::max(lb, yg);
                } else {
                    ub = F::min(ub, yg);
                }
            } else if self.alpha[i].reached_lower() {
                if self.targets[i] {
                    ub = F::min(ub, yg);
                } else {
                    lb = F::max(lb, yg);
                }
            } else {
                nfree += 1;
                sum_free += yg;
            }
        }

        if nfree > 0 {
            sum_free / F::cast(nfree)
        } else {
            (ub + lb) / F::cast(2.0)
        }
    }

    pub fn calculate_rho_nu(&mut self) -> F {
        let (mut nfree1, mut nfree2) = (0, 0);
        let (mut sum_free1, mut sum_free2) = (F::zero(), F::zero());
        let (mut ub1, mut ub2) = (F::infinity(), F::infinity());
        let (mut lb1, mut lb2) = (-F::infinity(), -F::infinity());

        for i in 0..self.nactive() {
            if self.targets[i] {
                if self.alpha[i].reached_upper() {
                    lb1 = F::max(lb1, self.gradient[i]);
                } else if self.alpha[i].reached_lower() {
                    ub1 = F::min(ub1, self.gradient[i]);
                } else {
                    nfree1 += 1;
                    sum_free1 += self.gradient[i];
                }
            } else {
                if self.alpha[i].reached_upper() {
                    lb2 = F::max(lb2, self.gradient[i]);
                } else if self.alpha[i].reached_lower() {
                    ub2 = F::min(ub2, self.gradient[i]);
                } else {
                    nfree2 += 1;
                    sum_free2 += self.gradient[i];
                }
            }
        }

        let r1 = if nfree1 > 0 {
            sum_free1 / F::cast(nfree1)
        } else {
            (ub1 + lb1) / F::cast(2.0)
        };
        let r2 = if nfree2 > 0 {
            sum_free2 / F::cast(nfree2)
        } else {
            (ub2 + lb2) / F::cast(2.0)
        };

        self.r = (r1 + r2) / F::cast(2.0);

        (r1 - r2) / F::cast(2.0)
    }
    pub fn solve(mut self) -> Svm<F, F> {
        let mut iter = 0;
        let max_iter = if self.targets.len() > usize::MAX / 100 {
            usize::MAX
        } else {
            100 * self.targets.len()
        };

        let max_iter = usize::max(10_000_000, max_iter);
        let mut counter = usize::min(self.targets.len(), 1000) + 1;
        while iter < max_iter {
            counter -= 1;
            if counter == 0 {
                counter = usize::min(self.ntotal(), 1000);
                if self.params.shrinking {
                    self.do_shrinking();
                }
            }

            let (mut i, mut j, is_optimal) = self.select_working_set();
            if is_optimal {
                self.reconstruct_gradient();
                let (i2, j2, is_optimal) = self.select_working_set();
                if is_optimal {
                    break;
                } else {
                    // do shrinking in the next iteration
                    counter = 1;
                    i = i2;
                    j = j2;
                }
            }

            iter += 1;

            // update alpha[i] and alpha[j]
            self.update((i, j));
        }

        if iter >= max_iter && self.nactive() < self.targets.len() {
            self.reconstruct_gradient();
            self.nactive = self.ntotal();
        }

        let rho = self.calculate_rho();
        let r = if self.nu_constraint {
            Some(self.r)
        } else {
            None
        };

        // calculate the objective function
        let mut v = F::zero();
        for i in 0..self.targets.len() {
            v += self.alpha[i].val() * (self.gradient[i] + self.p[i]);
        }
        let obj = v / F::cast(2.0);

        let exit_reason = if max_iter == iter {
            ExitReason::ReachedIterations
        } else {
            ExitReason::ReachedThreshold
        };

        // put back the solution
        let mut alpha: Vec<F> = (0..self.ntotal())
            .map(|i| self.alpha[self.active_set[i]].val())
            .collect();

        // If we are solving a regression problem, the solver computes
        // 2 * #samples alpha values. The final weight of each sample is then
        // alpha[i] - alpha[#samples + i].
        // If instead we are solving a classification problem, there is exactly
        // one alpha value per sample and they already represent the weights.

        // compute the final alpha values for regression
        if self.ntotal() > self.dataset.len_of(Axis(0)) {
            for i in 0..self.dataset.len_of(Axis(0)) {
                let tmp = alpha[i + self.dataset.len_of(Axis(0))];
                alpha[i] -= tmp;
            }
            alpha.truncate(self.dataset.len_of(Axis(0)));
        }

        // make immutable
        let alpha = alpha;

        // Now that the alpha values are set correctly we can compute the
        // support vectors. With a linear kernel they can be pre-combined, so
        // we only need to store the single vector given by their combination.
        // With a non-linear kernel we need to store all support vectors, so
        // that we can compute distances to new samples when making
        // predictions.
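        // With a linear kernel the decision function is f(x) = <w, x> - rho
        // with w = sum_i y_i * alpha_i * x_i; that single vector w is exactly
        // what the loop below accumulates.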
        let sep_hyperplane = if self.kernel.inner().is_linear() {
            let mut tmp = Array1::zeros(self.dataset.len_of(Axis(1)));

            for (i, elm) in self.dataset.outer_iter().enumerate() {
                tmp.scaled_add(self.target(i) * alpha[i], &elm);
            }

            SeparatingHyperplane::Linear(tmp)
        } else {
            let support_vectors = self.dataset.select(
                Axis(0),
                &alpha
                    .iter()
                    .enumerate()
                    .filter(|(_, a)| a.abs() > F::cast(100.) * F::epsilon())
                    .map(|(i, _)| i)
                    .collect::<Vec<_>>(),
            );
            SeparatingHyperplane::WeightedCombination(support_vectors)
        };

        Svm {
            alpha,
            rho,
            r,
            exit_reason,
            obj,
            iterations: iter,
            sep_hyperplane,
            kernel_method: self.kernel.into_inner().method,
            probability_coeffs: None,
            phantom: PhantomData,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{SolverParams, SolverState};
    use crate::permutable_kernel::PermutableKernel;
    use crate::SeparatingHyperplane;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<SolverState<f64, PermutableKernel<f64>>>();
        has_autotraits::<SolverParams<f64>>();
        has_autotraits::<SeparatingHyperplane<f64>>();
    }
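
    // A minimal sketch exercising the `Alpha` bookkeeping used by the solver:
    // zero counts as the lower bound, anything strictly between zero and the
    // bound is free, and the bound itself counts as the upper bound.
    #[test]
    fn alpha_status() {
        use super::Alpha;

        assert!(Alpha::from(0.0f64, 1.0).reached_lower());
        assert!(Alpha::from(0.5f64, 1.0).free_floating());
        assert!(Alpha::from(1.0f64, 1.0).reached_upper());
        assert!((Alpha::from(0.5f64, 1.0).val() - 0.5).abs() < f64::EPSILON);
    }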
}
/*
/// Optimize the booth function
#[test]
fn test_booth_function() {
    let kernel = array![[10., 8.], [8., 10.]];
    let kernel = Kernel {
        inner: KernelInner::Dense(kernel.clone()),
        fnc: Box::new(|_,_| 0.0),
        dataset: &kernel
    };
    let targets = vec![true, true];
    let kernel = PermutableKernel::new(&kernel, targets.clone());

    let p = vec![-34., -38.];
    let params = SolverParams {
        eps: 1e-6,
        shrinking: false
    };

    let solver = SolverState::new(vec![1.0, 1.0], p, targets, kernel, vec![1000.0; 2], &params, false);

    let res: SvmBase<f64> = solver.solve();

    println!("{:?}", res.alpha);
    println!("{}", res);
}*/