linfa_logistic/
lib.rs

//! # Logistic Regression
//!
//! ## The Big Picture
//!
//! `linfa-logistic` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`.
//!
//! ## Current state
//! `linfa-logistic` provides a pure Rust implementation of a [binomial logistic regression model](LogisticRegression) and a [multinomial logistic regression model](MultiLogisticRegression).
//!
//! ## Examples
//!
//! There is a usage example in the `examples/` directory. To run, use:
//!
//! ```bash
//! $ cargo run --example winequality
//! ```
//!

19pub mod error;
20
21use crate::error::{Error, Result};
22use argmin::core::{CostFunction, Executor, Gradient, IterState, OptimizationResult, Solver};
23use argmin::solver::linesearch::MoreThuenteLineSearch;
24use argmin::solver::quasinewton::LBFGS;
25use linfa::dataset::AsSingleTargets;
26use linfa::prelude::DatasetBase;
27use linfa::traits::{Fit, PredictInplace};
28use ndarray::{
29    s, Array, Array1, Array2, ArrayBase, ArrayView, ArrayView2, Axis, CowArray, Data, DataMut,
30    Dimension, IntoDimension, Ix1, Ix2, RemoveAxis, Slice, Zip,
31};
32use ndarray_stats::QuantileExt;
33use std::default::Default;
34
35#[cfg(feature = "serde")]
36use serde_crate::de::DeserializeOwned;
37#[cfg(feature = "serde")]
38use serde_crate::{Deserialize, Serialize};
39
40mod argmin_param;
41mod float;
42mod hyperparams;
43
44use argmin_param::*;
45use float::Float;
46use hyperparams::{LogisticRegressionParams, LogisticRegressionValidParams};
47
48/// A two-class logistic regression model.
49///
50/// Logistic regression combines linear models with
51/// the sigmoid function `sigm(x) = 1/(1+exp(-x))`
52/// to learn a family of functions that map the feature space to `[0,1]`.
53///
54/// Logistic regression is used in binary classification
55/// by interpreting the predicted value as the probability that the sample
56/// has label `1`. A threshold can be set in the [fitted model](FittedLogisticRegression) to decide the minimum
57/// probability needed to classify a sample as `1`, which defaults to `0.5`.
58///
59/// In this implementation any binary set of labels can be used, not necessarily `0` and `1`.
60///
61/// l2 regularization is used by this algorithm and is weighted by parameter `alpha`. Setting `alpha`
62/// close to zero removes regularization and the problem solved minimizes only the
63/// empirical risk. On the other hand, setting `alpha` to a high value increases
64/// the weight of the l2 norm of the linear model coefficients in the cost function.
65///
66/// ## Examples
67///
68/// Here's an example on how to train a logistic regression model on the `winequality` dataset
69/// ```rust
70/// use linfa::traits::{Fit, Predict};
71/// use linfa_logistic::LogisticRegression;
72///
73/// // Example on using binary labels different from 0 and 1
74/// let dataset = linfa_datasets::winequality().map_targets(|x| if *x > 6 { "good" } else { "bad" });
75/// let model = LogisticRegression::default().fit(&dataset).unwrap();
76/// let prediction = model.predict(&dataset);
77/// ```
78pub type LogisticRegression<F> = LogisticRegressionParams<F, Ix1>;
79
80/// Validated version of `LogisticRegression`
81pub type ValidLogisticRegression<F> = LogisticRegressionValidParams<F, Ix1>;
82
/// A multinomial logistic regression model.
///
/// The output labels can be any set of discrete values, since the algorithm computes the
/// likelihood of a feature vector belonging to each possible class using the softmax function
/// `softmax(x) = exp(x) / sum(exp(xi))`.
///
/// l2 regularization is used by this algorithm and is weighted by parameter `alpha`. Setting `alpha`
/// close to zero removes regularization and the problem solved minimizes only the
/// empirical risk. On the other hand, setting `alpha` to a high value increases
/// the weight of the l2 norm of the linear model coefficients in the cost function.
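///
/// ## Examples
///
/// A minimal usage sketch, fitting a multinomial model on the `winequality` dataset
/// (the quality scores are used directly as discrete classes; this assumes the same
/// `linfa_datasets` dev-dependency as the binomial example above):
/// ```rust
/// use linfa::traits::{Fit, Predict};
/// use linfa_logistic::MultiLogisticRegression;
///
/// let dataset = linfa_datasets::winequality();
/// let model = MultiLogisticRegression::default().max_iterations(50).fit(&dataset).unwrap();
/// let prediction = model.predict(&dataset);
/// ```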
93pub type MultiLogisticRegression<F> = LogisticRegressionParams<F, Ix2>;
94
95/// Validated version of `MultiLogisticRegression`
96pub type ValidMultiLogisticRegression<F> = LogisticRegressionValidParams<F, Ix2>;
97
98impl<F: Float, D: Dimension> Default for LogisticRegressionParams<F, D> {
99    fn default() -> Self {
100        LogisticRegressionParams::new()
101    }
102}
103
104type LBFGSType<F, D> = LBFGS<
105    MoreThuenteLineSearch<ArgminParam<F, D>, ArgminParam<F, D>, F>,
106    ArgminParam<F, D>,
107    ArgminParam<F, D>,
108    F,
109>;
110type LBFGSType1<F> = LBFGSType<F, Ix1>;
111type LBFGSType2<F> = LBFGSType<F, Ix2>;
112
113type IterStateType<F, D> = IterState<ArgminParam<F, D>, ArgminParam<F, D>, (), (), F>;
114
115impl<F: Float, D: Dimension> LogisticRegressionValidParams<F, D> {
116    /// Create the initial parameters, either from a user supplied array
117    /// or an array of 0s
118    fn setup_init_params(&self, dims: D::Pattern) -> ArgminParam<F, D> {
119        if let Some(params) = self.initial_params.as_ref() {
120            ArgminParam(params.clone())
121        } else {
122            let mut dims = dims.into_dimension();
123            dims.as_array_view_mut()[0] += self.fit_intercept as usize;
124            ArgminParam(Array::zeros(dims))
125        }
126    }
127
128    /// Ensure that `x` and `y` have the right shape and that all data and
129    /// configuration parameters are finite.
130    fn validate_data<A: Data<Elem = F>, B: Data<Elem = F>>(
131        &self,
132        x: &ArrayBase<A, Ix2>,
133        y: &ArrayBase<B, D>,
134    ) -> Result<()> {
135        if x.shape()[0] != y.shape()[0] {
136            return Err(Error::MismatchedShapes(x.shape()[0], y.shape()[0]));
137        }
138        if x.iter().any(|x| !x.is_finite()) || y.iter().any(|y| !y.is_finite()) {
139            return Err(Error::InvalidValues);
140        }
141        self.validate_init_dims(x.shape()[1], y.shape().get(1).copied())?;
142        Ok(())
143    }
144
145    fn validate_init_dims(&self, mut n_features: usize, n_classes: Option<usize>) -> Result<()> {
146        if let Some(params) = self.initial_params.as_ref() {
147            let shape = params.shape();
148            n_features += self.fit_intercept as usize;
149            if n_features != shape[0] {
150                return Err(Error::InitialParameterFeaturesMismatch {
151                    n_features,
152                    rows: shape[0],
153                });
154            }
155            if let Some(n_classes) = n_classes {
156                if n_classes != shape[1] {
157                    return Err(Error::InitialParameterClassesMismatch {
158                        n_classes,
159                        cols: shape[1],
160                    });
161                }
162            }
163        }
164        Ok(())
165    }
166
167    /// Create a `LogisticRegressionProblem`.
168    fn setup_problem<'a, A: Data<Elem = F>>(
169        &self,
170        x: &'a ArrayBase<A, Ix2>,
171        target: Array<F, D>,
172    ) -> LogisticRegressionProblem<'a, F, A, D> {
173        LogisticRegressionProblem {
174            x,
175            target,
176            alpha: self.alpha,
177        }
178    }
179
180    /// Create the LBFGS solver using MoreThuenteLineSearch and set gradient
181    /// tolerance.
182    fn setup_solver(&self) -> LBFGSType<F, D> {
183        let linesearch = MoreThuenteLineSearch::new();
184        LBFGS::new(linesearch, 10)
185            .with_tolerance_grad(self.gradient_tolerance)
186            .unwrap()
187    }
188}
189
190impl<
191        F: Float,
192        #[cfg(feature = "serde")] D: Dimension + Serialize + DeserializeOwned,
193        #[cfg(not(feature = "serde"))] D: Dimension,
194    > LogisticRegressionValidParams<F, D>
195{
196    /// Run the LBFGS solver until it converges or runs out of iterations.
197    fn run_solver<P: SolvableProblem<F, D>>(
198        &self,
199        problem: P,
200        solver: P::Solver,
201        init_params: ArgminParam<F, D>,
202    ) -> Result<OptimizationResult<P, P::Solver, IterStateType<F, D>>> {
203        Executor::new(problem, solver)
204            .configure(|state| state.param(init_params).max_iters(self.max_iterations))
205            .run()
206            .map_err(move |err| err.into())
207    }
208}
209
210impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
211    Fit<ArrayBase<D, Ix2>, T, Error> for ValidLogisticRegression<F>
212{
213    type Object = FittedLogisticRegression<F, C>;
214
215    /// Given a 2-dimensional feature matrix array `x` with shape
216    /// (n_samples, n_features) and an array of target classes to predict,
    /// create a `FittedLogisticRegression` object which allows making
218    /// predictions.
219    ///
    /// The array of target classes `y` must have exactly two discrete values (e.g. 0 and 1, "cat"
221    /// and "dog", ...), which represent the two different classes the model is supposed to
222    /// predict.
223    ///
224    /// The array `y` must also have exactly `n_samples` items, i.e.
225    /// exactly as many items as there are rows in the feature matrix `x`.
226    ///
227    /// This method returns an error if any of the preconditions are violated,
228    /// i.e. any values are `Inf` or `NaN`, `y` doesn't have as many items as
229    /// `x` has rows, or if other parameters (gradient_tolerance, alpha) have
    /// been set to invalid values.
231    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
232        let (x, y) = (dataset.records(), dataset.targets());
233        let (labels, target) = label_classes(y)?;
234        self.validate_data(x, &target)?;
235        let problem = self.setup_problem(x, target);
236        let solver = self.setup_solver();
237        let init_params = self.setup_init_params(x.ncols());
238        let result = self.run_solver(problem, solver, init_params)?;
239
240        let params = result
241            .state
242            .best_param
243            .unwrap_or(self.setup_init_params(x.ncols()));
244        let (w, intercept) = convert_params(x.ncols(), params.as_array());
245        Ok(FittedLogisticRegression::new(
246            *intercept.view().into_scalar(),
247            w.to_owned(),
248            labels,
249        ))
250    }
251}
252
253impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
254    Fit<ArrayBase<D, Ix2>, T, Error> for ValidMultiLogisticRegression<F>
255{
256    type Object = MultiFittedLogisticRegression<F, C>;
257
258    /// Given a 2-dimensional feature matrix array `x` with shape
259    /// (n_samples, n_features) and an array of target classes to predict,
260    /// create a `MultiFittedLogisticRegression` object which allows making
261    /// predictions. The target classes can have any number of discrete values.
262    ///
263    /// This method returns an error if any of the preconditions are violated,
264    /// i.e. any values are `Inf` or `NaN`, `y` doesn't have as many items as
265    /// `x` has rows, or if other parameters (gradient_tolerance, alpha) have
    /// been set to invalid values. It is also strongly recommended to normalize the input
    /// features to ensure numerical stability.
268    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
269        let (x, y) = (dataset.records(), dataset.targets());
270        let (classes, target) = label_classes_multi(y)?;
271        self.validate_data(x, &target)?;
272        let problem = self.setup_problem(x, target);
273        let solver = self.setup_solver();
274        let init_params = self.setup_init_params((x.ncols(), classes.len()));
275        let result = self.run_solver(problem, solver, init_params)?;
276
277        let params = result
278            .state
279            .best_param
280            .unwrap_or(self.setup_init_params((x.ncols(), classes.len())));
281        let (w, intercept) = convert_params(x.ncols(), params.as_array());
282        Ok(MultiFittedLogisticRegression::new(
283            intercept.to_owned(),
284            w.to_owned(),
285            classes,
286        ))
287    }
288}
289
/// Identify the two distinct class values in `y` and associate the target labels
/// `-1.0` and `1.0` with them. `-1.0` always labels the class with fewer occurrences
/// and `1.0` the class with more occurrences (ties are broken by which class is
/// encountered first in `y`).
///
/// It is an error to have more than two classes.
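///
/// For example, the targets `["a", "b", "b"]` are mapped to `[-1.0, 1.0, 1.0]`,
/// since `"b"` occurs more often than `"a"`.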
296fn label_classes<F, T, C>(y: T) -> Result<(BinaryClassLabels<F, C>, Array1<F>)>
297where
298    F: Float,
299    T: AsSingleTargets<Elem = C>,
300    C: Ord + Clone,
301{
302    let y = y.as_single_targets();
303
304    // counts the instances of two distinct class labels
305    let mut binary_classes = [None, None];
306    // find binary classes of our target dataset
307    for class in y {
308        binary_classes = match binary_classes {
309            // count the first class label
310            [None, None] => [Some((class, 1)), None],
311            // if the class has already been counted, increment the count
312            [Some((c, count)), c2] if c == class => [Some((class, count + 1)), c2],
313            [c1, Some((c, count))] if c == class => [c1, Some((class, count + 1))],
314            // count the second class label
315            [Some(c1), None] => [Some(c1), Some((class, 1))],
316
317            // should not be possible
318            [None, Some(_)] => unreachable!("impossible binary class array"),
319            // found 3rd distinct class
320            [Some(_), Some(_)] => return Err(Error::TooManyClasses),
321        };
322    }
323
324    let (pos_class, neg_class) = match binary_classes {
325        [Some(a), Some(b)] => (a, b),
326        _ => return Err(Error::TooFewClasses),
327    };
328
329    let mut target_array = y
330        .into_iter()
331        .map(|x| {
332            if x == pos_class.0 {
333                F::POSITIVE_LABEL
334            } else {
335                F::NEGATIVE_LABEL
336            }
337        })
338        .collect::<Array1<_>>();
339
340    let (pos_cl, neg_cl) = if pos_class.1 < neg_class.1 {
341        // If we found the larger class first, flip the sign in the target
342        // vector, so that -1.0 is always the label for the smaller class
343        // and 1.0 the label for the larger class
344        target_array *= -F::one();
345        (neg_class.0.clone(), pos_class.0.clone())
346    } else {
347        (pos_class.0.clone(), neg_class.0.clone())
348    };
349
350    Ok((
351        BinaryClassLabels {
352            pos: ClassLabel {
353                class: pos_cl,
354                label: F::POSITIVE_LABEL,
355            },
356            neg: ClassLabel {
357                class: neg_cl,
358                label: F::NEGATIVE_LABEL,
359            },
360        },
361        target_array,
362    ))
363}
364
365/// Identify the distinct values of the classes in `y` and map each value to an integer. Smaller
366/// classes (by `PartialOrd`) map to smaller integers. Returns the mapping along with a one-hot
367/// encoding of the numerical labels corresponding to `y`.
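///
/// For example, the targets `["b", "a", "b"]` produce the class list `["a", "b"]` and
/// the one-hot rows `[[0., 1.], [1., 0.], [0., 1.]]`.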
368fn label_classes_multi<F, T, C>(y: T) -> Result<(Vec<C>, Array2<F>)>
369where
370    F: Float,
371    T: AsSingleTargets<Elem = C>,
372    C: Ord + Clone,
373{
374    let y_single_target = y.as_single_targets();
375    let mut classes = y_single_target.to_vec();
376    // Dedup the list of classes
377    classes.sort();
378    classes.dedup();
379
380    let mut onehot = Array2::zeros((y_single_target.len(), classes.len()));
381    Zip::from(onehot.rows_mut())
382        .and(&y_single_target)
383        .for_each(|mut oh_row, cls| {
384            let idx = classes.binary_search(cls).unwrap();
385            oh_row[idx] = F::one();
386        });
387    Ok((classes, onehot))
388}
389
/// Conditionally split the parameter array `w` into the linear parameters and the
/// intercept parameter(s).
/// The first axis of `w` has length `n_features` (no intercept) or `n_features + 1`
/// (the last row holds the intercept); in the multinomial case `w` has an additional
/// `n_classes` axis.
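///
/// For example, in the binomial case with two features and a fitted intercept,
/// `w = [w0, w1, b]` is split into the parameter view `[w0, w1]` and the intercept `b`.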
393fn convert_params<F: Float, D: Dimension + RemoveAxis>(
394    n_features: usize,
395    w: &Array<F, D>,
396) -> (ArrayView<F, D>, CowArray<F, D::Smaller>) {
397    let nrows = w.shape()[0];
398    if n_features == nrows {
399        (
400            w.view(),
401            Array::zeros(w.raw_dim().remove_axis(Axis(0))).into(),
402        )
403    } else if n_features + 1 == nrows {
404        (
405            w.slice_axis(Axis(0), Slice::from(..n_features)),
406            w.index_axis(Axis(0), n_features).into(),
407        )
408    } else {
409        panic!(
            "Unexpected length of parameter vector `w`, expected {} or {}, found {}",
411            n_features,
412            n_features + 1,
413            nrows
414        );
415    }
416}
417
418/// The logistic function
419fn logistic<F: linfa::Float>(x: F) -> F {
420    F::one() / (F::one() + (-x).exp())
421}
422
423/// A numerically stable version of the log of the logistic function.
424///
425/// Taken from scikit-learn
426/// https://github.com/scikit-learn/scikit-learn/blob/0.23.1/sklearn/utils/_logistic_sigmoid.pyx
427///
428/// See the blog post describing this implementation:
429/// http://fa.bianp.net/blog/2013/numerical-optimizers-for-logistic-regression/
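///
/// It relies on the identity `log(sigm(x)) = -log(1 + exp(-x)) = x - log(1 + exp(x))`,
/// evaluating whichever form keeps the argument of `exp` non-positive.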
430fn log_logistic<F: linfa::Float>(x: F) -> F {
431    if x > F::zero() {
432        -(F::one() + (-x).exp()).ln()
433    } else {
434        x - (F::one() + x.exp()).ln()
435    }
436}
437
438/// Finds the log of the sum of exponents across a specific axis in a numerically stable way. More
/// specifically, computes `ln(exp(x1) + exp(x2) + exp(x3) + ...)` across an axis.
440///
441/// Based off this implementation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
442fn log_sum_exp<F: linfa::Float, A: Data<Elem = F>>(
443    m: &ArrayBase<A, Ix2>,
444    axis: Axis,
445) -> Array<F, Ix1> {
446    // Find max value of the array
447    let max = m.iter().copied().reduce(F::max).unwrap();
448    // Computes `max + ln(exp(x1-max) + exp(x2-max) + exp(x3-max) + ...)`, which is equal to the
449    // log_sum_exp formula
450    let reduced = m.fold_axis(axis, F::zero(), |acc, elem| *acc + (*elem - max).exp());
451    reduced.mapv_into(|e| e.max(F::cast(1e-15)).ln() + max)
452}
453
/// Computes `exp(n - max) / sum(exp(n - max))`, which is a numerically stable version of softmax.
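///
/// For example, `[1.0, 1.0]` becomes `[0.5, 0.5]` and `[0.0, ln(3.0)]` becomes `[0.25, 0.75]`.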
455fn softmax_inplace<F: linfa::Float, A: DataMut<Elem = F>>(v: &mut ArrayBase<A, Ix1>) {
456    let max = v.iter().copied().reduce(F::max).unwrap();
457    v.mapv_inplace(|n| (n - max).exp());
458    let sum = v.sum();
459    v.mapv_inplace(|n| n / sum);
460}
461
462/// Computes the logistic loss assuming the training labels $y \in {-1, 1}$
463///
/// Because the logistic function fulfills $\sigma(-z) = 1 - \sigma(z)$
465/// we can write $P(y=1|z) = \sigma(z) = \sigma(yz)$ and
466/// $P(y=-1|z) = 1 - P(y=1|z) = 1 - \sigma(z) = \sigma(-z) = \sigma(yz)$, so
467/// $P(y|z) = \sigma(yz)$ for both $y=1$ and $y=-1$.
468///
469/// Thus, the log loss can be written as
470/// $$-\sum_{i=1}^{N} \log(\sigma(y_i z_i)) + \frac{\alpha}{2}\text{params}^T\text{params}$$
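///
/// Note that the intercept term (when fitted) is excluded from the regularization penalty.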
471fn logistic_loss<F: Float, A: Data<Elem = F>>(
472    x: &ArrayBase<A, Ix2>,
473    y: &Array1<F>,
474    alpha: F,
475    w: &Array1<F>,
476) -> F {
477    let n_features = x.shape()[1];
478    let (params, intercept) = convert_params(n_features, w);
479    let yz = x.dot(&params.into_shape((params.len(), 1)).unwrap()) + intercept;
480    let len = yz.len();
481    let mut yz = yz.into_shape(len).unwrap() * y;
482    yz.mapv_inplace(log_logistic);
483    -yz.sum() + F::cast(0.5) * alpha * params.dot(&params)
484}
485
486/// Computes the gradient of the logistic loss function
487fn logistic_grad<F: Float, A: Data<Elem = F>>(
488    x: &ArrayBase<A, Ix2>,
489    y: &Array1<F>,
490    alpha: F,
491    w: &Array1<F>,
492) -> Array1<F> {
493    let n_features = x.shape()[1];
494    let (params, intercept) = convert_params(n_features, w);
495    let yz = x.dot(&params.into_shape((params.len(), 1)).unwrap()) + intercept;
496    let len = yz.len();
497    let mut yz = yz.into_shape(len).unwrap() * y;
498    yz.mapv_inplace(logistic);
499    yz -= F::one();
500    yz *= y;
501    if w.len() == n_features + 1 {
502        let mut grad = Array::zeros(w.len());
503        grad.slice_mut(s![..n_features])
504            .assign(&(x.t().dot(&yz) + (&params * alpha)));
505        grad[n_features] = yz.sum();
506        grad
507    } else {
508        x.t().dot(&yz) + (&params * alpha)
509    }
510}
511
512/// Compute the log of probabilities, which is `log(softmax(H))`, where H is `X . W + b`. Also
513/// returns `W` without the intercept.
514/// `Y` is the output (n_samples * n_classes), `X` is the input (n_samples * n_features), `W` is the
515/// params (n_features * n_classes), `b` is the intercept vector (n_classes).
516fn multi_logistic_prob_params<'a, F: Float, A: Data<Elem = F>>(
517    x: &ArrayBase<A, Ix2>,
518    w: &'a Array2<F>, // This parameter includes `W` and `b`
519) -> (Array2<F>, ArrayView2<'a, F>) {
520    let n_features = x.shape()[1];
521    let (params, intercept) = convert_params(n_features, w);
522    // Compute H
523    let h = x.dot(&params) + intercept;
524    // This computes `H - log(sum(exp(H)))`, which is equal to
525    // `log(softmax(H)) = log(exp(H) / sum(exp(H)))`
526    let log_prob = &h - log_sum_exp(&h, Axis(1)).into_shape((h.nrows(), 1)).unwrap();
527    (log_prob, params)
528}
529
/// Computes the loss function `-sum(Y * log(softmax(H))) + alpha/2 * norm(W)^2`
531fn multi_logistic_loss<F: Float, A: Data<Elem = F>>(
532    x: &ArrayBase<A, Ix2>,
533    y: &Array2<F>,
534    alpha: F,
535    w: &Array2<F>,
536) -> F {
537    let (log_prob, params) = multi_logistic_prob_params(x, w);
538    // Calculate loss
539    -elem_dot(&log_prob, y) + F::cast(0.5) * alpha * elem_dot(&params, &params)
540}
541
/// Computes the multinomial gradients for `W` and `b` and combines them.
543/// Gradient for `W` is `Xt . (softmax(H) - Y) + alpha * W`.
544/// Gradient for `b` is `sum(softmax(H) - Y)`.
545fn multi_logistic_grad<F: Float, A: Data<Elem = F>>(
546    x: &ArrayBase<A, Ix2>,
547    y: &Array2<F>,
548    alpha: F,
549    w: &Array2<F>,
550) -> Array2<F> {
551    let (log_prob, params) = multi_logistic_prob_params(x, w);
552    let (n_features, n_classes) = params.dim();
553    let intercept = w.nrows() > n_features;
554    let mut grad = Array::zeros((n_features + intercept as usize, n_classes));
555
556    // This value is `softmax(H)`
557    let prob = log_prob.mapv_into(num_traits::Float::exp);
558    let diff = prob - y;
559    // Compute gradient for `W` and place it at start of the grad matrix
560    let dw = x.t().dot(&diff) + (&params * alpha);
561    grad.slice_mut(s![..n_features, ..]).assign(&dw);
562    // Compute gradient for `b` and place it at end of grad matrix
563    if intercept {
564        grad.row_mut(n_features).assign(&diff.sum_axis(Axis(0)));
565    }
566    grad
567}
568
569/// A fitted logistic regression which can make predictions
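///
/// A minimal sketch of adjusting the decision threshold and obtaining class probabilities
/// (assuming the `linfa_datasets` dev-dependency used by the other examples in this crate):
/// ```rust
/// use linfa::traits::{Fit, Predict};
/// use linfa_logistic::LogisticRegression;
///
/// let dataset = linfa_datasets::winequality().map_targets(|x| *x > 6);
/// // Lower the threshold so that a probability of 0.3 or more maps to the positive class
/// let model = LogisticRegression::default()
///     .fit(&dataset)
///     .unwrap()
///     .set_threshold(0.3);
/// let probabilities = model.predict_probabilities(dataset.records());
/// let classes = model.predict(dataset.records());
/// ```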
570#[derive(Debug, Clone, PartialEq)]
571#[cfg_attr(
572    feature = "serde",
573    derive(Serialize, Deserialize),
574    serde(crate = "serde_crate"),
575    serde(bound(deserialize = "C: Deserialize<'de>"))
576)]
577pub struct FittedLogisticRegression<F: Float, C: PartialOrd + Clone> {
578    threshold: F,
579    intercept: F,
580    params: Array1<F>,
581    labels: BinaryClassLabels<F, C>,
582}
583
584impl<F: Float, C: PartialOrd + Clone> FittedLogisticRegression<F, C> {
585    fn new(
586        intercept: F,
587        params: Array1<F>,
588        labels: BinaryClassLabels<F, C>,
589    ) -> FittedLogisticRegression<F, C> {
590        FittedLogisticRegression {
591            threshold: F::cast(0.5),
592            intercept,
593            params,
594            labels,
595        }
596    }
597
598    /// Set the probability threshold for which the 'positive' class will be
599    /// predicted. Defaults to 0.5.
600    pub fn set_threshold(mut self, threshold: F) -> FittedLogisticRegression<F, C> {
601        if threshold < F::zero() || threshold > F::one() {
602            panic!("FittedLogisticRegression::set_threshold: threshold needs to be between 0.0 and 1.0");
603        }
604        self.threshold = threshold;
605        self
606    }
607
608    pub fn intercept(&self) -> F {
609        self.intercept
610    }
611
612    pub fn params(&self) -> &Array1<F> {
613        &self.params
614    }
615
616    /// Get the model positive and negative classes mapped to their
617    /// corresponding problem input labels.
618    pub fn labels(&self) -> &BinaryClassLabels<F, C> {
619        &self.labels
620    }
621
    /// Given a feature matrix, predict the probability that each sample belongs to
    /// the positive class learned when the model was fitted (see `labels()`).
625    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array1<F> {
626        let mut probs = x.dot(&self.params) + self.intercept;
627        probs.mapv_inplace(logistic);
628        probs
629    }
630}
631
632impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
633    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for FittedLogisticRegression<F, C>
634{
635    /// Given a feature matrix, predict the classes learned when the model was
636    /// fitted.
637    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
638        assert_eq!(
639            x.nrows(),
640            y.len(),
641            "The number of data points must match the number of output targets."
642        );
643        assert_eq!(
644            x.ncols(),
645            self.params.len(),
646            "Number of data features must match the number of features the model was trained with."
647        );
648
649        let pos_class = &self.labels.pos.class;
650        let neg_class = &self.labels.neg.class;
651        Zip::from(&self.predict_probabilities(x))
652            .and(y)
653            .for_each(|prob, out| {
654                *out = if *prob >= self.threshold {
655                    pos_class.clone()
656                } else {
657                    neg_class.clone()
658                }
659            });
660    }
661
662    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
663        Array1::default(x.nrows())
664    }
665}
666
667/// A fitted multinomial logistic regression which can make predictions
668#[derive(Debug, Clone, PartialEq, Eq)]
669#[cfg_attr(
670    feature = "serde",
671    derive(Serialize, Deserialize),
672    serde(crate = "serde_crate")
673)]
674pub struct MultiFittedLogisticRegression<F, C: PartialOrd + Clone> {
675    intercept: Array1<F>,
676    params: Array2<F>,
677    classes: Vec<C>,
678}
679
680impl<F: Float, C: PartialOrd + Clone> MultiFittedLogisticRegression<F, C> {
681    fn new(intercept: Array1<F>, params: Array2<F>, classes: Vec<C>) -> Self {
682        Self {
683            intercept,
684            params,
685            classes,
686        }
687    }
688
689    pub fn intercept(&self) -> &Array1<F> {
690        &self.intercept
691    }
692
693    pub fn params(&self) -> &Array2<F> {
694        &self.params
695    }
696
    /// Return the unnormalized scores `X . W + b` for each class (n_samples * n_classes);
    /// a row-wise softmax turns these into probabilities.
698    fn predict_nonorm_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
699        x.dot(&self.params) + &self.intercept
700    }
701
702    /// Return normalized probabilities for each output class. The output dimensions are (n_samples
703    /// * n_classes).
704    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
705        let mut probs = self.predict_nonorm_probabilities(x);
706        probs
707            .rows_mut()
708            .into_iter()
709            .for_each(|mut row| softmax_inplace(&mut row));
710        probs
711    }
712
713    /// Get the list of class labels, which maps the numerical class indices to the labels
714    pub fn classes(&self) -> &[C] {
715        &self.classes
716    }
717}
718
719impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
720    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for MultiFittedLogisticRegression<F, C>
721{
722    /// Given a feature matrix, predict the classes learned when the model was
723    /// fitted.
724    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
725        assert_eq!(
726            x.nrows(),
727            y.len(),
728            "The number of data points must match the number of output targets."
729        );
730        assert_eq!(
731            x.ncols(),
732            self.params.nrows(),
733            "Number of data features must match the number of features the model was trained with."
734        );
735
736        let probs = self.predict_nonorm_probabilities(x);
737        Zip::from(probs.rows()).and(y).for_each(|prob_row, out| {
738            let idx = prob_row.argmax().unwrap();
739            *out = self.classes[idx].clone();
740        });
741    }
742
743    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
744        Array1::default(x.nrows())
745    }
746}
747
748#[derive(Debug, Clone, PartialEq)]
749#[cfg_attr(
750    feature = "serde",
751    derive(Serialize, Deserialize),
752    serde(crate = "serde_crate")
753)]
754pub struct ClassLabel<F, C: PartialOrd> {
755    pub class: C,
756    pub label: F,
757}
758
759#[derive(Debug, Clone, PartialEq)]
760#[cfg_attr(
761    feature = "serde",
762    derive(Serialize, Deserialize),
763    serde(crate = "serde_crate")
764)]
765pub struct BinaryClassLabels<F, C: PartialOrd> {
766    pub pos: ClassLabel<F, C>,
767    pub neg: ClassLabel<F, C>,
768}
769
770/// Internal representation of a logistic regression problem.
771/// This data structure exists to be handed to Argmin.
772struct LogisticRegressionProblem<'a, F: Float, A: Data<Elem = F>, D: Dimension> {
773    x: &'a ArrayBase<A, Ix2>,
774    target: Array<F, D>,
775    alpha: F,
776}
777
778type LogisticRegressionProblem1<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix1>;
779type LogisticRegressionProblem2<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix2>;
780
781impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem1<'_, F, A> {
782    type Param = ArgminParam<F, Ix1>;
783    type Output = F;
784
785    /// Apply the cost function to a parameter `p`
786    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
787        let w = p.as_array();
788        let cost = logistic_loss(self.x, &self.target, self.alpha, w);
789        Ok(cost)
790    }
791}
792
793impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem1<'_, F, A> {
794    type Param = ArgminParam<F, Ix1>;
795    type Gradient = ArgminParam<F, Ix1>;
796
797    /// Compute the gradient at parameter `p`.
798    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
799        let w = p.as_array();
800        let grad = ArgminParam(logistic_grad(self.x, &self.target, self.alpha, w));
801        Ok(grad)
802    }
803}
804
805impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem2<'_, F, A> {
806    type Param = ArgminParam<F, Ix2>;
807    type Output = F;
808
809    /// Apply the cost function to a parameter `p`
810    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
811        let w = p.as_array();
812        let cost = multi_logistic_loss(self.x, &self.target, self.alpha, w);
813        Ok(cost)
814    }
815}
816
817impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem2<'_, F, A> {
818    type Param = ArgminParam<F, Ix2>;
819    type Gradient = ArgminParam<F, Ix2>;
820
821    /// Compute the gradient at parameter `p`.
822    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
823        let w = p.as_array();
824        let grad = ArgminParam(multi_logistic_grad(self.x, &self.target, self.alpha, w));
825        Ok(grad)
826    }
827}
828
829trait SolvableProblem<F: Float, D: Dimension>: Gradient + Sized {
830    type Solver: Solver<Self, IterStateType<F, D>>;
831}
832
833impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix1> for LogisticRegressionProblem1<'_, F, A> {
834    type Solver = LBFGSType1<F>;
835}
836
837impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix2> for LogisticRegressionProblem2<'_, F, A> {
838    type Solver = LBFGSType2<F>;
839}
840
841#[cfg(test)]
842mod test {
843    extern crate linfa;
844
845    use super::Error;
846    use super::*;
847    use approx::{assert_abs_diff_eq, assert_relative_eq, AbsDiffEq};
848    use linfa::prelude::*;
849    use ndarray::{array, Array2, Dim, Ix};
850
851    #[test]
852    fn autotraits() {
853        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
854        has_autotraits::<LogisticRegressionParams<f64, Dim<[Ix; 0]>>>();
855        has_autotraits::<LogisticRegressionValidParams<f64, Dim<[Ix; 0]>>>();
856        has_autotraits::<ArgminParam<f64, Dim<[Ix; 0]>>>();
857    }
858
859    /// Test that the logistic loss function works as expected.
860    /// The expected values were obtained from running sklearn's
861    /// _logistic_loss_and_grad function.
862    #[test]
863    fn test_logistic_loss() {
864        let x = array![
865            [0.0],
866            [1.0],
867            [2.0],
868            [3.0],
869            [4.0],
870            [5.0],
871            [6.0],
872            [7.0],
873            [8.0],
874            [9.0]
875        ];
876        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
877        let ws = vec![
878            array![0.0, 0.0],
879            array![0.0, 1.0],
880            array![1.0, 0.0],
881            array![1.0, 1.0],
882            array![0.0, -1.0],
883            array![-1.0, 0.0],
884            array![-1.0, -1.0],
885        ];
886        let alphas = &[0.0, 1.0, 10.0];
887        let expecteds = vec![
888            6.931471805599453,
889            6.931471805599453,
890            6.931471805599453,
891            4.652158847349118,
892            4.652158847349118,
893            4.652158847349118,
894            2.8012999588008323,
895            3.3012999588008323,
896            7.801299958800833,
897            2.783195429782239,
898            3.283195429782239,
899            7.783195429782239,
900            10.652158847349117,
901            10.652158847349117,
902            10.652158847349117,
903            41.80129995880083,
904            42.30129995880083,
905            46.80129995880083,
906            47.78319542978224,
907            48.28319542978224,
908            52.78319542978224,
909        ];
910
911        for ((w, alpha), exp) in ws
912            .iter()
913            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
914            .zip(&expecteds)
915        {
916            assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w), *exp);
917        }
918    }
919
920    /// Test that the logistic grad function works as expected.
921    /// The expected values were obtained from running sklearn's
922    /// _logistic_loss_and_grad function.
923    #[test]
924    fn test_logistic_grad() {
925        let x = array![
926            [0.0],
927            [1.0],
928            [2.0],
929            [3.0],
930            [4.0],
931            [5.0],
932            [6.0],
933            [7.0],
934            [8.0],
935            [9.0]
936        ];
937        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
938        let ws = vec![
939            array![0.0, 0.0],
940            array![0.0, 1.0],
941            array![1.0, 0.0],
942            array![1.0, 1.0],
943            array![0.0, -1.0],
944            array![-1.0, 0.0],
945            array![-1.0, -1.0],
946        ];
947        let alphas = &[0.0, 1.0, 10.0];
948        let expecteds = vec![
949            array![-19.5, -3.],
950            array![-19.5, -3.],
951            array![-19.5, -3.],
952            array![-10.48871543, -1.61364853],
953            array![-10.48871543, -1.61364853],
954            array![-10.48871543, -1.61364853],
955            array![-0.13041554, -0.02852148],
956            array![0.86958446, -0.02852148],
957            array![9.86958446, -0.02852148],
958            array![-0.04834401, -0.01058067],
959            array![0.95165599, -0.01058067],
960            array![9.95165599, -0.01058067],
961            array![-28.51128457, -4.38635147],
962            array![-28.51128457, -4.38635147],
963            array![-28.51128457, -4.38635147],
964            array![-38.86958446, -5.97147852],
965            array![-39.86958446, -5.97147852],
966            array![-48.86958446, -5.97147852],
967            array![-38.95165599, -5.98941933],
968            array![-39.95165599, -5.98941933],
969            array![-48.95165599, -5.98941933],
970        ];
971
972        for ((w, alpha), exp) in ws
973            .iter()
974            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
975            .zip(&expecteds)
976        {
977            let actual = logistic_grad(&x, &y, alpha, w);
978            assert!(actual.abs_diff_eq(exp, 1e-8));
979        }
980    }
981
982    #[test]
983    fn simple_example_1() {
984        let log_reg = LogisticRegression::default();
985        let x = array![[-1.0], [-0.01], [0.01], [1.0]];
986        let y = array![0, 0, 1, 1];
987        let dataset = Dataset::new(x, y);
988        let res = log_reg.fit(&dataset).unwrap();
989        assert_abs_diff_eq!(res.intercept(), 0.0);
990        assert!(res.params().abs_diff_eq(&array![-0.681], 1e-3));
991        assert_eq!(
992            &res.predict(dataset.records()),
993            dataset.targets().as_single_targets()
994        );
995    }
996
997    #[test]
998    fn simple_example_1_cats_dogs() {
999        let log_reg = LogisticRegression::default();
1000        let x = array![[0.01], [1.0], [-1.0], [-0.01]];
1001        let y = array!["dog", "dog", "cat", "cat"];
1002        let dataset = Dataset::new(x, y);
1003        let res = log_reg.fit(&dataset).unwrap();
1004        assert_abs_diff_eq!(res.intercept(), 0.0);
1005        assert!(res.params().abs_diff_eq(&array![0.681], 1e-3));
1006        assert!(res
1007            .predict_probabilities(dataset.records())
1008            .abs_diff_eq(&array![0.501, 0.664, 0.335, 0.498], 1e-3));
1009        assert_eq!(
1010            &res.predict(dataset.records()),
1011            dataset.targets().as_single_targets()
1012        );
1013        assert_eq!(res.labels().pos.class, "dog");
1014        assert_eq!(res.labels().neg.class, "cat");
1015    }
1016
1017    #[test]
1018    fn simple_example_2() {
1019        let log_reg = LogisticRegression::default().alpha(1.0);
1020        let x = array![
1021            [0.0],
1022            [1.0],
1023            [2.0],
1024            [3.0],
1025            [4.0],
1026            [5.0],
1027            [6.0],
1028            [7.0],
1029            [8.0],
1030            [9.0]
1031        ];
1032        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1033        let dataset = Dataset::new(x, y);
1034        let res = log_reg.fit(&dataset).unwrap();
1035        assert_eq!(
1036            &res.predict(dataset.records()),
1037            dataset.targets().as_single_targets()
1038        );
1039    }
1040
1041    #[test]
1042    fn simple_example_3() {
1043        let x = array![[1.0], [0.0], [1.0], [0.0]];
1044        let y = array![1, 0, 1, 0];
1045        let dataset = DatasetBase::new(x, y);
1046        let model = LogisticRegression::default().fit(&dataset).unwrap();
1047
1048        let pred = model.predict(&dataset.records);
1049        assert_eq!(dataset.targets(), pred);
1050    }
1051
1052    #[test]
1053    fn rejects_mismatching_x_y() {
1054        let log_reg = LogisticRegression::default();
1055        let x = array![[-1.0], [-0.01], [0.01]];
1056        let y = array![0, 0, 1, 1];
1057        let res = log_reg.fit(&Dataset::new(x, y));
1058        assert!(matches!(res.unwrap_err(), Error::MismatchedShapes(3, 4)));
1059    }
1060
1061    #[test]
1062    fn rejects_inf_values() {
1063        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
1064        let inf_xs: Vec<_> = infs.iter().map(|&inf| array![[1.0], [inf]]).collect();
1065        let log_reg = LogisticRegression::default();
1066        let normal_x = array![[-1.0], [1.0]];
1067        let y = array![0, 1];
1068        for inf_x in &inf_xs {
1069            let res = log_reg.fit(&DatasetBase::new(inf_x.view(), &y));
1070            assert!(matches!(res.unwrap_err(), Error::InvalidValues));
1071        }
1072        for inf in infs {
1073            let log_reg = LogisticRegression::default().alpha(*inf);
1074            let res = log_reg.fit(&DatasetBase::new(normal_x.view(), &y));
1075            assert!(matches!(res.unwrap_err(), Error::InvalidAlpha));
1076        }
1077        let mut non_positives = infs.to_vec();
1078        non_positives.push(-1.0);
1079        non_positives.push(0.0);
1080        for inf in &non_positives {
1081            let log_reg = LogisticRegression::default().gradient_tolerance(*inf);
1082            let res = log_reg.fit(&Dataset::new(normal_x.to_owned(), y.to_owned()));
1083            assert!(matches!(res.unwrap_err(), Error::InvalidGradientTolerance));
1084        }
1085    }
1086
1087    #[test]
1088    fn validates_initial_params() {
1089        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
1090        let normal_x = array![[-1.0], [1.0]];
1091        let normal_y = array![0, 1];
1092        let dataset = Dataset::new(normal_x, normal_y);
1093        for inf in infs {
1094            let log_reg = LogisticRegression::default().initial_params(array![*inf, 0.0]);
1095            let res = log_reg.fit(&dataset);
1096            assert!(matches!(res.unwrap_err(), Error::InvalidInitialParameters));
1097        }
1098        {
1099            let log_reg = LogisticRegression::default().initial_params(array![0.0, 0.0, 0.0]);
1100            let res = log_reg.fit(&dataset);
1101            assert!(matches!(
1102                res.unwrap_err(),
1103                Error::InitialParameterFeaturesMismatch {
1104                    rows: 3,
1105                    n_features: 2
1106                }
1107            ));
1108        }
1109        {
1110            let log_reg = LogisticRegression::default()
1111                .with_intercept(false)
1112                .initial_params(array![0.0, 0.0]);
1113            let res = log_reg.fit(&dataset);
1114            assert!(matches!(
1115                res.unwrap_err(),
1116                Error::InitialParameterFeaturesMismatch {
1117                    rows: 2,
1118                    n_features: 1
1119                }
1120            ));
1121        }
1122    }
1123
1124    #[test]
1125    fn uses_initial_params() {
1126        let params = array![1.2, -4.12];
1127        let log_reg = LogisticRegression::default()
1128            .initial_params(params)
1129            .max_iterations(5);
1130        let x = array![
1131            [0.0],
1132            [1.0],
1133            [2.0],
1134            [3.0],
1135            [4.0],
1136            [5.0],
1137            [6.0],
1138            [7.0],
1139            [8.0],
1140            [9.0]
1141        ];
1142        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1143        let dataset = Dataset::new(x, y);
1144        let res = log_reg.fit(&dataset).unwrap();
1145        assert!(res.intercept().abs_diff_eq(&-4.124, 1e-3));
1146        assert!(res.params().abs_diff_eq(&array![1.181], 1e-3));
1147        assert_eq!(
1148            &res.predict(dataset.records()),
1149            dataset.targets().as_single_targets()
1150        );
1151
1152        // Test serialization
1153        #[cfg(feature = "serde")]
1154        {
1155            let ser = rmp_serde::to_vec(&res).unwrap();
1156            let unser: FittedLogisticRegression<f32, f32> = rmp_serde::from_slice(&ser).unwrap();
1157
1158            let x = array![[1.0]];
1159            let y_hat = unser.predict(&x);
1160
1161            assert!(y_hat[0] == 0.0);
1162        }
1163    }
1164
1165    #[test]
1166    fn works_with_f32() {
1167        let log_reg = LogisticRegression::default();
1168        let x: Array2<f32> = array![[-1.0], [-0.01], [0.01], [1.0]];
1169        let y = array![0, 0, 1, 1];
1170        let dataset = Dataset::new(x, y);
1171        let res = log_reg.fit(&dataset).unwrap();
1172        assert_abs_diff_eq!(res.intercept(), 0.0_f32);
1173        assert!(res.params().abs_diff_eq(&array![-0.682_f32], 1e-3));
1174        assert_eq!(
1175            &res.predict(dataset.records()),
1176            dataset.targets().as_single_targets()
1177        );
1178    }
1179
1180    #[test]
1181    fn test_log_sum_exp() {
1182        let data = array![[3.3, 0.4, -2.1], [0.4, 2.2, -0.1], [1., 0., -1.]];
1183        let out = log_sum_exp(&data, Axis(1));
1184        assert_abs_diff_eq!(out, array![3.35783, 2.43551, 1.40761], epsilon = 1e-5);
1185    }
1186
1187    #[test]
1188    fn test_softmax() {
1189        let mut data = array![3.3, 5.5, 0.1, -4.4, 8.0];
1190        softmax_inplace(&mut data);
1191        assert_relative_eq!(
1192            data,
1193            array![0.0083324, 0.075200047, 0.000339647, 0.000003773, 0.91612413],
1194            epsilon = 1e-8
1195        );
1196        assert_abs_diff_eq!(data.sum(), 1.0);
1197    }
1198
1199    #[test]
1200    fn test_multi_logistic_loss_grad() {
1201        let x = array![
1202            [0.0, 0.5],
1203            [1.0, -1.0],
1204            [2.0, -2.0],
1205            [3.0, -3.0],
1206            [4.0, -4.0],
1207            [5.0, -5.0],
1208            [6.0, -6.0],
1209            [7.0, -7.0],
1210        ];
1211        let y = array![
1212            [1.0, 0.0, 0.0],
1213            [1.0, 0.0, 0.0],
1214            [0.0, 1.0, 0.0],
1215            [0.0, 1.0, 0.0],
1216            [0.0, 1.0, 0.0],
1217            [0.0, 0.0, 1.0],
1218            [0.0, 0.0, 1.0],
1219            [0.0, 0.0, 1.0],
1220        ];
1221        let params1 = array![[4.4, -1.2, 3.3], [3.4, 0.1, 0.0]];
1222        let params2 = array![[0.001, -3.2, 2.9], [0.1, 4.5, 5.7], [4.5, 2.2, 1.7]];
1223        let alpha = 0.6;
1224
1225        {
1226            let (log_prob, w) = multi_logistic_prob_params(&x, &params1);
1227            assert_abs_diff_eq!(
1228                log_prob,
1229                array![
1230                    [-3.18259845e-01, -1.96825985e+00, -2.01825985e+00],
1231                    [-2.40463987e+00, -4.70463987e+00, -1.04639868e-01],
1232                    [-4.61010168e+00, -9.21010168e+00, -1.01016809e-02],
1233                    [-6.90100829e+00, -1.38010083e+01, -1.00829256e-03],
1234                    [-9.20010104e+00, -1.84001010e+01, -1.01044506e-04],
1235                    [-1.15000101e+01, -2.30000101e+01, -1.01301449e-05],
1236                    [-1.38000010e+01, -2.76000010e+01, -1.01563199e-06],
1237                    [-1.61000001e+01, -3.22000001e+01, -1.01826043e-07],
1238                ],
1239                epsilon = 1e-6
1240            );
1241            assert_abs_diff_eq!(w, params1);
1242            let loss = multi_logistic_loss(&x, &y, alpha, &params1);
1243            assert_abs_diff_eq!(loss, 57.11212197835295, epsilon = 1e-6);
1244            let grad = multi_logistic_grad(&x, &y, alpha, &params1);
1245            assert_abs_diff_eq!(
1246                grad,
1247                array![
1248                    [1.7536815, -9.71074369, 11.85706219],
1249                    [2.79002537, 9.12059357, -9.81061893]
1250                ],
1251                epsilon = 1e-6
1252            );
1253        }
1254
1255        {
1256            let (log_prob, w) = multi_logistic_prob_params(&x, &params2);
1257            assert_abs_diff_eq!(
1258                log_prob,
1259                array![
1260                    [-1.06637742e+00, -1.16637742e+00, -1.06637742e+00],
1261                    [-4.12429463e-03, -9.90512429e+00, -5.50512429e+00],
1262                    [-2.74092305e-04, -1.75022741e+01, -8.20227409e+00],
1263                    [-1.84027855e-05, -2.51030184e+01, -1.09030184e+01],
1264                    [-1.23554225e-06, -3.27040012e+01, -1.36040012e+01],
1265                    [-8.29523046e-08, -4.03050001e+01, -1.63050001e+01],
1266                    [-5.56928016e-09, -4.79060000e+01, -1.90060000e+01],
1267                    [-3.73912013e-10, -5.55070000e+01, -2.17070000e+01]
1268                ],
1269                epsilon = 1e-6
1270            );
1271            assert_abs_diff_eq!(w, params2.slice(s![..params2.nrows() - 1, ..]));
1272            let loss = multi_logistic_loss(&x, &y, alpha, &params2);
1273            assert_abs_diff_eq!(loss, 154.8177958366479, epsilon = 1e-6);
1274            let grad = multi_logistic_grad(&x, &y, alpha, &params2);
1275            assert_abs_diff_eq!(
1276                grad,
1277                array![
1278                    [26.99587549, -10.91995003, -16.25532546],
1279                    [-27.26314882, 11.85569669, 21.58745213],
1280                    [5.33984376, -2.68845675, -2.65138701]
1281                ],
1282                epsilon = 1e-6
1283            );
1284        }
1285    }
1286
1287    #[test]
1288    fn simple_multi_example() {
1289        let x = array![[-1., 0.], [0., 1.], [1., 1.]];
1290        let y = array![2, 1, 0];
1291        let log_reg = MultiLogisticRegression::default()
1292            .alpha(0.1)
1293            .initial_params(Array::zeros((3, 3)));
1294        let dataset = Dataset::new(x, y);
1295        let res = log_reg.fit(&dataset).unwrap();
1296        assert_eq!(res.params().dim(), (2, 3));
1297        assert_eq!(res.intercept().dim(), 3);
1298        assert_eq!(
1299            &res.predict(dataset.records()),
1300            dataset.targets().as_single_targets()
1301        );
1302    }
1303
1304    #[test]
1305    fn simple_multi_example_2() {
1306        let x = array![[1.0], [0.0], [1.0], [0.0]];
1307        let y = array![1, 0, 1, 0];
1308        let dataset = DatasetBase::new(x, y);
1309        let model = MultiLogisticRegression::default().fit(&dataset).unwrap();
1310
1311        let pred = model.predict(&dataset.records);
1312        assert_eq!(dataset.targets(), pred);
1313    }
1314
1315    #[test]
1316    fn simple_multi_example_text() {
1317        let log_reg = MultiLogisticRegression::default().alpha(0.1);
1318        let x = array![[0.1], [1.0], [-1.0], [-0.1]];
1319        let y = array!["dog", "ape", "rocket", "cat"];
1320        let dataset = Dataset::new(x, y);
1321        let res = log_reg.fit(&dataset).unwrap();
1322        assert_eq!(res.params().dim(), (1, 4));
1323        assert_eq!(res.intercept().dim(), 4);
1324        assert_eq!(
1325            &res.predict(dataset.records()),
1326            dataset.targets().as_single_targets()
1327        );
1328    }
1329
1330    #[test]
1331    fn multi_on_binary_problem() {
1332        let log_reg = MultiLogisticRegression::default().alpha(1.0);
1333        let x = array![
1334            [0.0],
1335            [1.0],
1336            [2.0],
1337            [3.0],
1338            [4.0],
1339            [5.0],
1340            [6.0],
1341            [7.0],
1342            [8.0],
1343            [9.0]
1344        ];
1345        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1346        let dataset = Dataset::new(x, y);
1347        let res = log_reg.fit(&dataset).unwrap();
1348        assert_eq!(res.params().dim(), (1, 2));
1349        assert_eq!(res.intercept().dim(), 2);
1350        assert_eq!(
1351            &res.predict(dataset.records()),
1352            dataset.targets().as_single_targets()
1353        );
1354    }
1355
1356    #[test]
1357    fn reject_num_class_mismatch() {
1358        let n_samples = 4;
1359        let n_classes = 3;
1360        let n_features = 1;
1361        let x = Array2::<f64>::zeros((n_samples, n_features));
1362        let y = array![0, 1, 2, 0];
1363        let dataset = Dataset::new(x, y);
1364
1365        let log_reg = MultiLogisticRegression::default()
1366            .with_intercept(false)
1367            .initial_params(Array::zeros((n_features, n_classes - 1)));
1368        assert!(matches!(
1369            log_reg.fit(&dataset).unwrap_err(),
1370            Error::InitialParameterClassesMismatch {
1371                cols: 2,
1372                n_classes: 3,
1373            }
1374        ));
1375    }
1376}