linfa_logistic/
lib.rs

1//! # Logistic Regression
2//!
3//! ## The Big Picture
4//!
5//! `linfa-logistic` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`.
6//!
7//! ## Current state
8//! `linfa-logistic` provides a pure Rust implementation of a [binomial logistic regression model](LogisticRegression) and a [multinomial logistic regression model](MultiLogisticRegression).
9//!
10//! ## Examples
11//!
12//! There is a usage example in the `examples/` directory. To run, use:
13//!
14//! ```bash
15//! $ cargo run --example winequality
16//! ```
17//!
18
19pub mod error;
20
21use crate::error::{Error, Result};
22use argmin::core::{CostFunction, Executor, Gradient, IterState, OptimizationResult, Solver};
23use argmin::solver::linesearch::MoreThuenteLineSearch;
24use argmin::solver::quasinewton::LBFGS;
25use linfa::dataset::AsSingleTargets;
26use linfa::prelude::DatasetBase;
27use linfa::traits::{Fit, PredictInplace};
28use ndarray::{
29    s, Array, Array1, Array2, ArrayBase, ArrayView, ArrayView2, Axis, CowArray, Data, DataMut,
30    Dimension, IntoDimension, Ix1, Ix2, RemoveAxis, Slice, Zip,
31};
32use ndarray_stats::QuantileExt;
33use std::default::Default;
34
35#[cfg(feature = "serde")]
36use serde_crate::de::DeserializeOwned;
37#[cfg(feature = "serde")]
38use serde_crate::{Deserialize, Serialize};
39
40mod argmin_param;
41mod float;
42mod hyperparams;
43
44use argmin_param::*;
45use float::Float;
46use hyperparams::{LogisticRegressionParams, LogisticRegressionValidParams};
47
/// A two-class logistic regression model.
///
/// Logistic regression combines linear models with
/// the sigmoid function `sigm(x) = 1/(1+exp(-x))`
/// to learn a family of functions that map the feature space to `[0,1]`.
///
/// Logistic regression is used in binary classification
/// by interpreting the predicted value as the probability that the sample
/// has label `1`. A threshold can be set in the [fitted model](FittedLogisticRegression) to decide the minimum
/// probability needed to classify a sample as `1`, which defaults to `0.5`.
///
/// In this implementation any binary set of labels can be used, not necessarily `0` and `1`.
///
/// L2 regularization is used by this algorithm and is weighted by parameter `alpha`. Setting `alpha`
/// close to zero removes regularization and the problem solved minimizes only the
/// empirical risk. On the other hand, setting `alpha` to a high value increases
/// the weight of the L2 norm of the linear model coefficients in the cost function.
///
/// ## Examples
///
/// Here's an example on how to train a logistic regression model on the `winequality` dataset
/// ```rust
/// use linfa::traits::{Fit, Predict};
/// use linfa_logistic::LogisticRegression;
///
/// // Example on using binary labels different from 0 and 1
/// let dataset = linfa_datasets::winequality().map_targets(|x| if *x > 6 { "good" } else { "bad" });
/// let model = LogisticRegression::default().fit(&dataset).unwrap();
/// let prediction = model.predict(&dataset);
/// ```
pub type LogisticRegression<F> = LogisticRegressionParams<F, Ix1>;

/// Validated version of `LogisticRegression`
pub type ValidLogisticRegression<F> = LogisticRegressionValidParams<F, Ix1>;

/// A multinomial class logistic regression model.
///
/// The output labels can map to any discrete feature space, since the algorithm calculates the
/// likelihood of a feature vector corresponding to any given outcome using the softmax function
/// `softmax(x) = exp(x) / sum(exp(xi))`
///
/// L2 regularization is used by this algorithm and is weighted by parameter `alpha`. Setting `alpha`
/// close to zero removes regularization and the problem solved minimizes only the
/// empirical risk. On the other hand, setting `alpha` to a high value increases
/// the weight of the L2 norm of the linear model coefficients in the cost function.
pub type MultiLogisticRegression<F> = LogisticRegressionParams<F, Ix2>;

/// Validated version of `MultiLogisticRegression`
pub type ValidMultiLogisticRegression<F> = LogisticRegressionValidParams<F, Ix2>;
97
98impl<F: Float, D: Dimension> Default for LogisticRegressionParams<F, D> {
99    fn default() -> Self {
100        LogisticRegressionParams::new()
101    }
102}
103
/// Concrete LBFGS solver type over parameters of dimension `D`,
/// using More-Thuente line search.
type LBFGSType<F, D> = LBFGS<
    MoreThuenteLineSearch<ArgminParam<F, D>, ArgminParam<F, D>, F>,
    ArgminParam<F, D>,
    ArgminParam<F, D>,
    F,
>;
/// LBFGS over a 1-D parameter vector (binary logistic regression).
type LBFGSType1<F> = LBFGSType<F, Ix1>;
/// LBFGS over a 2-D parameter matrix (multinomial logistic regression).
type LBFGSType2<F> = LBFGSType<F, Ix2>;

/// Argmin iteration state used by the solvers in this crate.
type IterStateType<F, D> = IterState<ArgminParam<F, D>, ArgminParam<F, D>, (), (), (), F>;
114
impl<F: Float, D: Dimension> LogisticRegressionValidParams<F, D> {
    /// Create the initial parameters, either from a user supplied array
    /// or an array of 0s.
    ///
    /// `dims` is the parameter shape *without* the intercept entry; when
    /// `fit_intercept` is set, the first axis is grown by one to hold it.
    fn setup_init_params(&self, dims: D::Pattern) -> ArgminParam<F, D> {
        if let Some(params) = self.initial_params.as_ref() {
            ArgminParam(params.clone())
        } else {
            let mut dims = dims.into_dimension();
            // Reserve one extra row for the intercept if it is being fitted.
            dims.as_array_view_mut()[0] += self.fit_intercept as usize;
            ArgminParam(Array::zeros(dims))
        }
    }

    /// Ensure that `x` and `y` have the right shape and that all data and
    /// configuration parameters are finite.
    fn validate_data<A: Data<Elem = F>, B: Data<Elem = F>>(
        &self,
        x: &ArrayBase<A, Ix2>,
        y: &ArrayBase<B, D>,
    ) -> Result<()> {
        // One target (row) per sample.
        if x.shape()[0] != y.shape()[0] {
            return Err(Error::MismatchedShapes(x.shape()[0], y.shape()[0]));
        }
        // Reject NaN and infinite values in features and targets alike.
        if x.iter().any(|x| !x.is_finite()) || y.iter().any(|y| !y.is_finite()) {
            return Err(Error::InvalidValues);
        }
        // `y.shape().get(1)` is `None` in the binary (1-D target) case.
        self.validate_init_dims(x.shape()[1], y.shape().get(1).copied())?;
        Ok(())
    }

