1pub mod error;
20
21use crate::error::{Error, Result};
22use argmin::core::{CostFunction, Executor, Gradient, IterState, OptimizationResult, Solver};
23use argmin::solver::linesearch::MoreThuenteLineSearch;
24use argmin::solver::quasinewton::LBFGS;
25use linfa::dataset::AsSingleTargets;
26use linfa::prelude::DatasetBase;
27use linfa::traits::{Fit, PredictInplace};
28use ndarray::{
29 s, Array, Array1, Array2, ArrayBase, ArrayView, ArrayView2, Axis, CowArray, Data, DataMut,
30 Dimension, IntoDimension, Ix1, Ix2, RemoveAxis, Slice, Zip,
31};
32use ndarray_stats::QuantileExt;
33use std::default::Default;
34
35#[cfg(feature = "serde")]
36use serde_crate::de::DeserializeOwned;
37#[cfg(feature = "serde")]
38use serde_crate::{Deserialize, Serialize};
39
40mod argmin_param;
41mod float;
42mod hyperparams;
43
44use argmin_param::*;
45use float::Float;
46use hyperparams::{LogisticRegressionParams, LogisticRegressionValidParams};
47
/// Builder for a two-class logistic regression model (1-D parameter vector).
pub type LogisticRegression<F> = LogisticRegressionParams<F, Ix1>;

/// Validated hyperparameters for two-class logistic regression.
pub type ValidLogisticRegression<F> = LogisticRegressionValidParams<F, Ix1>;

/// Builder for a multinomial (softmax) logistic regression model (2-D parameter matrix).
pub type MultiLogisticRegression<F> = LogisticRegressionParams<F, Ix2>;

/// Validated hyperparameters for multinomial logistic regression.
pub type ValidMultiLogisticRegression<F> = LogisticRegressionValidParams<F, Ix2>;
97
98impl<F: Float, D: Dimension> Default for LogisticRegressionParams<F, D> {
99 fn default() -> Self {
100 LogisticRegressionParams::new()
101 }
102}
103
// L-BFGS solver specialised to our `ArgminParam` wrapper, using More-Thuente
// line search.
type LBFGSType<F, D> = LBFGS<
    MoreThuenteLineSearch<ArgminParam<F, D>, ArgminParam<F, D>, F>,
    ArgminParam<F, D>,
    ArgminParam<F, D>,
    F,
>;
// Solver aliases for the binary (`Ix1`) and multinomial (`Ix2`) cases.
type LBFGSType1<F> = LBFGSType<F, Ix1>;
type LBFGSType2<F> = LBFGSType<F, Ix2>;

// Iteration state threaded through the argmin executor.
type IterStateType<F, D> = IterState<ArgminParam<F, D>, ArgminParam<F, D>, (), (), (), F>;
114
impl<F: Float, D: Dimension> LogisticRegressionValidParams<F, D> {
    /// Build the optimizer's starting point: the user-supplied initial
    /// parameters if present, otherwise zeros — with one extra row reserved
    /// for the intercept when `fit_intercept` is enabled.
    fn setup_init_params(&self, dims: D::Pattern) -> ArgminParam<F, D> {
        if let Some(params) = self.initial_params.as_ref() {
            ArgminParam(params.clone())
        } else {
            let mut dims = dims.into_dimension();
            // The first axis holds the features; grow it by one for the intercept.
            dims.as_array_view_mut()[0] += self.fit_intercept as usize;
            ArgminParam(Array::zeros(dims))
        }
    }

    /// Check that `x` and `y` have matching sample counts, contain only finite
    /// values, and are shape-compatible with any user-supplied initial params.
    fn validate_data<A: Data<Elem = F>, B: Data<Elem = F>>(
        &self,
        x: &ArrayBase<A, Ix2>,
        y: &ArrayBase<B, D>,
    ) -> Result<()> {
        if x.shape()[0] != y.shape()[0] {
            return Err(Error::MismatchedShapes(x.shape()[0], y.shape()[0]));
        }
        if x.iter().any(|x| !x.is_finite()) || y.iter().any(|y| !y.is_finite()) {
            return Err(Error::InvalidValues);
        }
        // `y.shape().get(1)` is `None` in the binary (1-D target) case.
        self.validate_init_dims(x.shape()[1], y.shape().get(1).copied())?;
        Ok(())
    }

    /// Verify that user-supplied initial parameters (if any) have the expected
    /// row count (features, plus one when an intercept is fitted) and, for the
    /// multinomial case, the expected column count (classes).
    fn validate_init_dims(&self, mut n_features: usize, n_classes: Option<usize>) -> Result<()> {
        if let Some(params) = self.initial_params.as_ref() {
            let shape = params.shape();
            n_features += self.fit_intercept as usize;
            if n_features != shape[0] {
                return Err(Error::InitialParameterFeaturesMismatch {
                    n_features,
                    rows: shape[0],
                });
            }
            if let Some(n_classes) = n_classes {
                if n_classes != shape[1] {
                    return Err(Error::InitialParameterClassesMismatch {
                        n_classes,
                        cols: shape[1],
                    });
                }
            }
        }
        Ok(())
    }

