linfa_logistic/lib.rs

1//! # Logistic Regression
2//!
3//! ## The Big Picture
4//!
5//! `linfa-logistic` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`.
6//!
7//! ## Current state
8//! `linfa-logistic` provides a pure Rust implementation of a [binomial logistic regression model](LogisticRegression) and a [multinomial logistic regression model](MultiLogisticRegression).
9//!
10//! ## Examples
11//!
//! There is a usage example in the `examples/` directory. To run it, use:
13//!
14//! ```bash
15//! $ cargo run --example winequality
16//! ```
17//!
18
19pub mod error;
20
21use crate::error::{Error, Result};
22use argmin::core::{CostFunction, Executor, Gradient, IterState, OptimizationResult, Solver};
23use argmin::solver::linesearch::MoreThuenteLineSearch;
24use argmin::solver::quasinewton::LBFGS;
25use linfa::dataset::AsSingleTargets;
26use linfa::prelude::DatasetBase;
27use linfa::traits::{Fit, PredictInplace};
28use ndarray::{
29    s, Array, Array1, Array2, ArrayBase, ArrayView, ArrayView2, Axis, CowArray, Data, DataMut,
30    Dimension, IntoDimension, Ix1, Ix2, RemoveAxis, Slice, Zip,
31};
32use ndarray_stats::QuantileExt;
33use std::default::Default;
34
35#[cfg(feature = "serde")]
36use serde_crate::de::DeserializeOwned;
37#[cfg(feature = "serde")]
38use serde_crate::{Deserialize, Serialize};
39
40mod argmin_param;
41mod float;
42mod hyperparams;
43
44use argmin_param::*;
45use float::Float;
46use hyperparams::{LogisticRegressionParams, LogisticRegressionValidParams};
47
48/// A two-class logistic regression model.
49///
50/// Logistic regression combines linear models with
51/// the sigmoid function `sigm(x) = 1/(1+exp(-x))`
52/// to learn a family of functions that map the feature space to `[0,1]`.
53///
54/// Logistic regression is used in binary classification
55/// by interpreting the predicted value as the probability that the sample
56/// has label `1`. A threshold can be set in the [fitted model](FittedLogisticRegression) to decide the minimum
57/// probability needed to classify a sample as `1`, which defaults to `0.5`.
58///
/// In this implementation any two distinct labels can be used; they do not have to be `0` and `1`.
///
/// This algorithm applies L2 regularization, weighted by the parameter `alpha`. Setting `alpha`
/// close to zero effectively removes regularization, so that only the empirical risk is
/// minimized. Conversely, setting `alpha` to a high value increases the weight of the L2 norm
/// of the linear model coefficients in the cost function.
65///
66/// ## Examples
67///
68/// Here's an example on how to train a logistic regression model on the `winequality` dataset
69/// ```rust
70/// use linfa::traits::{Fit, Predict};
71/// use linfa_logistic::LogisticRegression;
72///
73/// // Example on using binary labels different from 0 and 1
74/// let dataset = linfa_datasets::winequality().map_targets(|x| if *x > 6 { "good" } else { "bad" });
75/// let model = LogisticRegression::default().fit(&dataset).unwrap();
76/// let prediction = model.predict(&dataset);
77/// ```
78pub type LogisticRegression<F> = LogisticRegressionParams<F, Ix1>;
79
80/// Validated version of `LogisticRegression`
81pub type ValidLogisticRegression<F> = LogisticRegressionValidParams<F, Ix1>;
82
/// A multinomial logistic regression model.
///
/// The output labels can take any discrete set of values, since the algorithm computes the
/// likelihood of a feature vector belonging to each class using the softmax function
/// `softmax(x_i) = exp(x_i) / sum(exp(x_j))`.
///
/// This algorithm applies L2 regularization, weighted by the parameter `alpha`. Setting `alpha`
/// close to zero effectively removes regularization, so that only the empirical risk is
/// minimized. Conversely, setting `alpha` to a high value increases the weight of the L2 norm
/// of the linear model coefficients in the cost function.
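///
/// ## Examples
///
/// Here's a minimal sketch fitting a multinomial model on a small, made-up dataset with three
/// integer classes (the data is purely illustrative)
/// ```rust
/// use linfa::prelude::*;
/// use linfa_logistic::MultiLogisticRegression;
/// use ndarray::array;
///
/// // three samples with two features each, labelled with three distinct classes
/// let x = array![[-1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
/// let y = array![2, 1, 0];
/// let dataset = Dataset::new(x, y);
/// let model = MultiLogisticRegression::default().alpha(0.1).fit(&dataset).unwrap();
/// let prediction = model.predict(&dataset);
/// ```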
93pub type MultiLogisticRegression<F> = LogisticRegressionParams<F, Ix2>;
94
95/// Validated version of `MultiLogisticRegression`
96pub type ValidMultiLogisticRegression<F> = LogisticRegressionValidParams<F, Ix2>;
97
98impl<F: Float, D: Dimension> Default for LogisticRegressionParams<F, D> {
99    fn default() -> Self {
100        LogisticRegressionParams::new()
101    }
102}
103
104type LBFGSType<F, D> = LBFGS<
105    MoreThuenteLineSearch<ArgminParam<F, D>, ArgminParam<F, D>, F>,
106    ArgminParam<F, D>,
107    ArgminParam<F, D>,
108    F,
109>;
110type LBFGSType1<F> = LBFGSType<F, Ix1>;
111type LBFGSType2<F> = LBFGSType<F, Ix2>;
112
113type IterStateType<F, D> = IterState<ArgminParam<F, D>, ArgminParam<F, D>, (), (), (), F>;
114
115impl<F: Float, D: Dimension> LogisticRegressionValidParams<F, D> {
    /// Create the initial parameters, either from a user-supplied array
    /// or an array of zeros
118    fn setup_init_params(&self, dims: D::Pattern) -> ArgminParam<F, D> {
119        if let Some(params) = self.initial_params.as_ref() {
120            ArgminParam(params.clone())
121        } else {
122            let mut dims = dims.into_dimension();
123            dims.as_array_view_mut()[0] += self.fit_intercept as usize;
124            ArgminParam(Array::zeros(dims))
125        }
126    }
127
128    /// Ensure that `x` and `y` have the right shape and that all data and
129    /// configuration parameters are finite.
130    fn validate_data<A: Data<Elem = F>, B: Data<Elem = F>>(
131        &self,
132        x: &ArrayBase<A, Ix2>,
133        y: &ArrayBase<B, D>,
134    ) -> Result<()> {
135        if x.shape()[0] != y.shape()[0] {
136            return Err(Error::MismatchedShapes(x.shape()[0], y.shape()[0]));
137        }
138        if x.iter().any(|x| !x.is_finite()) || y.iter().any(|y| !y.is_finite()) {
139            return Err(Error::InvalidValues);
140        }
141        self.validate_init_dims(x.shape()[1], y.shape().get(1).copied())?;
142        Ok(())
143    }
144
145    fn validate_init_dims(&self, mut n_features: usize, n_classes: Option<usize>) -> Result<()> {
146        if let Some(params) = self.initial_params.as_ref() {
147            let shape = params.shape();
148            n_features += self.fit_intercept as usize;
149            if n_features != shape[0] {
150                return Err(Error::InitialParameterFeaturesMismatch {
151                    n_features,
152                    rows: shape[0],
153                });
154            }
155            if let Some(n_classes) = n_classes {
156                if n_classes != shape[1] {
157                    return Err(Error::InitialParameterClassesMismatch {
158                        n_classes,
159                        cols: shape[1],
160                    });
161                }
162            }
163        }
164        Ok(())
165    }
166
167    /// Create a `LogisticRegressionProblem`.
168    fn setup_problem<'a, A: Data<Elem = F>>(
169        &self,
170        x: &'a ArrayBase<A, Ix2>,
171        target: Array<F, D>,
172    ) -> LogisticRegressionProblem<'a, F, A, D> {
173        LogisticRegressionProblem {
174            x,
175            target,
176            alpha: self.alpha,
177        }
178    }
179
180    /// Create the LBFGS solver using MoreThuenteLineSearch and set gradient
181    /// tolerance.
182    fn setup_solver(&self) -> LBFGSType<F, D> {
183        let linesearch = MoreThuenteLineSearch::new();
184        LBFGS::new(linesearch, 10)
185            .with_tolerance_grad(self.gradient_tolerance)
186            .unwrap()
187    }
188}
189
190impl<
191        F: Float,
192        #[cfg(feature = "serde")] D: Dimension + Serialize + DeserializeOwned,
193        #[cfg(not(feature = "serde"))] D: Dimension,
194    > LogisticRegressionValidParams<F, D>
195{
196    /// Run the LBFGS solver until it converges or runs out of iterations.
197    fn run_solver<P: SolvableProblem<F, D>>(
198        &self,
199        problem: P,
200        solver: P::Solver,
201        init_params: ArgminParam<F, D>,
202    ) -> Result<OptimizationResult<P, P::Solver, IterStateType<F, D>>> {
203        Executor::new(problem, solver)
204            .configure(|state| state.param(init_params).max_iters(self.max_iterations))
205            .run()
206            .map_err(move |err| err.into())
207    }
208}
209
210impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
211    Fit<ArrayBase<D, Ix2>, T, Error> for ValidLogisticRegression<F>
212{
213    type Object = FittedLogisticRegression<F, C>;
214
215    /// Given a 2-dimensional feature matrix array `x` with shape
216    /// (n_samples, n_features) and an array of target classes to predict,
    /// create a `FittedLogisticRegression` object which allows making
218    /// predictions.
219    ///
    /// The array of target classes `y` must have exactly two discrete values (e.g. 0 and 1, "cat"
221    /// and "dog", ...), which represent the two different classes the model is supposed to
222    /// predict.
223    ///
224    /// The array `y` must also have exactly `n_samples` items, i.e.
225    /// exactly as many items as there are rows in the feature matrix `x`.
226    ///
227    /// This method returns an error if any of the preconditions are violated,
228    /// i.e. any values are `Inf` or `NaN`, `y` doesn't have as many items as
229    /// `x` has rows, or if other parameters (gradient_tolerance, alpha) have
    /// been set to invalid values.
231    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
232        let (x, y) = (dataset.records(), dataset.targets());
233        let (labels, target) = label_classes(y)?;
234        self.validate_data(x, &target)?;
235        let problem = self.setup_problem(x, target);
236        let solver = self.setup_solver();
237        let init_params = self.setup_init_params(x.ncols());
238        let result = self.run_solver(problem, solver, init_params)?;
239
240        let params = result
241            .state
242            .best_param
243            .unwrap_or(self.setup_init_params(x.ncols()));
244        let (w, intercept) = convert_params(x.ncols(), params.as_array());
245        Ok(FittedLogisticRegression::new(
246            *intercept.view().into_scalar(),
247            w.to_owned(),
248            labels,
249        ))
250    }
251}
252
253impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
254    Fit<ArrayBase<D, Ix2>, T, Error> for ValidMultiLogisticRegression<F>
255{
256    type Object = MultiFittedLogisticRegression<F, C>;
257
258    /// Given a 2-dimensional feature matrix array `x` with shape
259    /// (n_samples, n_features) and an array of target classes to predict,
260    /// create a `MultiFittedLogisticRegression` object which allows making
261    /// predictions. The target classes can have any number of discrete values.
262    ///
263    /// This method returns an error if any of the preconditions are violated,
264    /// i.e. any values are `Inf` or `NaN`, `y` doesn't have as many items as
265    /// `x` has rows, or if other parameters (gradient_tolerance, alpha) have
    /// been set to invalid values. It is also strongly recommended to normalize the input
    /// features to ensure numerical stability.
268    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
269        let (x, y) = (dataset.records(), dataset.targets());
270        let (classes, target) = label_classes_multi(y)?;
271        self.validate_data(x, &target)?;
272        let problem = self.setup_problem(x, target);
273        let solver = self.setup_solver();
274        let init_params = self.setup_init_params((x.ncols(), classes.len()));
275        let result = self.run_solver(problem, solver, init_params)?;
276
277        let params = result
278            .state
279            .best_param
280            .unwrap_or(self.setup_init_params((x.ncols(), classes.len())));
281        let (w, intercept) = convert_params(x.ncols(), params.as_array());
282        Ok(MultiFittedLogisticRegression::new(
283            intercept.to_owned(),
284            w.to_owned(),
285            classes,
286        ))
287    }
288}
289
/// Identify the two distinct class values in `y` and associate the target labels
/// `-1.0` and `1.0` with them. `-1.0` always labels the less frequent class and
/// `1.0` the more frequent class (for example, in `["b", "a", "a"]` the label
/// `1.0` is assigned to `"a"`).
///
/// It is an error to have more or fewer than two distinct classes.
296fn label_classes<F, T, C>(y: T) -> Result<(BinaryClassLabels<F, C>, Array1<F>)>
297where
298    F: Float,
299    T: AsSingleTargets<Elem = C>,
300    C: Ord + Clone,
301{
302    let y = y.as_single_targets();
303
304    // counts the instances of two distinct class labels
305    let mut binary_classes = [None, None];
306    // find binary classes of our target dataset
307    for class in y {
308        binary_classes = match binary_classes {
309            // count the first class label
310            [None, None] => [Some((class, 1)), None],
311            // if the class has already been counted, increment the count
312            [Some((c, count)), c2] if c == class => [Some((class, count + 1)), c2],
313            [c1, Some((c, count))] if c == class => [c1, Some((class, count + 1))],
314            // count the second class label
315            [Some(c1), None] => [Some(c1), Some((class, 1))],
316
317            // should not be possible
318            [None, Some(_)] => unreachable!("impossible binary class array"),
319            // found 3rd distinct class
320            [Some(_), Some(_)] => return Err(Error::TooManyClasses),
321        };
322    }
323
324    let (pos_class, neg_class) = match binary_classes {
325        [Some(a), Some(b)] => (a, b),
326        _ => return Err(Error::TooFewClasses),
327    };
328
329    let mut target_array = y
330        .into_iter()
331        .map(|x| {
332            if x == pos_class.0 {
333                F::POSITIVE_LABEL
334            } else {
335                F::NEGATIVE_LABEL
336            }
337        })
338        .collect::<Array1<_>>();
339
340    let (pos_cl, neg_cl) = if pos_class.1 < neg_class.1 {
        // If the class we encountered first is the less frequent one, flip the signs in the
        // target vector so that -1.0 always labels the less frequent class and 1.0 the
        // more frequent class
344        target_array *= -F::one();
345        (neg_class.0.clone(), pos_class.0.clone())
346    } else {
347        (pos_class.0.clone(), neg_class.0.clone())
348    };
349
350    Ok((
351        BinaryClassLabels {
352            pos: ClassLabel {
353                class: pos_cl,
354                label: F::POSITIVE_LABEL,
355            },
356            neg: ClassLabel {
357                class: neg_cl,
358                label: F::NEGATIVE_LABEL,
359            },
360        },
361        target_array,
362    ))
363}
364
365/// Identify the distinct values of the classes in `y` and map each value to an integer. Smaller
366/// classes (by `PartialOrd`) map to smaller integers. Returns the mapping along with a one-hot
367/// encoding of the numerical labels corresponding to `y`.
368fn label_classes_multi<F, T, C>(y: T) -> Result<(Vec<C>, Array2<F>)>
369where
370    F: Float,
371    T: AsSingleTargets<Elem = C>,
372    C: Ord + Clone,
373{
374    let y_single_target = y.as_single_targets();
375    let mut classes = y_single_target.to_vec();
376    // Dedup the list of classes
377    classes.sort();
378    classes.dedup();
379
380    let mut onehot = Array2::zeros((y_single_target.len(), classes.len()));
381    Zip::from(onehot.rows_mut())
382        .and(&y_single_target)
383        .for_each(|mut oh_row, cls| {
384            let idx = classes.binary_search(cls).unwrap();
385            oh_row[idx] = F::one();
386        });
387    Ok((classes, onehot))
388}
389
/// Conditionally split the feature vector `w` into the parameter array and the
/// intercept parameter.
/// `w` has shape `(f)` or `(f + 1)` in the binary case and `(f, n_classes)` or
/// `(f + 1, n_classes)` in the multinomial case, where the extra row holds the intercept.
393fn convert_params<F: Float, D: Dimension + RemoveAxis>(
394    n_features: usize,
395    w: &Array<F, D>,
396) -> (ArrayView<'_, F, D>, CowArray<'_, F, D::Smaller>) {
397    let nrows = w.shape()[0];
398    if n_features == nrows {
399        (
400            w.view(),
401            Array::zeros(w.raw_dim().remove_axis(Axis(0))).into(),
402        )
403    } else if n_features + 1 == nrows {
404        (
405            w.slice_axis(Axis(0), Slice::from(..n_features)),
406            w.index_axis(Axis(0), n_features).into(),
407        )
408    } else {
409        panic!(
            "Unexpected length of parameter vector `w`, expected {} or {}, found {}",
411            n_features,
412            n_features + 1,
413            nrows
414        );
415    }
416}
417
418/// The logistic function
419fn logistic<F: linfa::Float>(x: F) -> F {
420    F::one() / (F::one() + (-x).exp())
421}
422
423/// A numerically stable version of the log of the logistic function.
424///
425/// Taken from scikit-learn
426/// https://github.com/scikit-learn/scikit-learn/blob/0.23.1/sklearn/utils/_logistic_sigmoid.pyx
427///
428/// See the blog post describing this implementation:
429/// http://fa.bianp.net/blog/2013/numerical-optimizers-for-logistic-regression/
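///
/// For `x > 0` this computes `-ln(1 + exp(-x))` directly; for `x <= 0` it uses the
/// algebraically equivalent form `x - ln(1 + exp(x))`, so the argument of `exp` is never
/// positive and the computation cannot overflow for large `|x|`.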
430fn log_logistic<F: linfa::Float>(x: F) -> F {
431    if x > F::zero() {
432        -(F::one() + (-x).exp()).ln()
433    } else {
434        x - (F::one() + x.exp()).ln()
435    }
436}
437
/// Finds the log of the sum of exponentials across a specific axis in a numerically stable way. More
/// specifically, computes `ln(exp(x1) + exp(x2) + exp(x3) + ...)` across an axis.
///
/// Based on this implementation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
442fn log_sum_exp<F: linfa::Float, A: Data<Elem = F>>(
443    m: &ArrayBase<A, Ix2>,
444    axis: Axis,
445) -> Array<F, Ix1> {
446    // Find max value of the array
447    let max = m.iter().copied().reduce(F::max).unwrap();
448    // Computes `max + ln(exp(x1-max) + exp(x2-max) + exp(x3-max) + ...)`, which is equal to the
449    // log_sum_exp formula
450    let reduced = m.fold_axis(axis, F::zero(), |acc, elem| *acc + (*elem - max).exp());
451    reduced.mapv_into(|e| e.max(F::cast(1e-15)).ln() + max)
452}
453
/// Computes `exp(n - max) / sum(exp(n - max))`, which is a numerically stable version of softmax
455fn softmax_inplace<F: linfa::Float, A: DataMut<Elem = F>>(v: &mut ArrayBase<A, Ix1>) {
456    let max = v.iter().copied().reduce(F::max).unwrap();
457    v.mapv_inplace(|n| (n - max).exp());
458    let sum = v.sum();
459    v.mapv_inplace(|n| n / sum);
460}
461
/// Computes the logistic loss assuming the training labels $y \in \{-1, 1\}$
///
/// Because the logistic function fulfills $\sigma(-z) = 1 - \sigma(z)$
465/// we can write $P(y=1|z) = \sigma(z) = \sigma(yz)$ and
466/// $P(y=-1|z) = 1 - P(y=1|z) = 1 - \sigma(z) = \sigma(-z) = \sigma(yz)$, so
467/// $P(y|z) = \sigma(yz)$ for both $y=1$ and $y=-1$.
468///
469/// Thus, the log loss can be written as
470/// $$-\sum_{i=1}^{N} \log(\sigma(y_i z_i)) + \frac{\alpha}{2}\text{params}^T\text{params}$$
471fn logistic_loss<F: Float, A: Data<Elem = F>>(
472    x: &ArrayBase<A, Ix2>,
473    y: &Array1<F>,
474    alpha: F,
475    w: &Array1<F>,
476) -> F {
477    let n_features = x.shape()[1];
478    let (params, intercept) = convert_params(n_features, w);
479    let yz = x.dot(&params.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
480    let len = yz.len();
481    let mut yz = yz.into_shape_with_order(len).unwrap() * y;
482    yz.mapv_inplace(log_logistic);
483    -yz.sum() + F::cast(0.5) * alpha * params.dot(&params)
484}
485
486/// Computes the gradient of the logistic loss function
487fn logistic_grad<F: Float, A: Data<Elem = F>>(
488    x: &ArrayBase<A, Ix2>,
489    y: &Array1<F>,
490    alpha: F,
491    w: &Array1<F>,
492) -> Array1<F> {
493    let n_features = x.shape()[1];
494    let (params, intercept) = convert_params(n_features, w);
495    let yz = x.dot(&params.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
496    let len = yz.len();
497    let mut yz = yz.into_shape_with_order(len).unwrap() * y;
498    yz.mapv_inplace(logistic);
499    yz -= F::one();
500    yz *= y;
501    if w.len() == n_features + 1 {
502        let mut grad = Array::zeros(w.len());
503        grad.slice_mut(s![..n_features])
504            .assign(&(x.t().dot(&yz) + (&params * alpha)));
505        grad[n_features] = yz.sum();
506        grad
507    } else {
508        x.t().dot(&yz) + (&params * alpha)
509    }
510}
511
512/// Compute the log of probabilities, which is `log(softmax(H))`, where H is `X . W + b`. Also
513/// returns `W` without the intercept.
514/// `Y` is the output (n_samples * n_classes), `X` is the input (n_samples * n_features), `W` is the
515/// params (n_features * n_classes), `b` is the intercept vector (n_classes).
516fn multi_logistic_prob_params<'a, F: Float, A: Data<Elem = F>>(
517    x: &ArrayBase<A, Ix2>,
518    w: &'a Array2<F>, // This parameter includes `W` and `b`
519) -> (Array2<F>, ArrayView2<'a, F>) {
520    let n_features = x.shape()[1];
521    let (params, intercept) = convert_params(n_features, w);
522    // Compute H
523    let h = x.dot(&params) + intercept;
524    // This computes `H - log(sum(exp(H)))`, which is equal to
525    // `log(softmax(H)) = log(exp(H) / sum(exp(H)))`
526    let log_prob = &h
527        - log_sum_exp(&h, Axis(1))
528            .into_shape_with_order((h.nrows(), 1))
529            .unwrap();
530    (log_prob, params)
531}
532
/// Computes the loss function `-sum(Y * log(softmax(H))) + alpha/2 * norm(W)^2`
534fn multi_logistic_loss<F: Float, A: Data<Elem = F>>(
535    x: &ArrayBase<A, Ix2>,
536    y: &Array2<F>,
537    alpha: F,
538    w: &Array2<F>,
539) -> F {
540    let (log_prob, params) = multi_logistic_prob_params(x, w);
541    // Calculate loss
542    -elem_dot(&log_prob, y) + F::cast(0.5) * alpha * elem_dot(&params, &params)
543}
544
/// Computes the multinomial gradients for `W` and `b` and combines them.
546/// Gradient for `W` is `Xt . (softmax(H) - Y) + alpha * W`.
547/// Gradient for `b` is `sum(softmax(H) - Y)`.
548fn multi_logistic_grad<F: Float, A: Data<Elem = F>>(
549    x: &ArrayBase<A, Ix2>,
550    y: &Array2<F>,
551    alpha: F,
552    w: &Array2<F>,
553) -> Array2<F> {
554    let (log_prob, params) = multi_logistic_prob_params(x, w);
555    let (n_features, n_classes) = params.dim();
556    let intercept = w.nrows() > n_features;
557    let mut grad = Array::zeros((n_features + intercept as usize, n_classes));
558
559    // This value is `softmax(H)`
560    let prob = log_prob.mapv_into(num_traits::Float::exp);
561    let diff = prob - y;
562    // Compute gradient for `W` and place it at start of the grad matrix
563    let dw = x.t().dot(&diff) + (&params * alpha);
564    grad.slice_mut(s![..n_features, ..]).assign(&dw);
565    // Compute gradient for `b` and place it at end of grad matrix
566    if intercept {
567        grad.row_mut(n_features).assign(&diff.sum_axis(Axis(0)));
568    }
569    grad
570}
571
572/// A fitted logistic regression which can make predictions
573#[derive(Debug, Clone, PartialEq)]
574#[cfg_attr(
575    feature = "serde",
576    derive(Serialize, Deserialize),
577    serde(crate = "serde_crate")
578)]
579pub struct FittedLogisticRegression<F: Float, C: PartialOrd + Clone> {
580    threshold: F,
581    intercept: F,
582    params: Array1<F>,
583    labels: BinaryClassLabels<F, C>,
584}
585
586impl<F: Float, C: PartialOrd + Clone> FittedLogisticRegression<F, C> {
587    fn new(
588        intercept: F,
589        params: Array1<F>,
590        labels: BinaryClassLabels<F, C>,
591    ) -> FittedLogisticRegression<F, C> {
592        FittedLogisticRegression {
593            threshold: F::cast(0.5),
594            intercept,
595            params,
596            labels,
597        }
598    }
599
    /// Set the probability threshold at or above which the 'positive' class is
    /// predicted. Defaults to 0.5.
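    ///
    /// A minimal sketch on a small, made-up dataset (the data is purely illustrative)
    /// ```rust
    /// use linfa::prelude::*;
    /// use linfa_logistic::LogisticRegression;
    /// use ndarray::array;
    ///
    /// let dataset = Dataset::new(array![[-1.0], [-0.5], [0.5], [1.0]], array![0, 0, 1, 1]);
    /// // require a predicted probability of at least 0.7 before classifying a sample
    /// // as the positive class
    /// let model = LogisticRegression::default()
    ///     .fit(&dataset)
    ///     .unwrap()
    ///     .set_threshold(0.7);
    /// ```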
602    pub fn set_threshold(mut self, threshold: F) -> FittedLogisticRegression<F, C> {
603        if threshold < F::zero() || threshold > F::one() {
604            panic!("FittedLogisticRegression::set_threshold: threshold needs to be between 0.0 and 1.0");
605        }
606        self.threshold = threshold;
607        self
608    }
609
610    pub fn intercept(&self) -> F {
611        self.intercept
612    }
613
614    pub fn params(&self) -> &Array1<F> {
615        &self.params
616    }
617
618    /// Get the model positive and negative classes mapped to their
619    /// corresponding problem input labels.
620    pub fn labels(&self) -> &BinaryClassLabels<F, C> {
621        &self.labels
622    }
623
    /// Given a feature matrix, predict for each sample the probability that it
    /// belongs to the positive class learned when the model was fitted.
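    ///
    /// A minimal sketch on a small, made-up dataset (the data is purely illustrative)
    /// ```rust
    /// use linfa::prelude::*;
    /// use linfa_logistic::LogisticRegression;
    /// use ndarray::array;
    ///
    /// let dataset = Dataset::new(array![[-1.0], [-0.5], [0.5], [1.0]], array![0, 0, 1, 1]);
    /// let model = LogisticRegression::default().fit(&dataset).unwrap();
    /// // one probability per input row, each in the interval [0, 1]
    /// let probabilities = model.predict_probabilities(dataset.records());
    /// assert_eq!(probabilities.len(), 4);
    /// ```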
627    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array1<F> {
628        let mut probs = x.dot(&self.params) + self.intercept;
629        probs.mapv_inplace(logistic);
630        probs
631    }
632}
633
634impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
635    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for FittedLogisticRegression<F, C>
636{
637    /// Given a feature matrix, predict the classes learned when the model was
638    /// fitted.
639    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
640        assert_eq!(
641            x.nrows(),
642            y.len(),
643            "The number of data points must match the number of output targets."
644        );
645        assert_eq!(
646            x.ncols(),
647            self.params.len(),
648            "Number of data features must match the number of features the model was trained with."
649        );
650
651        let pos_class = &self.labels.pos.class;
652        let neg_class = &self.labels.neg.class;
653        Zip::from(&self.predict_probabilities(x))
654            .and(y)
655            .for_each(|prob, out| {
656                *out = if *prob >= self.threshold {
657                    pos_class.clone()
658                } else {
659                    neg_class.clone()
660                }
661            });
662    }
663
664    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
665        Array1::default(x.nrows())
666    }
667}
668
669/// A fitted multinomial logistic regression which can make predictions
670#[derive(Debug, Clone, PartialEq, Eq)]
671#[cfg_attr(
672    feature = "serde",
673    derive(Serialize, Deserialize),
674    serde(crate = "serde_crate")
675)]
676pub struct MultiFittedLogisticRegression<F, C: PartialOrd + Clone> {
677    intercept: Array1<F>,
678    params: Array2<F>,
679    classes: Vec<C>,
680}
681
682impl<F: Float, C: PartialOrd + Clone> MultiFittedLogisticRegression<F, C> {
683    fn new(intercept: Array1<F>, params: Array2<F>, classes: Vec<C>) -> Self {
684        Self {
685            intercept,
686            params,
687            classes,
688        }
689    }
690
691    pub fn intercept(&self) -> &Array1<F> {
692        &self.intercept
693    }
694
695    pub fn params(&self) -> &Array2<F> {
696        &self.params
697    }
698
    /// Return the non-normalized class scores (n_samples * n_classes), i.e. the linear
    /// predictions before softmax normalization
700    fn predict_nonorm_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
701        x.dot(&self.params) + &self.intercept
702    }
703
704    /// Return normalized probabilities for each output class. The output dimensions are (n_samples
705    /// * n_classes).
706    pub fn predict_probabilities<A: Data<Elem = F>>(&self, x: &ArrayBase<A, Ix2>) -> Array2<F> {
707        let mut probs = self.predict_nonorm_probabilities(x);
708        probs
709            .rows_mut()
710            .into_iter()
711            .for_each(|mut row| softmax_inplace(&mut row));
712        probs
713    }
714
715    /// Get the list of class labels, which maps the numerical class indices to the labels
716    pub fn classes(&self) -> &[C] {
717        &self.classes
718    }
719}
720
721impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
722    PredictInplace<ArrayBase<D, Ix2>, Array1<C>> for MultiFittedLogisticRegression<F, C>
723{
724    /// Given a feature matrix, predict the classes learned when the model was
725    /// fitted.
726    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<C>) {
727        assert_eq!(
728            x.nrows(),
729            y.len(),
730            "The number of data points must match the number of output targets."
731        );
732        assert_eq!(
733            x.ncols(),
734            self.params.nrows(),
735            "Number of data features must match the number of features the model was trained with."
736        );
737
738        let probs = self.predict_nonorm_probabilities(x);
739        Zip::from(probs.rows()).and(y).for_each(|prob_row, out| {
740            let idx = prob_row.argmax().unwrap();
741            *out = self.classes[idx].clone();
742        });
743    }
744
745    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<C> {
746        Array1::default(x.nrows())
747    }
748}
749
750#[derive(Debug, Clone, PartialEq)]
751#[cfg_attr(
752    feature = "serde",
753    derive(Serialize, Deserialize),
754    serde(crate = "serde_crate")
755)]
756pub struct ClassLabel<F, C: PartialOrd> {
757    pub class: C,
758    pub label: F,
759}
760
761#[derive(Debug, Clone, PartialEq)]
762#[cfg_attr(
763    feature = "serde",
764    derive(Serialize, Deserialize),
765    serde(crate = "serde_crate")
766)]
767pub struct BinaryClassLabels<F, C: PartialOrd> {
768    pub pos: ClassLabel<F, C>,
769    pub neg: ClassLabel<F, C>,
770}
771
772/// Internal representation of a logistic regression problem.
773/// This data structure exists to be handed to Argmin.
774struct LogisticRegressionProblem<'a, F: Float, A: Data<Elem = F>, D: Dimension> {
775    x: &'a ArrayBase<A, Ix2>,
776    target: Array<F, D>,
777    alpha: F,
778}
779
780type LogisticRegressionProblem1<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix1>;
781type LogisticRegressionProblem2<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix2>;
782
783impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem1<'_, F, A> {
784    type Param = ArgminParam<F, Ix1>;
785    type Output = F;
786
787    /// Apply the cost function to a parameter `p`
788    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
789        let w = p.as_array();
790        let cost = logistic_loss(self.x, &self.target, self.alpha, w);
791        Ok(cost)
792    }
793}
794
795impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem1<'_, F, A> {
796    type Param = ArgminParam<F, Ix1>;
797    type Gradient = ArgminParam<F, Ix1>;
798
799    /// Compute the gradient at parameter `p`.
800    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
801        let w = p.as_array();
802        let grad = ArgminParam(logistic_grad(self.x, &self.target, self.alpha, w));
803        Ok(grad)
804    }
805}
806
807impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem2<'_, F, A> {
808    type Param = ArgminParam<F, Ix2>;
809    type Output = F;
810
811    /// Apply the cost function to a parameter `p`
812    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
813        let w = p.as_array();
814        let cost = multi_logistic_loss(self.x, &self.target, self.alpha, w);
815        Ok(cost)
816    }
817}
818
819impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem2<'_, F, A> {
820    type Param = ArgminParam<F, Ix2>;
821    type Gradient = ArgminParam<F, Ix2>;
822
823    /// Compute the gradient at parameter `p`.
824    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
825        let w = p.as_array();
826        let grad = ArgminParam(multi_logistic_grad(self.x, &self.target, self.alpha, w));
827        Ok(grad)
828    }
829}
830
831trait SolvableProblem<F: Float, D: Dimension>: Gradient + Sized {
832    type Solver: Solver<Self, IterStateType<F, D>>;
833}
834
835impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix1> for LogisticRegressionProblem1<'_, F, A> {
836    type Solver = LBFGSType1<F>;
837}
838
839impl<F: Float, A: Data<Elem = F>> SolvableProblem<F, Ix2> for LogisticRegressionProblem2<'_, F, A> {
840    type Solver = LBFGSType2<F>;
841}
842
843#[cfg(test)]
844mod test {
845    extern crate linfa;
846
847    use super::Error;
848    use super::*;
849    use approx::{assert_abs_diff_eq, assert_relative_eq, AbsDiffEq};
850    use linfa::prelude::*;
851    use ndarray::{array, Array2, Dim, Ix};
852
853    #[test]
854    fn autotraits() {
855        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
856        has_autotraits::<LogisticRegressionParams<f64, Dim<[Ix; 0]>>>();
857        has_autotraits::<LogisticRegressionValidParams<f64, Dim<[Ix; 0]>>>();
858        has_autotraits::<ArgminParam<f64, Dim<[Ix; 0]>>>();
859    }
860
861    /// Test that the logistic loss function works as expected.
862    /// The expected values were obtained from running sklearn's
863    /// _logistic_loss_and_grad function.
864    #[test]
865    fn test_logistic_loss() {
866        let x = array![
867            [0.0],
868            [1.0],
869            [2.0],
870            [3.0],
871            [4.0],
872            [5.0],
873            [6.0],
874            [7.0],
875            [8.0],
876            [9.0]
877        ];
878        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
879        let ws = vec![
880            array![0.0, 0.0],
881            array![0.0, 1.0],
882            array![1.0, 0.0],
883            array![1.0, 1.0],
884            array![0.0, -1.0],
885            array![-1.0, 0.0],
886            array![-1.0, -1.0],
887        ];
888        let alphas = &[0.0, 1.0, 10.0];
889        let expecteds = vec![
890            6.931471805599453,
891            6.931471805599453,
892            6.931471805599453,
893            4.652158847349118,
894            4.652158847349118,
895            4.652158847349118,
896            2.8012999588008323,
897            3.3012999588008323,
898            7.801299958800833,
899            2.783195429782239,
900            3.283195429782239,
901            7.783195429782239,
902            10.652158847349117,
903            10.652158847349117,
904            10.652158847349117,
905            41.80129995880083,
906            42.30129995880083,
907            46.80129995880083,
908            47.78319542978224,
909            48.28319542978224,
910            52.78319542978224,
911        ];
912
913        for ((w, alpha), exp) in ws
914            .iter()
915            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
916            .zip(&expecteds)
917        {
918            assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w), *exp);
919        }
920    }
921
922    /// Test that the logistic grad function works as expected.
923    /// The expected values were obtained from running sklearn's
924    /// _logistic_loss_and_grad function.
925    #[test]
926    fn test_logistic_grad() {
927        let x = array![
928            [0.0],
929            [1.0],
930            [2.0],
931            [3.0],
932            [4.0],
933            [5.0],
934            [6.0],
935            [7.0],
936            [8.0],
937            [9.0]
938        ];
939        let y = array![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
940        let ws = vec![
941            array![0.0, 0.0],
942            array![0.0, 1.0],
943            array![1.0, 0.0],
944            array![1.0, 1.0],
945            array![0.0, -1.0],
946            array![-1.0, 0.0],
947            array![-1.0, -1.0],
948        ];
949        let alphas = &[0.0, 1.0, 10.0];
950        let expecteds = vec![
951            array![-19.5, -3.],
952            array![-19.5, -3.],
953            array![-19.5, -3.],
954            array![-10.48871543, -1.61364853],
955            array![-10.48871543, -1.61364853],
956            array![-10.48871543, -1.61364853],
957            array![-0.13041554, -0.02852148],
958            array![0.86958446, -0.02852148],
959            array![9.86958446, -0.02852148],
960            array![-0.04834401, -0.01058067],
961            array![0.95165599, -0.01058067],
962            array![9.95165599, -0.01058067],
963            array![-28.51128457, -4.38635147],
964            array![-28.51128457, -4.38635147],
965            array![-28.51128457, -4.38635147],
966            array![-38.86958446, -5.97147852],
967            array![-39.86958446, -5.97147852],
968            array![-48.86958446, -5.97147852],
969            array![-38.95165599, -5.98941933],
970            array![-39.95165599, -5.98941933],
971            array![-48.95165599, -5.98941933],
972        ];
973
974        for ((w, alpha), exp) in ws
975            .iter()
976            .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
977            .zip(&expecteds)
978        {
979            let actual = logistic_grad(&x, &y, alpha, w);
980            assert!(actual.abs_diff_eq(exp, 1e-8));
981        }
982    }
983
984    #[test]
985    fn simple_example_1() {
986        let log_reg = LogisticRegression::default();
987        let x = array![[-1.0], [-0.01], [0.01], [1.0]];
988        let y = array![0, 0, 1, 1];
989        let dataset = Dataset::new(x, y);
990        let res = log_reg.fit(&dataset).unwrap();
991        assert_abs_diff_eq!(res.intercept(), 0.0);
992        assert!(res.params().abs_diff_eq(&array![-0.681], 1e-3));
993        assert_eq!(
994            &res.predict(dataset.records()),
995            dataset.targets().as_single_targets()
996        );
997    }
998
999    #[test]
1000    fn simple_example_1_cats_dogs() {
1001        let log_reg = LogisticRegression::default();
1002        let x = array![[0.01], [1.0], [-1.0], [-0.01]];
1003        let y = array!["dog", "dog", "cat", "cat"];
1004        let dataset = Dataset::new(x, y);
1005        let res = log_reg.fit(&dataset).unwrap();
1006        assert_abs_diff_eq!(res.intercept(), 0.0);
1007        assert!(res.params().abs_diff_eq(&array![0.681], 1e-3));
1008        assert!(res
1009            .predict_probabilities(dataset.records())
1010            .abs_diff_eq(&array![0.501, 0.664, 0.335, 0.498], 1e-3));
1011        assert_eq!(
1012            &res.predict(dataset.records()),
1013            dataset.targets().as_single_targets()
1014        );
1015        assert_eq!(res.labels().pos.class, "dog");
1016        assert_eq!(res.labels().neg.class, "cat");
1017    }
1018
1019    #[test]
1020    fn simple_example_2() {
1021        let log_reg = LogisticRegression::default().alpha(1.0);
1022        let x = array![
1023            [0.0],
1024            [1.0],
1025            [2.0],
1026            [3.0],
1027            [4.0],
1028            [5.0],
1029            [6.0],
1030            [7.0],
1031            [8.0],
1032            [9.0]
1033        ];
1034        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1035        let dataset = Dataset::new(x, y);
1036        let res = log_reg.fit(&dataset).unwrap();
1037        assert_eq!(
1038            &res.predict(dataset.records()),
1039            dataset.targets().as_single_targets()
1040        );
1041    }
1042
1043    #[test]
1044    fn simple_example_3() {
1045        let x = array![[1.0], [0.0], [1.0], [0.0]];
1046        let y = array![1, 0, 1, 0];
1047        let dataset = DatasetBase::new(x, y);
1048        let model = LogisticRegression::default().fit(&dataset).unwrap();
1049
1050        let pred = model.predict(&dataset.records);
1051        assert_eq!(dataset.targets(), pred);
1052    }
1053
1054    #[test]
1055    fn rejects_mismatching_x_y() {
1056        let log_reg = LogisticRegression::default();
1057        let x = array![[-1.0], [-0.01], [0.01]];
1058        let y = array![0, 0, 1, 1];
1059        let res = log_reg.fit(&Dataset::new(x, y));
1060        assert!(matches!(res.unwrap_err(), Error::MismatchedShapes(3, 4)));
1061    }
1062
1063    #[test]
1064    fn rejects_inf_values() {
1065        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
1066        let inf_xs: Vec<_> = infs.iter().map(|&inf| array![[1.0], [inf]]).collect();
1067        let log_reg = LogisticRegression::default();
1068        let normal_x = array![[-1.0], [1.0]];
1069        let y = array![0, 1];
1070        for inf_x in &inf_xs {
1071            let res = log_reg.fit(&DatasetBase::new(inf_x.view(), &y));
1072            assert!(matches!(res.unwrap_err(), Error::InvalidValues));
1073        }
1074        for inf in infs {
1075            let log_reg = LogisticRegression::default().alpha(*inf);
1076            let res = log_reg.fit(&DatasetBase::new(normal_x.view(), &y));
1077            assert!(matches!(res.unwrap_err(), Error::InvalidAlpha));
1078        }
1079        let mut non_positives = infs.to_vec();
1080        non_positives.push(-1.0);
1081        non_positives.push(0.0);
1082        for inf in &non_positives {
1083            let log_reg = LogisticRegression::default().gradient_tolerance(*inf);
1084            let res = log_reg.fit(&Dataset::new(normal_x.to_owned(), y.to_owned()));
1085            assert!(matches!(res.unwrap_err(), Error::InvalidGradientTolerance));
1086        }
1087    }
1088
1089    #[test]
1090    fn validates_initial_params() {
1091        let infs = &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN];
1092        let normal_x = array![[-1.0], [1.0]];
1093        let normal_y = array![0, 1];
1094        let dataset = Dataset::new(normal_x, normal_y);
1095        for inf in infs {
1096            let log_reg = LogisticRegression::default().initial_params(array![*inf, 0.0]);
1097            let res = log_reg.fit(&dataset);
1098            assert!(matches!(res.unwrap_err(), Error::InvalidInitialParameters));
1099        }
1100        {
1101            let log_reg = LogisticRegression::default().initial_params(array![0.0, 0.0, 0.0]);
1102            let res = log_reg.fit(&dataset);
1103            assert!(matches!(
1104                res.unwrap_err(),
1105                Error::InitialParameterFeaturesMismatch {
1106                    rows: 3,
1107                    n_features: 2
1108                }
1109            ));
1110        }
1111        {
1112            let log_reg = LogisticRegression::default()
1113                .with_intercept(false)
1114                .initial_params(array![0.0, 0.0]);
1115            let res = log_reg.fit(&dataset);
1116            assert!(matches!(
1117                res.unwrap_err(),
1118                Error::InitialParameterFeaturesMismatch {
1119                    rows: 2,
1120                    n_features: 1
1121                }
1122            ));
1123        }
1124    }
1125
1126    #[test]
1127    fn uses_initial_params() {
1128        let params = array![1.2, -4.12];
1129        let log_reg = LogisticRegression::default()
1130            .initial_params(params)
1131            .max_iterations(5);
1132        let x = array![
1133            [0.0],
1134            [1.0],
1135            [2.0],
1136            [3.0],
1137            [4.0],
1138            [5.0],
1139            [6.0],
1140            [7.0],
1141            [8.0],
1142            [9.0]
1143        ];
1144        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1145        let dataset = Dataset::new(x, y);
1146        let res = log_reg.fit(&dataset).unwrap();
1147        assert!(res.intercept().abs_diff_eq(&-4.124, 1e-3));
1148        assert!(res.params().abs_diff_eq(&array![1.181], 1e-3));
1149        assert_eq!(
1150            &res.predict(dataset.records()),
1151            dataset.targets().as_single_targets()
1152        );
1153
1154        // Test serialization
1155        #[cfg(feature = "serde")]
1156        {
1157            let ser = rmp_serde::to_vec(&res).unwrap();
1158            let unser: FittedLogisticRegression<f32, f32> = rmp_serde::from_slice(&ser).unwrap();
1159
1160            let x = array![[1.0]];
1161            let y_hat = unser.predict(&x);
1162
1163            assert!(y_hat[0] == 0.0);
1164        }
1165    }
1166
1167    #[test]
1168    fn works_with_f32() {
1169        let log_reg = LogisticRegression::default();
1170        let x: Array2<f32> = array![[-1.0], [-0.01], [0.01], [1.0]];
1171        let y = array![0, 0, 1, 1];
1172        let dataset = Dataset::new(x, y);
1173        let res = log_reg.fit(&dataset).unwrap();
1174        assert_abs_diff_eq!(res.intercept(), 0.0_f32);
1175        assert!(res.params().abs_diff_eq(&array![-0.682_f32], 1e-3));
1176        assert_eq!(
1177            &res.predict(dataset.records()),
1178            dataset.targets().as_single_targets()
1179        );
1180    }
1181
1182    #[test]
1183    fn test_log_sum_exp() {
1184        let data = array![[3.3, 0.4, -2.1], [0.4, 2.2, -0.1], [1., 0., -1.]];
1185        let out = log_sum_exp(&data, Axis(1));
1186        assert_abs_diff_eq!(out, array![3.35783, 2.43551, 1.40761], epsilon = 1e-5);
1187    }
1188
1189    #[test]
1190    fn test_softmax() {
1191        let mut data = array![3.3, 5.5, 0.1, -4.4, 8.0];
1192        softmax_inplace(&mut data);
1193        assert_relative_eq!(
1194            data,
1195            array![0.0083324, 0.075200047, 0.000339647, 0.000003773, 0.91612413],
1196            epsilon = 1e-8
1197        );
1198        assert_abs_diff_eq!(data.sum(), 1.0);
1199    }
1200
1201    #[test]
1202    fn test_multi_logistic_loss_grad() {
1203        let x = array![
1204            [0.0, 0.5],
1205            [1.0, -1.0],
1206            [2.0, -2.0],
1207            [3.0, -3.0],
1208            [4.0, -4.0],
1209            [5.0, -5.0],
1210            [6.0, -6.0],
1211            [7.0, -7.0],
1212        ];
1213        let y = array![
1214            [1.0, 0.0, 0.0],
1215            [1.0, 0.0, 0.0],
1216            [0.0, 1.0, 0.0],
1217            [0.0, 1.0, 0.0],
1218            [0.0, 1.0, 0.0],
1219            [0.0, 0.0, 1.0],
1220            [0.0, 0.0, 1.0],
1221            [0.0, 0.0, 1.0],
1222        ];
1223        let params1 = array![[4.4, -1.2, 3.3], [3.4, 0.1, 0.0]];
1224        let params2 = array![[0.001, -3.2, 2.9], [0.1, 4.5, 5.7], [4.5, 2.2, 1.7]];
1225        let alpha = 0.6;
1226
1227        {
1228            let (log_prob, w) = multi_logistic_prob_params(&x, &params1);
1229            assert_abs_diff_eq!(
1230                log_prob,
1231                array![
1232                    [-3.18259845e-01, -1.96825985e+00, -2.01825985e+00],
1233                    [-2.40463987e+00, -4.70463987e+00, -1.04639868e-01],
1234                    [-4.61010168e+00, -9.21010168e+00, -1.01016809e-02],
1235                    [-6.90100829e+00, -1.38010083e+01, -1.00829256e-03],
1236                    [-9.20010104e+00, -1.84001010e+01, -1.01044506e-04],
1237                    [-1.15000101e+01, -2.30000101e+01, -1.01301449e-05],
1238                    [-1.38000010e+01, -2.76000010e+01, -1.01563199e-06],
1239                    [-1.61000001e+01, -3.22000001e+01, -1.01826043e-07],
1240                ],
1241                epsilon = 1e-6
1242            );
1243            assert_abs_diff_eq!(w, params1);
1244            let loss = multi_logistic_loss(&x, &y, alpha, &params1);
1245            assert_abs_diff_eq!(loss, 57.11212197835295, epsilon = 1e-6);
1246            let grad = multi_logistic_grad(&x, &y, alpha, &params1);
1247            assert_abs_diff_eq!(
1248                grad,
1249                array![
1250                    [1.7536815, -9.71074369, 11.85706219],
1251                    [2.79002537, 9.12059357, -9.81061893]
1252                ],
1253                epsilon = 1e-6
1254            );
1255        }
1256
1257        {
1258            let (log_prob, w) = multi_logistic_prob_params(&x, &params2);
1259            assert_abs_diff_eq!(
1260                log_prob,
1261                array![
1262                    [-1.06637742e+00, -1.16637742e+00, -1.06637742e+00],
1263                    [-4.12429463e-03, -9.90512429e+00, -5.50512429e+00],
1264                    [-2.74092305e-04, -1.75022741e+01, -8.20227409e+00],
1265                    [-1.84027855e-05, -2.51030184e+01, -1.09030184e+01],
1266                    [-1.23554225e-06, -3.27040012e+01, -1.36040012e+01],
1267                    [-8.29523046e-08, -4.03050001e+01, -1.63050001e+01],
1268                    [-5.56928016e-09, -4.79060000e+01, -1.90060000e+01],
1269                    [-3.73912013e-10, -5.55070000e+01, -2.17070000e+01]
1270                ],
1271                epsilon = 1e-6
1272            );
1273            assert_abs_diff_eq!(w, params2.slice(s![..params2.nrows() - 1, ..]));
1274            let loss = multi_logistic_loss(&x, &y, alpha, &params2);
1275            assert_abs_diff_eq!(loss, 154.8177958366479, epsilon = 1e-6);
1276            let grad = multi_logistic_grad(&x, &y, alpha, &params2);
1277            assert_abs_diff_eq!(
1278                grad,
1279                array![
1280                    [26.99587549, -10.91995003, -16.25532546],
1281                    [-27.26314882, 11.85569669, 21.58745213],
1282                    [5.33984376, -2.68845675, -2.65138701]
1283                ],
1284                epsilon = 1e-6
1285            );
1286        }
1287    }
1288
1289    #[test]
1290    fn simple_multi_example() {
1291        let x = array![[-1., 0.], [0., 1.], [1., 1.]];
1292        let y = array![2, 1, 0];
1293        let log_reg = MultiLogisticRegression::default()
1294            .alpha(0.1)
1295            .initial_params(Array::zeros((3, 3)));
1296        let dataset = Dataset::new(x, y);
1297        let res = log_reg.fit(&dataset).unwrap();
1298        assert_eq!(res.params().dim(), (2, 3));
1299        assert_eq!(res.intercept().dim(), 3);
1300        assert_eq!(
1301            &res.predict(dataset.records()),
1302            dataset.targets().as_single_targets()
1303        );
1304    }
1305
1306    #[test]
1307    fn simple_multi_example_2() {
1308        let x = array![[1.0], [0.0], [1.0], [0.0]];
1309        let y = array![1, 0, 1, 0];
1310        let dataset = DatasetBase::new(x, y);
1311        let model = MultiLogisticRegression::default().fit(&dataset).unwrap();
1312
1313        let pred = model.predict(&dataset.records);
1314        assert_eq!(dataset.targets(), pred);
1315    }
1316
1317    #[test]
1318    fn simple_multi_example_text() {
1319        let log_reg = MultiLogisticRegression::default().alpha(0.1);
1320        let x = array![[0.1], [1.0], [-1.0], [-0.1]];
1321        let y = array!["dog", "ape", "rocket", "cat"];
1322        let dataset = Dataset::new(x, y);
1323        let res = log_reg.fit(&dataset).unwrap();
1324        assert_eq!(res.params().dim(), (1, 4));
1325        assert_eq!(res.intercept().dim(), 4);
1326        assert_eq!(
1327            &res.predict(dataset.records()),
1328            dataset.targets().as_single_targets()
1329        );
1330    }
1331
1332    #[test]
1333    fn multi_on_binary_problem() {
1334        let log_reg = MultiLogisticRegression::default().alpha(1.0);
1335        let x = array![
1336            [0.0],
1337            [1.0],
1338            [2.0],
1339            [3.0],
1340            [4.0],
1341            [5.0],
1342            [6.0],
1343            [7.0],
1344            [8.0],
1345            [9.0]
1346        ];
1347        let y = array![0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
1348        let dataset = Dataset::new(x, y);
1349        let res = log_reg.fit(&dataset).unwrap();
1350        assert_eq!(res.params().dim(), (1, 2));
1351        assert_eq!(res.intercept().dim(), 2);
1352        assert_eq!(
1353            &res.predict(dataset.records()),
1354            dataset.targets().as_single_targets()
1355        );
1356    }
1357
1358    #[test]
1359    fn reject_num_class_mismatch() {
1360        let n_samples = 4;
1361        let n_classes = 3;
1362        let n_features = 1;
1363        let x = Array2::<f64>::zeros((n_samples, n_features));
1364        let y = array![0, 1, 2, 0];
1365        let dataset = Dataset::new(x, y);
1366
1367        let log_reg = MultiLogisticRegression::default()
1368            .with_intercept(false)
1369            .initial_params(Array::zeros((n_features, n_classes - 1)));
1370        assert!(matches!(
1371            log_reg.fit(&dataset).unwrap_err(),
1372            Error::InitialParameterClassesMismatch {
1373                cols: 2,
1374                n_classes: 3,
1375            }
1376        ));
1377    }
1378}