    /// Check that user-supplied initial parameters (if any) match the
    /// problem dimensions, accounting for the extra intercept row.
    fn validate_init_dims(&self, mut n_features: usize, n_classes: Option<usize>) -> Result<()> {
        if let Some(params) = self.initial_params.as_ref() {
            let shape = params.shape();
            // The initial parameters must include the intercept row when
            // `fit_intercept` is enabled.
            n_features += self.fit_intercept as usize;
            if n_features != shape[0] {
                return Err(Error::InitialParameterFeaturesMismatch {
                    n_features,
                    rows: shape[0],
                });
            }
            // Only meaningful for the multinomial (2-D) case.
            if let Some(n_classes) = n_classes {
                if n_classes != shape[1] {
                    return Err(Error::InitialParameterClassesMismatch {
                        n_classes,
                        cols: shape[1],
                    });
                }
            }
        }
        Ok(())
    }

    /// Create a `LogisticRegressionProblem`.
    fn setup_problem<'a, A: Data<Elem = F>>(
        &self,
        x: &'a ArrayBase<A, Ix2>,
        target: Array<F, D>,
    ) -> LogisticRegressionProblem<'a, F, A, D> {
        LogisticRegressionProblem {
            x,
            target,
            alpha: self.alpha,
        }
    }

    /// Create the LBFGS solver using MoreThuenteLineSearch and set gradient
    /// tolerance.
    fn setup_solver(&self) -> LBFGSType<F, D> {
        let linesearch = MoreThuenteLineSearch::new();
        // 10 = LBFGS history size.
        // NOTE(review): the unwrap assumes `gradient_tolerance` is accepted
        // by argmin (i.e. positive) — presumably guaranteed by parameter
        // validation; confirm in hyperparams.rs.
        LBFGS::new(linesearch, 10)
            .with_tolerance_grad(self.gradient_tolerance)
            .unwrap()
    }
}
189
impl<
        F: Float,
        #[cfg(feature = "serde")] D: Dimension + Serialize + DeserializeOwned,
        #[cfg(not(feature = "serde"))] D: Dimension,
    > LogisticRegressionValidParams<F, D>
{
    /// Run the LBFGS solver until it converges or runs out of iterations.
    ///
    /// Returns the full argmin `OptimizationResult` so the caller can pick
    /// the best parameters found; solver errors are converted into `Error`.
    fn run_solver<P: SolvableProblem<F, D>>(
        &self,
        problem: P,
        solver: P::Solver,
        init_params: ArgminParam<F, D>,
    ) -> Result<OptimizationResult<P, P::Solver, IterStateType<F, D>>> {
        Executor::new(problem, solver)
            .configure(|state| state.param(init_params).max_iters(self.max_iterations))
            .run()
            .map_err(move |err| err.into())
    }
}
209
210impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
211    Fit<ArrayBase<D, Ix2>, T, Error> for ValidLogisticRegression<F>
212{
213    type Object = FittedLogisticRegression<F, C>;
214
215    /// Given a 2-dimensional feature matrix array `x` with shape
216    /// (n_samples, n_features) and an array of target classes to predict,
217    /// create a `FittedLinearRegression` object which allows making
218    /// predictions.
219    ///
220    /// The array of target classes `y` must have exactly two discrete values, (e.g. 0 and 1, "cat"
221    /// and "dog", ...), which represent the two different classes the model is supposed to
222    /// predict.
223    ///
224    /// The array `y` must also have exactly `n_samples` items, i.e.
225    /// exactly as many items as there are rows in the feature matrix `x`.
226    ///
227    /// This method returns an error if any of the preconditions are violated,
228    /// i.e. any values are `Inf` or `NaN`, `y` doesn't have as many items as
229    /// `x` has rows, or if other parameters (gradient_tolerance, alpha) have
230    /// been set to inalid values.
231    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
232        let (x, y) = (dataset.records(), dataset.targets());
233        let (labels, target) = label_classes(y)?;
234        self.validate_data(x, &target)?;
235        let problem = self.setup_problem(x, target);
236        let solver = self.setup_solver();
237        let init_params = self.setup_init_params(x.ncols());
238        let result = self.run_solver(problem, solver, init_params)?;
239
240        let params = result
241            .state
242            .best_param
243            .unwrap_or(self.setup_init_params(x.ncols()));
244        let (w, intercept) = convert_params(x.ncols(), params.as_array());
245        Ok(FittedLogisticRegression::new(
246            *intercept.view().into_scalar(),
247            w.to_owned(),
248            labels,
249        ))
250    }
251}
252
253impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
254    Fit<ArrayBase<D, Ix2>, T, Error> for ValidMultiLogisticRegression<F>
255{
256    type Object = MultiFittedLogisticRegression<F, C>;
257
258    /// Given a 2-dimensional feature matrix array `x` with shape
259    /// (n_samples, n_features) and an array of target classes to predict,
260    /// create a `MultiFittedLogisticRegression` object which allows making
261    /// predictions. The target classes can have any number of discrete values.
262    ///
263    /// This method returns an error if any of the preconditions are violated,
264    /// i.e. any values are `Inf` or `NaN`, `y` doesn't have as many items as
265    /// `x` has rows, or if other parameters (gradient_tolerance, alpha) have
266    /// been set to inalid values. The input features are also strongly recommended to be
267    /// normalized to ensure numerical stability.
268    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
269        let (x, y) = (dataset.records(), dataset.targets());
270        let (classes, target) = label_classes_multi(y)?;
271        self.validate_data(x, &target)?;
272        let problem = self.setup_problem(x, target);
273        let solver = self.setup_solver();
274        let init_params = self.setup_init_params((x.ncols(), classes.len()));
275        let result = self.run_solver(problem, solver, init_params)?;
276
277        let params = result
278            .state
279            .best_param
280            .unwrap_or(self.setup_init_params((x.ncols(), classes.len())));
281        let (w, intercept) = convert_params(x.ncols(), params.as_array());
282        Ok(MultiFittedLogisticRegression::new(
283            intercept.to_owned(),
284            w.to_owned(),
285            classes,
286        ))
287    }
288}
289
290/// Identify the distinct values of the classes `y` and associate
291/// the target labels `-1.0` and `1.0` to it. -1.0 always labels the
292/// smaller class (by PartialOrd) and 1.0 always labels the larger
293/// class.
294///
295/// It is an error to have more than two classes.
296fn label_classes<F, T, C>(y: T) -> Result<(BinaryClassLabels<F, C>, Array1<F>)>
297where
298    F: Float,
299    T: AsSingleTargets<Elem = C>,
300    C: Ord + Clone,
301{
302    let y = y.as_single_targets();
303
304    let mut binary_classes = [None, None];
305    for class in y {
306        binary_classes = match binary_classes {
307            [None, None] => [Some((class, 1)), None],
308            [Some((c, count)), c2] if c == class => [Some((class, count + 1)), c2],
309            [c1, Some((c, count))] if c == class => [c1, Some((class, count + 1))],
310            [Some(c1), None] => [Some(c1), Some((class, 1))],
311            [None, Some(_)] => unreachable!("impossible binary class array"),
312            [Some(_), Some(_)] => return Err(Error::TooManyClasses),
313        };
314    }
315
316    let (class_a, class_b) = match binary_classes {
317        [Some(a), Some(b)] => (a, b),
318        _ => return Err(Error::TooFewClasses),
319    };
320
321    // Sort by label value (Ord), not by encounter order or count.
322    // The smaller label is always negative (-1),
323    // the larger label is always positive (+1).
324    let (neg_class, pos_class) = if class_a.0 < class_b.0 {
325        (class_a, class_b)
326    } else {
327        (class_b, class_a)
328    };
329
330    let target_array = y
331        .into_iter()
332        .map(|x| {
333            if x == pos_class.0 {
334                F::POSITIVE_LABEL
335            } else {
336                F::NEGATIVE_LABEL
337            }
338        })
339        .collect::<Array1<_>>();
340
341    Ok((
342        BinaryClassLabels {
343            pos: ClassLabel {
344                class: pos_class.0.clone(),
345                label: F::POSITIVE_LABEL,
346            },
347            neg: ClassLabel {
348                class: neg_class.0.clone(),
349                label: F::NEGATIVE_LABEL,
350            },
351        },
352        target_array,
353    ))
354}
355
356/// Identify the distinct values of the classes in `y` and map each value to an integer. Smaller
357/// classes (by `PartialOrd`) map to smaller integers. Returns the mapping along with a one-hot
358/// encoding of the numerical labels corresponding to `y`.
359fn label_classes_multi<F, T, C>(y: T) -> Result<(Vec<C>, Array2<F>)>
360where
361    F: Float,
362    T: AsSingleTargets<Elem = C>,
363    C: Ord + Clone,
364{
365    let y_single_target = y.as_single_targets();
366    let mut classes = y_single_target.to_vec();
367    // Dedup the list of classes
368    classes.sort();
369    classes.dedup();
370
371    let mut onehot = Array2::zeros((y_single_target.len(), classes.len()));
372    Zip::from(onehot.rows_mut())
373        .and(&y_single_target)
374        .for_each(|mut oh_row, cls| {
375            let idx = classes.binary_search(cls).unwrap();
376            oh_row[idx] = F::one();
377        });
378    Ok((classes, onehot))
379}
380
381/// Conditionally split the feature vector `w` into parameter vector and
382/// intercept parameter.
383/// Dimensions of `w` are either (f) or (f, n_classes)
384fn convert_params<F: Float, D: Dimension + RemoveAxis>(
385    n_features: usize,
386    w: &Array<F, D>,
387) -> (ArrayView<'_, F, D>, CowArray<'_, F, D::Smaller>) {
388    let nrows = w.shape()[0];
389    if n_features == nrows {
390        (
391            w.view(),
392            Array::zeros(w.raw_dim().remove_axis(Axis(0))).into(),
393        )
394    } else if n_features + 1 == nrows {
395        (
396            w.slice_axis(Axis(0), Slice::from(..n_features)),
397            w.index_axis(Axis(0), n_features).into(),
398        )
399    } else {
400        panic!(
401            "Unexpected length of parameter vector `w`, exected {} or {}, found {}",
402            n_features,
403            n_features + 1,
404            nrows
405        );
406    }
407}
408
409/// The logistic function
410fn logistic<F: linfa::Float>(x: F) -> F {
411    F::one() / (F::one() + (-x).exp())
412}
413
414/// A numerically stable version of the log of the logistic function.
415///
416/// Taken from scikit-learn
417/// https://github.com/scikit-learn/scikit-learn/blob/0.23.1/sklearn/utils/_logistic_sigmoid.pyx
418///
419/// See the blog post describing this implementation:
420/// http://fa.bianp.net/blog/2013/numerical-optimizers-for-logistic-regression/
421fn log_logistic<F: linfa::Float>(x: F) -> F {
422    if x > F::zero() {
423        -(F::one() + (-x).exp()).ln()
424    } else {
425        x - (F::one() + x.exp()).ln()
426    }
427}
428
429/// Finds the log of the sum of exponents across a specific axis in a numerically stable way. More
430/// specifically, computes `ln(exp(x1) + exp(x2) + exp(e3) + ...)` across an axis.
431///
432/// Based off this implementation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
433fn log_sum_exp<F: linfa::Float, A: Data<Elem = F>>(
434    m: &ArrayBase<A, Ix2>,
435    axis: Axis,
436) -> Array<F, Ix1> {
437    // Find max value of the array
438    let max = m.iter().copied().reduce(F::max).unwrap();
439    // Computes `max + ln(exp(x1-max) + exp(x2-max) + exp(x3-max) + ...)`, which is equal to the
440    // log_sum_exp formula
441    let reduced = m.fold_axis(axis, F::zero(), |acc, elem| *acc + (*elem - max).exp());
442    reduced.mapv_into(|e| e.max(F::cast(1e-15)).ln() + max)
443}
444
445/// Computes `exp(n - max) / sum(exp(n- max))`, which is a numerically stable version of softmax
446fn softmax_inplace<F: linfa::Float, A: DataMut<Elem = F>>(v: &mut ArrayBase<A, Ix1>) {
447    let max = v.iter().copied().reduce(F::max).unwrap();
448    v.mapv_inplace(|n| (n - max).exp());
449    let sum = v.sum();
450    v.mapv_inplace(|n| n / sum);
451}
452
/// Computes the logistic loss assuming the training labels $y \in {-1, 1}$
///
/// Because the logistic function fulfills $\sigma(-z) = 1 - \sigma(z)$
/// we can write $P(y=1|z) = \sigma(z) = \sigma(yz)$ and
/// $P(y=-1|z) = 1 - P(y=1|z) = 1 - \sigma(z) = \sigma(-z) = \sigma(yz)$, so
/// $P(y|z) = \sigma(yz)$ for both $y=1$ and $y=-1$.
///
/// Thus, the log loss can be written as
/// $$-\sum_{i=1}^{N} \log(\sigma(y_i z_i)) + \frac{\alpha}{2}\text{params}^T\text{params}$$
///
/// `w` holds either `n_features` entries (no intercept) or `n_features + 1`
/// entries, in which case the last entry is the intercept.
fn logistic_loss<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array1<F>,
    alpha: F,
    w: &Array1<F>,
) -> F {
    let n_features = x.shape()[1];
    // Split `w` into the linear coefficients and the (possibly zero) intercept.
    let (params, intercept) = convert_params(n_features, w);
    // z = X . params + intercept, computed as an (n, 1) matrix product.
    let yz = x.dot(&params.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
    let len = yz.len();
    // Flatten back to 1-D and multiply elementwise by the +/-1 labels.
    let mut yz = yz.into_shape_with_order(len).unwrap() * y;
    // log(sigma(y_i * z_i)), evaluated in a numerically stable way.
    yz.mapv_inplace(log_logistic);
    // Negative log-likelihood plus the l2 penalty; the intercept is not
    // penalized because `params` excludes it.
    -yz.sum() + F::cast(0.5) * alpha * params.dot(&params)
}
476
/// Computes the gradient of the logistic loss function
fn logistic_grad<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array1<F>,
    alpha: F,
    w: &Array1<F>,
) -> Array1<F> {
    let n_features = x.shape()[1];
    // Split `w` into the linear coefficients and the (possibly zero) intercept.
    let (params, intercept) = convert_params(n_features, w);
    // z = X . params + intercept as an (n, 1) column, then flattened to 1-D.
    let yz = x.dot(&params.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
    let len = yz.len();
    let mut yz = yz.into_shape_with_order(len).unwrap() * y;
    // yz_i := y_i * (sigma(y_i z_i) - 1), the per-sample dloss/dz_i.
    yz.mapv_inplace(logistic);
    yz -= F::one();
    yz *= y;
    if w.len() == n_features + 1 {
        // An intercept is being fitted: the coefficient gradient goes first
        // and the last entry is dloss/db = sum_i dloss/dz_i. The intercept
        // carries no l2 penalty term.
        let mut grad = Array::zeros(w.len());
        grad.slice_mut(s![..n_features])
            .assign(&(x.t().dot(&yz) + (&params * alpha)));
        grad[n_features] = yz.sum();
        grad
    } else {
        // No intercept: Xt . dloss/dz plus the l2 penalty gradient.
        x.t().dot(&yz) + (&params * alpha)
    }
}
502
/// Compute the log of probabilities, which is `log(softmax(H))`, where H is `X . W + b`. Also
/// returns `W` without the intercept.
/// `Y` is the output (n_samples * n_classes), `X` is the input (n_samples * n_features), `W` is the
/// params (n_features * n_classes), `b` is the intercept vector (n_classes).
fn multi_logistic_prob_params<'a, F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    w: &'a Array2<F>, // This parameter includes `W` and `b`
) -> (Array2<F>, ArrayView2<'a, F>) {
    let n_features = x.shape()[1];
    // Split `w` into the coefficient matrix `W` and (possibly zero) intercept `b`.
    let (params, intercept) = convert_params(n_features, w);
    // Compute H
    let h = x.dot(&params) + intercept;
    // This computes `H - log(sum(exp(H)))`, which is equal to
    // `log(softmax(H)) = log(exp(H) / sum(exp(H)))`; the row-wise
    // log-sum-exp is broadcast over the columns via the (n, 1) reshape.
    let log_prob = &h
        - log_sum_exp(&h, Axis(1))
            .into_shape_with_order((h.nrows(), 1))
            .unwrap();
    (log_prob, params)
}
523
524/// Computes loss function of `-sum(Y * log(softmax(H))) + alpha/2 * norm(W)`
525fn multi_logistic_loss<F: Float, A: Data<Elem = F>>(
526    x: &ArrayBase<A, Ix2>,
527    y: &Array2<F>,
528    alpha: F,
529    w: &Array2<F>,
530) -> F {
531    let (log_prob, params) = multi_logistic_prob_params(x, w);
532    // Calculate loss
533    -elem_dot(&log_prob, y) + F::cast(0.5) * alpha * elem_dot(&params, &params)
534}
535
/// Computes multinomial gradients for `W` and `b` and combine them.
/// Gradient for `W` is `Xt . (softmax(H) - Y) + alpha * W`.
/// Gradient for `b` is `sum(softmax(H) - Y)`.
fn multi_logistic_grad<F: Float, A: Data<Elem = F>>(
    x: &ArrayBase<A, Ix2>,
    y: &Array2<F>,
    alpha: F,
    w: &Array2<F>,
) -> Array2<F> {
    let (log_prob, params) = multi_logistic_prob_params(x, w);
    let (n_features, n_classes) = params.dim();
    // `w` has one extra row exactly when an intercept is being fitted.
    let intercept = w.nrows() > n_features;
    let mut grad = Array::zeros((n_features + intercept as usize, n_classes));

    // This value is `softmax(H)`
    let prob = log_prob.mapv_into(num_traits::Float::exp);
    let diff = prob - y;
    // Compute gradient for `W` and place it at start of the grad matrix
    let dw = x.t().dot(&diff) + (&params * alpha);
    grad.slice_mut(s![..n_features, ..]).assign(&dw);
    // Compute gradient for `b` and place it at end of grad matrix
    if intercept {
        grad.row_mut(n_features).assign(&diff.sum_axis(Axis(0)));
    }
    grad
}
562
/// A fitted logistic regression which can make predictions
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct FittedLogisticRegression<F: Float, C: PartialOrd + Clone> {
    /// Minimum predicted probability at which a sample is classified as the
    /// positive class; defaults to 0.5.
    threshold: F,
    /// Intercept (bias) term of the fitted linear model.
    intercept: F,
    /// Coefficients of the fitted linear model, one per feature.
    params: Array1<F>,
    /// Mapping of the problem's two class values to the internal -1/+1 labels.
    labels: BinaryClassLabels<F, C>,
}
576
577impl<F: Float, C: PartialOrd + Clone> FittedLogisticRegression<F, C> {
578    fn new(
579        intercept: F,
580        params: Array1<F>,
581        labels: BinaryClassLabels<F, C>,
582    ) -> FittedLogisticRegression<F, C> {
583        FittedLogisticRegression {
584            threshold: F::cast(0.5),
585            intercept,
586            params,
587            labels,
588        }
589    }
590
591    /// Set the probability threshold for which the 'positive' class will be
592    /// predicted. Defaults to 0.5.
593    pub fn set_threshold(mut self, threshold: F) -> FittedLogisticRegression<F, C> {
594        if threshold < F::zero() || threshold > F::one() {
595            panic!("FittedLogisticRegression::set_threshold: threshold needs to be between 0.0 and 1.0");
596        }
597        self.threshold = threshold;
598        self
599    }
600
601    pub fn intercept(&self) -> F {
602        self.intercept
603    }
604
605    pub fn params(&self) -> &Array1<F> {
606        &self.params
607    }
608
609    /// Get the model positive and negative classes mapped to their
610    /// corresponding problem input labels.
611    pub fn labels(&self) -> &BinaryClassLabels<F, C> {
612        &self.labels
613    }
614
615    /// Given a feature matrix, predict the probabilities that a sample
616    /// should be classified as the larger of the two classes learned when the
617    /// model was fitted.
618    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array1<F> {
619        let mut probs = x.dot(&self.params) + self.intercept;
620        probs.mapv_inplace(logistic);
621        probs
622    }
623}
624
625impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
626    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for FittedLogisticRegression<F, C>
627{
628    /// Given a feature matrix, predict the classes learned when the model was
629    /// fitted.
630    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
631        assert_eq!(
632            x.nrows(),
633            y.len(),
634            "The number of data points must match the number of output targets."
635        );
636        assert_eq!(
637            x.ncols(),
638            self.params.len(),
639            "Number of data features must match the number of features the model was trained with."
640        );
641
642        let pos_class = &self.labels.pos.class;
643        let neg_class = &self.labels.neg.class;
644        Zip::from(&self.predict_probabilities(x))
645            .and(y)
646            .for_each(|prob, out| {
647                *out = if *prob >= self.threshold {
648                    pos_class.clone()
649                } else {
650                    neg_class.clone()
651                }
652            });
653    }
654
655    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
656        Array1::default(x.nrows())
657    }
658}
659
/// A fitted multinomial logistic regression which can make predictions
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct MultiFittedLogisticRegression<F, C: PartialOrd + Clone> {
    /// Per-class intercept (bias) vector of length n_classes.
    intercept: Array1<F>,
    /// Coefficient matrix of shape (n_features, n_classes).
    params: Array2<F>,
    /// Sorted class values; column index `i` corresponds to `classes[i]`.
    classes: Vec<C>,
}
672
673impl<F: Float, C: PartialOrd + Clone> MultiFittedLogisticRegression<F, C> {
674    fn new(intercept: Array1<F>, params: Array2<F>, classes: Vec<C>) -> Self {
675        Self {
676            intercept,
677            params,
678            classes,
679        }
680    }
681
682    pub fn intercept(&self) -> &Array1<F> {
683        &self.intercept
684    }
685
686    pub fn params(&self) -> &Array2<F> {
687        &self.params
688    }
689
690    /// Return non-normalized probabilities (n_samples * n_classes)
691    fn predict_nonorm_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
692        x.dot(&self.params) + &self.intercept
693    }
694
695    /// Return normalized probabilities for each output class. The output dimensions are (n_samples
696    /// * n_classes).
697    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
698        let mut probs = self.predict_nonorm_probabilities(x);
699        probs
700            .rows_mut()
701            .into_iter()
702            .for_each(|mut row| softmax_inplace(&mut row));
703        probs
704    }
705
706    /// Get the list of class labels, which maps the numerical class indices to the labels
707    pub fn classes(&self) -> &[C] {
708        &self.classes
709    }
710}
711
712impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
713    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for MultiFittedLogisticRegression<F, C>
714{
715    /// Given a feature matrix, predict the classes learned when the model was
716    /// fitted.
717    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
718        assert_eq!(
719            x.nrows(),
720            y.len(),
721            "The number of data points must match the number of output targets."
722        );
723        assert_eq!(
724            x.ncols(),
725            self.params.nrows(),
726            "Number of data features must match the number of features the model was trained with."
727        );
728
729        let probs = self.predict_nonorm_probabilities(x);
730        Zip::from(probs.rows()).and(y).for_each(|prob_row, out| {
731            let idx = prob_row.argmax().unwrap();
732            *out = self.classes[idx].clone();
733        });
734    }
735
736    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
737        Array1::default(x.nrows())
738    }
739}
740
/// A problem class value together with the numeric label that represents it
/// inside the model.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct ClassLabel<F, C: PartialOrd> {
    /// The class value as supplied in the training targets.
    pub class: C,
    /// The numeric training label assigned to this class.
    pub label: F,
}
751
/// The pair of class labels of a binary model: the positive class (numeric
/// label +1) and the negative class (numeric label -1).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct BinaryClassLabels<F, C: PartialOrd> {
    /// Label information for the positive class.
    pub pos: ClassLabel<F, C>,
    /// Label information for the negative class.
    pub neg: ClassLabel<F, C>,
}
762
/// Internal representation of a logistic regression problem.
/// This data structure exists to be handed to Argmin.
struct LogisticRegressionProblem<'a, F: Float, A: Data<Elem = F>, D: Dimension> {
    /// Feature matrix (n_samples, n_features), borrowed from the dataset.
    x: &'a ArrayBase<A, Ix2>,
    /// Encoded targets: +/-1 labels (`Ix1`) or a one-hot matrix (`Ix2`).
    target: Array<F, D>,
    /// l2 regularization strength.
    alpha: F,
}

