// linfa_linear/glm/mod.rs

1//! Generalized Linear Models (GLM)
2
3mod distribution;
4mod hyperparams;
5mod link;
6
7use crate::error::{LinearError, Result};
8use crate::float::Float;
9use argmin_math::{
10    ArgminAdd, ArgminDot, ArgminL1Norm, ArgminL2Norm, ArgminMinMax, ArgminMul, ArgminSignum,
11    ArgminSub, ArgminZero,
12};
13use distribution::TweedieDistribution;
14pub use hyperparams::TweedieRegressorParams;
15pub use hyperparams::TweedieRegressorValidParams;
16use linfa::dataset::AsSingleTargets;
17pub use link::Link;
18
19use argmin::core::{CostFunction, Executor, Gradient};
20use argmin::solver::linesearch::MoreThuenteLineSearch;
21use argmin::solver::quasinewton::LBFGS;
22use ndarray::{array, concatenate, s};
23use ndarray::{Array, Array1, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
24#[cfg(feature = "serde")]
25use serde_crate::{Deserialize, Serialize};
26
27use linfa::traits::*;
28use linfa::DatasetBase;
29
impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>>
    Fit<ArrayBase<D, Ix2>, T, LinearError<F>> for TweedieRegressorValidParams<F>
where
    // Vector-space operations the argmin L-BFGS solver requires on the
    // parameter type (`Array1<F>`) and the scalar type (`F`).
    Array1<F>: ArgminAdd<Array1<F>, Array1<F>>
        + ArgminSub<Array1<F>, Array1<F>>
        + ArgminSub<F, Array1<F>>
        + ArgminAdd<F, Array1<F>>
        + ArgminMul<F, Array1<F>>
        + ArgminMul<Array1<F>, Array1<F>>
        + ArgminDot<Array1<F>, F>
        + ArgminL2Norm<F>
        + ArgminL1Norm<F>
        + ArgminSignum
        + ArgminMinMax,
    F: ArgminMul<Array1<F>, Array1<F>> + ArgminZero,
{
    type Object = TweedieRegressor<F>;

    /// Fit a Tweedie GLM by minimizing the L2-penalized deviance with L-BFGS.
    ///
    /// Errors when the distribution power is invalid, when the targets fall
    /// outside the distribution's support, or when the solver fails.
    fn fit(&self, ds: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, F> {
        let (x, y) = (ds.records(), ds.as_single_targets());

        // Validates `power` and builds the distribution used for the
        // deviance and its derivative.
        let dist = TweedieDistribution::new(self.power())?;
        // Link function taken from the validated hyperparameters
        // (presumably resolved to an appropriate default for the chosen
        // power when not set explicitly — see `hyperparams`).
        let link = self.link();

        if !dist.in_range(&y) {
            // An error is sent when y has values in the range not applicable
            // for the distribution
            return Err(LinearError::InvalidTargetRange(self.power()));
        }
        // We initialize the coefficients to zero; when an intercept is
        // requested it is prepended as the first parameter, initialized
        // at link(mean(y)).
        let mut coef = Array::zeros(x.ncols());
        if self.fit_intercept() {
            let temp = link.link(&array![y.mean().unwrap()]);
            coef = concatenate!(Axis(0), temp, coef);
        }

        // Constructing a struct that satisfies the requirements of the L-BFGS solver
        // with functions implemented for the objective function and the parameter
        // gradient
        let problem = TweedieProblem {
            x: x.view(),
            y,
            fit_intercept: self.fit_intercept(),
            link: &link,
            dist,
            alpha: self.alpha(),
        };
        let linesearch = MoreThuenteLineSearch::new();

        // L-BFGS maintains a history of the past m updates of the
        // position x and gradient ∇f(x), where generally the history
        // size m can be small (often m < 10)
        // For our problem we set m as 7
        let solver = LBFGS::new(linesearch, 7).with_tolerance_grad(F::cast(self.tol()))?;

        let mut result = Executor::new(problem, solver)
            .configure(|state| state.param(coef).max_iters(self.max_iter() as u64))
            .run()?;
        // Take the best parameter found by the solver; expected to be
        // present after a successful run (NOTE(review): `unwrap` would
        // panic otherwise — confirm argmin guarantees a best param here).
        coef = result.state.take_best_param().unwrap();

        // Split the parameter vector back into intercept and coefficients.
        if self.fit_intercept() {
            Ok(TweedieRegressor {
                coef: coef.slice(s![1..]).to_owned(),
                intercept: *coef.get(0).unwrap(),
                link,
            })
        } else {
            Ok(TweedieRegressor {
                coef: coef.to_owned(),
                intercept: F::cast(0.),
                link,
            })
        }
    }
}
108
/// Optimization problem handed to the L-BFGS solver: the data plus the
/// distribution/link pair and regularization strength needed to evaluate
/// the objective (`CostFunction`) and its gradient (`Gradient`).
struct TweedieProblem<'a, F: Float> {
    /// Feature matrix, one row per sample.
    x: ArrayView2<'a, F>,
    /// Target values, one per sample.
    y: ArrayView1<'a, F>,
    /// When true, the first entry of the parameter vector is the intercept.
    fit_intercept: bool,
    /// Link function relating the linear predictor to the mean response.
    link: &'a Link,
    /// Tweedie distribution providing deviance and its derivative.
    dist: TweedieDistribution<F>,
    /// L2 regularization strength (intercept is not penalized).
    alpha: F,
}
117
118impl<A: Float> TweedieProblem<'_, A> {
119    fn ypred(&self, p: &Array1<A>) -> (Array1<A>, Array1<A>, usize) {
120        let mut offset = 0;
121        let mut intercept = A::from(0.).unwrap();
122        if self.fit_intercept {
123            offset = 1;
124            intercept = *p.get(0).unwrap();
125        }
126
127        let lin_pred = self
128            .x
129            .view()
130            .dot(&p.slice(s![offset..]))
131            .mapv(|x| x + intercept);
132
133        (self.link.inverse(&lin_pred), lin_pred, offset)
134    }
135}
136
137impl<A: Float> CostFunction for TweedieProblem<'_, A> {
138    type Param = Array1<A>;
139    type Output = A;
140
141    // This function calculates the value of the objective function we are trying
142    // to minimize,
143    //
144    // 0.5 * (deviance(y, ypred) + alpha * |p|_2)
145    //
146    // - `p` is the parameter we are optimizing (coefficients and intercept)
147    // - `alpha` is the regularization hyperparameter
148    fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
149        let (ypred, _, offset) = self.ypred(p);
150
151        let dev = self.dist.deviance(self.y, ypred.view())?;
152
153        let pscaled = p
154            .slice(s![offset..])
155            .mapv(|x| x * A::from(self.alpha).unwrap());
156
157        let obj = A::from(0.5).unwrap() * (dev + p.slice(s![offset..]).dot(&pscaled));
158
159        Ok(obj)
160    }
161}
162
163impl<A: Float> Gradient for TweedieProblem<'_, A> {
164    type Param = Array1<A>;
165    type Gradient = Array1<A>;
166
167    fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
168        let (ypred, lin_pred, offset) = self.ypred(p);
169
170        let devp;
171        let der = self.link.inverse_derviative(&lin_pred);
172        let temp = der * self.dist.deviance_derivative(self.y, ypred.view());
173        if self.fit_intercept {
174            devp = concatenate![Axis(0), array![temp.sum()], temp.dot(&self.x)];
175        } else {
176            devp = temp.dot(&self.x);
177        }
178
179        let pscaled = p
180            .slice(s![offset..])
181            .mapv(|x| x * A::from(self.alpha).unwrap());
182
183        let mut objp = devp.mapv(|x| x * A::from(0.5).unwrap());
184        objp.slice_mut(s![offset..])
185            .zip_mut_with(&pscaled, |x, y| *x += *y);
186
187        Ok(objp)
188    }
189}
190
/// Generalized Linear Model (GLM) with a Tweedie distribution
///
/// The Regressor can be used to model different GLMs depending on
/// [`power`](TweedieRegressorParams),
/// which determines the underlying distribution.
///
/// | Power  | Distribution           |
/// | ------ | ---------------------- |
/// | 0      | Normal                 |
/// | 1      | Poisson                |
/// | (1, 2) | Compound Poisson Gamma |
/// | 2      | Gamma                  |
/// | 3      | Inverse Gaussian       |
///
/// NOTE: No distribution exists between 0 and 1
///
/// Learn more from sklearn's excellent [User Guide](https://scikit-learn.org/stable/modules/linear_model.html#generalized-linear-regression)
///
/// ## Examples
///
/// Here's an example on how to train a GLM on the `diabetes` dataset
/// ```rust
/// use linfa::traits::{Fit, Predict};
/// use linfa_linear::TweedieRegressor;
/// use linfa::prelude::SingleTargetRegression;
///
/// let dataset = linfa_datasets::diabetes();
/// let model = TweedieRegressor::params().fit(&dataset).unwrap();
/// let pred = model.predict(&dataset);
/// let r2 = pred.r2(&dataset).unwrap();
/// println!("r2 from prediction: {}", r2);
/// ```
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct TweedieRegressor<A> {
    /// Estimated coefficients for the linear predictor
    pub coef: Array1<A>,
    /// Intercept or bias added to the linear model
    pub intercept: A,
    // Link function fitted with the model; its inverse maps the linear
    // predictor back to the target scale during prediction.
    link: Link,
}
236
237impl<A: Float, D: Data<Elem = A>> PredictInplace<ArrayBase<D, Ix2>, Array1<A>>
238    for TweedieRegressor<A>
239{
240    /// Predict the target
241    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<A>) {
242        assert_eq!(
243            x.nrows(),
244            y.len(),
245            "The number of data points must match the number of output targets."
246        );
247
248        let ypred = x.dot(&self.coef) + self.intercept;
249        *y = self.link.inverse(&ypred);
250    }
251
252    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<A> {
253        Array1::zeros(x.nrows())
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::glm::hyperparams::TweedieRegressorParams;
    use approx::assert_abs_diff_eq;
    use linfa::Dataset;
    use ndarray::{array, Array2};

    #[test]
    fn autotraits() {
        // Public types must be thread-safe and movable so they compose
        // with generic code and multi-threaded pipelines.
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TweedieRegressor<f64>>();
        has_autotraits::<TweedieRegressorValidParams<f64>>();
        has_autotraits::<TweedieRegressorParams<f64>>();
    }

    // Generates one recovery test per (power, intercept) combination.
    // Each test builds a tiny synthetic dataset y = exp(X·coef) — exactly
    // representable under the log link with no regularization — and checks
    // that fitting recovers the generating coefficients (and intercept,
    // when requested) to 1e-3.
    macro_rules! test_tweedie {
        ($($name:ident: {power: $power:expr, intercept: $intercept:expr,},)*) => {
            $(
                #[test]
                fn $name() {
                    // Generating parameters: intercept 0.2, slope -0.1.
                    let coef = array![0.2, -0.1];
                    // First column is all ones (intercept feature), second
                    // column is the actual feature 0..4.
                    let mut x: Array2<f64> = array![[1., 1., 1., 1., 1.], [0., 1., 2., 3., 4.]].reversed_axes();
                    let y = x.dot(&coef).mapv(|x| x.exp());

                    let glm = TweedieRegressor::params()
                        .alpha(0.)
                        .power($power)
                        .link(Link::Log)
                        .tol(1e-7)
                        .fit_intercept($intercept);

                    if $intercept {
                        // Drop the ones column: the model's own intercept
                        // should absorb coef[0].
                        x = x.slice(s![.., 1..]).to_owned();
                        let dataset = Dataset::new(x, y);
                        let glm = glm.fit(&dataset).unwrap();

                        assert_abs_diff_eq!(glm.intercept, coef.get(0).unwrap(), epsilon = 1e-3);
                        assert_abs_diff_eq!(glm.coef, coef.slice(s![1..]), epsilon = 1e-3);
                    } else {
                        // Keep the ones column so both generating
                        // coefficients appear in `glm.coef`.
                        let dataset = Dataset::new(x, y);
                        let glm = glm.fit(&dataset).unwrap();

                        assert_abs_diff_eq!(glm.coef, coef, epsilon = 1e-3);
                    }
                }
            )*
        }
    }

    test_tweedie! {
        test_glm_normal1: {
            power: 0.,
            intercept: true,
        },
        test_glm_normal2: {
            power: 0.,
            intercept: false,
        },
        test_glm_poisson1: {
            power: 1.,
            intercept: true,
        },
        test_glm_poisson2: {
            power: 1.,
            intercept: false,
        },
        test_glm_gamma1: {
            power: 2.,
            intercept: true,
        },
        test_glm_gamma2: {
            power: 2.,
            intercept: false,
        },
        test_glm_inverse_gaussian1: {
            power: 3.,
            intercept: true,
        },
        test_glm_inverse_gaussian2: {
            power: 3.,
            intercept: false,
        },
        test_glm_tweedie1: {
            power: 1.5,
            intercept: true,
        },
        test_glm_tweedie2: {
            power: 1.5,
            intercept: false,
        },
    }
}