    /// Bundle the data, encoded targets and regularization strength into the
    /// problem struct consumed by the argmin solver.
    fn setup_problem<'a, A: Data<Elem = F>>(
        &self,
        x: &'a ArrayBase<A, Ix2>,
        target: Array<F, D>,
    ) -> LogisticRegressionProblem<'a, F, A, D> {
        LogisticRegressionProblem {
            x,
            target,
            alpha: self.alpha,
        }
    }

    /// Create an L-BFGS solver (history size 10) with More-Thuente line search
    /// and the configured gradient tolerance.
    fn setup_solver(&self) -> LBFGSType<F, D> {
        let linesearch = MoreThuenteLineSearch::new();
        // `with_tolerance_grad` only errors for non-positive tolerances, which
        // hyperparameter validation has already ruled out.
        LBFGS::new(linesearch, 10)
            .with_tolerance_grad(self.gradient_tolerance)
            .unwrap()
    }
}
189
impl<
        F: Float,
        #[cfg(feature = "serde")] D: Dimension + Serialize + DeserializeOwned,
        #[cfg(not(feature = "serde"))] D: Dimension,
    > LogisticRegressionValidParams<F, D>
{
    /// Run `solver` on `problem` starting from `init_params`, capped at the
    /// configured maximum number of iterations. Returns the full argmin
    /// optimization result so callers can inspect the best parameters found.
    fn run_solver<P: SolvableProblem<F, D>>(
        &self,
        problem: P,
        solver: P::Solver,
        init_params: ArgminParam<F, D>,
    ) -> Result<OptimizationResult<P, P::Solver, IterStateType<F, D>>> {
        Executor::new(problem, solver)
            .configure(|state| state.param(init_params).max_iters(self.max_iterations))
            .run()
            .map_err(move |err| err.into())
    }
}
209
210impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
211 Fit<ArrayBase<D, Ix2>, T, Error> for ValidLogisticRegression<F>
212{
213 type Object = FittedLogisticRegression<F, C>;
214
215 fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
232 let (x, y) = (dataset.records(), dataset.targets());
233 let (labels, target) = label_classes(y)?;
234 self.validate_data(x, &target)?;
235 let problem = self.setup_problem(x, target);
236 let solver = self.setup_solver();
237 let init_params = self.setup_init_params(x.ncols());
238 let result = self.run_solver(problem, solver, init_params)?;
239
240 let params = result
241 .state
242 .best_param
243 .unwrap_or(self.setup_init_params(x.ncols()));
244 let (w, intercept) = convert_params(x.ncols(), params.as_array());
245 Ok(FittedLogisticRegression::new(
246 *intercept.view().into_scalar(),
247 w.to_owned(),
248 labels,
249 ))
250 }
251}
252
253impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
254 Fit<ArrayBase<D, Ix2>, T, Error> for ValidMultiLogisticRegression<F>
255{
256 type Object = MultiFittedLogisticRegression<F, C>;
257
258 fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
269 let (x, y) = (dataset.records(), dataset.targets());
270 let (classes, target) = label_classes_multi(y)?;
271 self.validate_data(x, &target)?;
272 let problem = self.setup_problem(x, target);
273 let solver = self.setup_solver();
274 let init_params = self.setup_init_params((x.ncols(), classes.len()));
275 let result = self.run_solver(problem, solver, init_params)?;
276
277 let params = result
278 .state
279 .best_param
280 .unwrap_or(self.setup_init_params((x.ncols(), classes.len())));
281 let (w, intercept) = convert_params(x.ncols(), params.as_array());
282 Ok(MultiFittedLogisticRegression::new(
283 intercept.to_owned(),
284 w.to_owned(),
285 classes,
286 ))
287 }
288}
289
/// Identify the two distinct classes in `y` and map them onto the numeric
/// labels used by the optimizer (`POSITIVE_LABEL` / `NEGATIVE_LABEL`).
///
/// The class comparing greater under `Ord` becomes the positive class, so the
/// assignment is independent of the order in which samples appear. Errors when
/// `y` contains fewer or more than two distinct classes.
fn label_classes<F, T, C>(y: T) -> Result<(BinaryClassLabels<F, C>, Array1<F>)>
where
    F: Float,
    T: AsSingleTargets<Elem = C>,
    C: Ord + Clone,
{
    let y = y.as_single_targets();

    // Collect up to two (class, count) pairs in a single pass; a third
    // distinct class aborts with `TooManyClasses`.
    let mut binary_classes = [None, None];
    for class in y {
        binary_classes = match binary_classes {
            // first class ever seen
            [None, None] => [Some((class, 1)), None],
            // repeat occurrence of an already-known class
            [Some((c, count)), c2] if c == class => [Some((class, count + 1)), c2],
            [c1, Some((c, count))] if c == class => [c1, Some((class, count + 1))],
            // second distinct class
            [Some(c1), None] => [Some(c1), Some((class, 1))],
            // the first slot is always filled before the second
            [None, Some(_)] => unreachable!("impossible binary class array"),
            [Some(_), Some(_)] => return Err(Error::TooManyClasses),
        };
    }

    let (class_a, class_b) = match binary_classes {
        [Some(a), Some(b)] => (a, b),
        _ => return Err(Error::TooFewClasses),
    };

    // Order by `Ord` so the positive/negative assignment is deterministic.
    let (neg_class, pos_class) = if class_a.0 < class_b.0 {
        (class_a, class_b)
    } else {
        (class_b, class_a)
    };

    // Encode targets as +1/-1 floats for the logistic loss.
    let target_array = y
        .into_iter()
        .map(|x| {
            if x == pos_class.0 {
                F::POSITIVE_LABEL
            } else {
                F::NEGATIVE_LABEL
            }
        })
        .collect::<Array1<_>>();

    Ok((
        BinaryClassLabels {
            pos: ClassLabel {
                class: pos_class.0.clone(),
                label: F::POSITIVE_LABEL,
            },
            neg: ClassLabel {
                class: neg_class.0.clone(),
                label: F::NEGATIVE_LABEL,
            },
        },
        target_array,
    ))
}
355
356fn label_classes_multi<F, T, C>(y: T) -> Result<(Vec<C>, Array2<F>)>
360where
361 F: Float,
362 T: AsSingleTargets<Elem = C>,
363 C: Ord + Clone,
364{
365 let y_single_target = y.as_single_targets();
366 let mut classes = y_single_target.to_vec();
367 classes.sort();
369 classes.dedup();
370
371 let mut onehot = Array2::zeros((y_single_target.len(), classes.len()));
372 Zip::from(onehot.rows_mut())
373 .and(&y_single_target)
374 .for_each(|mut oh_row, cls| {
375 let idx = classes.binary_search(cls).unwrap();
376 oh_row[idx] = F::one();
377 });
378 Ok((classes, onehot))
379}
380
381fn convert_params<F: Float, D: Dimension + RemoveAxis>(
385 n_features: usize,
386 w: &Array<F, D>,
387) -> (ArrayView<'_, F, D>, CowArray<'_, F, D::Smaller>) {
388 let nrows = w.shape()[0];
389 if n_features == nrows {
390 (
391 w.view(),
392 Array::zeros(w.raw_dim().remove_axis(Axis(0))).into(),
393 )
394 } else if n_features + 1 == nrows {
395 (
396 w.slice_axis(Axis(0), Slice::from(..n_features)),
397 w.index_axis(Axis(0), n_features).into(),
398 )
399 } else {
400 panic!(
401 "Unexpected length of parameter vector `w`, exected {} or {}, found {}",
402 n_features,
403 n_features + 1,
404 nrows
405 );
406 }
407}
408
409fn logistic<F: linfa::Float>(x: F) -> F {
411 F::one() / (F::one() + (-x).exp())
412}
413
/// Numerically stable computation of `ln(logistic(x))`.
///
/// Branching on the sign of `x` keeps the argument of `exp` non-positive, so
/// the exponential never overflows for large-magnitude inputs.
fn log_logistic<F: linfa::Float>(x: F) -> F {
    if x > F::zero() {
        -(F::one() + (-x).exp()).ln()
    } else {
        x - (F::one() + x.exp()).ln()
    }
}
428
/// Compute `ln(sum(exp(m)))` along `axis` in a numerically stable way by
/// shifting every entry by the global maximum before exponentiating.
fn log_sum_exp<F: linfa::Float, A: Data<Elem = F>>(
    m: &ArrayBase<A, Ix2>,
    axis: Axis,
) -> Array<F, Ix1> {
    // Global (not per-lane) maximum; subtracting it bounds every exponent by 0.
    let max = m.iter().copied().reduce(F::max).unwrap();
    let reduced = m.fold_axis(axis, F::zero(), |acc, elem| *acc + (*elem - max).exp());
    // Clamp away from zero before taking the log to avoid `ln(0)`.
    reduced.mapv_into(|e| e.max(F::cast(1e-15)).ln() + max)
}
444
445fn softmax_inplace<F: linfa::Float, A: DataMut<Elem = F>>(v: &mut ArrayBase<A, Ix1>) {
447 let max = v.iter().copied().reduce(F::max).unwrap();
448 v.mapv_inplace(|n| (n - max).exp());
449 let sum = v.sum();
450 v.mapv_inplace(|n| n / sum);
451}
452
453fn logistic_loss<F: Float, A: Data<Elem = F>>(
463 x: &ArrayBase<A, Ix2>,
464 y: &Array1<F>,
465 alpha: F,
466 w: &Array1<F>,
467) -> F {
468 let n_features = x.shape()[1];
469 let (params, intercept) = convert_params(n_features, w);
470 let yz = x.dot(¶ms.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
471 let len = yz.len();
472 let mut yz = yz.into_shape_with_order(len).unwrap() * y;
473 yz.mapv_inplace(log_logistic);
474 -yz.sum() + F::cast(0.5) * alpha * params.dot(¶ms)
475}
476
477fn logistic_grad<F: Float, A: Data<Elem = F>>(
479 x: &ArrayBase<A, Ix2>,
480 y: &Array1<F>,
481 alpha: F,
482 w: &Array1<F>,
483) -> Array1<F> {
484 let n_features = x.shape()[1];
485 let (params, intercept) = convert_params(n_features, w);
486 let yz = x.dot(¶ms.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
487 let len = yz.len();
488 let mut yz = yz.into_shape_with_order(len).unwrap() * y;
489 yz.mapv_inplace(logistic);
490 yz -= F::one();
491 yz *= y;
492 if w.len() == n_features + 1 {
493 let mut grad = Array::zeros(w.len());
494 grad.slice_mut(s![..n_features])
495 .assign(&(x.t().dot(&yz) + (¶ms * alpha)));
496 grad[n_features] = yz.sum();
497 grad
498 } else {
499 x.t().dot(&yz) + (¶ms * alpha)
500 }
501}
502
503fn multi_logistic_prob_params<'a, F: Float, A: Data<Elem = F>>(
508 x: &ArrayBase<A, Ix2>,
509 w: &'a Array2<F>, ) -> (Array2<F>, ArrayView2<'a, F>) {
511 let n_features = x.shape()[1];
512 let (params, intercept) = convert_params(n_features, w);
513 let h = x.dot(¶ms) + intercept;
515 let log_prob = &h
518 - log_sum_exp(&h, Axis(1))
519 .into_shape_with_order((h.nrows(), 1))
520 .unwrap();
521 (log_prob, params)
522}
523
524fn multi_logistic_loss<F: Float, A: Data<Elem = F>>(
526 x: &ArrayBase<A, Ix2>,
527 y: &Array2<F>,
528 alpha: F,
529 w: &Array2<F>,
530) -> F {
531 let (log_prob, params) = multi_logistic_prob_params(x, w);
532 -elem_dot(&log_prob, y) + F::cast(0.5) * alpha * elem_dot(¶ms, ¶ms)
534}
535
536fn multi_logistic_grad<F: Float, A: Data<Elem = F>>(
540 x: &ArrayBase<A, Ix2>,
541 y: &Array2<F>,
542 alpha: F,
543 w: &Array2<F>,
544) -> Array2<F> {
545 let (log_prob, params) = multi_logistic_prob_params(x, w);
546 let (n_features, n_classes) = params.dim();
547 let intercept = w.nrows() > n_features;
548 let mut grad = Array::zeros((n_features + intercept as usize, n_classes));
549
550 let prob = log_prob.mapv_into(num_traits::Float::exp);
552 let diff = prob - y;
553 let dw = x.t().dot(&diff) + (¶ms * alpha);
555 grad.slice_mut(s![..n_features, ..]).assign(&dw);
556 if intercept {
558 grad.row_mut(n_features).assign(&diff.sum_axis(Axis(0)));
559 }
560 grad
561}
562
/// A fitted two-class logistic regression model, produced by fitting
/// `LogisticRegression` on a dataset with exactly two distinct classes.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct FittedLogisticRegression<F: Float, C: PartialOrd + Clone> {
    // Probability cutoff at or above which a sample is assigned the positive class.
    threshold: F,
    // Fitted bias term.
    intercept: F,
    // Fitted weight vector, one entry per feature.
    params: Array1<F>,
    // Mapping between the original class values and the +1/-1 labels.
    labels: BinaryClassLabels<F, C>,
}
576
impl<F: Float, C: PartialOrd + Clone> FittedLogisticRegression<F, C> {
    /// Construct a fitted model with the default decision threshold of 0.5.
    fn new(
        intercept: F,
        params: Array1<F>,
        labels: BinaryClassLabels<F, C>,
    ) -> FittedLogisticRegression<F, C> {
        FittedLogisticRegression {
            threshold: F::cast(0.5),
            intercept,
            params,
            labels,
        }
    }

