linfa_linear/glm/
distribution.rs

1use linfa::Float;
2use ndarray::Zip;
3use ndarray::{Array1, ArrayView1};
4
5use crate::error::{LinearError, Result};
6
7#[derive(Debug, Clone, PartialEq)]
8pub struct TweedieDistribution<F: Float> {
9    power: F,
10    lower_bound: F,
11    inclusive: bool,
12}
13
14impl<F: Float> TweedieDistribution<F> {
15    pub fn new(power: F) -> Result<Self, F> {
16        // Based on the `power` value, the lower bound of `y` is selected
17        let dist = match power {
18            power if power <= F::zero() => Self {
19                power,
20                lower_bound: F::neg_infinity(),
21                inclusive: false,
22            },
23            power if (power > F::zero() && power < F::one()) => {
24                return Err(LinearError::InvalidTweediePower(power));
25            }
26            power if (F::one()..F::cast(2.0)).contains(&power) => Self {
27                power,
28                lower_bound: F::zero(),
29                inclusive: true,
30            },
31            power if power >= F::cast(2.0) => Self {
32                power,
33                lower_bound: F::zero(),
34                inclusive: false,
35            },
36            _ => unreachable!(),
37        };
38
39        Ok(dist)
40    }
41
42    // Returns `true` if y is in the valid range
43    pub fn in_range(&self, y: &ArrayView1<F>) -> bool {
44        if self.inclusive {
45            return y.iter().all(|&x| x >= self.lower_bound);
46        }
47        y.iter().all(|&x| x > self.lower_bound)
48    }
49
50    fn unit_variance(&self, ypred: ArrayView1<F>) -> Array1<F> {
51        // ypred ^ power
52        ypred.mapv(|x| x.powf(self.power))
53    }
54
55    fn unit_deviance(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Result<Array1<F>, F> {
56        match self.power {
57            power if power.is_negative() => {
58                let mut left = y.mapv(|x| x.max(F::zero()));
59
60                left.mapv_inplace(|x| {
61                    x.powf(F::cast(2.) - self.power)
62                        / ((F::one() - self.power) * (F::cast(2.) - self.power))
63                });
64
65                let middle =
66                    &y * &ypred.mapv(|x| x.powf(F::cast(1.) - self.power) / (F::cast(1.) - power));
67
68                let right =
69                    ypred.mapv(|x| x.powf(F::cast(2.) - self.power) / (F::cast(2.) - self.power));
70
71                Ok((left - middle + right).mapv(|x| F::cast(2.) * x))
72            }
73            // Normal distribution
74            // (y - ypred)^2
75            power if power == F::zero() => Ok((&y - &ypred).mapv(|x| x * x)),
76            power if power < F::one() => Err(LinearError::InvalidTweediePower(power)),
77            // Poisson distribution
78            // 2 * (y * log(y / ypred) - y + ypred)
79            power if (power - F::one()).abs() < F::cast(1e-6) => {
80                let mut div = &y / &ypred;
81                Zip::from(&mut div).and(y).for_each(|y, &x| {
82                    if x == F::zero() {
83                        *y = F::zero();
84                    } else {
85                        *y = F::cast(2.) * (x * y.ln());
86                    }
87                });
88                Ok(div - y + ypred)
89            }
90            // Gamma distribution
91            // 2 * (log(ypred / y) + (y / ypred) - 1)
92            power if (power - F::cast(2.)).abs() < F::cast(1e-6) => {
93                let mut temp = (&ypred / &y).mapv(|x| x.ln()) + (&y / &ypred);
94                temp.mapv_inplace(|x| x - F::one());
95                Ok(temp.mapv(|x| F::cast(2.) * x))
96            }
97            power => {
98                let left = y.mapv(|x| {
99                    x.powf(F::cast(2.) - power) / ((F::one() - power) * (F::cast(2.) - power))
100                });
101
102                let middle = &y * &ypred.mapv(|x| x.powf(F::one() - power) / (F::one() - power));
103
104                let right = ypred.mapv(|x| x.powf(F::cast(2.) - power) / (F::cast(2.) - power));
105
106                Ok((left - middle + right).mapv(|x| F::cast(2.) * x))
107            }
108        }
109    }
110
111    fn unit_deviance_derivative(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Array1<F> {
112        ((&y - &ypred) / &self.unit_variance(ypred)).mapv(|x| F::cast(-2.) * x)
113    }
114
115    pub fn deviance(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Result<F, F> {
116        Ok(self.unit_deviance(y, ypred)?.sum())
117    }
118
119    pub fn deviance_derivative(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Array1<F> {
120        self.unit_deviance_derivative(y, ypred)
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Targets shared by every `deviance_derivative` test case.
    fn sample_y() -> Array1<f64> {
        array![
            0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
            1.28521452, 1.35710428, 0.77688304
        ]
    }

    /// Predictions shared by every `deviance_derivative` test case.
    fn sample_ypred() -> Array1<f64> {
        array![
            1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
            2.11783437, 2.13526103, 1.64689519
        ]
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TweedieDistribution<f64>>();
    }

    #[test]
    fn test_distribution_error() {
        let tweedie = TweedieDistribution::new(0.2);
        assert!(tweedie.is_err());
    }

    // Generates one test per case checking `in_range` against the expected verdict.
    macro_rules! test_bounds {
        ($($name:ident: ($dist:expr, $input:expr, $expected:expr),)*) => {
            $(
                #[test]
                #[allow(clippy::bool_assert_comparison)]
                fn $name() {
                    let output = $dist.in_range(&$input.view());
                    assert_eq!(output, $expected);
                }
            )*
        };
    }

    test_bounds! {
        test_bounds_normal: (TweedieDistribution::new(0.).unwrap(), array![-1., 0., 1.], true),
        test_bounds_poisson1: (TweedieDistribution::new(1.).unwrap(), array![-1., 0., 1.], false),
        test_bounds_poisson2: (TweedieDistribution::new(1.).unwrap(), array![0., 1., 2.], true),
        test_bounds_tweedie1: (TweedieDistribution::new(1.5).unwrap(), array![-1., 0., 1.], false),
        test_bounds_tweedie2: (TweedieDistribution::new(1.5).unwrap(), array![0., 1., 4.], true),
        test_bounds_gamma1: (TweedieDistribution::new(2.).unwrap(), array![-1., 0., 1.], false),
        test_bounds_gamma2: (TweedieDistribution::new(2.).unwrap(), array![0., 1., 2.], false),
        test_bounds_gamma3: (TweedieDistribution::new(2.).unwrap(), array![1., 2., 3.], true),
        test_bounds_inverse_gaussian: (TweedieDistribution::new(3.).unwrap(), array![-1., 0., 1.], false),
        test_bounds_tweedie3: (TweedieDistribution::new(3.5).unwrap(), array![-1., 0., 1.], false),
    }

    // Generates one test per case checking that the deviance of a sample against
    // itself is zero.
    macro_rules! test_deviance {
        ($($name:ident: ($dist:expr, $input:expr),)*) => {
            $(
                #[test]
                fn $name() {
                    let output = $dist.deviance($input.view(), $input.view()).unwrap();
                    assert_abs_diff_eq!(output, 0.0, epsilon = 1e-9);
                }
            )*
        }
    }

    test_deviance! {
        test_deviance_normal: (TweedieDistribution::new(0.).unwrap(), array![-1.5, -0.1, 0.1, 2.5]),
        test_deviance_poisson: (TweedieDistribution::new(1.).unwrap(), array![0.1, 1.5]),
        test_deviance_gamma: (TweedieDistribution::new(2.).unwrap(), array![0.1, 1.5]),
        test_deviance_inverse_gaussian: (TweedieDistribution::new(3.).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie1: (TweedieDistribution::new(-2.5).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie2: (TweedieDistribution::new(-1.).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie3: (TweedieDistribution::new(1.5).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie4: (TweedieDistribution::new(2.5).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie5: (TweedieDistribution::new(-4.).unwrap(), array![0.1, 1.5]),
    }

    // The cases above only compare a sample with itself, where every deviance is
    // zero. The two tests below use `y != ypred` so that a wrong scale or sign in
    // the formula is noticed.

    #[test]
    fn test_deviance_normal_nonzero() {
        let dist = TweedieDistribution::<f64>::new(0.).unwrap();
        let output = dist
            .deviance(array![1., 2.].view(), array![2., 4.].view())
            .unwrap();
        // (1 - 2)^2 + (2 - 4)^2
        assert_abs_diff_eq!(output, 5.0, epsilon = 1e-9);
    }

    #[test]
    fn test_deviance_gamma_nonzero() {
        let dist = TweedieDistribution::<f64>::new(2.).unwrap();
        let output = dist
            .deviance(array![1., 2.].view(), array![2., 4.].view())
            .unwrap();
        // Each element contributes 2 * (ln(2) + 0.5 - 1).
        assert_abs_diff_eq!(output, 4. * 2f64.ln() - 2., epsilon = 1e-9);
    }

    // Generates one test per case comparing `deviance_derivative`, evaluated on
    // the shared `sample_y` / `sample_ypred` fixtures, with the expected values.
    macro_rules! test_deviance_derivative {
        ($($name:ident: {dist: $dist:expr, expected: $expected:expr,},)*) => {
            $(
                #[test]
                fn $name() {
                    let output = $dist.deviance_derivative(sample_y().view(), sample_ypred().view());
                    assert_abs_diff_eq!(output, $expected, epsilon = 1e-6);
                }
            )*
        };
    }

    test_deviance_derivative! {
        test_derivative_normal: {
            dist: TweedieDistribution::new(0.).unwrap(),
            expected: array![
                1.58345008, 1.05778984, 1.13608912, 1.85119328, 0.14207212, 0.1742586, 0.04043679,
                1.66523969, 1.5563135, 1.7400243
            ],
        },
        test_derivative_poisson: {
            dist: TweedieDistribution::new(1.).unwrap(),
            expected: array![
                0.91318817, 0.64596835, 0.72628385, 0.99317135, 0.15996728, 0.15469509, 0.047503,
                0.78629364, 0.72886335, 1.05654829
            ],
        },
        test_derivative_gamma: {
            dist: TweedieDistribution::new(2.).unwrap(),
            expected: array![
                0.52664283, 0.39447827, 0.46430181, 0.53283973, 0.18011648, 0.13732793, 0.05580401,
                0.37127249, 0.34134625, 0.64153949
            ],
        },
        test_derivative_inverse_gaussian: {
            dist: TweedieDistribution::new(3.).unwrap(),
            expected: array![
                0.30371908, 0.24089896, 0.29682082, 0.28587029, 0.20280364, 0.12191052, 0.06555559,
                0.17530761, 0.1598616, 0.38954482
            ],
        },
        test_derivative_tweedie1: {
            dist: TweedieDistribution::new(-2.5).unwrap(),
            expected: array![
                6.26923606,
                3.62969199,
                3.47678178,
                8.78052969,
                0.10560953,
                0.23468666,
                0.02703435,
                10.86942904,
                10.36870504,
                6.05647896
            ],
        },
        test_derivative_tweedie2: {
            dist: TweedieDistribution::new(-1.).unwrap(),
            expected: array![
                2.74567086, 1.73215816, 1.77712679, 3.45047865, 0.12617885, 0.1962962, 0.03442171,
                3.52670184, 3.32313557, 2.86563764
            ],
        },
        test_derivative_tweedie3: {
            dist: TweedieDistribution::new(1.5).unwrap(),
            expected: array![
                0.69348684, 0.50479746, 0.58070208, 0.72746214, 0.16974317, 0.14575307, 0.05148648,
                0.54030473, 0.49879331, 0.8232967
            ],
        },
        test_derivative_tweedie4: {
            dist: TweedieDistribution::new(2.5).unwrap(),
            expected: array![
                0.39993934, 0.3082684, 0.37123368, 0.39028586, 0.19112372, 0.1293898, 0.06048359,
                0.25512133, 0.23359829, 0.49990837
            ],
        },
        test_derivative_tweedie5: {
            dist: TweedieDistribution::new(-4.).unwrap(),
            expected: array![
                1.43146513e+01,
                7.60592435e+00,
                6.80199725e+00,
                2.23440599e+01,
                8.83933634e-02,
                2.80585306e-01,
                2.12324135e-02,
                3.34999932e+01,
                3.23519886e+01,
                1.28002707e+01
            ],
        },
    }
}
359}