pub mod error;

use crate::error::{Error, Result};
use argmin::core::{CostFunction, Executor, Gradient, IterState, OptimizationResult, Solver};
use argmin::solver::linesearch::MoreThuenteLineSearch;
use argmin::solver::quasinewton::LBFGS;
use linfa::dataset::AsSingleTargets;
use linfa::prelude::DatasetBase;
use linfa::traits::{Fit, PredictInplace};
use ndarray::{
    s, Array, Array1, Array2, ArrayBase, ArrayView, ArrayView2, Axis, CowArray, Data, DataMut,
    Dimension, IntoDimension, Ix1, Ix2, RemoveAxis, Slice, Zip,
};
use ndarray_stats::QuantileExt;
use std::default::Default;

#[cfg(feature = "serde")]
use serde_crate::de::DeserializeOwned;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

mod argmin_param;
mod float;
mod hyperparams;

use argmin_param::*;
use float::Float;
use hyperparams::{LogisticRegressionParams, LogisticRegressionValidParams};

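/// A two-class logistic regression model.
///
/// Logistic regression combines linear models with the sigmoid (logistic)
/// function to model class membership probabilities. The model is fitted with
/// L-BFGS, and the `alpha` parameter controls the strength of the L2
/// regularization term added to the loss.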
pub type LogisticRegression<F> = LogisticRegressionParams<F, Ix1>;

pub type ValidLogisticRegression<F> = LogisticRegressionValidParams<F, Ix1>;

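/// A multinomial logistic regression model, which fits one parameter column
/// and intercept per class and assigns class probabilities via the softmax
/// function.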
pub type MultiLogisticRegression<F> = LogisticRegressionParams<F, Ix2>;

pub type ValidMultiLogisticRegression<F> = LogisticRegressionValidParams<F, Ix2>;

impl<F: Float, D: Dimension> Default for LogisticRegressionParams<F, D> {
    fn default() -> Self {
        LogisticRegressionParams::new()
    }
}

type LBFGSType<F, D> = LBFGS<
    MoreThuenteLineSearch<ArgminParam<F, D>, ArgminParam<F, D>, F>,
    ArgminParam<F, D>,
    ArgminParam<F, D>,
    F,
>;
type LBFGSType1<F> = LBFGSType<F, Ix1>;
type LBFGSType2<F> = LBFGSType<F, Ix2>;

type IterStateType<F, D> = IterState<ArgminParam<F, D>, ArgminParam<F, D>, (), (), (), F>;

impl<F: Float, D: Dimension> LogisticRegressionValidParams<F, D> {
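    /// Create the initial parameters, either from a user-supplied array or an
    /// array of zeros, with one extra row for the intercept when
    /// `fit_intercept` is set.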
    fn setup_init_params(&self, dims: D::Pattern) -> ArgminParam<F, D> {
        if let Some(params) = self.initial_params.as_ref() {
            ArgminParam(params.clone())
        } else {
            let mut dims = dims.into_dimension();
            dims.as_array_view_mut()[0] += self.fit_intercept as usize;
            ArgminParam(Array::zeros(dims))
        }
    }

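    /// Ensure that `x` and `y` have the right shape and that all data is
    /// finite.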
    fn validate_data<A: Data<Elem = F>, B: Data<Elem = F>>(
        &self,
        x: &ArrayBase<A, Ix2>,
        y: &ArrayBase<B, D>,
    ) -> Result<()> {
        if x.shape()[0] != y.shape()[0] {
            return Err(Error::MismatchedShapes(x.shape()[0], y.shape()[0]));
        }
        if x.iter().any(|x| !x.is_finite()) || y.iter().any(|y| !y.is_finite()) {
            return Err(Error::InvalidValues);
        }
        self.validate_init_dims(x.shape()[1], y.shape().get(1).copied())?;
        Ok(())
    }

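    /// Check that any user-supplied initial parameters match the number of
    /// features (plus one for the intercept, if fitted) and, in the
    /// multinomial case, the number of classes.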
    fn validate_init_dims(&self, mut n_features: usize, n_classes: Option<usize>) -> Result<()> {
        if let Some(params) = self.initial_params.as_ref() {
            let shape = params.shape();
            n_features += self.fit_intercept as usize;
            if n_features != shape[0] {
                return Err(Error::InitialParameterFeaturesMismatch {
                    n_features,
                    rows: shape[0],
                });
            }
            if let Some(n_classes) = n_classes {
                if n_classes != shape[1] {
                    return Err(Error::InitialParameterClassesMismatch {
                        n_classes,
                        cols: shape[1],
                    });
                }
            }
        }
        Ok(())
    }

    fn setup_problem<'a, A: Data<Elem = F>>(
        &self,
        x: &'a ArrayBase<A, Ix2>,
        target: Array<F, D>,
    ) -> LogisticRegressionProblem<'a, F, A, D> {
        LogisticRegressionProblem {
            x,
            target,
            alpha: self.alpha,
        }
    }

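    /// Create an L-BFGS solver (memory size 10) with a More-Thuente line
    /// search, using the configured gradient tolerance as the convergence
    /// criterion.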
    fn setup_solver(&self) -> LBFGSType<F, D> {
        let linesearch = MoreThuenteLineSearch::new();
        LBFGS::new(linesearch, 10)
            .with_tolerance_grad(self.gradient_tolerance)
            .unwrap()
    }
}

impl<
        F: Float,
        #[cfg(feature = "serde")] D: Dimension + Serialize + DeserializeOwned,
        #[cfg(not(feature = "serde"))] D: Dimension,
    > LogisticRegressionValidParams<F, D>
{
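    /// Run the provided solver on the problem until it converges or reaches
    /// the maximum number of iterations.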
    fn run_solver<P: SolvableProblem<F, D>>(
        &self,
        problem: P,
        solver: P::Solver,
        init_params: ArgminParam<F, D>,
    ) -> Result<OptimizationResult<P, P::Solver, IterStateType<F, D>>> {
        Executor::new(problem, solver)
            .configure(|state| state.param(init_params).max_iters(self.max_iterations))
            .run()
            .map_err(move |err| err.into())
    }
}

impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
    Fit<ArrayBase<D, Ix2>, T, Error> for ValidLogisticRegression<F>
{
    type Object = FittedLogisticRegression<F, C>;

    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
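        // Fitting proceeds as follows:
        //  1. Map the two target classes to `+1` and `-1` labels.
        //  2. Validate the feature matrix, targets and any initial parameters.
        //  3. Build the optimization problem and the L-BFGS solver.
        //  4. Run the solver from the initial parameters.
        //  5. Split the best parameters found into weights and intercept.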
        let (x, y) = (dataset.records(), dataset.targets());
        let (labels, target) = label_classes(y)?;
        self.validate_data(x, &target)?;
        let problem = self.setup_problem(x, target);
        let solver = self.setup_solver();
        let init_params = self.setup_init_params(x.ncols());
        let result = self.run_solver(problem, solver, init_params)?;

        let params = result
            .state
            .best_param
            .unwrap_or(self.setup_init_params(x.ncols()));
        let (w, intercept) = convert_params(x.ncols(), params.as_array());
        Ok(FittedLogisticRegression::new(
            *intercept.view().into_scalar(),
            w.to_owned(),
            labels,
        ))
    }
}

impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
    Fit<ArrayBase<D, Ix2>, T, Error> for ValidMultiLogisticRegression<F>
{
    type Object = MultiFittedLogisticRegression<F, C>;

    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
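        // Same procedure as the binary case, except that the targets are
        // one-hot encoded and the parameters form a matrix with one column
        // per class.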
        let (x, y) = (dataset.records(), dataset.targets());
        let (classes, target) = label_classes_multi(y)?;
        self.validate_data(x, &target)?;
        let problem = self.setup_problem(x, target);
        let solver = self.setup_solver();
        let init_params = self.setup_init_params((x.ncols(), classes.len()));
        let result = self.run_solver(problem, solver, init_params)?;

        let params = result
            .state
            .best_param
            .unwrap_or(self.setup_init_params((x.ncols(), classes.len())));
        let (w, intercept) = convert_params(x.ncols(), params.as_array());
        Ok(MultiFittedLogisticRegression::new(
            intercept.to_owned(),
            w.to_owned(),
            classes,
        ))
    }
}

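/// Identify the distinct values of the targets and associate them with the
/// `+1` and `-1` labels, with the more frequent class receiving the positive
/// label. Fails unless there are exactly two distinct classes. Returns the
/// label mapping together with the relabeled target array.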
fn label_classes<F, T, C>(y: T) -> Result<(BinaryClassLabels<F, C>, Array1<F>)>
where
    F: Float,
    T: AsSingleTargets<Elem = C>,
    C: Ord + Clone,
{
    let y = y.as_single_targets();

    let mut binary_classes = [None, None];
    for class in y {
        binary_classes = match binary_classes {
            [None, None] => [Some((class, 1)), None],
            [Some((c, count)), c2] if c == class => [Some((class, count + 1)), c2],
            [c1, Some((c, count))] if c == class => [c1, Some((class, count + 1))],
            [Some(c1), None] => [Some(c1), Some((class, 1))],

            [None, Some(_)] => unreachable!("impossible binary class array"),
            [Some(_), Some(_)] => return Err(Error::TooManyClasses),
        };
    }

    let (pos_class, neg_class) = match binary_classes {
        [Some(a), Some(b)] => (a, b),
        _ => return Err(Error::TooFewClasses),
    };

    let mut target_array = y
        .into_iter()
        .map(|x| {
            if x == pos_class.0 {
                F::POSITIVE_LABEL
            } else {
                F::NEGATIVE_LABEL
            }
        })
        .collect::<Array1<_>>();

    let (pos_cl, neg_cl) = if pos_class.1 < neg_class.1 {
        target_array *= -F::one();
        (neg_class.0.clone(), pos_class.0.clone())
    } else {
        (pos_class.0.clone(), neg_class.0.clone())
    };

    Ok((
        BinaryClassLabels {
            pos: ClassLabel {
                class: pos_cl,
                label: F::POSITIVE_LABEL,
            },
            neg: ClassLabel {
                class: neg_cl,
                label: F::NEGATIVE_LABEL,
            },
        },
        target_array,
    ))
}

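/// Identify the distinct values of the targets and one-hot encode each sample
/// against the sorted, deduplicated list of classes.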
fn label_classes_multi<F, T, C>(y: T) -> Result<(Vec<C>, Array2<F>)>
where
    F: Float,
    T: AsSingleTargets<Elem = C>,
    C: Ord + Clone,
{
    let y_single_target = y.as_single_targets();
    let mut classes = y_single_target.to_vec();
    classes.sort();
    classes.dedup();

    let mut onehot = Array2::zeros((y_single_target.len(), classes.len()));
    Zip::from(onehot.rows_mut())
        .and(&y_single_target)
        .for_each(|mut oh_row, cls| {
            let idx = classes.binary_search(cls).unwrap();
            oh_row[idx] = F::one();
        });
    Ok((classes, onehot))
}

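/// Split the parameter array `w` into a view of the feature weights and the
/// intercept. If `w` has exactly `n_features` rows, no intercept was fitted
/// and zeros are returned in its place; if it has one extra row, that row is
/// the intercept. Any other shape is a programming error and panics.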
fn convert_params<F: Float, D: Dimension + RemoveAxis>(
    n_features: usize,
    w: &Array<F, D>,
) -> (ArrayView<'_, F, D>, CowArray<'_, F, D::Smaller>) {
    let nrows = w.shape()[0];
    if n_features == nrows {
        (
            w.view(),
            Array::zeros(w.raw_dim().remove_axis(Axis(0))).into(),
        )
    } else if n_features + 1 == nrows {
        (
            w.slice_axis(Axis(0), Slice::from(..n_features)),
            w.index_axis(Axis(0), n_features).into(),
        )
    } else {
        panic!(
            "Unexpected length of parameter vector `w`, expected {} or {}, found {}",
            n_features,
            n_features + 1,
            nrows
        );
    }
}

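/// The logistic function, `1 / (1 + exp(-x))`.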
fn logistic<F: linfa::Float>(x: F) -> F {
    F::one() / (F::one() + (-x).exp())
}

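/// A numerically stable version of the log of the logistic function. For
/// positive `x` it computes `-ln(1 + exp(-x))`, and for non-positive `x` the
/// equivalent `x - ln(1 + exp(x))`, so that `exp` never overflows.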
fn log_logistic<F: linfa::Float>(x: F) -> F {
    if x > F::zero() {
        -(F::one() + (-x).exp()).ln()
    } else {
        x - (F::one() + x.exp()).ln()
    }
}

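/// Compute `ln(sum(exp(...)))` of a 2D array along the given axis in a
/// numerically stable way: entries are shifted by the global maximum before
/// exponentiating, and the reduced sums are clamped away from zero before
/// taking the log.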
fn log_sum_exp<F: linfa::Float, A: Data<Elem = F>>(
    m: &ArrayBase<A, Ix2>,
    axis: Axis,
) -> Array<F, Ix1> {
    let max = m.iter().copied().reduce(F::max).unwrap();
    let reduced = m.fold_axis(axis, F::zero(), |acc, elem| *acc + (*elem - max).exp());
    reduced.mapv_into(|e| e.max(F::cast(1e-15)).ln() + max)
}

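/// Compute the softmax of a vector in place, shifting by the maximum element
/// for numerical stability.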
fn softmax_inplace<F: linfa::Float, A: DataMut<Elem = F>>(v: &mut ArrayBase<A, Ix1>) {
    let max = v.iter().copied().reduce(F::max).unwrap();
    v.mapv_inplace(|n| (n - max).exp());
    let sum = v.sum();
    v.mapv_inplace(|n| n / sum);
}

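/// Compute the binary logistic loss, assuming `y` contains `+1`/`-1` labels:
/// `-sum(log_logistic(y_i * (x_i . w + b))) + 0.5 * alpha * |w|^2`. The
/// parameter vector `w` may or may not contain a trailing intercept entry
/// (see `convert_params`).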
fn logistic_loss<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array1<F>,
    alpha: F,
    w: &Array1<F>,
) -> F {
    let n_features = x.shape()[1];
    let (params, intercept) = convert_params(n_features, w);
    let yz = x.dot(&params.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
    let len = yz.len();
    let mut yz = yz.into_shape_with_order(len).unwrap() * y;
    yz.mapv_inplace(log_logistic);
    -yz.sum() + F::cast(0.5) * alpha * params.dot(&params)
}

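/// Compute the gradient of the binary logistic loss with respect to `w`,
/// including the intercept component when `w` contains one.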
fn logistic_grad<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array1<F>,
    alpha: F,
    w: &Array1<F>,
) -> Array1<F> {
    let n_features = x.shape()[1];
    let (params, intercept) = convert_params(n_features, w);
    let yz = x.dot(&params.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
    let len = yz.len();
    let mut yz = yz.into_shape_with_order(len).unwrap() * y;
    yz.mapv_inplace(logistic);
    yz -= F::one();
    yz *= y;
    if w.len() == n_features + 1 {
        let mut grad = Array::zeros(w.len());
        grad.slice_mut(s![..n_features])
            .assign(&(x.t().dot(&yz) + (&params * alpha)));
        grad[n_features] = yz.sum();
        grad
    } else {
        x.t().dot(&yz) + (&params * alpha)
    }
}

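/// Compute the per-class log-probabilities for the multinomial case, together
/// with a view of the feature weights (without the intercept row) for use in
/// the regularization term.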
fn multi_logistic_prob_params<'a, F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    w: &'a Array2<F>,
) -> (Array2<F>, ArrayView2<'a, F>) {
    let n_features = x.shape()[1];
    let (params, intercept) = convert_params(n_features, w);
    let h = x.dot(&params) + intercept;
    let log_prob = &h
        - log_sum_exp(&h, Axis(1))
            .into_shape_with_order((h.nrows(), 1))
            .unwrap();
    (log_prob, params)
}

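/// Compute the multinomial logistic loss against one-hot encoded targets:
/// `-sum(y * log_prob) + 0.5 * alpha * |w|^2`.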
fn multi_logistic_loss<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array2<F>,
    alpha: F,
    w: &Array2<F>,
) -> F {
    let (log_prob, params) = multi_logistic_prob_params(x, w);
    -elem_dot(&log_prob, y) + F::cast(0.5) * alpha * elem_dot(&params, &params)
}

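/// Compute the gradient of the multinomial logistic loss with respect to the
/// parameter matrix. The weight block is `x^T (softmax(h) - y) + alpha * w`;
/// the intercept row, when present, receives the column sums of
/// `softmax(h) - y`.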
fn multi_logistic_grad<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array2<F>,
    alpha: F,
    w: &Array2<F>,
) -> Array2<F> {
    let (log_prob, params) = multi_logistic_prob_params(x, w);
    let (n_features, n_classes) = params.dim();
    let intercept = w.nrows() > n_features;
    let mut grad = Array::zeros((n_features + intercept as usize, n_classes));

    let prob = log_prob.mapv_into(num_traits::Float::exp);
    let diff = prob - y;
    let dw = x.t().dot(&diff) + (&params * alpha);
    grad.slice_mut(s![..n_features, ..]).assign(&dw);
    if intercept {
        grad.row_mut(n_features).assign(&diff.sum_axis(Axis(0)));
    }
    grad
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct FittedLogisticRegression<F: Float, C: PartialOrd + Clone> {
    threshold: F,
    intercept: F,
    params: Array1<F>,
    labels: BinaryClassLabels<F, C>,
}

impl<F: Float, C: PartialOrd + Clone> FittedLogisticRegression<F, C> {
    fn new(
        intercept: F,
        params: Array1<F>,
        labels: BinaryClassLabels<F, C>,
    ) -> FittedLogisticRegression<F, C> {
        FittedLogisticRegression {
            threshold: F::cast(0.5),
            intercept,
            params,
            labels,
        }
    }

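    /// Set the probability threshold at or above which the positive class is
    /// predicted. Defaults to 0.5. Panics if the threshold lies outside
    /// `[0.0, 1.0]`.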
    pub fn set_threshold(mut self, threshold: F) -> FittedLogisticRegression<F, C> {
        if threshold < F::zero() || threshold > F::one() {
            panic!("FittedLogisticRegression::set_threshold: threshold needs to be between 0.0 and 1.0");
        }
        self.threshold = threshold;
        self
    }

    pub fn intercept(&self) -> F {
        self.intercept
    }

    pub fn params(&self) -> &Array1<F> {
        &self.params
    }

    pub fn labels(&self) -> &BinaryClassLabels<F, C> {
        &self.labels
    }

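    /// Given a feature matrix, predict the probability of the positive class
    /// for each sample.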
    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array1<F> {
        let mut probs = x.dot(&self.params) + self.intercept;
        probs.mapv_inplace(logistic);
        probs
    }
}

impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for FittedLogisticRegression<F, C>
{
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );
        assert_eq!(
            x.ncols(),
            self.params.len(),
            "Number of data features must match the number of features the model was trained with."
        );

        let pos_class = &self.labels.pos.class;
        let neg_class = &self.labels.neg.class;
        Zip::from(&self.predict_probabilities(x))
            .and(y)
            .for_each(|prob, out| {
                *out = if *prob >= self.threshold {
                    pos_class.clone()
                } else {
                    neg_class.clone()
                }
            });
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
        Array1::default(x.nrows())
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct MultiFittedLogisticRegression<F, C: PartialOrd + Clone> {
    intercept: Array1<F>,
    params: Array2<F>,
    classes: Vec<C>,
}

impl<F: Float, C: PartialOrd + Clone> MultiFittedLogisticRegression<F, C> {
    fn new(intercept: Array1<F>, params: Array2<F>, classes: Vec<C>) -> Self {
        Self {
            intercept,
            params,
            classes,
        }
    }

    pub fn intercept(&self) -> &Array1<F> {
        &self.intercept
    }

    pub fn params(&self) -> &Array2<F> {
        &self.params
    }

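    /// Compute the untransformed per-class scores `x . W + b` for each
    /// sample; rows can be compared via `argmax` without normalization.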
    fn predict_nonorm_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
        x.dot(&self.params) + &self.intercept
    }

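    /// Given a feature matrix, predict the normalized per-class probabilities
    /// for each sample by applying the softmax to each row of scores.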
    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
        let mut probs = self.predict_nonorm_probabilities(x);
        probs
            .rows_mut()
            .into_iter()
            .for_each(|mut row| softmax_inplace(&mut row));
        probs
    }

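    /// Return the class labels, in the order of the columns of the
    /// probability matrix returned by `predict_probabilities`.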
    pub fn classes(&self) -> &[C] {
        &self.classes
    }
}

impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for MultiFittedLogisticRegression<F, C>
{
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );
        assert_eq!(
            x.ncols(),
            self.params.nrows(),
            "Number of data features must match the number of features the model was trained with."
        );

        let probs = self.predict_nonorm_probabilities(x);
        Zip::from(probs.rows()).and(y).for_each(|prob_row, out| {
            let idx = prob_row.argmax().unwrap();
            *out = self.classes[idx].clone();
        });
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
        Array1::default(x.nrows())
    }
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct ClassLabel<F, C: PartialOrd> {
    pub class: C,
    pub label: F,
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct BinaryClassLabels<F, C: PartialOrd> {
    pub pos: ClassLabel<F, C>,
    pub neg: ClassLabel<F, C>,
}

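/// Internal representation of a logistic regression problem, handed to the
/// argmin solver. `CostFunction` and `Gradient` are implemented on it below
/// for both the binary and multinomial cases.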
struct LogisticRegressionProblem<'a, F: Float, A: Data<Elem = F>, D: Dimension> {
    x: &'a ArrayBase<A, Ix2>,
    target: Array<F, D>,
    alpha: F,
}

type LogisticRegressionProblem1<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix1>;
type LogisticRegressionProblem2<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix2>;

impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem1<'_, F, A> {
    type Param = ArgminParam<F, Ix1>;
    type Output = F;

    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
        let w = p.as_array();
        let cost = logistic_loss(self.x, &self.target, self.alpha, w);
        Ok(cost)
    }
}

impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem1<'_, F, A> {
    type Param = ArgminParam<F, Ix1>;
    type Gradient = ArgminParam<F, Ix1>;

    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
        let w = p.as_array();
        let grad = ArgminParam(logistic_grad(self.x, &self.target, self.alpha, w));
        Ok(grad)
    }
}

impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem2<'_, F, A> {
    type Param = ArgminParam<F, Ix2>;
    type Output = F;

    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
        let w = p.as_array();
        let cost = multi_logistic_loss(self.x, &self.target, self.alpha, w);
        Ok(cost)
    }
}

impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem2<'_, F, A> {
    type Param = ArgminParam<F, Ix2>;
    type Gradient = ArgminParam<F, Ix2>;

    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
        let w = p.as_array();
        let grad = ArgminParam(multi_logistic_grad(self.x, &self.target, self.alpha, w));
        Ok(grad)
    }
}

trait SolvableProblem<F: Float, D: Dimension>: Gradient + Sized {
    type Solver: Solver<Self, IterStateType<F, D>>;
}

impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix1> for LogisticRegressionProblem1<'_, F, A> {
    type Solver = LBFGSType1<F>;
}

impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix2> for LogisticRegressionProblem2<'_, F, A> {
    type Solver = LBFGSType2<F>;
}

#[cfg(test)]
mod test {
    extern crate linfa;

    use super::Error;
    use super::*;
    use approx::{assert_abs_diff_eq, assert_relative_eq, AbsDiffEq};
    use linfa::prelude::*;
    use ndarray::{array, Array2, Dim, Ix};

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<LogisticRegressionParams<f64, Dim<[Ix; 0]>>>();
        has_autotraits::<LogisticRegressionValidParams<f64, Dim<[Ix; 0]>>>();
        has_autotraits::<ArgminParam<f64, Dim<[Ix; 0]>>>();
    }

    #[test]
    fn test_logistic_loss() {
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let ws = vec![
            array![0.0, 0.0],
            array![0.0, 1.0],
            array![1.0, 0.0],
            array![1.0, 1.0],
            array![0.0, -1.0],
            array![-1.0, 0.0],
            array![-1.0, -1.0],
        ];
        let alphas = &[0.0, 1.0, 10.0];
        let expecteds = vec![
            6.931471805599453,
            6.931471805599453,
            6.931471805599453,
            4.652158847349118,
            4.652158847349118,
            4.652158847349118,
            2.8012999588008323,
            3.3012999588008323,
            7.801299958800833,
            2.783195429782239,
            3.283195429782239,
            7.783195429782239,
            10.652158847349117,
            10.652158847349117,
            10.652158847349117,
            41.80129995880083,
            42.30129995880083,
            46.80129995880083,
            47.78319542978224,
            48.28319542978224,
            52.78319542978224,
        ];

        for ((w, alpha), exp) in ws
            .iter()
            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
            .zip(&expecteds)
        {
            assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w), *exp);
        }
    }

    #[test]
    fn test_logistic_grad() {
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let ws = vec![
            array![0.0, 0.0],
            array![0.0, 1.0],
            array![1.0, 0.0],
            array![1.0, 1.0],
            array![0.0, -1.0],
            array![-1.0, 0.0],
            array![-1.0, -1.0],
        ];
        let alphas = &[0.0, 1.0, 10.0];
        let expecteds = vec![
            array![-19.5, -3.],
            array![-19.5, -3.],
            array![-19.5, -3.],
            array![-10.48871543, -1.61364853],
            array![-10.48871543, -1.61364853],
            array![-10.48871543, -1.61364853],
            array![-0.13041554, -0.02852148],
            array![0.86958446, -0.02852148],
            array![9.86958446, -0.02852148],
            array![-0.04834401, -0.01058067],
            array![0.95165599, -0.01058067],
            array![9.95165599, -0.01058067],
            array![-28.51128457, -4.38635147],
            array![-28.51128457, -4.38635147],
            array![-28.51128457, -4.38635147],
            array![-38.86958446, -5.97147852],
            array![-39.86958446, -5.97147852],
            array![-48.86958446, -5.97147852],
            array![-38.95165599, -5.98941933],
            array![-39.95165599, -5.98941933],
            array![-48.95165599, -5.98941933],
        ];

        for ((w, alpha), exp) in ws
            .iter()
            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
            .zip(&expecteds)
        {
            let actual = logistic_grad(&x, &y, alpha, w);
            assert!(actual.abs_diff_eq(exp, 1e-8));
        }
    }

    #[test]
    fn simple_example_1() {
        let log_reg = LogisticRegression::default();
        let x = array![[-1.0], [-0.01], [0.01], [1.0]];
        let y = array![0, 0, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(res.intercept(), 0.0);
        assert!(res.params().abs_diff_eq(&array![-0.681], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn simple_example_1_cats_dogs() {
        let log_reg = LogisticRegression::default();
        let x = array![[0.01], [1.0], [-1.0], [-0.01]];
        let y = array!["dog", "dog", "cat", "cat"];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(res.intercept(), 0.0);
        assert!(res.params().abs_diff_eq(&array![0.681], 1e-3));
        assert!(res
            .predict_probabilities(dataset.records())
            .abs_diff_eq(&array![0.501, 0.664, 0.335, 0.498], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
        assert_eq!(res.labels().pos.class, "dog");
        assert_eq!(res.labels().neg.class, "cat");
    }

    #[test]
    fn simple_example_2() {
        let log_reg = LogisticRegression::default().alpha(1.0);
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn simple_example_3() {
        let x = array![[1.0], [0.0], [1.0], [0.0]];
        let y = array![1, 0, 1, 0];
        let dataset = DatasetBase::new(x, y);
        let model = LogisticRegression::default().fit(&dataset).unwrap();

        let pred = model.predict(&dataset.records);
        assert_eq!(dataset.targets(), pred);
    }

    #[test]
    fn rejects_mismatching_x_y() {
        let log_reg = LogisticRegression::default();
        let x = array![[-1.0], [-0.01], [0.01]];
        let y = array![0, 0, 1, 1];
        let res = log_reg.fit(&Dataset::new(x, y));
        assert!(matches!(res.unwrap_err(), Error::MismatchedShapes(3, 4)));
    }

    #[test]
    fn rejects_inf_values() {
        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
        let inf_xs: Vec<_> = infs.iter().map(|&inf| array![[1.0], [inf]]).collect();
        let log_reg = LogisticRegression::default();
        let normal_x = array![[-1.0], [1.0]];
        let y = array![0, 1];
        for inf_x in &inf_xs {
            let res = log_reg.fit(&DatasetBase::new(inf_x.view(), &y));
            assert!(matches!(res.unwrap_err(), Error::InvalidValues));
        }
        for inf in infs {
            let log_reg = LogisticRegression::default().alpha(*inf);
            let res = log_reg.fit(&DatasetBase::new(normal_x.view(), &y));
            assert!(matches!(res.unwrap_err(), Error::InvalidAlpha));
        }
        let mut non_positives = infs.to_vec();
        non_positives.push(-1.0);
        non_positives.push(0.0);
        for inf in &non_positives {
            let log_reg = LogisticRegression::default().gradient_tolerance(*inf);
            let res = log_reg.fit(&Dataset::new(normal_x.to_owned(), y.to_owned()));
            assert!(matches!(res.unwrap_err(), Error::InvalidGradientTolerance));
        }
    }

    #[test]
    fn validates_initial_params() {
        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
        let normal_x = array![[-1.0], [1.0]];
        let normal_y = array![0, 1];
        let dataset = Dataset::new(normal_x, normal_y);
        for inf in infs {
            let log_reg = LogisticRegression::default().initial_params(array![*inf, 0.0]);
            let res = log_reg.fit(&dataset);
            assert!(matches!(res.unwrap_err(), Error::InvalidInitialParameters));
        }
        {
            let log_reg = LogisticRegression::default().initial_params(array![0.0, 0.0, 0.0]);
            let res = log_reg.fit(&dataset);
            assert!(matches!(
                res.unwrap_err(),
                Error::InitialParameterFeaturesMismatch {
                    rows: 3,
                    n_features: 2
                }
            ));
        }
        {
            let log_reg = LogisticRegression::default()
                .with_intercept(false)
                .initial_params(array![0.0, 0.0]);
            let res = log_reg.fit(&dataset);
            assert!(matches!(
                res.unwrap_err(),
                Error::InitialParameterFeaturesMismatch {
                    rows: 2,
                    n_features: 1
                }
            ));
        }
    }

    #[test]
    fn uses_initial_params() {
        let params = array![1.2, -4.12];
        let log_reg = LogisticRegression::default()
            .initial_params(params)
            .max_iterations(5);
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert!(res.intercept().abs_diff_eq(&-4.124, 1e-3));
        assert!(res.params().abs_diff_eq(&array![1.181], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );

        #[cfg(feature = "serde")]
        {
            let ser = rmp_serde::to_vec(&res).unwrap();
            let unser: FittedLogisticRegression<f32, f32> = rmp_serde::from_slice(&ser).unwrap();

            let x = array![[1.0]];
            let y_hat = unser.predict(&x);

            assert!(y_hat[0] == 0.0);
        }
    }

    #[test]
    fn works_with_f32() {
        let log_reg = LogisticRegression::default();
        let x: Array2<f32> = array![[-1.0], [-0.01], [0.01], [1.0]];
        let y = array![0, 0, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(res.intercept(), 0.0_f32);
        assert!(res.params().abs_diff_eq(&array![-0.682_f32], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn test_log_sum_exp() {
        let data = array![[3.3, 0.4, -2.1], [0.4, 2.2, -0.1], [1., 0., -1.]];
        let out = log_sum_exp(&data, Axis(1));
        assert_abs_diff_eq!(out, array![3.35783, 2.43551, 1.40761], epsilon = 1e-5);
    }

    #[test]
    fn test_softmax() {
        let mut data = array![3.3, 5.5, 0.1, -4.4, 8.0];
        softmax_inplace(&mut data);
        assert_relative_eq!(
            data,
            array![0.0083324, 0.075200047, 0.000339647, 0.000003773, 0.91612413],
            epsilon = 1e-8
        );
        assert_abs_diff_eq!(data.sum(), 1.0);
    }

    #[test]
    fn test_multi_logistic_loss_grad() {
        let x = array![
            [0.0, 0.5],
            [1.0, -1.0],
            [2.0, -2.0],
            [3.0, -3.0],
            [4.0, -4.0],
            [5.0, -5.0],
            [6.0, -6.0],
            [7.0, -7.0],
        ];
        let y = array![
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
        ];
        let params1 = array![[4.4, -1.2, 3.3], [3.4, 0.1, 0.0]];
        let params2 = array![[0.001, -3.2, 2.9], [0.1, 4.5, 5.7], [4.5, 2.2, 1.7]];
        let alpha = 0.6;

        {
            let (log_prob, w) = multi_logistic_prob_params(&x, &params1);
            assert_abs_diff_eq!(
                log_prob,
                array![
                    [-3.18259845e-01, -1.96825985e+00, -2.01825985e+00],
                    [-2.40463987e+00, -4.70463987e+00, -1.04639868e-01],
                    [-4.61010168e+00, -9.21010168e+00, -1.01016809e-02],
                    [-6.90100829e+00, -1.38010083e+01, -1.00829256e-03],
                    [-9.20010104e+00, -1.84001010e+01, -1.01044506e-04],
                    [-1.15000101e+01, -2.30000101e+01, -1.01301449e-05],
                    [-1.38000010e+01, -2.76000010e+01, -1.01563199e-06],
                    [-1.61000001e+01, -3.22000001e+01, -1.01826043e-07],
                ],
                epsilon = 1e-6
            );
            assert_abs_diff_eq!(w, params1);
            let loss = multi_logistic_loss(&x, &y, alpha, &params1);
            assert_abs_diff_eq!(loss, 57.11212197835295, epsilon = 1e-6);
            let grad = multi_logistic_grad(&x, &y, alpha, &params1);
            assert_abs_diff_eq!(
                grad,
                array![
                    [1.7536815, -9.71074369, 11.85706219],
                    [2.79002537, 9.12059357, -9.81061893]
                ],
                epsilon = 1e-6
            );
        }

        {
            let (log_prob, w) = multi_logistic_prob_params(&x, &params2);
            assert_abs_diff_eq!(
                log_prob,
                array![
                    [-1.06637742e+00, -1.16637742e+00, -1.06637742e+00],
                    [-4.12429463e-03, -9.90512429e+00, -5.50512429e+00],
                    [-2.74092305e-04, -1.75022741e+01, -8.20227409e+00],
                    [-1.84027855e-05, -2.51030184e+01, -1.09030184e+01],
                    [-1.23554225e-06, -3.27040012e+01, -1.36040012e+01],
                    [-8.29523046e-08, -4.03050001e+01, -1.63050001e+01],
                    [-5.56928016e-09, -4.79060000e+01, -1.90060000e+01],
                    [-3.73912013e-10, -5.55070000e+01, -2.17070000e+01]
                ],
                epsilon = 1e-6
            );
            assert_abs_diff_eq!(w, params2.slice(s![..params2.nrows() - 1, ..]));
            let loss = multi_logistic_loss(&x, &y, alpha, &params2);
            assert_abs_diff_eq!(loss, 154.8177958366479, epsilon = 1e-6);
            let grad = multi_logistic_grad(&x, &y, alpha, &params2);
            assert_abs_diff_eq!(
                grad,
                array![
                    [26.99587549, -10.91995003, -16.25532546],
                    [-27.26314882, 11.85569669, 21.58745213],
                    [5.33984376, -2.68845675, -2.65138701]
                ],
                epsilon = 1e-6
            );
        }
    }

    #[test]
    fn simple_multi_example() {
        let x = array![[-1., 0.], [0., 1.], [1., 1.]];
        let y = array![2, 1, 0];
        let log_reg = MultiLogisticRegression::default()
            .alpha(0.1)
            .initial_params(Array::zeros((3, 3)));
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(res.params().dim(), (2, 3));
        assert_eq!(res.intercept().dim(), 3);
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn simple_multi_example_2() {
        let x = array![[1.0], [0.0], [1.0], [0.0]];
        let y = array![1, 0, 1, 0];
        let dataset = DatasetBase::new(x, y);
        let model = MultiLogisticRegression::default().fit(&dataset).unwrap();

        let pred = model.predict(&dataset.records);
        assert_eq!(dataset.targets(), pred);
    }

    #[test]
    fn simple_multi_example_text() {
        let log_reg = MultiLogisticRegression::default().alpha(0.1);
        let x = array![[0.1], [1.0], [-1.0], [-0.1]];
        let y = array!["dog", "ape", "rocket", "cat"];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(res.params().dim(), (1, 4));
        assert_eq!(res.intercept().dim(), 4);
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn multi_on_binary_problem() {
        let log_reg = MultiLogisticRegression::default().alpha(1.0);
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(res.params().dim(), (1, 2));
        assert_eq!(res.intercept().dim(), 2);
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    #[test]
    fn reject_num_class_mismatch() {
        let n_samples = 4;
        let n_classes = 3;
        let n_features = 1;
        let x = Array2::<f64>::zeros((n_samples, n_features));
        let y = array![0, 1, 2, 0];
        let dataset = Dataset::new(x, y);

        let log_reg = MultiLogisticRegression::default()
            .with_intercept(false)
            .initial_params(Array::zeros((n_features, n_classes - 1)));
        assert!(matches!(
            log_reg.fit(&dataset).unwrap_err(),
            Error::InitialParameterClassesMismatch {
                cols: 2,
                n_classes: 3,
            }
        ));
    }
}