    /// Set the probability threshold above which a sample is classified as
    /// belonging to the positive class.
    ///
    /// # Panics
    ///
    /// Panics if `threshold` is outside `[0.0, 1.0]`.
    pub fn set_threshold(mut self, threshold: F) -> FittedLogisticRegression<F, C> {
        if threshold < F::zero() || threshold > F::one() {
            panic!("FittedLogisticRegression::set_threshold: threshold needs to be between 0.0 and 1.0");
        }
        self.threshold = threshold;
        self
    }

    /// The fitted intercept (bias) term.
    pub fn intercept(&self) -> F {
        self.intercept
    }

    /// The fitted weight vector, one entry per feature.
    pub fn params(&self) -> &Array1<F> {
        &self.params
    }

    /// The positive and negative class labels chosen during fitting.
    pub fn labels(&self) -> &BinaryClassLabels<F, C> {
        &self.labels
    }

    /// Probability of the positive class for each row of `x`.
    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array1<F> {
        let mut probs = x.dot(&self.params) + self.intercept;
        probs.mapv_inplace(logistic);
        probs
    }
}
624
impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for FittedLogisticRegression<F, C>
{
    /// Assign each row of `x` the positive class when its predicted
    /// probability reaches the threshold, otherwise the negative class.
    ///
    /// # Panics
    ///
    /// Panics when `y`'s length does not match the sample count of `x`, or
    /// when `x`'s feature count differs from the fitted model's.
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );
        assert_eq!(
            x.ncols(),
            self.params.len(),
            "Number of data features must match the number of features the model was trained with."
        );

        let pos_class = &self.labels.pos.class;
        let neg_class = &self.labels.neg.class;
        Zip::from(&self.predict_probabilities(x))
            .and(y)
            .for_each(|prob, out| {
                *out = if *prob >= self.threshold {
                    pos_class.clone()
                } else {
                    neg_class.clone()
                }
            });
    }

    /// Allocate a default-initialized target array with one entry per sample.
    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
        Array1::default(x.nrows())
    }
}
659
/// A fitted multinomial logistic regression model, produced by fitting
/// `MultiLogisticRegression` on a dataset with any number of classes.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct MultiFittedLogisticRegression<F, C: PartialOrd + Clone> {
    // Fitted bias terms, one per class.
    intercept: Array1<F>,
    // Fitted weight matrix, shape (n_features, n_classes).
    params: Array2<F>,
    // Distinct classes in sorted order; column `i` of `params` scores `classes[i]`.
    classes: Vec<C>,
}
672
impl<F: Float, C: PartialOrd + Clone> MultiFittedLogisticRegression<F, C> {
    /// Construct a fitted multinomial model from its parts.
    fn new(intercept: Array1<F>, params: Array2<F>, classes: Vec<C>) -> Self {
        Self {
            intercept,
            params,
            classes,
        }
    }

