pub mod error;

use crate::error::{Error, Result};
use argmin::core::{CostFunction, Executor, Gradient, IterState, OptimizationResult, Solver};
use argmin::solver::linesearch::MoreThuenteLineSearch;
use argmin::solver::quasinewton::LBFGS;
use linfa::dataset::AsSingleTargets;
use linfa::prelude::DatasetBase;
use linfa::traits::{Fit, PredictInplace};
use ndarray::{
    s, Array, Array1, Array2, ArrayBase, ArrayView, ArrayView2, Axis, CowArray, Data, DataMut,
    Dimension, IntoDimension, Ix1, Ix2, RemoveAxis, Slice, Zip,
};
use ndarray_stats::QuantileExt;
use std::default::Default;

#[cfg(feature = "serde")]
use serde_crate::de::DeserializeOwned;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

mod argmin_param;
mod float;
mod hyperparams;

use argmin_param::*;
use float::Float;
use hyperparams::{LogisticRegressionParams, LogisticRegressionValidParams};

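/// A two-class logistic regression model, parameterized over a float type
/// `F` and one-dimensional (`Ix1`) parameters. Calling `fit` trains the model
/// with the L-BFGS solver and yields a [`FittedLogisticRegression`].
///
/// A minimal usage sketch, mirroring `simple_example_1` in the tests below
/// (the data values are purely illustrative):
///
/// ```ignore
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let x = array![[-1.0], [-0.01], [0.01], [1.0]];
/// let y = array![0, 0, 1, 1];
/// let model = LogisticRegression::default().fit(&Dataset::new(x, y)).unwrap();
/// let labels = model.predict(&array![[-0.5], [0.5]]);
/// ```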
pub type LogisticRegression<F> = LogisticRegressionParams<F, Ix1>;

pub type ValidLogisticRegression<F> = LogisticRegressionValidParams<F, Ix1>;

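/// A multinomial logistic regression model over two-dimensional (`Ix2`)
/// parameters: one weight column per class, fitted against one-hot encoded
/// targets with a softmax link. Calling `fit` yields a
/// [`MultiFittedLogisticRegression`].
///
/// A usage sketch, following `simple_multi_example_text` in the tests below:
///
/// ```ignore
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let x = array![[0.1], [1.0], [-1.0], [-0.1]];
/// let y = array!["dog", "ape", "rocket", "cat"];
/// let model = MultiLogisticRegression::default()
///     .alpha(0.1)
///     .fit(&Dataset::new(x, y))
///     .unwrap();
/// ```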
pub type MultiLogisticRegression<F> = LogisticRegressionParams<F, Ix2>;

pub type ValidMultiLogisticRegression<F> = LogisticRegressionValidParams<F, Ix2>;

impl<F: Float, D: Dimension> Default for LogisticRegressionParams<F, D> {
    fn default() -> Self {
        LogisticRegressionParams::new()
    }
}

type LBFGSType<F, D> = LBFGS<
    MoreThuenteLineSearch<ArgminParam<F, D>, ArgminParam<F, D>, F>,
    ArgminParam<F, D>,
    ArgminParam<F, D>,
    F,
>;
type LBFGSType1<F> = LBFGSType<F, Ix1>;
type LBFGSType2<F> = LBFGSType<F, Ix2>;

type IterStateType<F, D> = IterState<ArgminParam<F, D>, ArgminParam<F, D>, (), (), F>;

impl<F: Float, D: Dimension> LogisticRegressionValidParams<F, D> {
    fn setup_init_params(&self, dims: D::Pattern) -> ArgminParam<F, D> {
        if let Some(params) = self.initial_params.as_ref() {
            ArgminParam(params.clone())
        } else {
            let mut dims = dims.into_dimension();
            dims.as_array_view_mut()[0] += self.fit_intercept as usize;
            ArgminParam(Array::zeros(dims))
        }
    }

    fn validate_data<A: Data<Elem = F>, B: Data<Elem = F>>(
        &self,
        x: &ArrayBase<A, Ix2>,
        y: &ArrayBase<B, D>,
    ) -> Result<()> {
        if x.shape()[0] != y.shape()[0] {
            return Err(Error::MismatchedShapes(x.shape()[0], y.shape()[0]));
        }
        if x.iter().any(|x| !x.is_finite()) || y.iter().any(|y| !y.is_finite()) {
            return Err(Error::InvalidValues);
        }
        self.validate_init_dims(x.shape()[1], y.shape().get(1).copied())?;
        Ok(())
    }

    fn validate_init_dims(&self, mut n_features: usize, n_classes: Option<usize>) -> Result<()> {
        if let Some(params) = self.initial_params.as_ref() {
            let shape = params.shape();
            n_features += self.fit_intercept as usize;
            if n_features != shape[0] {
                return Err(Error::InitialParameterFeaturesMismatch {
                    n_features,
                    rows: shape[0],
                });
            }
            if let Some(n_classes) = n_classes {
                if n_classes != shape[1] {
                    return Err(Error::InitialParameterClassesMismatch {
                        n_classes,
                        cols: shape[1],
                    });
                }
            }
        }
        Ok(())
    }

    fn setup_problem<'a, A: Data<Elem = F>>(
        &self,
        x: &'a ArrayBase<A, Ix2>,
        target: Array<F, D>,
    ) -> LogisticRegressionProblem<'a, F, A, D> {
        LogisticRegressionProblem {
            x,
            target,
            alpha: self.alpha,
        }
    }

    fn setup_solver(&self) -> LBFGSType<F, D> {
        let linesearch = MoreThuenteLineSearch::new();
        LBFGS::new(linesearch, 10)
            .with_tolerance_grad(self.gradient_tolerance)
            .unwrap()
    }
}

impl<
        F: Float,
        #[cfg(feature = "serde")] D: Dimension + Serialize + DeserializeOwned,
        #[cfg(not(feature = "serde"))] D: Dimension,
    > LogisticRegressionValidParams<F, D>
{
    fn run_solver<P: SolvableProblem<F, D>>(
        &self,
        problem: P,
        solver: P::Solver,
        init_params: ArgminParam<F, D>,
    ) -> Result<OptimizationResult<P, P::Solver, IterStateType<F, D>>> {
        Executor::new(problem, solver)
            .configure(|state| state.param(init_params).max_iters(self.max_iterations))
            .run()
            .map_err(move |err| err.into())
    }
}

impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
    Fit<ArrayBase<D, Ix2>, T, Error> for ValidLogisticRegression<F>
{
    type Object = FittedLogisticRegression<F, C>;

    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let (x, y) = (dataset.records(), dataset.targets());
        let (labels, target) = label_classes(y)?;
        self.validate_data(x, &target)?;
        let problem = self.setup_problem(x, target);
        let solver = self.setup_solver();
        let init_params = self.setup_init_params(x.ncols());
        let result = self.run_solver(problem, solver, init_params)?;

        let params = result
            .state
            .best_param
            .unwrap_or(self.setup_init_params(x.ncols()));
        let (w, intercept) = convert_params(x.ncols(), params.as_array());
        Ok(FittedLogisticRegression::new(
            *intercept.view().into_scalar(),
            w.to_owned(),
            labels,
        ))
    }
}

impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
    Fit<ArrayBase<D, Ix2>, T, Error> for ValidMultiLogisticRegression<F>
{
    type Object = MultiFittedLogisticRegression<F, C>;

    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let (x, y) = (dataset.records(), dataset.targets());
        let (classes, target) = label_classes_multi(y)?;
        self.validate_data(x, &target)?;
        let problem = self.setup_problem(x, target);
        let solver = self.setup_solver();
        let init_params = self.setup_init_params((x.ncols(), classes.len()));
        let result = self.run_solver(problem, solver, init_params)?;

        let params = result
            .state
            .best_param
            .unwrap_or(self.setup_init_params((x.ncols(), classes.len())));
        let (w, intercept) = convert_params(x.ncols(), params.as_array());
        Ok(MultiFittedLogisticRegression::new(
            intercept.to_owned(),
            w.to_owned(),
            classes,
        ))
    }
}

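/// Identifies the distinct values of the binary target array and maps them to
/// the numeric labels `F::POSITIVE_LABEL` and `F::NEGATIVE_LABEL`. The class
/// that occurs more often ends up as the positive class; a third distinct
/// value fails with `Error::TooManyClasses`, and fewer than two distinct
/// values fail with `Error::TooFewClasses`. Returns the label mapping together
/// with the numeric target array.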
fn label_classes<F, T, C>(y: T) -> Result<(BinaryClassLabels<F, C>, Array1<F>)>
where
    F: Float,
    T: AsSingleTargets<Elem = C>,
    C: Ord + Clone,
{
    let y = y.as_single_targets();

    let mut binary_classes = [None, None];
    for class in y {
        binary_classes = match binary_classes {
            [None, None] => [Some((class, 1)), None],
            [Some((c, count)), c2] if c == class => [Some((class, count + 1)), c2],
            [c1, Some((c, count))] if c == class => [c1, Some((class, count + 1))],
            [Some(c1), None] => [Some(c1), Some((class, 1))],

            [None, Some(_)] => unreachable!("impossible binary class array"),
            [Some(_), Some(_)] => return Err(Error::TooManyClasses),
        };
    }

    let (pos_class, neg_class) = match binary_classes {
        [Some(a), Some(b)] => (a, b),
        _ => return Err(Error::TooFewClasses),
    };

    let mut target_array = y
        .into_iter()
        .map(|x| {
            if x == pos_class.0 {
                F::POSITIVE_LABEL
            } else {
                F::NEGATIVE_LABEL
            }
        })
        .collect::<Array1<_>>();

    let (pos_cl, neg_cl) = if pos_class.1 < neg_class.1 {
        target_array *= -F::one();
        (neg_class.0.clone(), pos_class.0.clone())
    } else {
        (pos_class.0.clone(), neg_class.0.clone())
    };

    Ok((
        BinaryClassLabels {
            pos: ClassLabel {
                class: pos_cl,
                label: F::POSITIVE_LABEL,
            },
            neg: ClassLabel {
                class: neg_cl,
                label: F::NEGATIVE_LABEL,
            },
        },
        target_array,
    ))
}

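/// Collects the distinct target values into a sorted, deduplicated class list
/// and builds the corresponding one-hot target matrix, with one row per sample
/// and one column per class.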
fn label_classes_multi<F, T, C>(y: T) -> Result<(Vec<C>, Array2<F>)>
where
    F: Float,
    T: AsSingleTargets<Elem = C>,
    C: Ord + Clone,
{
    let y_single_target = y.as_single_targets();
    let mut classes = y_single_target.to_vec();
    classes.sort();
    classes.dedup();

    let mut onehot = Array2::zeros((y_single_target.len(), classes.len()));
    Zip::from(onehot.rows_mut())
        .and(&y_single_target)
        .for_each(|mut oh_row, cls| {
            let idx = classes.binary_search(cls).unwrap();
            oh_row[idx] = F::one();
        });
    Ok((classes, onehot))
}

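/// Splits the parameter array `w` into a weight part and an intercept part.
/// If `w` has exactly `n_features` rows, the model was fitted without an
/// intercept and a zero intercept is returned; with `n_features + 1` rows the
/// last row is the intercept. Any other shape is a programming error and
/// panics.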
fn convert_params<F: Float, D: Dimension + RemoveAxis>(
    n_features: usize,
    w: &Array<F, D>,
) -> (ArrayView<F, D>, CowArray<F, D::Smaller>) {
    let nrows = w.shape()[0];
    if n_features == nrows {
        (
            w.view(),
            Array::zeros(w.raw_dim().remove_axis(Axis(0))).into(),
        )
    } else if n_features + 1 == nrows {
        (
            w.slice_axis(Axis(0), Slice::from(..n_features)),
            w.index_axis(Axis(0), n_features).into(),
        )
    } else {
        panic!(
            "Unexpected length of parameter vector `w`, expected {} or {}, found {}",
            n_features,
            n_features + 1,
            nrows
        );
    }
}

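/// The standard logistic (sigmoid) function, `1 / (1 + exp(-x))`.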
fn logistic<F: linfa::Float>(x: F) -> F {
    F::one() / (F::one() + (-x).exp())
}

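/// A numerically stable version of `ln(logistic(x))`. For positive `x` it is
/// computed as `-ln(1 + exp(-x))`, otherwise as `x - ln(1 + exp(x))`, so the
/// argument of `exp` is never positive and cannot overflow. Both branches are
/// algebraically equal to `ln(1 / (1 + exp(-x)))`.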
fn log_logistic<F: linfa::Float>(x: F) -> F {
    if x > F::zero() {
        -(F::one() + (-x).exp()).ln()
    } else {
        x - (F::one() + x.exp()).ln()
    }
}

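/// Computes `ln(sum(exp(x)))` along `axis` using the log-sum-exp trick: the
/// global maximum is subtracted before exponentiating and added back after
/// the logarithm, which avoids overflow. The `1e-15` floor guards against
/// taking the logarithm of zero.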
fn log_sum_exp<F: linfa::Float, A: Data<Elem = F>>(
    m: &ArrayBase<A, Ix2>,
    axis: Axis,
) -> Array<F, Ix1> {
    let max = m.iter().copied().reduce(F::max).unwrap();
    let reduced = m.fold_axis(axis, F::zero(), |acc, elem| *acc + (*elem - max).exp());
    reduced.mapv_into(|e| e.max(F::cast(1e-15)).ln() + max)
}

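/// Normalizes `v` into a probability distribution via the softmax function,
/// in place. The maximum entry is subtracted before exponentiating, which
/// leaves the result unchanged but avoids overflow for large inputs.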
fn softmax_inplace<F: linfa::Float, A: DataMut<Elem = F>>(v: &mut ArrayBase<A, Ix1>) {
    let max = v.iter().copied().reduce(F::max).unwrap();
    v.mapv_inplace(|n| (n - max).exp());
    let sum = v.sum();
    v.mapv_inplace(|n| n / sum);
}

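/// Computes the logistic loss assuming the targets `y` consist of the labels
/// `+1.0` and `-1.0`: the negative log-likelihood
/// `-sum(log_logistic(y * (x.dot(w) + intercept)))` plus the ridge penalty
/// `0.5 * alpha * ||w||^2` on the weights (the intercept is not penalized).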
fn logistic_loss<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array1<F>,
    alpha: F,
    w: &Array1<F>,
) -> F {
    let n_features = x.shape()[1];
    let (params, intercept) = convert_params(n_features, w);
    let yz = x.dot(&params.into_shape((params.len(), 1)).unwrap()) + intercept;
    let len = yz.len();
    let mut yz = yz.into_shape(len).unwrap() * y;
    yz.mapv_inplace(log_logistic);
    -yz.sum() + F::cast(0.5) * alpha * params.dot(&params)
}

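/// Computes the gradient of the logistic loss with respect to `w`:
/// `x^T * ((logistic(y * z) - 1) * y) + alpha * w` with
/// `z = x.dot(w) + intercept`. When `w` carries an intercept row, its gradient
/// entry is the sum of `(logistic(y * z) - 1) * y` over all samples.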
fn logistic_grad<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array1<F>,
    alpha: F,
    w: &Array1<F>,
) -> Array1<F> {
    let n_features = x.shape()[1];
    let (params, intercept) = convert_params(n_features, w);
    let yz = x.dot(&params.into_shape((params.len(), 1)).unwrap()) + intercept;
    let len = yz.len();
    let mut yz = yz.into_shape(len).unwrap() * y;
    yz.mapv_inplace(logistic);
    yz -= F::one();
    yz *= y;
    if w.len() == n_features + 1 {
        let mut grad = Array::zeros(w.len());
        grad.slice_mut(s![..n_features])
            .assign(&(x.t().dot(&yz) + (&params * alpha)));
        grad[n_features] = yz.sum();
        grad
    } else {
        x.t().dot(&yz) + (&params * alpha)
    }
}

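/// Computes the log of the per-class probability matrix for the multinomial
/// model: the row-wise log-softmax of `h = x.dot(w) + intercept`, i.e.
/// `h - log_sum_exp(h, Axis(1))`. Also returns a view of the weights with the
/// intercept row stripped, which the loss and gradient reuse for the
/// regularization term.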
fn multi_logistic_prob_params<'a, F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    w: &'a Array2<F>,
) -> (Array2<F>, ArrayView2<'a, F>) {
    let n_features = x.shape()[1];
    let (params, intercept) = convert_params(n_features, w);
    let h = x.dot(&params) + intercept;
    let log_prob = &h - log_sum_exp(&h, Axis(1)).into_shape((h.nrows(), 1)).unwrap();
    (log_prob, params)
}

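/// Computes the multinomial cross-entropy loss against the one-hot targets
/// `y`, plus the ridge penalty: `-sum(y * log_prob) + 0.5 * alpha * ||w||^2`,
/// where the intercept is again excluded from the penalty.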
fn multi_logistic_loss<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array2<F>,
    alpha: F,
    w: &Array2<F>,
) -> F {
    let (log_prob, params) = multi_logistic_prob_params(x, w);
    -elem_dot(&log_prob, y) + F::cast(0.5) * alpha * elem_dot(&params, &params)
}

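/// Computes the gradient of the multinomial loss with respect to `w`:
/// `x^T * (softmax(h) - y) + alpha * w`. When an intercept is fitted, its row
/// of the gradient is the column-wise sum of `softmax(h) - y`.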
fn multi_logistic_grad<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array2<F>,
    alpha: F,
    w: &Array2<F>,
) -> Array2<F> {
    let (log_prob, params) = multi_logistic_prob_params(x, w);
    let (n_features, n_classes) = params.dim();
    let intercept = w.nrows() > n_features;
    let mut grad = Array::zeros((n_features + intercept as usize, n_classes));

    let prob = log_prob.mapv_into(num_traits::Float::exp);
    let diff = prob - y;
    let dw = x.t().dot(&diff) + (&params * alpha);
    grad.slice_mut(s![..n_features, ..]).assign(&dw);
    if intercept {
        grad.row_mut(n_features).assign(&diff.sum_axis(Axis(0)));
    }
    grad
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate"),
    serde(bound(deserialize = "C: Deserialize<'de>"))
)]
pub struct FittedLogisticRegression<F: Float, C: PartialOrd + Clone> {
    threshold: F,
    intercept: F,
    params: Array1<F>,
    labels: BinaryClassLabels<F, C>,
}

impl<F: Float, C: PartialOrd + Clone> FittedLogisticRegression<F, C> {
    fn new(
        intercept: F,
        params: Array1<F>,
        labels: BinaryClassLabels<F, C>,
    ) -> FittedLogisticRegression<F, C> {
        FittedLogisticRegression {
            threshold: F::cast(0.5),
            intercept,
            params,
            labels,
        }
    }

    pub fn set_threshold(mut self, threshold: F) -> FittedLogisticRegression<F, C> {
        if threshold < F::zero() || threshold > F::one() {
            panic!("FittedLogisticRegression::set_threshold: threshold needs to be between 0.0 and 1.0");
        }
        self.threshold = threshold;
        self
    }

    pub fn intercept(&self) -> F {
        self.intercept
    }

    pub fn params(&self) -> &Array1<F> {
        &self.params
    }

    pub fn labels(&self) -> &BinaryClassLabels<F, C> {
        &self.labels
    }

    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array1<F> {
        let mut probs = x.dot(&self.params) + self.intercept;
        probs.mapv_inplace(logistic);
        probs
    }
}

impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for FittedLogisticRegression<F, C>
{
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );
        assert_eq!(
            x.ncols(),
            self.params.len(),
            "Number of data features must match the number of features the model was trained with."
        );

        let pos_class = &self.labels.pos.class;
        let neg_class = &self.labels.neg.class;
        Zip::from(&self.predict_probabilities(x))
            .and(y)
            .for_each(|prob, out| {
                *out = if *prob >= self.threshold {
                    pos_class.clone()
                } else {
                    neg_class.clone()
                }
            });
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
        Array1::default(x.nrows())
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct MultiFittedLogisticRegression<F, C: PartialOrd + Clone> {
    intercept: Array1<F>,
    params: Array2<F>,
    classes: Vec<C>,
}

impl<F: Float, C: PartialOrd + Clone> MultiFittedLogisticRegression<F, C> {
    fn new(intercept: Array1<F>, params: Array2<F>, classes: Vec<C>) -> Self {
        Self {
            intercept,
            params,
            classes,
        }
    }

    pub fn intercept(&self) -> &Array1<F> {
        &self.intercept
    }

    pub fn params(&self) -> &Array2<F> {
        &self.params
    }

    fn predict_nonorm_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
        x.dot(&self.params) + &self.intercept
    }

    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
        let mut probs = self.predict_nonorm_probabilities(x);
        probs
            .rows_mut()
            .into_iter()
            .for_each(|mut row| softmax_inplace(&mut row));
        probs
    }

    pub fn classes(&self) -> &[C] {
        &self.classes
    }
}

impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for MultiFittedLogisticRegression<F, C>
{
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );
        assert_eq!(
            x.ncols(),
            self.params.nrows(),
            "Number of data features must match the number of features the model was trained with."
        );

        let probs = self.predict_nonorm_probabilities(x);
        Zip::from(probs.rows()).and(y).for_each(|prob_row, out| {
            let idx = prob_row.argmax().unwrap();
            *out = self.classes[idx].clone();
        });
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
        Array1::default(x.nrows())
    }
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct ClassLabel<F, C: PartialOrd> {
    pub class: C,
    pub label: F,
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct BinaryClassLabels<F, C: PartialOrd> {
    pub pos: ClassLabel<F, C>,
    pub neg: ClassLabel<F, C>,
}

struct LogisticRegressionProblem<'a, F: Float, A: Data<Elem = F>, D: Dimension> {
    x: &'a ArrayBase<A, Ix2>,
    target: Array<F, D>,
    alpha: F,
}

type LogisticRegressionProblem1<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix1>;
type LogisticRegressionProblem2<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix2>;

impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem1<'_, F, A> {
    type Param = ArgminParam<F, Ix1>;
    type Output = F;

    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
        let w = p.as_array();
        let cost = logistic_loss(self.x, &self.target, self.alpha, w);
        Ok(cost)
    }
}

impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem1<'_, F, A> {
    type Param = ArgminParam<F, Ix1>;
    type Gradient = ArgminParam<F, Ix1>;

    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
        let w = p.as_array();
        let grad = ArgminParam(logistic_grad(self.x, &self.target, self.alpha, w));
        Ok(grad)
    }
}

impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem2<'_, F, A> {
    type Param = ArgminParam<F, Ix2>;
    type Output = F;

    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
        let w = p.as_array();
        let cost = multi_logistic_loss(self.x, &self.target, self.alpha, w);
        Ok(cost)
    }
}

impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem2<'_, F, A> {
    type Param = ArgminParam<F, Ix2>;
    type Gradient = ArgminParam<F, Ix2>;

    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
        let w = p.as_array();
        let grad = ArgminParam(multi_logistic_grad(self.x, &self.target, self.alpha, w));
        Ok(grad)
    }
}

trait SolvableProblem<F: Float, D: Dimension>: Gradient + Sized {
    type Solver: Solver<Self, IterStateType<F, D>>;
}

impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix1> for LogisticRegressionProblem1<'_, F, A> {
    type Solver = LBFGSType1<F>;
}

impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix2> for LogisticRegressionProblem2<'_, F, A> {
    type Solver = LBFGSType2<F>;
}

#[cfg(test)]
mod test {
    extern crate linfa;

    use super::Error;
    use super::*;
    use approx::{assert_abs_diff_eq, assert_relative_eq, AbsDiffEq};
    use linfa::prelude::*;
    use ndarray::{array, Array2, Dim, Ix};

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<LogisticRegressionParams<f64, Dim<[Ix; 0]>>>();
        has_autotraits::<LogisticRegressionValidParams<f64, Dim<[Ix; 0]>>>();
        has_autotraits::<ArgminParam<f64, Dim<[Ix; 0]>>>();
    }

    #[test]
    fn test_logistic_loss() {
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let ws = vec![
            array![0.0, 0.0],
            array![0.0, 1.0],
            array![1.0, 0.0],
            array![1.0, 1.0],
            array![0.0, -1.0],
            array![-1.0, 0.0],
            array![-1.0, -1.0],
        ];
        let alphas = &[0.0, 1.0, 10.0];
        let expecteds = vec![
            6.931471805599453,
            6.931471805599453,
            6.931471805599453,
            4.652158847349118,
            4.652158847349118,
            4.652158847349118,
            2.8012999588008323,
            3.3012999588008323,
            7.801299958800833,
            2.783195429782239,
            3.283195429782239,
            7.783195429782239,
            10.652158847349117,
            10.652158847349117,
            10.652158847349117,
            41.80129995880083,
            42.30129995880083,
            46.80129995880083,
            47.78319542978224,
            48.28319542978224,
            52.78319542978224,
        ];

        for ((w, alpha), exp) in ws
            .iter()
            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
            .zip(&expecteds)
        {
            assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w), *exp);
        }
    }

    #[test]
    fn test_logistic_grad() {
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let ws = vec![
            array![0.0, 0.0],
            array![0.0, 1.0],
            array![1.0, 0.0],
            array![1.0, 1.0],
            array![0.0, -1.0],
            array![-1.0, 0.0],
            array![-1.0, -1.0],
        ];
        let alphas = &[0.0, 1.0, 10.0];
        let expecteds = vec![
            array![-19.5, -3.],
            array![-19.5, -3.],
            array![-19.5, -3.],
            array![-10.48871543, -1.61364853],
            array![-10.48871543, -1.61364853],
            array![-10.48871543, -1.61364853],
            array![-0.13041554, -0.02852148],
            array![0.86958446, -0.02852148],
            array![9.86958446, -0.02852148],
            array![-0.04834401, -0.01058067],
            array![0.95165599, -0.01058067],
            array![9.95165599, -0.01058067],
            array![-28.51128457, -4.38635147],
            array![-28.51128457, -4.38635147],
            array![-28.51128457, -4.38635147],
            array![-38.86958446, -5.97147852],
            array![-39.86958446, -5.97147852],
            array![-48.86958446, -5.97147852],
            array![-38.95165599, -5.98941933],
            array![-39.95165599, -5.98941933],
            array![-48.95165599, -5.98941933],
        ];

        for ((w, alpha), exp) in ws
            .iter()
            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
            .zip(&expecteds)
        {
            let actual = logistic_grad(&x, &y, alpha, w);
            assert!(actual.abs_diff_eq(exp, 1e-8));
        }
    }

    #[test]
    fn simple_example_1() {
        let log_reg = LogisticRegression::default();
        let x = array![[-1.0], [-0.01], [0.01], [1.0]];
        let y = array![0, 0, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(res.intercept(), 0.0);
        assert!(res.params().abs_diff_eq(&array![-0.681], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn simple_example_1_cats_dogs() {
        let log_reg = LogisticRegression::default();
        let x = array![[0.01], [1.0], [-1.0], [-0.01]];
        let y = array!["dog", "dog", "cat", "cat"];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(res.intercept(), 0.0);
        assert!(res.params().abs_diff_eq(&array![0.681], 1e-3));
        assert!(res
            .predict_probabilities(dataset.records())
            .abs_diff_eq(&array![0.501, 0.664, 0.335, 0.498], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
        assert_eq!(res.labels().pos.class, "dog");
        assert_eq!(res.labels().neg.class, "cat");
    }

    #[test]
    fn simple_example_2() {
        let log_reg = LogisticRegression::default().alpha(1.0);
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn simple_example_3() {
        let x = array![[1.0], [0.0], [1.0], [0.0]];
        let y = array![1, 0, 1, 0];
        let dataset = DatasetBase::new(x, y);
        let model = LogisticRegression::default().fit(&dataset).unwrap();

        let pred = model.predict(&dataset.records);
        assert_eq!(dataset.targets(), pred);
    }

    #[test]
    fn rejects_mismatching_x_y() {
        let log_reg = LogisticRegression::default();
        let x = array![[-1.0], [-0.01], [0.01]];
        let y = array![0, 0, 1, 1];
        let res = log_reg.fit(&Dataset::new(x, y));
        assert!(matches!(res.unwrap_err(), Error::MismatchedShapes(3, 4)));
    }

    #[test]
    fn rejects_inf_values() {
        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
        let inf_xs: Vec<_> = infs.iter().map(|&inf| array![[1.0], [inf]]).collect();
        let log_reg = LogisticRegression::default();
        let normal_x = array![[-1.0], [1.0]];
        let y = array![0, 1];
        for inf_x in &inf_xs {
            let res = log_reg.fit(&DatasetBase::new(inf_x.view(), &y));
            assert!(matches!(res.unwrap_err(), Error::InvalidValues));
        }
        for inf in infs {
            let log_reg = LogisticRegression::default().alpha(*inf);
            let res = log_reg.fit(&DatasetBase::new(normal_x.view(), &y));
            assert!(matches!(res.unwrap_err(), Error::InvalidAlpha));
        }
        let mut non_positives = infs.to_vec();
        non_positives.push(-1.0);
        non_positives.push(0.0);
        for inf in &non_positives {
            let log_reg = LogisticRegression::default().gradient_tolerance(*inf);
            let res = log_reg.fit(&Dataset::new(normal_x.to_owned(), y.to_owned()));
            assert!(matches!(res.unwrap_err(), Error::InvalidGradientTolerance));
        }
    }

    #[test]
    fn validates_initial_params() {
        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
        let normal_x = array![[-1.0], [1.0]];
        let normal_y = array![0, 1];
        let dataset = Dataset::new(normal_x, normal_y);
        for inf in infs {
            let log_reg = LogisticRegression::default().initial_params(array![*inf, 0.0]);
            let res = log_reg.fit(&dataset);
            assert!(matches!(res.unwrap_err(), Error::InvalidInitialParameters));
        }
        {
            let log_reg = LogisticRegression::default().initial_params(array![0.0, 0.0, 0.0]);
            let res = log_reg.fit(&dataset);
            assert!(matches!(
                res.unwrap_err(),
                Error::InitialParameterFeaturesMismatch {
                    rows: 3,
                    n_features: 2
                }
            ));
        }
        {
            let log_reg = LogisticRegression::default()
                .with_intercept(false)
                .initial_params(array![0.0, 0.0]);
            let res = log_reg.fit(&dataset);
            assert!(matches!(
                res.unwrap_err(),
                Error::InitialParameterFeaturesMismatch {
                    rows: 2,
                    n_features: 1
                }
            ));
        }
    }

    #[test]
    fn uses_initial_params() {
        let params = array![1.2, -4.12];
        let log_reg = LogisticRegression::default()
            .initial_params(params)
            .max_iterations(5);
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert!(res.intercept().abs_diff_eq(&-4.124, 1e-3));
        assert!(res.params().abs_diff_eq(&array![1.181], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );

        #[cfg(feature = "serde")]
        {
            let ser = rmp_serde::to_vec(&res).unwrap();
            let unser: FittedLogisticRegression<f32, f32> = rmp_serde::from_slice(&ser).unwrap();

            let x = array![[1.0]];
            let y_hat = unser.predict(&x);

            assert!(y_hat[0] == 0.0);
        }
    }

    #[test]
    fn works_with_f32() {
        let log_reg = LogisticRegression::default();
        let x: Array2<f32> = array![[-1.0], [-0.01], [0.01], [1.0]];
        let y = array![0, 0, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(res.intercept(), 0.0_f32);
        assert!(res.params().abs_diff_eq(&array![-0.682_f32], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn test_log_sum_exp() {
        let data = array![[3.3, 0.4, -2.1], [0.4, 2.2, -0.1], [1., 0., -1.]];
        let out = log_sum_exp(&data, Axis(1));
        assert_abs_diff_eq!(out, array![3.35783, 2.43551, 1.40761], epsilon = 1e-5);
    }

    #[test]
    fn test_softmax() {
        let mut data = array![3.3, 5.5, 0.1, -4.4, 8.0];
        softmax_inplace(&mut data);
        assert_relative_eq!(
            data,
            array![0.0083324, 0.075200047, 0.000339647, 0.000003773, 0.91612413],
            epsilon = 1e-8
        );
        assert_abs_diff_eq!(data.sum(), 1.0);
    }

    #[test]
    fn test_multi_logistic_loss_grad() {
        let x = array![
            [0.0, 0.5],
            [1.0, -1.0],
            [2.0, -2.0],
            [3.0, -3.0],
            [4.0, -4.0],
            [5.0, -5.0],
            [6.0, -6.0],
            [7.0, -7.0],
        ];
        let y = array![
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
        ];
        let params1 = array![[4.4, -1.2, 3.3], [3.4, 0.1, 0.0]];
        let params2 = array![[0.001, -3.2, 2.9], [0.1, 4.5, 5.7], [4.5, 2.2, 1.7]];
        let alpha = 0.6;

        {
            let (log_prob, w) = multi_logistic_prob_params(&x, &params1);
            assert_abs_diff_eq!(
                log_prob,
                array![
                    [-3.18259845e-01, -1.96825985e+00, -2.01825985e+00],
                    [-2.40463987e+00, -4.70463987e+00, -1.04639868e-01],
                    [-4.61010168e+00, -9.21010168e+00, -1.01016809e-02],
                    [-6.90100829e+00, -1.38010083e+01, -1.00829256e-03],
                    [-9.20010104e+00, -1.84001010e+01, -1.01044506e-04],
                    [-1.15000101e+01, -2.30000101e+01, -1.01301449e-05],
                    [-1.38000010e+01, -2.76000010e+01, -1.01563199e-06],
                    [-1.61000001e+01, -3.22000001e+01, -1.01826043e-07],
                ],
                epsilon = 1e-6
            );
            assert_abs_diff_eq!(w, params1);
            let loss = multi_logistic_loss(&x, &y, alpha, &params1);
            assert_abs_diff_eq!(loss, 57.11212197835295, epsilon = 1e-6);
            let grad = multi_logistic_grad(&x, &y, alpha, &params1);
            assert_abs_diff_eq!(
                grad,
                array![
                    [1.7536815, -9.71074369, 11.85706219],
                    [2.79002537, 9.12059357, -9.81061893]
                ],
                epsilon = 1e-6
            );
        }

        {
            let (log_prob, w) = multi_logistic_prob_params(&x, &params2);
            assert_abs_diff_eq!(
                log_prob,
                array![
                    [-1.06637742e+00, -1.16637742e+00, -1.06637742e+00],
                    [-4.12429463e-03, -9.90512429e+00, -5.50512429e+00],
                    [-2.74092305e-04, -1.75022741e+01, -8.20227409e+00],
                    [-1.84027855e-05, -2.51030184e+01, -1.09030184e+01],
                    [-1.23554225e-06, -3.27040012e+01, -1.36040012e+01],
                    [-8.29523046e-08, -4.03050001e+01, -1.63050001e+01],
                    [-5.56928016e-09, -4.79060000e+01, -1.90060000e+01],
                    [-3.73912013e-10, -5.55070000e+01, -2.17070000e+01]
                ],
                epsilon = 1e-6
            );
            assert_abs_diff_eq!(w, params2.slice(s![..params2.nrows() - 1, ..]));
            let loss = multi_logistic_loss(&x, &y, alpha, &params2);
            assert_abs_diff_eq!(loss, 154.8177958366479, epsilon = 1e-6);
            let grad = multi_logistic_grad(&x, &y, alpha, &params2);
            assert_abs_diff_eq!(
                grad,
                array![
                    [26.99587549, -10.91995003, -16.25532546],
                    [-27.26314882, 11.85569669, 21.58745213],
                    [5.33984376, -2.68845675, -2.65138701]
                ],
                epsilon = 1e-6
            );
        }
    }

    #[test]
    fn simple_multi_example() {
        let x = array![[-1., 0.], [0., 1.], [1., 1.]];
        let y = array![2, 1, 0];
        let log_reg = MultiLogisticRegression::default()
            .alpha(0.1)
            .initial_params(Array::zeros((3, 3)));
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(res.params().dim(), (2, 3));
        assert_eq!(res.intercept().dim(), 3);
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn simple_multi_example_2() {
        let x = array![[1.0], [0.0], [1.0], [0.0]];
        let y = array![1, 0, 1, 0];
        let dataset = DatasetBase::new(x, y);
        let model = MultiLogisticRegression::default().fit(&dataset).unwrap();

        let pred = model.predict(&dataset.records);
        assert_eq!(dataset.targets(), pred);
    }

    #[test]
    fn simple_multi_example_text() {
        let log_reg = MultiLogisticRegression::default().alpha(0.1);
        let x = array![[0.1], [1.0], [-1.0], [-0.1]];
        let y = array!["dog", "ape", "rocket", "cat"];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(res.params().dim(), (1, 4));
        assert_eq!(res.intercept().dim(), 4);
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn multi_on_binary_problem() {
        let log_reg = MultiLogisticRegression::default().alpha(1.0);
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(res.params().dim(), (1, 2));
        assert_eq!(res.intercept().dim(), 2);
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn reject_num_class_mismatch() {
        let n_samples = 4;
        let n_classes = 3;
        let n_features = 1;
        let x = Array2::<f64>::zeros((n_samples, n_features));
        let y = array![0, 1, 2, 0];
        let dataset = Dataset::new(x, y);

        let log_reg = MultiLogisticRegression::default()
            .with_intercept(false)
            .initial_params(Array::zeros((n_features, n_classes - 1)));
        assert!(matches!(
            log_reg.fit(&dataset).unwrap_err(),
            Error::InitialParameterClassesMismatch {
                cols: 2,
                n_classes: 3,
            }
        ));
    }
}