/// Binary (two-class) problem with a 1-D parameter vector.
type LogisticRegressionProblem1<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix1>;
/// Multinomial problem with a 2-D parameter matrix.
type LogisticRegressionProblem2<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix2>;
773
774impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem1<'_, F, A> {
775    type Param = ArgminParam<F, Ix1>;
776    type Output = F;
777
778    /// Apply the cost function to a parameter `p`
779    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
780        let w = p.as_array();
781        let cost = logistic_loss(self.x, &self.target, self.alpha, w);
782        Ok(cost)
783    }
784}
785
786impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem1<'_, F, A> {
787    type Param = ArgminParam<F, Ix1>;
788    type Gradient = ArgminParam<F, Ix1>;
789
790    /// Compute the gradient at parameter `p`.
791    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
792        let w = p.as_array();
793        let grad = ArgminParam(logistic_grad(self.x, &self.target, self.alpha, w));
794        Ok(grad)
795    }
796}
797
798impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem2<'_, F, A> {
799    type Param = ArgminParam<F, Ix2>;
800    type Output = F;
801
802    /// Apply the cost function to a parameter `p`
803    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
804        let w = p.as_array();
805        let cost = multi_logistic_loss(self.x, &self.target, self.alpha, w);
806        Ok(cost)
807    }
808}
809
810impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem2<'_, F, A> {
811    type Param = ArgminParam<F, Ix2>;
812    type Gradient = ArgminParam<F, Ix2>;
813
814    /// Compute the gradient at parameter `p`.
815    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
816        let w = p.as_array();
817        let grad = ArgminParam(multi_logistic_grad(self.x, &self.target, self.alpha, w));
818        Ok(grad)
819    }
820}
821
/// Associates an optimization problem with the concrete argmin solver used
/// to minimize it, so binomial (`Ix1`) and multinomial (`Ix2`) problems can
/// share one generic fitting routine.
trait SolvableProblem<F: Float, D: Dimension>: Gradient + Sized {
    /// The solver type (an L-BFGS variant) matching this problem's
    /// parameter dimension.
    type Solver: Solver<Self, IterStateType<F, D>>;
}
825
// Binomial problems (1-D parameter vector) are solved with the Ix1 L-BFGS.
impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix1> for LogisticRegressionProblem1<'_, F, A> {
    type Solver = LBFGSType1<F>;
}
829
// Multinomial problems (2-D parameter matrix) are solved with the Ix2 L-BFGS.
impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix2> for LogisticRegressionProblem2<'_, F, A> {
    type Solver = LBFGSType2<F>;
}
833
/// Unit tests for the binomial and multinomial logistic regression models,
/// their loss/gradient helpers, and hyperparameter validation.
#[cfg(test)]
mod test {
    extern crate linfa;