    /// The fitted intercepts, one per class.
    pub fn intercept(&self) -> &Array1<F> {
        &self.intercept
    }

    /// The fitted weight matrix of shape `(n_features, n_classes)`.
    pub fn params(&self) -> &Array2<F> {
        &self.params
    }

    /// Raw class scores (logits) for each sample; rows are not normalized.
    fn predict_nonorm_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
        x.dot(&self.params) + &self.intercept
    }

    /// Class probabilities for each sample; each row sums to one.
    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
        let mut probs = self.predict_nonorm_probabilities(x);
        probs
            .rows_mut()
            .into_iter()
            .for_each(|mut row| softmax_inplace(&mut row));
        probs
    }

    /// The distinct classes seen during fitting, in sorted order.
    pub fn classes(&self) -> &[C] {
        &self.classes
    }
}
711
impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for MultiFittedLogisticRegression<F, C>
{
    /// Assign each row of `x` the class with the highest score. Softmax is
    /// monotone, so the un-normalized logits suffice for the argmax.
    ///
    /// # Panics
    ///
    /// Panics when `y`'s length does not match the sample count of `x`, or
    /// when `x`'s feature count differs from the fitted model's.
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );
        assert_eq!(
            x.ncols(),
            self.params.nrows(),
            "Number of data features must match the number of features the model was trained with."
        );

        let probs = self.predict_nonorm_probabilities(x);
        Zip::from(probs.rows()).and(y).for_each(|prob_row, out| {
            let idx = prob_row.argmax().unwrap();
            *out = self.classes[idx].clone();
        });
    }

    /// Allocate a default-initialized target array with one entry per sample.
    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
        Array1::default(x.nrows())
    }
}
740
/// Pairs an original class value with the numeric label it was encoded as
/// during fitting.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct ClassLabel<F, C: PartialOrd> {
    pub class: C,
    pub label: F,
}
751
/// The positive/negative class assignment of a fitted binary model.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct BinaryClassLabels<F, C: PartialOrd> {
    pub pos: ClassLabel<F, C>,
    pub neg: ClassLabel<F, C>,
}
762
/// The optimization problem handed to argmin: borrowed records, encoded
/// targets and the L2 regularization strength.
struct LogisticRegressionProblem<'a, F: Float, A: Data<Elem = F>, D: Dimension> {
    x: &'a ArrayBase<A, Ix2>,
    target: Array<F, D>,
    alpha: F,
}

