1#![allow(non_snake_case)]
3use crate::error::{LinearError, Result};
4#[cfg(feature = "blas")]
5use linfa::dataset::{WithLapack, WithoutLapack};
6use linfa::Float;
7#[cfg(not(feature = "blas"))]
8use linfa_linalg::qr::LeastSquaresQrInto;
9use ndarray::{concatenate, s, Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
10#[cfg(feature = "blas")]
11use ndarray_linalg::LeastSquaresSvdInto;
12#[cfg(feature = "serde")]
13use serde_crate::{Deserialize, Serialize};
14
15use linfa::dataset::{AsSingleTargets, DatasetBase};
16use linfa::traits::{Fit, PredictInplace};
17
/// Configuration for an ordinary least-squares linear regression.
///
/// The only setting is whether an intercept term is estimated alongside the
/// feature coefficients. Build one with [`LinearRegression::new`] (or
/// `Default::default()`) and adjust it with [`LinearRegression::with_intercept`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate"))]
pub struct LinearRegression {
    /// `true` to estimate an intercept during fitting; `false` to leave the
    /// fitted intercept at zero.
    fit_intercept: bool,
}
55
56#[derive(Debug, Clone, PartialEq)]
57#[cfg_attr(
58 feature = "serde",
59 derive(Serialize, Deserialize),
60 serde(crate = "serde_crate")
61)]
62pub struct FittedLinearRegression<F> {
64 intercept: F,
65 params: Array1<F>,
66}
67
68impl Default for LinearRegression {
69 fn default() -> Self {
70 LinearRegression::new()
71 }
72}
73
74impl LinearRegression {
76 pub fn new() -> LinearRegression {
79 LinearRegression {
80 fit_intercept: true,
81 }
82 }
83
84 pub fn with_intercept(mut self, intercept: bool) -> Self {
86 self.fit_intercept = intercept;
87 self
88 }
89}
90
91impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>>
92 Fit<ArrayBase<D, Ix2>, T, LinearError<F>> for LinearRegression
93{
94 type Object = FittedLinearRegression<F>;
95
96 fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, F> {
107 let X = dataset.records();
108 let y = dataset.as_single_targets();
109
110 let (n_samples, _) = X.dim();
111
112 assert_eq!(y.dim(), n_samples);
114
115 if self.fit_intercept {
116 let X = concatenate(Axis(1), &[X.view(), Array2::ones((X.nrows(), 1)).view()]).unwrap();
117 let params: Array1<F> = solve_least_squares(X, y.to_owned())?;
118 let intercept = *params.last().unwrap();
119 let params = params.slice(s![..params.len() - 1]).to_owned();
120 Ok(FittedLinearRegression { intercept, params })
121 } else {
122 let (X, y) = (X.to_owned(), y.to_owned());
125
126 Ok(FittedLinearRegression {
127 intercept: F::cast(0),
128 params: solve_least_squares(X, y)?,
129 })
130 }
131 }
132}
133
134fn solve_least_squares<F>(mut X: Array<F, Ix2>, mut y: Array<F, Ix1>) -> Result<Array1<F>, F>
137where
138 F: Float,
139{
140 let (X, y) = (X.view_mut(), y.view_mut());
142
143 #[cfg(not(feature = "blas"))]
144 let out = X
145 .least_squares_into(y.insert_axis(Axis(1)))?
146 .remove_axis(Axis(1));
147 #[cfg(feature = "blas")]
148 let out = X
149 .with_lapack()
150 .least_squares_into(y.with_lapack())
151 .map(|x| x.solution)?
152 .without_lapack();
153 Ok(out)
154}
155
156impl<F: Float> FittedLinearRegression<F> {
159 pub fn params(&self) -> &Array1<F> {
161 &self.params
162 }
163
164 pub fn intercept(&self) -> F {
166 self.intercept
167 }
168}
169
170impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<F>>
171 for FittedLinearRegression<F>
172{
173 fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<F>) {
177 assert_eq!(
178 x.nrows(),
179 y.len(),
180 "The number of data points must match the number of output targets."
181 );
182
183 *y = x.dot(&self.params) + self.intercept;
184 }
185
186 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<F> {
187 Array1::zeros(x.nrows())
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::{traits::Predict, Dataset};
    use ndarray::array;

    /// Compile-time check that the public types are `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<FittedLinearRegression<f64>>();
        has_autotraits::<LinearRegression>();
        has_autotraits::<LinearError<f64>>();
    }

    /// Two points, intercept enabled: the fitted line passes through both,
    /// so predicting on the training records reproduces the targets.
    #[test]
    fn fits_a_line_through_two_dots() {
        let dataset = Dataset::new(array![[0f64], [1.]], array![1., 2.]);
        let model = LinearRegression::new().fit(&dataset).unwrap();
        let predicted = model.predict(dataset.records());

        assert_abs_diff_eq!(predicted, &array![1., 2.], epsilon = 1e-12);
    }

    /// A single point (1, 1) without intercept gives the line through the
    /// origin with slope one, checked at x = 0 and x = 1.
    #[test]
    fn without_intercept_fits_line_through_origin() {
        let dataset = Dataset::new(array![[1.]], array![1.]);
        let model = LinearRegression::new()
            .with_intercept(false)
            .fit(&dataset)
            .unwrap();
        let predicted = model.predict(&array![[0.], [1.]]);

        assert_abs_diff_eq!(predicted, &array![0., 1.], epsilon = 1e-12);
    }

    /// Points (-1, 1) and (1, 1) cannot be hit by a line through the origin;
    /// the least-squares solution without intercept is the line y = 0.
    #[test]
    fn fits_least_squares_line_through_two_dots() {
        let dataset = Dataset::new(array![[-1.], [1.]], array![1., 1.]);
        let model = LinearRegression::new()
            .with_intercept(false)
            .fit(&dataset)
            .unwrap();
        let predicted = model.predict(dataset.records());

        assert_abs_diff_eq!(predicted, &array![0., 0.], epsilon = 1e-12);
    }

    /// Points (0, 0), (1, 0), (2, 2) are not collinear; the least-squares
    /// line is y = x - 1/3, giving predictions -1/3, 2/3 and 5/3.
    #[test]
    fn fits_least_squares_line_through_three_dots() {
        let dataset = Dataset::new(array![[0.], [1.], [2.]], array![0., 0., 2.]);
        let model = LinearRegression::new().fit(&dataset).unwrap();
        let predicted = model.predict(dataset.records());

        assert_abs_diff_eq!(
            predicted,
            array![-1. / 3., 2. / 3., 5. / 3.],
            epsilon = 1e-12
        );
    }

    /// Exact fit of f(x) = 1 + 2 x + x^2 from features (x, x^2) at x = 0, 1, 2.
    #[test]
    fn fits_three_parameters_through_three_dots() {
        let dataset = Dataset::new(array![[0f64, 0.], [1., 1.], [2., 4.]], array![1., 4., 9.]);
        let model = LinearRegression::new().fit(&dataset).unwrap();

        assert_abs_diff_eq!(model.params(), &array![2., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), &1., epsilon = 1e-12);
    }

    /// Exact fit of f(x) = (x + 1)^3 = 1 + 3 x + 3 x^2 + x^3 from features
    /// (x, x^2, x^3) at x = 0, 1, 2, 3.
    #[test]
    fn fits_four_parameters_through_four_dots() {
        let dataset = Dataset::new(
            array![[0f64, 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]],
            array![1., 8., 27., 64.],
        );
        let model = LinearRegression::new().fit(&dataset).unwrap();

        assert_abs_diff_eq!(model.params(), &array![3., 3., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), &1., epsilon = 1e-12);
    }

    /// Same quadratic as `fits_three_parameters_through_three_dots`, checked
    /// with looser tolerances.
    ///
    /// NOTE(review): despite the `_f32` suffix, the records literal is typed
    /// `0f64`, so this currently exercises the `f64` path. Confirm whether
    /// `0f32` was intended, and whether the tolerances still hold if so.
    #[test]
    fn fits_three_parameters_through_three_dots_f32() {
        let dataset = Dataset::new(array![[0f64, 0.], [1., 1.], [2., 4.]], array![1., 4., 9.]);
        let model = LinearRegression::new().fit(&dataset).unwrap();

        assert_abs_diff_eq!(model.params(), &array![2., 1.], epsilon = 1e-4);
        assert_abs_diff_eq!(model.intercept(), &1., epsilon = 1e-6);
    }
}