    use super::Error;
    use super::*;
    use approx::{assert_abs_diff_eq, assert_relative_eq, AbsDiffEq};
    use linfa::prelude::*;
    use ndarray::{array, Array2, Dim, Ix};

    // Public types must stay Send + Sync + Sized + Unpin so models and
    // parameter sets can be moved across threads.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<LogisticRegressionParams<f64, Dim<[Ix; 0]>>>();
        has_autotraits::<LogisticRegressionValidParams<f64, Dim<[Ix; 0]>>>();
        has_autotraits::<ArgminParam<f64, Dim<[Ix; 0]>>>();
    }

    /// Test that the logistic loss function works as expected.
    /// The expected values were obtained from running sklearn's
    /// _logistic_loss_and_grad function.
    #[test]
    fn test_logistic_loss() {
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let ws = [
            array![0.0, 0.0],
            array![0.0, 1.0],
            array![1.0, 0.0],
            array![1.0, 1.0],
            array![0.0, -1.0],
            array![-1.0, 0.0],
            array![-1.0, -1.0],
        ];
        let alphas = &[0.0, 1.0, 10.0];
        // One expected loss per (w, alpha) pair, in the same iteration order
        // as the cartesian product below.
        let expecteds = vec![
            6.931471805599453,
            6.931471805599453,
            6.931471805599453,
            4.652158847349118,
            4.652158847349118,
            4.652158847349118,
            2.8012999588008323,
            3.3012999588008323,
            7.801299958800833,
            2.783195429782239,
            3.283195429782239,
            7.783195429782239,
            10.652158847349117,
            10.652158847349117,
            10.652158847349117,
            41.80129995880083,
            42.30129995880083,
            46.80129995880083,
            47.78319542978224,
            48.28319542978224,
            52.78319542978224,
        ];

        for ((w, alpha), exp) in ws
            .iter()
            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
            .zip(&expecteds)
        {
            assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w), *exp);
        }
    }

    /// Test that the logistic grad function works as expected.
    /// The expected values were obtained from running sklearn's
    /// _logistic_loss_and_grad function.
    #[test]
    fn test_logistic_grad() {
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let ws = [
            array![0.0, 0.0],
            array![0.0, 1.0],
            array![1.0, 0.0],
            array![1.0, 1.0],
            array![0.0, -1.0],
            array![-1.0, 0.0],
            array![-1.0, -1.0],
        ];
        let alphas = &[0.0, 1.0, 10.0];
        // One expected gradient per (w, alpha) pair, same ordering as the
        // cartesian product below.
        let expecteds = vec![
            array![-19.5, -3.],
            array![-19.5, -3.],
            array![-19.5, -3.],
            array![-10.48871543, -1.61364853],
            array![-10.48871543, -1.61364853],
            array![-10.48871543, -1.61364853],
            array![-0.13041554, -0.02852148],
            array![0.86958446, -0.02852148],
            array![9.86958446, -0.02852148],
            array![-0.04834401, -0.01058067],
            array![0.95165599, -0.01058067],
            array![9.95165599, -0.01058067],
            array![-28.51128457, -4.38635147],
            array![-28.51128457, -4.38635147],
            array![-28.51128457, -4.38635147],
            array![-38.86958446, -5.97147852],
            array![-39.86958446, -5.97147852],
            array![-48.86958446, -5.97147852],
            array![-38.95165599, -5.98941933],
            array![-39.95165599, -5.98941933],
            array![-48.95165599, -5.98941933],
        ];

        for ((w, alpha), exp) in ws
            .iter()
            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
            .zip(&expecteds)
        {
            let actual = logistic_grad(&x, &y, alpha, w);
            assert!(actual.abs_diff_eq(exp, 1e-8));
        }
    }

    // A symmetric, linearly separable binary problem: the fitted model has
    // zero intercept, a known coefficient, and classifies the training set.
    #[test]
    fn simple_example_1() {
        let log_reg = LogisticRegression::default();
        let x = array![[-1.0], [-0.01], [0.01], [1.0]];
        let y = array![0, 0, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(res.intercept(), 0.0);
        assert!(res.params().abs_diff_eq(&array![0.681], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    // Same problem with string labels: checks probabilities and that the
    // positive/negative classes are assigned as expected.
    #[test]
    fn simple_example_1_cats_dogs() {
        let log_reg = LogisticRegression::default();
        let x = array![[0.01], [1.0], [-1.0], [-0.01]];
        let y = array!["dog", "dog", "cat", "cat"];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(res.intercept(), 0.0);
        assert!(res.params().abs_diff_eq(&array![0.681], 1e-3));
        assert!(res
            .predict_probabilities(dataset.records())
            .abs_diff_eq(&array![0.501, 0.664, 0.335, 0.498], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
        assert_eq!(res.labels().pos.class, "dog");
        assert_eq!(res.labels().neg.class, "cat");
    }

    // With L2 regularization (alpha = 1.0) the model still separates this
    // threshold-at-3.5 training set.
    #[test]
    fn simple_example_2() {
        let log_reg = LogisticRegression::default().alpha(1.0);
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    // Fitting via DatasetBase directly (instead of Dataset) works the same.
    #[test]
    fn simple_example_3() {
        let x = array![[1.0], [0.0], [1.0], [0.0]];
        let y = array![1, 0, 1, 0];
        let dataset = DatasetBase::new(x, y);
        let model = LogisticRegression::default().fit(&dataset).unwrap();

        let pred = model.predict(&dataset.records);
        assert_eq!(dataset.targets(), pred);
    }

    // Row-count mismatch between records and targets must be rejected.
    #[test]
    fn rejects_mismatching_x_y() {
        let log_reg = LogisticRegression::default();
        let x = array![[-1.0], [-0.01], [0.01]];
        let y = array![0, 0, 1, 1];
        let res = log_reg.fit(&Dataset::new(x, y));
        assert!(matches!(res.unwrap_err(), Error::MismatchedShapes(3, 4)));
    }

    // Non-finite features, alpha, or gradient tolerance each produce the
    // corresponding validation error.
    #[test]
    fn rejects_inf_values() {
        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
        let inf_xs: Vec<_> = infs.iter().map(|&inf| array![[1.0], [inf]]).collect();
        let log_reg = LogisticRegression::default();
        let normal_x = array![[-1.0], [1.0]];
        let y = array![0, 1];
        for inf_x in &inf_xs {
            let res = log_reg.fit(&DatasetBase::new(inf_x.view(), &y));
            assert!(matches!(res.unwrap_err(), Error::InvalidValues));
        }
        for inf in infs {
            let log_reg = LogisticRegression::default().alpha(*inf);
            let res = log_reg.fit(&DatasetBase::new(normal_x.view(), &y));
            assert!(matches!(res.unwrap_err(), Error::InvalidAlpha));
        }
        // Gradient tolerance must be strictly positive and finite.
        let mut non_positives = infs.to_vec();
        non_positives.push(-1.0);
        non_positives.push(0.0);
        for inf in &non_positives {
            let log_reg = LogisticRegression::default().gradient_tolerance(*inf);
            let res = log_reg.fit(&Dataset::new(normal_x.to_owned(), y.to_owned()));
            assert!(matches!(res.unwrap_err(), Error::InvalidGradientTolerance));
        }
    }

    // Initial parameters must be finite and sized to the feature count
    // (plus one when an intercept is fitted).
    #[test]
    fn validates_initial_params() {
        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
        let normal_x = array![[-1.0], [1.0]];
        let normal_y = array![0, 1];
        let dataset = Dataset::new(normal_x, normal_y);
        for inf in infs {
            let log_reg = LogisticRegression::default().initial_params(array![*inf, 0.0]);
            let res = log_reg.fit(&dataset);
            assert!(matches!(res.unwrap_err(), Error::InvalidInitialParameters));
        }
        {
            // 3 initial params but only 1 feature + intercept = 2 expected.
            let log_reg = LogisticRegression::default().initial_params(array![0.0, 0.0, 0.0]);
            let res = log_reg.fit(&dataset);
            assert!(matches!(
                res.unwrap_err(),
                Error::InitialParameterFeaturesMismatch {
                    rows: 3,
                    n_features: 2
                }
            ));
        }
        {
            // Without an intercept, only 1 parameter is expected.
            let log_reg = LogisticRegression::default()
                .with_intercept(false)
                .initial_params(array![0.0, 0.0]);
            let res = log_reg.fit(&dataset);
            assert!(matches!(
                res.unwrap_err(),
                Error::InitialParameterFeaturesMismatch {
                    rows: 2,
                    n_features: 1
                }
            ));
        }
    }

    // With good initial params and few iterations the solver stays near
    // the starting point; also exercises (de)serialization under `serde`.
    #[test]
    fn uses_initial_params() {
        let params = array![1.2, -4.12];
        let log_reg = LogisticRegression::default()
            .initial_params(params)
            .max_iterations(5);
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert!(res.intercept().abs_diff_eq(&-4.124, 1e-3));
        assert!(res.params().abs_diff_eq(&array![1.181], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );

        // Test serialization
        #[cfg(feature = "serde")]
        {
            let ser = rmp_serde::to_vec(&res).unwrap();
            let unser: FittedLogisticRegression<f32, f32> = rmp_serde::from_slice(&ser).unwrap();

            let x = array![[1.0]];
            let y_hat = unser.predict(&x);

            assert!(y_hat[0] == 0.0);
        }
    }

    // The model is generic over the float type: f32 gives comparable results.
    #[test]
    fn works_with_f32() {
        let log_reg = LogisticRegression::default();
        let x: Array2<f32> = array![[-1.0], [-0.01], [0.01], [1.0]];
        let y = array![0, 0, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(res.intercept(), 0.0_f32);
        assert!(res.params().abs_diff_eq(&array![0.682_f32], 1e-3));
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    // log-sum-exp along axis 1 matches independently computed values.
    #[test]
    fn test_log_sum_exp() {
        let data = array![[3.3, 0.4, -2.1], [0.4, 2.2, -0.1], [1., 0., -1.]];
        let out = log_sum_exp(&data, Axis(1));
        assert_abs_diff_eq!(out, array![3.35783, 2.43551, 1.40761], epsilon = 1e-5);
    }

    // In-place softmax produces a probability vector (sums to 1).
    #[test]
    fn test_softmax() {
        let mut data = array![3.3, 5.5, 0.1, -4.4, 8.0];
        softmax_inplace(&mut data);
        assert_relative_eq!(
            data,
            array![0.0083324, 0.075200047, 0.000339647, 0.000003773, 0.91612413],
            epsilon = 1e-8
        );
        assert_abs_diff_eq!(data.sum(), 1.0);
    }

    // Multinomial log-probabilities, loss and gradient against precomputed
    // values, both with and without an intercept row in the parameters.
    #[test]
    fn test_multi_logistic_loss_grad() {
        let x = array![
            [0.0, 0.5],
            [1.0, -1.0],
            [2.0, -2.0],
            [3.0, -3.0],
            [4.0, -4.0],
            [5.0, -5.0],
            [6.0, -6.0],
            [7.0, -7.0],
        ];
        let y = array![
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
        ];
        // params1: features x classes (no intercept row).
        let params1 = array![[4.4, -1.2, 3.3], [3.4, 0.1, 0.0]];
        // params2: one extra row, interpreted as the intercept.
        let params2 = array![[0.001, -3.2, 2.9], [0.1, 4.5, 5.7], [4.5, 2.2, 1.7]];
        let alpha = 0.6;

        {
            let (log_prob, w) = multi_logistic_prob_params(&x, &params1);
            assert_abs_diff_eq!(
                log_prob,
                array![
                    [-3.18259845e-01, -1.96825985e+00, -2.01825985e+00],
                    [-2.40463987e+00, -4.70463987e+00, -1.04639868e-01],
                    [-4.61010168e+00, -9.21010168e+00, -1.01016809e-02],
                    [-6.90100829e+00, -1.38010083e+01, -1.00829256e-03],
                    [-9.20010104e+00, -1.84001010e+01, -1.01044506e-04],
                    [-1.15000101e+01, -2.30000101e+01, -1.01301449e-05],
                    [-1.38000010e+01, -2.76000010e+01, -1.01563199e-06],
                    [-1.61000001e+01, -3.22000001e+01, -1.01826043e-07],
                ],
                epsilon = 1e-6
            );
            assert_abs_diff_eq!(w, params1);
            let loss = multi_logistic_loss(&x, &y, alpha, &params1);
            assert_abs_diff_eq!(loss, 57.11212197835295, epsilon = 1e-6);
            let grad = multi_logistic_grad(&x, &y, alpha, &params1);
            assert_abs_diff_eq!(
                grad,
                array![
                    [1.7536815, -9.71074369, 11.85706219],
                    [2.79002537, 9.12059357, -9.81061893]
                ],
                epsilon = 1e-6
            );
        }

        {
            let (log_prob, w) = multi_logistic_prob_params(&x, &params2);
            assert_abs_diff_eq!(
                log_prob,
                array![
                    [-1.06637742e+00, -1.16637742e+00, -1.06637742e+00],
                    [-4.12429463e-03, -9.90512429e+00, -5.50512429e+00],
                    [-2.74092305e-04, -1.75022741e+01, -8.20227409e+00],
                    [-1.84027855e-05, -2.51030184e+01, -1.09030184e+01],
                    [-1.23554225e-06, -3.27040012e+01, -1.36040012e+01],
                    [-8.29523046e-08, -4.03050001e+01, -1.63050001e+01],
                    [-5.56928016e-09, -4.79060000e+01, -1.90060000e+01],
                    [-3.73912013e-10, -5.55070000e+01, -2.17070000e+01]
                ],
                epsilon = 1e-6
            );
            // The returned weight matrix excludes the intercept row.
            assert_abs_diff_eq!(w, params2.slice(s![..params2.nrows() - 1, ..]));
            let loss = multi_logistic_loss(&x, &y, alpha, &params2);
            assert_abs_diff_eq!(loss, 154.8177958366479, epsilon = 1e-6);
            let grad = multi_logistic_grad(&x, &y, alpha, &params2);
            assert_abs_diff_eq!(
                grad,
                array![
                    [26.99587549, -10.91995003, -16.25532546],
                    [-27.26314882, 11.85569669, 21.58745213],
                    [5.33984376, -2.68845675, -2.65138701]
                ],
                epsilon = 1e-6
            );
        }
    }

    // Three-class multinomial fit: parameter/intercept shapes match the
    // class count and the training set is classified correctly.
    #[test]
    fn simple_multi_example() {
        let x = array![[-1., 0.], [0., 1.], [1., 1.]];
        let y = array![2, 1, 0];
        let log_reg = MultiLogisticRegression::default()
            .alpha(0.1)
            .initial_params(Array::zeros((3, 3)));
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(res.params().dim(), (2, 3));
        assert_eq!(res.intercept().dim(), 3);
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    // Multinomial fit via DatasetBase directly.
    #[test]
    fn simple_multi_example_2() {
        let x = array![[1.0], [0.0], [1.0], [0.0]];
        let y = array![1, 0, 1, 0];
        let dataset = DatasetBase::new(x, y);
        let model = MultiLogisticRegression::default().fit(&dataset).unwrap();

        let pred = model.predict(&dataset.records);
        assert_eq!(dataset.targets(), pred);
    }

    // Multinomial fit with four string labels; shapes track the class count.
    #[test]
    fn simple_multi_example_text() {
        let log_reg = MultiLogisticRegression::default().alpha(0.1);
        let x = array![[0.1], [1.0], [-1.0], [-0.1]];
        let y = array!["dog", "ape", "rocket", "cat"];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(res.params().dim(), (1, 4));
        assert_eq!(res.intercept().dim(), 4);
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    // The multinomial model also handles a two-class problem.
    #[test]
    fn multi_on_binary_problem() {
        let log_reg = MultiLogisticRegression::default().alpha(1.0);
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0]
        ];
        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        let dataset = Dataset::new(x, y);
        let res = log_reg.fit(&dataset).unwrap();
        assert_eq!(res.params().dim(), (1, 2));
        assert_eq!(res.intercept().dim(), 2);
        assert_eq!(
            &res.predict(dataset.records()),
            dataset.targets().as_single_targets()
        );
    }

    // Initial parameters with the wrong number of class columns must be
    // rejected with a class-count mismatch error.
    #[test]
    fn reject_num_class_mismatch() {
        let n_samples = 4;
        let n_classes = 3;
        let n_features = 1;
        let x = Array2::<f64>::zeros((n_samples, n_features));
        let y = array![0, 1, 2, 0];
        let dataset = Dataset::new(x, y);

        let log_reg = MultiLogisticRegression::default()
            .with_intercept(false)
            .initial_params(Array::zeros((n_features, n_classes - 1)));
        assert!(matches!(
            log_reg.fit(&dataset).unwrap_err(),
            Error::InitialParameterClassesMismatch {
                cols: 2,
                n_classes: 3,
            }
        ));
    }

    // The fitted model must not depend on the order in which labels appear
    // in the training data: mirrored datasets yield identical models.
    #[test]
    fn label_order_independent() {
        let x1 = array![[-1.0], [1.0], [-0.5], [0.5]];
        let y1 = array!["cat", "dog", "cat", "dog"];

        let x2 = array![[1.0], [-1.0], [0.5], [-0.5]];
        let y2 = array!["dog", "cat", "dog", "cat"];

        let model1 = LogisticRegression::default()
            .fit(&Dataset::new(x1, y1))
            .unwrap();
        let model2 = LogisticRegression::default()
            .fit(&Dataset::new(x2, y2))
            .unwrap();

        assert_eq!(model1.labels().pos.class, "dog");
        assert_eq!(model1.labels().neg.class, "cat");
        assert_eq!(model2.labels().pos.class, "dog");
        assert_eq!(model2.labels().neg.class, "cat");

        assert_abs_diff_eq!(model1.intercept(), model2.intercept());
        assert!(model1.params().abs_diff_eq(model2.params(), 1e-6));
    }
}