// Problem aliases for the binary (Ix1) and multinomial (Ix2) cases.
type LogisticRegressionProblem1<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix1>;
type LogisticRegressionProblem2<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix2>;
773
774impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem1<'_, F, A> {
775 type Param = ArgminParam<F, Ix1>;
776 type Output = F;
777
778 fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
780 let w = p.as_array();
781 let cost = logistic_loss(self.x, &self.target, self.alpha, w);
782 Ok(cost)
783 }
784}
785
786impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem1<'_, F, A> {
787 type Param = ArgminParam<F, Ix1>;
788 type Gradient = ArgminParam<F, Ix1>;
789
790 fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
792 let w = p.as_array();
793 let grad = ArgminParam(logistic_grad(self.x, &self.target, self.alpha, w));
794 Ok(grad)
795 }
796}
797
798impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem2<'_, F, A> {
799 type Param = ArgminParam<F, Ix2>;
800 type Output = F;
801
802 fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
804 let w = p.as_array();
805 let cost = multi_logistic_loss(self.x, &self.target, self.alpha, w);
806 Ok(cost)
807 }
808}
809
810impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem2<'_, F, A> {
811 type Param = ArgminParam<F, Ix2>;
812 type Gradient = ArgminParam<F, Ix2>;
813
814 fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
816 let w = p.as_array();
817 let grad = ArgminParam(multi_logistic_grad(self.x, &self.target, self.alpha, w));
818 Ok(grad)
819 }
820}
821
/// Links an optimization problem to the concrete argmin solver type capable
/// of solving it, so `run_solver` can be written once for both cases.
trait SolvableProblem<F: Float, D: Dimension>: Gradient + Sized {
    type Solver: Solver<Self, IterStateType<F, D>>;
}

// The binary problem is solved by the 1-D L-BFGS instantiation.
impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix1> for LogisticRegressionProblem1<'_, F, A> {
    type Solver = LBFGSType1<F>;
}

