linfa_linear/
ols.rs

1//! Ordinary Least Squares
2#![allow(non_snake_case)]
3use crate::error::{LinearError, Result};
4#[cfg(feature = "blas")]
5use linfa::dataset::{WithLapack, WithoutLapack};
6use linfa::Float;
7#[cfg(not(feature = "blas"))]
8use linfa_linalg::qr::LeastSquaresQrInto;
9use ndarray::{concatenate, s, Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
10#[cfg(feature = "blas")]
11use ndarray_linalg::LeastSquaresSvdInto;
12#[cfg(feature = "serde")]
13use serde_crate::{Deserialize, Serialize};
14
15use linfa::dataset::{AsSingleTargets, DatasetBase};
16use linfa::traits::{Fit, PredictInplace};
17
/// An ordinary least squares univariate linear regression model.
///
/// Given predictors `x` and responses `y`, ordinary least squares linear
/// regression estimates a model of the form
///
/// `y = xW + b`
///
/// by finding a coefficient vector `W` (one coefficient per feature) and a
/// scalar intercept `b` which minimize the sum of squared residuals
/// `sum_j (y_j - x_jW - b)^2` over a dataset
/// `{(x_j, y_j) for j in 1..=n_samples}`.
///
/// The algorithm is only implemented for _univariate_ regression. This means
/// that `b` and each `y_j` are scalars and `W` is just one column.
///
/// When the model is configured with `with_intercept(false)`, `b` is fixed at
/// zero and only `W` is estimated, so the fitted model passes through the origin.
///
/// ## Examples
///
/// Here's an example of how to train a linear regression model on the `diabetes` dataset
/// ```rust
/// use linfa::traits::{Fit, Predict};
/// use linfa_linear::LinearRegression;
/// use linfa::prelude::SingleTargetRegression;
///
/// let dataset = linfa_datasets::diabetes();
/// let model = LinearRegression::default().fit(&dataset).unwrap();
/// let pred = model.predict(&dataset);
/// let r2 = pred.r2(&dataset).unwrap();
/// println!("r2 from prediction: {}", r2);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct LinearRegression {
    /// Whether an intercept term `b` is estimated (`true`) or fixed at zero (`false`).
    fit_intercept: bool,
}
55
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
/// A fitted linear regression model which can be used for making predictions.
///
/// Predictions are computed as `x · params + intercept`.
pub struct FittedLinearRegression<F> {
    /// Fitted intercept `b`; zero when the model was fitted without an intercept.
    intercept: F,
    /// Fitted coefficients, one per feature column (the intercept is not included).
    params: Array1<F>,
}
67
68impl Default for LinearRegression {
69    fn default() -> Self {
70        LinearRegression::new()
71    }
72}
73
74/// Configure and fit a linear regression model
75impl LinearRegression {
76    /// Create a default linear regression model.
77    /// By default, an intercept will be fitted.
78    pub fn new() -> LinearRegression {
79        LinearRegression {
80            fit_intercept: true,
81        }
82    }
83
84    /// Configure the linear regression model to fit an intercept.
85    pub fn with_intercept(mut self, intercept: bool) -> Self {
86        self.fit_intercept = intercept;
87        self
88    }
89}
90
91impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>>
92    Fit<ArrayBase<D, Ix2>, T, LinearError<F>> for LinearRegression
93{
94    type Object = FittedLinearRegression<F>;
95
96    /// Fit a linear regression model given a feature matrix `X` and a target
97    /// variable `y`.
98    ///
99    /// The feature matrix `X` must have shape `(n_samples, n_features)`
100    ///
101    /// The target variable `y` must have shape `(n_samples)`
102    ///
103    /// Returns a `FittedLinearRegression` object which contains the fitted
104    /// parameters and can be used to `predict` values of the target variable
105    /// for new feature values.
106    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, F> {
107        let X = dataset.records();
108        let y = dataset.as_single_targets();
109
110        let (n_samples, _) = X.dim();
111
112        // Check that our inputs have compatible shapes
113        assert_eq!(y.dim(), n_samples);
114
115        if self.fit_intercept {
116            let X = concatenate(Axis(1), &[X.view(), Array2::ones((X.nrows(), 1)).view()]).unwrap();
117            let params: Array1<F> = solve_least_squares(X, y.to_owned())?;
118            let intercept = *params.last().unwrap();
119            let params = params.slice(s![..params.len() - 1]).to_owned();
120            Ok(FittedLinearRegression { intercept, params })
121        } else {
122            // `LeastSquaresSvdInto` needs a mutable reference to the data and `dataset` is taken
123            // by reference. Therefore copy the problem matrix and target vector.
124            let (X, y) = (X.to_owned(), y.to_owned());
125
126            Ok(FittedLinearRegression {
127                intercept: F::cast(0),
128                params: solve_least_squares(X, y)?,
129            })
130        }
131    }
132}
133
/// Find the `b` that minimizes the 2-norm of `X b - y`.
///
/// The solver backend depends on the `blas` feature:
/// - without `blas`, `LeastSquaresQrInto::least_squares_into` (linfa-linalg) is
///   used, with `y` viewed as a single-column matrix and the column axis removed
///   from the result;
/// - with `blas`, `LeastSquaresSvdInto::least_squares_into` (ndarray-linalg) is
///   used on `with_lapack` conversions of the inputs, and only the `solution`
///   field of its result is kept.
///
/// `X` and `y` are taken by value because the `*Into` solvers are called on
/// mutable views of them. Solver errors are propagated with `?`.
fn solve_least_squares<F>(mut X: Array<F, Ix2>, mut y: Array<F, Ix1>) -> Result<Array1<F>, F>
where
    F: Float,
{
    // Take mutable views of both owned inputs; the solvers below are called on these views.
    // NOTE(review): this line was originally commented "ensure that B = C"; presumably that
    // refers to the solver needing matching storage types for both operands — confirm.
    let (X, y) = (X.view_mut(), y.view_mut());

    #[cfg(not(feature = "blas"))]
    let out = X
        .least_squares_into(y.insert_axis(Axis(1)))?
        .remove_axis(Axis(1));
    #[cfg(feature = "blas")]
    let out = X
        .with_lapack()
        .least_squares_into(y.with_lapack())
        .map(|x| x.solution)?
        .without_lapack();
    Ok(out)
}
155
156/// View the fitted parameters and make predictions with a fitted
157/// linear regresssion model.
158impl<F: Float> FittedLinearRegression<F> {
159    /// Get the fitted parameters
160    pub fn params(&self) -> &Array1<F> {
161        &self.params
162    }
163
164    /// Get the fitted intercept, 0. if no intercept was fitted
165    pub fn intercept(&self) -> F {
166        self.intercept
167    }
168}
169
170impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<F>>
171    for FittedLinearRegression<F>
172{
173    /// Given an input matrix `X`, with shape `(n_samples, n_features)`,
174    /// `predict` returns the target variable according to linear model
175    /// learned from the training data distribution.
176    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<F>) {
177        assert_eq!(
178            x.nrows(),
179            y.len(),
180            "The number of data points must match the number of output targets."
181        );
182
183        *y = x.dot(&self.params) + self.intercept;
184    }
185
186    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<F> {
187        Array1::zeros(x.nrows())
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::{traits::Predict, Dataset};
    use ndarray::array;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<FittedLinearRegression<f64>>();
        has_autotraits::<LinearRegression>();
        has_autotraits::<LinearError<f64>>();
    }

    #[test]
    fn fits_a_line_through_two_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0f64], [1.]], array![1., 2.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let result = model.predict(dataset.records());

        assert_abs_diff_eq!(result, &array![1., 2.], epsilon = 1e-12);
    }

    /// When `with_intercept` is set to false, the
    /// fitted line runs through the origin. For a perfect
    /// fit we only need to provide one point.
    #[test]
    fn without_intercept_fits_line_through_origin() {
        let lin_reg = LinearRegression::new().with_intercept(false);
        let dataset = Dataset::new(array![[1.]], array![1.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let result = model.predict(&array![[0.], [1.]]);

        assert_abs_diff_eq!(result, &array![0., 1.], epsilon = 1e-12);
    }

    /// We can't fit a line through two points in general without fitting the
    /// intercept. In this case we should find the solution that minimizes the
    /// squares. Fitting a line *without* intercept (i.e. through the origin)
    /// through the points (-1, 1), (1, 1) has the least-squares solution
    /// f(x) = 0
    #[test]
    fn fits_least_squares_line_through_two_dots() {
        let lin_reg = LinearRegression::new().with_intercept(false);
        let dataset = Dataset::new(array![[-1.], [1.]], array![1., 1.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let result = model.predict(dataset.records());

        assert_abs_diff_eq!(result, &array![0., 0.], epsilon = 1e-12);
    }

    /// We can't fit a line through three points in general
    /// - in this case we should find the solution that minimizes
    ///   the squares. Fitting a line with intercept through the
    ///   points (0, 0), (1, 0), (2, 2) has the least-squares solution
    ///   f(x) = -1./3. + x
    #[test]
    fn fits_least_squares_line_through_three_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0.], [1.], [2.]], array![0., 0., 2.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let actual = model.predict(dataset.records());

        assert_abs_diff_eq!(actual, array![-1. / 3., 2. / 3., 5. / 3.], epsilon = 1e-12);
    }

    /// Check that the linear regression perfectly fits three datapoints for
    /// the model
    /// f(x) = (x + 1)^2 = x^2 + 2x + 1
    #[test]
    fn fits_three_parameters_through_three_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0f64, 0.], [1., 1.], [2., 4.]], array![1., 4., 9.]);
        let model = lin_reg.fit(&dataset).unwrap();

        assert_abs_diff_eq!(model.params(), &array![2., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), &1., epsilon = 1e-12);
    }

    /// Check that the linear regression perfectly fits four datapoints for
    /// the model
    /// f(x) = (x + 1)^3 = x^3 + 3x^2 + 3x + 1
    #[test]
    fn fits_four_parameters_through_four_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(
            array![[0f64, 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]],
            array![1., 8., 27., 64.],
        );
        let model = lin_reg.fit(&dataset).unwrap();

        assert_abs_diff_eq!(model.params(), &array![3., 3., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), &1., epsilon = 1e-12);
    }

    /// Check that the linear regression fits three datapoints for the model
    /// f(x) = (x + 1)^2 = x^2 + 2x + 1
    /// in single precision (`f32`), using tolerances suited to `f32` rounding.
    #[test]
    fn fits_three_parameters_through_three_dots_f32() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0f32, 0.], [1., 1.], [2., 4.]], array![1f32, 4., 9.]);
        let model = lin_reg.fit(&dataset).unwrap();

        assert_abs_diff_eq!(model.params(), &array![2f32, 1.], epsilon = 1e-4);
        assert_abs_diff_eq!(model.intercept(), &1f32, epsilon = 1e-4);
    }

    /// Fitting with a different number of targets than samples is a caller
    /// error and must not silently succeed.
    #[test]
    #[should_panic]
    fn fit_panics_on_mismatched_target_count() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0f64], [1.], [2.]], array![1., 2.]);
        let _ = lin_reg.fit(&dataset);
    }
}