1mod distribution;
4mod hyperparams;
5mod link;
6
7use crate::error::{LinearError, Result};
8use crate::float::Float;
9use argmin_math::{
10 ArgminAdd, ArgminDot, ArgminL1Norm, ArgminL2Norm, ArgminMinMax, ArgminMul, ArgminSignum,
11 ArgminSub, ArgminZero,
12};
13use distribution::TweedieDistribution;
14pub use hyperparams::TweedieRegressorParams;
15pub use hyperparams::TweedieRegressorValidParams;
16use linfa::dataset::AsSingleTargets;
17pub use link::Link;
18
19use argmin::core::{CostFunction, Executor, Gradient};
20use argmin::solver::linesearch::MoreThuenteLineSearch;
21use argmin::solver::quasinewton::LBFGS;
22use ndarray::{array, concatenate, s};
23use ndarray::{Array, Array1, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
24#[cfg(feature = "serde")]
25use serde_crate::{Deserialize, Serialize};
26
27use linfa::traits::*;
28use linfa::DatasetBase;
29
30impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>>
31 Fit<ArrayBase<D, Ix2>, T, LinearError<F>> for TweedieRegressorValidParams<F>
32where
33 Array1<F>: ArgminAdd<Array1<F>, Array1<F>>
34 + ArgminSub<Array1<F>, Array1<F>>
35 + ArgminSub<F, Array1<F>>
36 + ArgminAdd<F, Array1<F>>
37 + ArgminMul<F, Array1<F>>
38 + ArgminMul<Array1<F>, Array1<F>>
39 + ArgminDot<Array1<F>, F>
40 + ArgminL2Norm<F>
41 + ArgminL1Norm<F>
42 + ArgminSignum
43 + ArgminMinMax,
44 F: ArgminMul<Array1<F>, Array1<F>> + ArgminZero,
45{
46 type Object = TweedieRegressor<F>;
47
48 fn fit(&self, ds: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, F> {
49 let (x, y) = (ds.records(), ds.as_single_targets());
50
51 let dist = TweedieDistribution::new(self.power())?;
52 let link = self.link();
53
54 if !dist.in_range(&y) {
58 return Err(LinearError::InvalidTargetRange(self.power()));
61 }
62 let mut coef = Array::zeros(x.ncols());
64 if self.fit_intercept() {
65 let temp = link.link(&array![y.mean().unwrap()]);
66 coef = concatenate!(Axis(0), temp, coef);
67 }
68
69 let problem = TweedieProblem {
73 x: x.view(),
74 y,
75 fit_intercept: self.fit_intercept(),
76 link: &link,
77 dist,
78 alpha: self.alpha(),
79 };
80 let linesearch = MoreThuenteLineSearch::new();
81
82 let solver = LBFGS::new(linesearch, 7).with_tolerance_grad(F::cast(self.tol()))?;
87
88 let mut result = Executor::new(problem, solver)
89 .configure(|state| state.param(coef).max_iters(self.max_iter() as u64))
90 .run()?;
91 coef = result.state.take_best_param().unwrap();
92
93 if self.fit_intercept() {
94 Ok(TweedieRegressor {
95 coef: coef.slice(s![1..]).to_owned(),
96 intercept: *coef.get(0).unwrap(),
97 link,
98 })
99 } else {
100 Ok(TweedieRegressor {
101 coef: coef.to_owned(),
102 intercept: F::cast(0.),
103 link,
104 })
105 }
106 }
107}
108
/// Optimization problem handed to argmin's L-BFGS solver: bundles the
/// training data with the link/distribution choice so the penalized
/// deviance and its gradient can be evaluated for a parameter vector.
struct TweedieProblem<'a, F: Float> {
    // Feature matrix (observations x features).
    x: ArrayView2<'a, F>,
    // Target values, one per observation.
    y: ArrayView1<'a, F>,
    // When true, the first entry of the parameter vector is the intercept.
    fit_intercept: bool,
    // Link function relating the linear predictor to the predicted mean.
    link: &'a Link,
    // Tweedie distribution used to compute the deviance.
    dist: TweedieDistribution<F>,
    // L2 (ridge) regularization strength.
    alpha: F,
}
117
118impl<A: Float> TweedieProblem<'_, A> {
119 fn ypred(&self, p: &Array1<A>) -> (Array1<A>, Array1<A>, usize) {
120 let mut offset = 0;
121 let mut intercept = A::from(0.).unwrap();
122 if self.fit_intercept {
123 offset = 1;
124 intercept = *p.get(0).unwrap();
125 }
126
127 let lin_pred = self
128 .x
129 .view()
130 .dot(&p.slice(s![offset..]))
131 .mapv(|x| x + intercept);
132
133 (self.link.inverse(&lin_pred), lin_pred, offset)
134 }
135}
136
137impl<A: Float> CostFunction for TweedieProblem<'_, A> {
138 type Param = Array1<A>;
139 type Output = A;
140
141 fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
149 let (ypred, _, offset) = self.ypred(p);
150
151 let dev = self.dist.deviance(self.y, ypred.view())?;
152
153 let pscaled = p
154 .slice(s![offset..])
155 .mapv(|x| x * A::from(self.alpha).unwrap());
156
157 let obj = A::from(0.5).unwrap() * (dev + p.slice(s![offset..]).dot(&pscaled));
158
159 Ok(obj)
160 }
161}
162
163impl<A: Float> Gradient for TweedieProblem<'_, A> {
164 type Param = Array1<A>;
165 type Gradient = Array1<A>;
166
167 fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
168 let (ypred, lin_pred, offset) = self.ypred(p);
169
170 let devp;
171 let der = self.link.inverse_derviative(&lin_pred);
172 let temp = der * self.dist.deviance_derivative(self.y, ypred.view());
173 if self.fit_intercept {
174 devp = concatenate![Axis(0), array![temp.sum()], temp.dot(&self.x)];
175 } else {
176 devp = temp.dot(&self.x);
177 }
178
179 let pscaled = p
180 .slice(s![offset..])
181 .mapv(|x| x * A::from(self.alpha).unwrap());
182
183 let mut objp = devp.mapv(|x| x * A::from(0.5).unwrap());
184 objp.slice_mut(s![offset..])
185 .zip_mut_with(&pscaled, |x, y| *x += *y);
186
187 Ok(objp)
188 }
189}
190
/// Fitted Tweedie (generalized linear) regressor produced by fitting
/// `TweedieRegressorValidParams` on a dataset.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct TweedieRegressor<A> {
    /// Learned feature coefficients (intercept excluded).
    pub coef: Array1<A>,
    /// Learned intercept; zero when fitting without an intercept.
    pub intercept: A,
    // Link function chosen at fit time; its inverse is applied at prediction.
    link: Link,
}
236
237impl<A: Float, D: Data<Elem = A>> PredictInplace<ArrayBase<D, Ix2>, Array1<A>>
238 for TweedieRegressor<A>
239{
240 fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<A>) {
242 assert_eq!(
243 x.nrows(),
244 y.len(),
245 "The number of data points must match the number of output targets."
246 );
247
248 let ypred = x.dot(&self.coef) + self.intercept;
249 *y = self.link.inverse(&ypred);
250 }
251
252 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<A> {
253 Array1::zeros(x.nrows())
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::glm::hyperparams::TweedieRegressorParams;
    use approx::assert_abs_diff_eq;
    use linfa::Dataset;
    use ndarray::{array, Array2};

    // Compile-time check that the public types are Send + Sync + Unpin.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TweedieRegressor<f64>>();
        has_autotraits::<TweedieRegressorValidParams<f64>>();
        has_autotraits::<TweedieRegressorParams<f64>>();
    }

    // Generates one recovery test per (power, intercept) combination:
    // targets are produced by a known log-linear model, and the fitted
    // coefficients must match the generating ones to within 1e-3.
    macro_rules! test_tweedie {
        ($($name:ident: {power: $power:expr, intercept: $intercept:expr,},)*) => {
            $(
                #[test]
                fn $name() {
                    let coef = array![0.2, -0.1];
                    let mut x: Array2<f64> = array![[1., 1., 1., 1., 1.], [0., 1., 2., 3., 4.]].reversed_axes();
                    let y = x.dot(&coef).mapv(|x| x.exp());

                    let glm = TweedieRegressor::params()
                        .alpha(0.)
                        .power($power)
                        .link(Link::Log)
                        .tol(1e-7)
                        .fit_intercept($intercept);

                    if $intercept {
                        // Drop the explicit ones column: the model fits
                        // the intercept itself, so coef[0] must show up
                        // as `glm.intercept`.
                        x = x.slice(s![.., 1..]).to_owned();
                        let dataset = Dataset::new(x, y);
                        let glm = glm.fit(&dataset).unwrap();

                        assert_abs_diff_eq!(glm.intercept, coef.get(0).unwrap(), epsilon = 1e-3);
                        assert_abs_diff_eq!(glm.coef, coef.slice(s![1..]), epsilon = 1e-3);
                    } else {
                        // Keep the ones column so coef[0] is recovered as
                        // an ordinary coefficient.
                        let dataset = Dataset::new(x, y);
                        let glm = glm.fit(&dataset).unwrap();

                        assert_abs_diff_eq!(glm.coef, coef, epsilon = 1e-3);
                    }
                }
            )*
        }
    }

    // Powers cover normal (0), Poisson (1), Tweedie (1.5), gamma (2),
    // and inverse-Gaussian (3), each with and without an intercept.
    test_tweedie! {
        test_glm_normal1: {
            power: 0.,
            intercept: true,
        },
        test_glm_normal2: {
            power: 0.,
            intercept: false,
        },
        test_glm_poisson1: {
            power: 1.,
            intercept: true,
        },
        test_glm_poisson2: {
            power: 1.,
            intercept: false,
        },
        test_glm_gamma1: {
            power: 2.,
            intercept: true,
        },
        test_glm_gamma2: {
            power: 2.,
            intercept: false,
        },
        test_glm_inverse_gaussian1: {
            power: 3.,
            intercept: true,
        },
        test_glm_inverse_gaussian2: {
            power: 3.,
            intercept: false,
        },
        test_glm_tweedie1: {
            power: 1.5,
            intercept: true,
        },
        test_glm_tweedie2: {
            power: 1.5,
            intercept: false,
        },
    }
}