// The multinomial problem is solved by the 2-D L-BFGS instantiation.
impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix2> for LogisticRegressionProblem2<'_, F, A> {
    type Solver = LBFGSType2<F>;
}
833
834#[cfg(test)]
835mod test {
836 extern crate linfa;
837
838 use super::Error;
839 use super::*;
840 use approx::{assert_abs_diff_eq, assert_relative_eq, AbsDiffEq};
841 use linfa::prelude::*;
842 use ndarray::{array, Array2, Dim, Ix};
843
844 #[test]
845 fn autotraits() {
846 fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
847 has_autotraits::<LogisticRegressionParams<f64, Dim<[Ix; 0]>>>();
848 has_autotraits::<LogisticRegressionValidParams<f64, Dim<[Ix; 0]>>>();
849 has_autotraits::<ArgminParam<f64, Dim<[Ix; 0]>>>();
850 }
851
852 #[test]
856 fn test_logistic_loss() {
857 let x = array![
858 [0.0],
859 [1.0],
860 [2.0],
861 [3.0],
862 [4.0],
863 [5.0],
864 [6.0],
865 [7.0],
866 [8.0],
867 [9.0]
868 ];
869 let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
870 let ws = [
871 array![0.0, 0.0],
872 array![0.0, 1.0],
873 array![1.0, 0.0],
874 array![1.0, 1.0],
875 array![0.0, -1.0],
876 array![-1.0, 0.0],
877 array![-1.0, -1.0],
878 ];
879 let alphas = &[0.0, 1.0, 10.0];
880 let expecteds = vec![
881 6.931471805599453,
882 6.931471805599453,
883 6.931471805599453,
884 4.652158847349118,
885 4.652158847349118,
886 4.652158847349118,
887 2.8012999588008323,
888 3.3012999588008323,
889 7.801299958800833,
890 2.783195429782239,
891 3.283195429782239,
892 7.783195429782239,
893 10.652158847349117,
894 10.652158847349117,
895 10.652158847349117,
896 41.80129995880083,
897 42.30129995880083,
898 46.80129995880083,
899 47.78319542978224,
900 48.28319542978224,
901 52.78319542978224,
902 ];
903
904 for ((w, alpha), exp) in ws
905 .iter()
906 .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
907 .zip(&expecteds)
908 {
909 assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w), *exp);
910 }
911 }
912
913 #[test]
917 fn test_logistic_grad() {
918 let x = array![
919 [0.0],
920 [1.0],
921 [2.0],
922 [3.0],
923 [4.0],
924 [5.0],
925 [6.0],
926 [7.0],
927 [8.0],
928 [9.0]
929 ];
930 let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
931 let ws = [
932 array![0.0, 0.0],
933 array![0.0, 1.0],
934 array![1.0, 0.0],
935 array![1.0, 1.0],
936 array![0.0, -1.0],
937 array![-1.0, 0.0],
938 array![-1.0, -1.0],
939 ];
940 let alphas = &[0.0, 1.0, 10.0];
941 let expecteds = vec![
942 array![-19.5, -3.],
943 array![-19.5, -3.],
944 array![-19.5, -3.],
945 array![-10.48871543, -1.61364853],
946 array![-10.48871543, -1.61364853],
947 array![-10.48871543, -1.61364853],
948 array![-0.13041554, -0.02852148],
949 array![0.86958446, -0.02852148],
950 array![9.86958446, -0.02852148],
951 array![-0.04834401, -0.01058067],
952 array![0.95165599, -0.01058067],
953 array![9.95165599, -0.01058067],
954 array![-28.51128457, -4.38635147],
955 array![-28.51128457, -4.38635147],
956 array![-28.51128457, -4.38635147],
957 array![-38.86958446, -5.97147852],
958 array![-39.86958446, -5.97147852],
959 array![-48.86958446, -5.97147852],
960 array![-38.95165599, -5.98941933],
961 array![-39.95165599, -5.98941933],
962 array![-48.95165599, -5.98941933],
963 ];
964
965 for ((w, alpha), exp) in ws
966 .iter()
967 .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
968 .zip(&expecteds)
969 {
970 let actual = logistic_grad(&x, &y, alpha, w);
971 assert!(actual.abs_diff_eq(exp, 1e-8));
972 }
973 }
974
975 #[test]
976 fn simple_example_1() {
977 let log_reg = LogisticRegression::default();
978 let x = array![[-1.0], [-0.01], [0.01], [1.0]];
979 let y = array![0, 0, 1, 1];
980 let dataset = Dataset::new(x, y);
981 let res = log_reg.fit(&dataset).unwrap();
982 assert_abs_diff_eq!(res.intercept(), 0.0);
983 assert!(res.params().abs_diff_eq(&array![0.681], 1e-3));
984 assert_eq!(
985 &res.predict(dataset.records()),
986 dataset.targets().as_single_targets()
987 );
988 }
989
990 #[test]
991 fn simple_example_1_cats_dogs() {
992 let log_reg = LogisticRegression::default();
993 let x = array![[0.01], [1.0], [-1.0], [-0.01]];
994 let y = array!["dog", "dog", "cat", "cat"];
995 let dataset = Dataset::new(x, y);
996 let res = log_reg.fit(&dataset).unwrap();
997 assert_abs_diff_eq!(res.intercept(), 0.0);
998 assert!(res.params().abs_diff_eq(&array![0.681], 1e-3));
999 assert!(res
1000 .predict_probabilities(dataset.records())
1001 .abs_diff_eq(&array![0.501, 0.664, 0.335, 0.498], 1e-3));
1002 assert_eq!(
1003 &res.predict(dataset.records()),
1004 dataset.targets().as_single_targets()
1005 );
1006 assert_eq!(res.labels().pos.class, "dog");
1007 assert_eq!(res.labels().neg.class, "cat");
1008 }
1009
1010 #[test]
1011 fn simple_example_2() {
1012 let log_reg = LogisticRegression::default().alpha(1.0);
1013 let x = array![
1014 [0.0],
1015 [1.0],
1016 [2.0],
1017 [3.0],
1018 [4.0],
1019 [5.0],
1020 [6.0],
1021 [7.0],
1022 [8.0],
1023 [9.0]
1024 ];
1025 let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1026 let dataset = Dataset::new(x, y);
1027 let res = log_reg.fit(&dataset).unwrap();
1028 assert_eq!(
1029 &res.predict(dataset.records()),
1030 dataset.targets().as_single_targets()
1031 );
1032 }
1033
1034 #[test]
1035 fn simple_example_3() {
1036 let x = array![[1.0], [0.0], [1.0], [0.0]];
1037 let y = array![1, 0, 1, 0];
1038 let dataset = DatasetBase::new(x, y);
1039 let model = LogisticRegression::default().fit(&dataset).unwrap();
1040
1041 let pred = model.predict(&dataset.records);
1042 assert_eq!(dataset.targets(), pred);
1043 }
1044
1045 #[test]
1046 fn rejects_mismatching_x_y() {
1047 let log_reg = LogisticRegression::default();
1048 let x = array![[-1.0], [-0.01], [0.01]];
1049 let y = array![0, 0, 1, 1];
1050 let res = log_reg.fit(&Dataset::new(x, y));
1051 assert!(matches!(res.unwrap_err(), Error::MismatchedShapes(3, 4)));
1052 }
1053
1054 #[test]
1055 fn rejects_inf_values() {
1056 let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
1057 let inf_xs: Vec<_> = infs.iter().map(|&inf| array![[1.0], [inf]]).collect();
1058 let log_reg = LogisticRegression::default();
1059 let normal_x = array![[-1.0], [1.0]];
1060 let y = array![0, 1];
1061 for inf_x in &inf_xs {
1062 let res = log_reg.fit(&DatasetBase::new(inf_x.view(), &y));
1063 assert!(matches!(res.unwrap_err(), Error::InvalidValues));
1064 }
1065 for inf in infs {
1066 let log_reg = LogisticRegression::default().alpha(*inf);
1067 let res = log_reg.fit(&DatasetBase::new(normal_x.view(), &y));
1068 assert!(matches!(res.unwrap_err(), Error::InvalidAlpha));
1069 }
1070 let mut non_positives = infs.to_vec();
1071 non_positives.push(-1.0);
1072 non_positives.push(0.0);
1073 for inf in &non_positives {
1074 let log_reg = LogisticRegression::default().gradient_tolerance(*inf);
1075 let res = log_reg.fit(&Dataset::new(normal_x.to_owned(), y.to_owned()));
1076 assert!(matches!(res.unwrap_err(), Error::InvalidGradientTolerance));
1077 }
1078 }
1079
1080 #[test]
1081 fn validates_initial_params() {
1082 let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
1083 let normal_x = array![[-1.0], [1.0]];
1084 let normal_y = array![0, 1];
1085 let dataset = Dataset::new(normal_x, normal_y);
1086 for inf in infs {
1087 let log_reg = LogisticRegression::default().initial_params(array![*inf, 0.0]);
1088 let res = log_reg.fit(&dataset);
1089 assert!(matches!(res.unwrap_err(), Error::InvalidInitialParameters));
1090 }
1091 {
1092 let log_reg = LogisticRegression::default().initial_params(array![0.0, 0.0, 0.0]);
1093 let res = log_reg.fit(&dataset);
1094 assert!(matches!(
1095 res.unwrap_err(),
1096 Error::InitialParameterFeaturesMismatch {
1097 rows: 3,
1098 n_features: 2
1099 }
1100 ));
1101 }
1102 {
1103 let log_reg = LogisticRegression::default()
1104 .with_intercept(false)
1105 .initial_params(array![0.0, 0.0]);
1106 let res = log_reg.fit(&dataset);
1107 assert!(matches!(
1108 res.unwrap_err(),
1109 Error::InitialParameterFeaturesMismatch {
1110 rows: 2,
1111 n_features: 1
1112 }
1113 ));
1114 }
1115 }
1116
1117 #[test]
1118 fn uses_initial_params() {
1119 let params = array![1.2, -4.12];
1120 let log_reg = LogisticRegression::default()
1121 .initial_params(params)
1122 .max_iterations(5);
1123 let x = array![
1124 [0.0],
1125 [1.0],
1126 [2.0],
1127 [3.0],
1128 [4.0],
1129 [5.0],
1130 [6.0],
1131 [7.0],
1132 [8.0],
1133 [9.0]
1134 ];
1135 let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1136 let dataset = Dataset::new(x, y);
1137 let res = log_reg.fit(&dataset).unwrap();
1138 assert!(res.intercept().abs_diff_eq(&-4.124, 1e-3));
1139 assert!(res.params().abs_diff_eq(&array![1.181], 1e-3));
1140 assert_eq!(
1141 &res.predict(dataset.records()),
1142 dataset.targets().as_single_targets()
1143 );
1144
1145 #[cfg(feature = "serde")]
1147 {
1148 let ser = rmp_serde::to_vec(&res).unwrap();
1149 let unser: FittedLogisticRegression<f32, f32> = rmp_serde::from_slice(&ser).unwrap();
1150
1151 let x = array![[1.0]];
1152 let y_hat = unser.predict(&x);
1153
1154 assert!(y_hat[0] == 0.0);
1155 }
1156 }
1157
1158 #[test]
1159 fn works_with_f32() {
1160 let log_reg = LogisticRegression::default();
1161 let x: Array2<f32> = array![[-1.0], [-0.01], [0.01], [1.0]];
1162 let y = array![0, 0, 1, 1];
1163 let dataset = Dataset::new(x, y);
1164 let res = log_reg.fit(&dataset).unwrap();
1165 assert_abs_diff_eq!(res.intercept(), 0.0_f32);
1166 assert!(res.params().abs_diff_eq(&array![0.682_f32], 1e-3));
1167 assert_eq!(
1168 &res.predict(dataset.records()),
1169 dataset.targets().as_single_targets()
1170 );
1171 }
1172
1173 #[test]
1174 fn test_log_sum_exp() {
1175 let data = array![[3.3, 0.4, -2.1], [0.4, 2.2, -0.1], [1., 0., -1.]];
1176 let out = log_sum_exp(&data, Axis(1));
1177 assert_abs_diff_eq!(out, array![3.35783, 2.43551, 1.40761], epsilon = 1e-5);
1178 }
1179
1180 #[test]
1181 fn test_softmax() {
1182 let mut data = array![3.3, 5.5, 0.1, -4.4, 8.0];
1183 softmax_inplace(&mut data);
1184 assert_relative_eq!(
1185 data,
1186 array![0.0083324, 0.075200047, 0.000339647, 0.000003773, 0.91612413],
1187 epsilon = 1e-8
1188 );
1189 assert_abs_diff_eq!(data.sum(), 1.0);
1190 }
1191
1192 #[test]
1193 fn test_multi_logistic_loss_grad() {
1194 let x = array![
1195 [0.0, 0.5],
1196 [1.0, -1.0],
1197 [2.0, -2.0],
1198 [3.0, -3.0],
1199 [4.0, -4.0],
1200 [5.0, -5.0],
1201 [6.0, -6.0],
1202 [7.0, -7.0],
1203 ];
1204 let y = array![
1205 [1.0, 0.0, 0.0],
1206 [1.0, 0.0, 0.0],
1207 [0.0, 1.0, 0.0],
1208 [0.0, 1.0, 0.0],
1209 [0.0, 1.0, 0.0],
1210 [0.0, 0.0, 1.0],
1211 [0.0, 0.0, 1.0],
1212 [0.0, 0.0, 1.0],
1213 ];
1214 let params1 = array![[4.4, -1.2, 3.3], [3.4, 0.1, 0.0]];
1215 let params2 = array![[0.001, -3.2, 2.9], [0.1, 4.5, 5.7], [4.5, 2.2, 1.7]];
1216 let alpha = 0.6;
1217
1218 {
1219 let (log_prob, w) = multi_logistic_prob_params(&x, ¶ms1);
1220 assert_abs_diff_eq!(
1221 log_prob,
1222 array![
1223 [-3.18259845e-01, -1.96825985e+00, -2.01825985e+00],
1224 [-2.40463987e+00, -4.70463987e+00, -1.04639868e-01],
1225 [-4.61010168e+00, -9.21010168e+00, -1.01016809e-02],
1226 [-6.90100829e+00, -1.38010083e+01, -1.00829256e-03],
1227 [-9.20010104e+00, -1.84001010e+01, -1.01044506e-04],
1228 [-1.15000101e+01, -2.30000101e+01, -1.01301449e-05],
1229 [-1.38000010e+01, -2.76000010e+01, -1.01563199e-06],
1230 [-1.61000001e+01, -3.22000001e+01, -1.01826043e-07],
1231 ],
1232 epsilon = 1e-6
1233 );
1234 assert_abs_diff_eq!(w, params1);
1235 let loss = multi_logistic_loss(&x, &y, alpha, ¶ms1);
1236 assert_abs_diff_eq!(loss, 57.11212197835295, epsilon = 1e-6);
1237 let grad = multi_logistic_grad(&x, &y, alpha, ¶ms1);
1238 assert_abs_diff_eq!(
1239 grad,
1240 array![
1241 [1.7536815, -9.71074369, 11.85706219],
1242 [2.79002537, 9.12059357, -9.81061893]
1243 ],
1244 epsilon = 1e-6
1245 );
1246 }
1247
1248 {
1249 let (log_prob, w) = multi_logistic_prob_params(&x, ¶ms2);
1250 assert_abs_diff_eq!(
1251 log_prob,
1252 array![
1253 [-1.06637742e+00, -1.16637742e+00, -1.06637742e+00],
1254 [-4.12429463e-03, -9.90512429e+00, -5.50512429e+00],
1255 [-2.74092305e-04, -1.75022741e+01, -8.20227409e+00],
1256 [-1.84027855e-05, -2.51030184e+01, -1.09030184e+01],
1257 [-1.23554225e-06, -3.27040012e+01, -1.36040012e+01],
1258 [-8.29523046e-08, -4.03050001e+01, -1.63050001e+01],
1259 [-5.56928016e-09, -4.79060000e+01, -1.90060000e+01],
1260 [-3.73912013e-10, -5.55070000e+01, -2.17070000e+01]
1261 ],
1262 epsilon = 1e-6
1263 );
1264 assert_abs_diff_eq!(w, params2.slice(s![..params2.nrows() - 1, ..]));
1265 let loss = multi_logistic_loss(&x, &y, alpha, ¶ms2);
1266 assert_abs_diff_eq!(loss, 154.8177958366479, epsilon = 1e-6);
1267 let grad = multi_logistic_grad(&x, &y, alpha, ¶ms2);
1268 assert_abs_diff_eq!(
1269 grad,
1270 array![
1271 [26.99587549, -10.91995003, -16.25532546],
1272 [-27.26314882, 11.85569669, 21.58745213],
1273 [5.33984376, -2.68845675, -2.65138701]
1274 ],
1275 epsilon = 1e-6
1276 );
1277 }
1278 }
1279
1280 #[test]
1281 fn simple_multi_example() {
1282 let x = array![[-1., 0.], [0., 1.], [1., 1.]];
1283 let y = array![2, 1, 0];
1284 let log_reg = MultiLogisticRegression::default()
1285 .alpha(0.1)
1286 .initial_params(Array::zeros((3, 3)));
1287 let dataset = Dataset::new(x, y);
1288 let res = log_reg.fit(&dataset).unwrap();
1289 assert_eq!(res.params().dim(), (2, 3));
1290 assert_eq!(res.intercept().dim(), 3);
1291 assert_eq!(
1292 &res.predict(dataset.records()),
1293 dataset.targets().as_single_targets()
1294 );
1295 }
1296
1297 #[test]
1298 fn simple_multi_example_2() {
1299 let x = array![[1.0], [0.0], [1.0], [0.0]];
1300 let y = array![1, 0, 1, 0];
1301 let dataset = DatasetBase::new(x, y);
1302 let model = MultiLogisticRegression::default().fit(&dataset).unwrap();
1303
1304 let pred = model.predict(&dataset.records);
1305 assert_eq!(dataset.targets(), pred);
1306 }
1307
1308 #[test]
1309 fn simple_multi_example_text() {
1310 let log_reg = MultiLogisticRegression::default().alpha(0.1);
1311 let x = array![[0.1], [1.0], [-1.0], [-0.1]];
1312 let y = array!["dog", "ape", "rocket", "cat"];
1313 let dataset = Dataset::new(x, y);
1314 let res = log_reg.fit(&dataset).unwrap();
1315 assert_eq!(res.params().dim(), (1, 4));
1316 assert_eq!(res.intercept().dim(), 4);
1317 assert_eq!(
1318 &res.predict(dataset.records()),
1319 dataset.targets().as_single_targets()
1320 );
1321 }
1322
1323 #[test]
1324 fn multi_on_binary_problem() {
1325 let log_reg = MultiLogisticRegression::default().alpha(1.0);
1326 let x = array![
1327 [0.0],
1328 [1.0],
1329 [2.0],
1330 [3.0],
1331 [4.0],
1332 [5.0],
1333 [6.0],
1334 [7.0],
1335 [8.0],
1336 [9.0]
1337 ];
1338 let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1339 let dataset = Dataset::new(x, y);
1340 let res = log_reg.fit(&dataset).unwrap();
1341 assert_eq!(res.params().dim(), (1, 2));
1342 assert_eq!(res.intercept().dim(), 2);
1343 assert_eq!(
1344 &res.predict(dataset.records()),
1345 dataset.targets().as_single_targets()
1346 );
1347 }
1348
1349 #[test]
1350 fn reject_num_class_mismatch() {
1351 let n_samples = 4;
1352 let n_classes = 3;
1353 let n_features = 1;
1354 let x = Array2::<f64>::zeros((n_samples, n_features));
1355 let y = array![0, 1, 2, 0];
1356 let dataset = Dataset::new(x, y);
1357
1358 let log_reg = MultiLogisticRegression::default()
1359 .with_intercept(false)
1360 .initial_params(Array::zeros((n_features, n_classes - 1)));
1361 assert!(matches!(
1362 log_reg.fit(&dataset).unwrap_err(),
1363 Error::InitialParameterClassesMismatch {
1364 cols: 2,
1365 n_classes: 3,
1366 }
1367 ));
1368 }
1369
1370 #[test]
1371 fn label_order_independent() {
1372 let x1 = array![[-1.0], [1.0], [-0.5], [0.5]];
1373 let y1 = array!["cat", "dog", "cat", "dog"];
1374
1375 let x2 = array![[1.0], [-1.0], [0.5], [-0.5]];
1376 let y2 = array!["dog", "cat", "dog", "cat"];
1377
1378 let model1 = LogisticRegression::default()
1379 .fit(&Dataset::new(x1, y1))
1380 .unwrap();
1381 let model2 = LogisticRegression::default()
1382 .fit(&Dataset::new(x2, y2))
1383 .unwrap();
1384
1385 assert_eq!(model1.labels().pos.class, "dog");
1386 assert_eq!(model1.labels().neg.class, "cat");
1387 assert_eq!(model2.labels().pos.class, "dog");
1388 assert_eq!(model2.labels().neg.class, "cat");
1389
1390 assert_abs_diff_eq!(model1.intercept(), model2.intercept());
1391 assert!(model1.params().abs_diff_eq(model2.params(), 1e-6));
1392 }
1393}