1use linfa::Float;
2use ndarray::Zip;
3use ndarray::{Array1, ArrayView1};
4
5use crate::error::{LinearError, Result};
6
7#[derive(Debug, Clone, PartialEq)]
8pub struct TweedieDistribution<F: Float> {
9 power: F,
10 lower_bound: F,
11 inclusive: bool,
12}
13
14impl<F: Float> TweedieDistribution<F> {
15 pub fn new(power: F) -> Result<Self, F> {
16 let dist = match power {
18 power if power <= F::zero() => Self {
19 power,
20 lower_bound: F::neg_infinity(),
21 inclusive: false,
22 },
23 power if (power > F::zero() && power < F::one()) => {
24 return Err(LinearError::InvalidTweediePower(power));
25 }
26 power if (F::one()..F::cast(2.0)).contains(&power) => Self {
27 power,
28 lower_bound: F::zero(),
29 inclusive: true,
30 },
31 power if power >= F::cast(2.0) => Self {
32 power,
33 lower_bound: F::zero(),
34 inclusive: false,
35 },
36 _ => unreachable!(),
37 };
38
39 Ok(dist)
40 }
41
42 pub fn in_range(&self, y: &ArrayView1<F>) -> bool {
44 if self.inclusive {
45 return y.iter().all(|&x| x >= self.lower_bound);
46 }
47 y.iter().all(|&x| x > self.lower_bound)
48 }
49
50 fn unit_variance(&self, ypred: ArrayView1<F>) -> Array1<F> {
51 ypred.mapv(|x| x.powf(self.power))
53 }
54
55 fn unit_deviance(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Result<Array1<F>, F> {
56 match self.power {
57 power if power.is_negative() => {
58 let mut left = y.mapv(|x| x.max(F::zero()));
59
60 left.mapv_inplace(|x| {
61 x.powf(F::cast(2.) - self.power)
62 / ((F::one() - self.power) * (F::cast(2.) - self.power))
63 });
64
65 let middle =
66 &y * &ypred.mapv(|x| x.powf(F::cast(1.) - self.power) / (F::cast(1.) - power));
67
68 let right =
69 ypred.mapv(|x| x.powf(F::cast(2.) - self.power) / (F::cast(2.) - self.power));
70
71 Ok((left - middle + right).mapv(|x| F::cast(2.) * x))
72 }
73 power if power == F::zero() => Ok((&y - &ypred).mapv(|x| x * x)),
76 power if power < F::one() => Err(LinearError::InvalidTweediePower(power)),
77 power if (power - F::one()).abs() < F::cast(1e-6) => {
80 let mut div = &y / &ypred;
81 Zip::from(&mut div).and(y).for_each(|y, &x| {
82 if x == F::zero() {
83 *y = F::zero();
84 } else {
85 *y = F::cast(2.) * (x * y.ln());
86 }
87 });
88 Ok(div - y + ypred)
89 }
90 power if (power - F::cast(2.)).abs() < F::cast(1e-6) => {
93 let mut temp = (&ypred / &y).mapv(|x| x.ln()) + (&y / &ypred);
94 temp.mapv_inplace(|x| x - F::one());
95 Ok(temp.mapv(|x| F::cast(2.) * x))
96 }
97 power => {
98 let left = y.mapv(|x| {
99 x.powf(F::cast(2.) - power) / ((F::one() - power) * (F::cast(2.) - power))
100 });
101
102 let middle = &y * &ypred.mapv(|x| x.powf(F::one() - power) / (F::one() - power));
103
104 let right = ypred.mapv(|x| x.powf(F::cast(2.) - power) / (F::cast(2.) - power));
105
106 Ok((left - middle + right).mapv(|x| F::cast(2.) * x))
107 }
108 }
109 }
110
111 fn unit_deviance_derivative(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Array1<F> {
112 ((&y - &ypred) / &self.unit_variance(ypred)).mapv(|x| F::cast(-2.) * x)
113 }
114
115 pub fn deviance(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Result<F, F> {
116 Ok(self.unit_deviance(y, ypred)?.sum())
117 }
118
119 pub fn deviance_derivative(&self, y: ArrayView1<F>, ypred: ArrayView1<F>) -> Array1<F> {
120 self.unit_deviance_derivative(y, ypred)
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1};

    /// Observed targets shared by every `deviance_derivative` test case.
    fn sample_y() -> Array1<f64> {
        array![
            0.94225502, 1.10863089, 0.99620489, 0.9383247, 0.81709632, 1.03933563, 0.83102873,
            1.28521452, 1.35710428, 0.77688304
        ]
    }

    /// Predictions shared by every `deviance_derivative` test case.
    fn sample_ypred() -> Array1<f64> {
        array![
            1.73398006, 1.6375258, 1.56424946, 1.86392134, 0.88813238, 1.12646493, 0.85124713,
            2.11783437, 2.13526103, 1.64689519
        ]
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TweedieDistribution<f64>>();
    }

    #[test]
    fn test_distribution_error() {
        // Powers strictly between 0 and 1 are rejected.
        let tweedie = TweedieDistribution::new(0.2);
        assert!(tweedie.is_err());
    }

    /// Generates one `in_range` test per `(distribution, input, expected)` entry.
    macro_rules! test_bounds {
        ($($name:ident: ($dist:expr, $input:expr, $expected:expr),)*) => {
            $(
                #[test]
                // The expected value is a macro argument, so comparing with
                // `assert_eq!` keeps a single expansion for both outcomes.
                #[allow(clippy::bool_assert_comparison)]
                fn $name() {
                    let output = $dist.in_range(&$input.view());
                    assert_eq!(output, $expected);
                }
            )*
        };
    }

    test_bounds! {
        test_bounds_normal: (TweedieDistribution::new(0.).unwrap(), array![-1., 0., 1.], true),
        test_bounds_poisson1: (TweedieDistribution::new(1.).unwrap(), array![-1., 0., 1.], false),
        test_bounds_poisson2: (TweedieDistribution::new(1.).unwrap(), array![0., 1., 2.], true),
        test_bounds_tweedie1: (TweedieDistribution::new(1.5).unwrap(), array![-1., 0., 1.], false),
        test_bounds_tweedie2: (TweedieDistribution::new(1.5).unwrap(), array![0., 1., 4.], true),
        test_bounds_gamma1: (TweedieDistribution::new(2.).unwrap(), array![-1., 0., 1.], false),
        test_bounds_gamma2: (TweedieDistribution::new(2.).unwrap(), array![0., 1., 2.], false),
        test_bounds_gamma3: (TweedieDistribution::new(2.).unwrap(), array![1., 2., 3.], true),
        test_bounds_inverse_gaussian: (TweedieDistribution::new(3.).unwrap(), array![-1., 0., 1.], false),
        test_bounds_tweedie3: (TweedieDistribution::new(3.5).unwrap(), array![-1., 0., 1.], false),
    }

    /// Generates one test per entry checking that the deviance of an input
    /// against itself is zero.
    macro_rules! test_deviance {
        ($($name:ident: ($dist:expr, $input:expr),)*) => {
            $(
                #[test]
                fn $name() {
                    let output = $dist.deviance($input.view(), $input.view()).unwrap();
                    assert_abs_diff_eq!(output, 0.0, epsilon = 1e-9);
                }
            )*
        };
    }

    test_deviance! {
        test_deviance_normal: (TweedieDistribution::new(0.).unwrap(), array![-1.5, -0.1, 0.1, 2.5]),
        test_deviance_poisson: (TweedieDistribution::new(1.).unwrap(), array![0.1, 1.5]),
        test_deviance_gamma: (TweedieDistribution::new(2.).unwrap(), array![0.1, 1.5]),
        test_deviance_inverse_gaussian: (TweedieDistribution::new(3.).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie1: (TweedieDistribution::new(-2.5).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie2: (TweedieDistribution::new(-1.).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie3: (TweedieDistribution::new(1.5).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie4: (TweedieDistribution::new(2.5).unwrap(), array![0.1, 1.5]),
        test_deviance_tweedie5: (TweedieDistribution::new(-4.).unwrap(), array![0.1, 1.5]),
    }

    #[test]
    fn test_deviance_nonzero_residual() {
        // Normal: sum of squared residuals, 1^2 + 2^2.
        let normal = TweedieDistribution::new(0.).unwrap();
        let dev = normal
            .deviance(array![1., 2.].view(), array![0., 0.].view())
            .unwrap();
        assert_abs_diff_eq!(dev, 5.0, epsilon = 1e-9);

        // Gamma: 2 * (ln(ypred / y) + y / ypred - 1) = 2 * ln(2) - 1 for y = 1, ypred = 2.
        let gamma = TweedieDistribution::new(2.).unwrap();
        let dev = gamma
            .deviance(array![1.].view(), array![2.].view())
            .unwrap();
        assert_abs_diff_eq!(dev, 2.0 * 2f64.ln() - 1.0, epsilon = 1e-9);
    }

    /// Generates one test per entry comparing `deviance_derivative(y, ypred)`
    /// with the expected array.
    macro_rules! test_deviance_derivative {
        ($($name:ident: {dist: $dist:expr, y: $y:expr, ypred: $ypred:expr, expected: $expected:expr,},)*) => {
            $(
                #[test]
                fn $name() {
                    let output = $dist.deviance_derivative($y.view(), $ypred.view());
                    assert_abs_diff_eq!(output, $expected, epsilon = 1e-6);
                }
            )*
        };
    }

    test_deviance_derivative! {
        test_derivative_normal: {
            dist: TweedieDistribution::new(0.).unwrap(),
            y: sample_y(),
            ypred: sample_ypred(),
            expected: array![
                1.58345008, 1.05778984, 1.13608912, 1.85119328, 0.14207212, 0.1742586, 0.04043679,
                1.66523969, 1.5563135, 1.7400243
            ],
        },
        test_derivative_poisson: {
            dist: TweedieDistribution::new(1.).unwrap(),
            y: sample_y(),
            ypred: sample_ypred(),
            expected: array![
                0.91318817, 0.64596835, 0.72628385, 0.99317135, 0.15996728, 0.15469509, 0.047503,
                0.78629364, 0.72886335, 1.05654829
            ],
        },
        test_derivative_gamma: {
            dist: TweedieDistribution::new(2.).unwrap(),
            y: sample_y(),
            ypred: sample_ypred(),
            expected: array![
                0.52664283, 0.39447827, 0.46430181, 0.53283973, 0.18011648, 0.13732793, 0.05580401,
                0.37127249, 0.34134625, 0.64153949
            ],
        },
        test_derivative_inverse_gaussian: {
            dist: TweedieDistribution::new(3.).unwrap(),
            y: sample_y(),
            ypred: sample_ypred(),
            expected: array![
                0.30371908, 0.24089896, 0.29682082, 0.28587029, 0.20280364, 0.12191052, 0.06555559,
                0.17530761, 0.1598616, 0.38954482
            ],
        },
        test_derivative_tweedie1: {
            dist: TweedieDistribution::new(-2.5).unwrap(),
            y: sample_y(),
            ypred: sample_ypred(),
            expected: array![
                6.26923606,
                3.62969199,
                3.47678178,
                8.78052969,
                0.10560953,
                0.23468666,
                0.02703435,
                10.86942904,
                10.36870504,
                6.05647896
            ],
        },
        test_derivative_tweedie2: {
            dist: TweedieDistribution::new(-1.).unwrap(),
            y: sample_y(),
            ypred: sample_ypred(),
            expected: array![
                2.74567086, 1.73215816, 1.77712679, 3.45047865, 0.12617885, 0.1962962, 0.03442171,
                3.52670184, 3.32313557, 2.86563764
            ],
        },
        test_derivative_tweedie3: {
            dist: TweedieDistribution::new(1.5).unwrap(),
            y: sample_y(),
            ypred: sample_ypred(),
            expected: array![
                0.69348684, 0.50479746, 0.58070208, 0.72746214, 0.16974317, 0.14575307, 0.05148648,
                0.54030473, 0.49879331, 0.8232967
            ],
        },
        test_derivative_tweedie4: {
            dist: TweedieDistribution::new(2.5).unwrap(),
            y: sample_y(),
            ypred: sample_ypred(),
            expected: array![
                0.39993934, 0.3082684, 0.37123368, 0.39028586, 0.19112372, 0.1293898, 0.06048359,
                0.25512133, 0.23359829, 0.49990837
            ],
        },
        test_derivative_tweedie5: {
            dist: TweedieDistribution::new(-4.).unwrap(),
            y: sample_y(),
            ypred: sample_ypred(),
            expected: array![
                1.43146513e+01,
                7.60592435e+00,
                6.80199725e+00,
                2.23440599e+01,
                8.83933634e-02,
                2.80585306e-01,
                2.12324135e-02,
                3.34999932e+01,
                3.23519886e+01,
                1.28002707e+01
            ],
        },
    }
}