linfa_linear/glm/
link.rs

1//! Link functions used by GLM
2
3use ndarray::Array1;
4#[cfg(feature = "serde")]
5use serde_crate::{Deserialize, Serialize};
6
7use crate::float::Float;
8
/// Link function of a generalized linear model (GLM).
///
/// A link `g` maps the mean `ypred = E[y]` to the linear predictor; its inverse
/// `h = g^-1` maps the linear predictor back to the mean.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub enum Link {
    /// Identity link, `g(x) = x`.
    Identity,
    /// Log link, `g(x) = ln(x)`.
    Log,
    /// Logit link, `g(x) = ln(x / (1 - x))`.
    Logit,
}
24
25impl Link {
26    /// Compute the link function `g(ypred)`
27    ///
28    /// The link function links the mean `ypred=E[y]` to the so called
29    /// linear predictor, `g(ypred)=linear predictor`
30    pub fn link<A: Float>(&self, ypred: &Array1<A>) -> Array1<A> {
31        match self {
32            Self::Identity => IdentityLink::link(ypred),
33            Self::Log => LogLink::link(ypred),
34            Self::Logit => LogitLink::link(ypred),
35        }
36    }
37
38    /// Computes the derivative of the link `g'(ypred)`
39    pub fn link_derivative<A: Float>(&self, ypred: &Array1<A>) -> Array1<A> {
40        match self {
41            Self::Identity => IdentityLink::link_derivative(ypred),
42            Self::Log => LogLink::link_derivative(ypred),
43            Self::Logit => LogitLink::link_derivative(ypred),
44        }
45    }
46
47    /// Computes the inverse link function `h(linear predictor)`
48    ///
49    /// Gives the inverse relationship between the linear predictor and the mean
50    /// `ypred=E[y]`, i.e. `h(linear predictor)=ypred`
51    pub fn inverse<A: Float>(&self, lin_pred: &Array1<A>) -> Array1<A> {
52        match self {
53            Self::Identity => IdentityLink::inverse(lin_pred),
54            Self::Log => LogLink::inverse(lin_pred),
55            Self::Logit => LogitLink::inverse(lin_pred),
56        }
57    }
58
59    /// Computes the derivative of the inverse link function `h'(linear predictor)`
60    pub fn inverse_derviative<A: Float>(&self, lin_pred: &Array1<A>) -> Array1<A> {
61        match self {
62            Self::Identity => IdentityLink::inverse_derivative(lin_pred),
63            Self::Log => LogLink::inverse_derivative(lin_pred),
64            Self::Logit => LogitLink::inverse_derivative(lin_pred),
65        }
66    }
67}
68
69trait LinkFn<A> {
70    fn link(ypred: &Array1<A>) -> Array1<A>;
71    fn link_derivative(ypred: &Array1<A>) -> Array1<A>;
72    fn inverse(lin_pred: &Array1<A>) -> Array1<A>;
73    fn inverse_derivative(lin_pred: &Array1<A>) -> Array1<A>;
74}
75
76struct IdentityLink;
77
78impl<A: Float> LinkFn<A> for IdentityLink {
79    fn link(ypred: &Array1<A>) -> Array1<A> {
80        ypred.clone()
81    }
82
83    fn link_derivative(ypred: &Array1<A>) -> Array1<A> {
84        Array1::ones(ypred.shape()[0])
85    }
86
87    fn inverse(lin_pred: &Array1<A>) -> Array1<A> {
88        lin_pred.clone()
89    }
90
91    fn inverse_derivative(lin_pred: &Array1<A>) -> Array1<A> {
92        Array1::ones(lin_pred.shape()[0])
93    }
94}
95
96struct LogLink;
97
98impl<A: linfa::Float> LinkFn<A> for LogLink {
99    fn link(ypred: &Array1<A>) -> Array1<A> {
100        ypred.mapv(|x| x.ln())
101    }
102
103    fn link_derivative(ypred: &Array1<A>) -> Array1<A> {
104        // 1 / ypred
105        ypred.mapv(|x| {
106            let lower_bound = A::from(1e-7).unwrap();
107            if x < lower_bound {
108                return lower_bound.recip();
109            }
110            x.recip()
111        })
112    }
113
114    fn inverse(lin_pred: &Array1<A>) -> Array1<A> {
115        lin_pred.mapv(|x| x.exp())
116    }
117
118    fn inverse_derivative(lin_pred: &Array1<A>) -> Array1<A> {
119        lin_pred.mapv(|x| x.exp())
120    }
121}
122
123struct LogitLink;
124
125impl<A: linfa::Float> LinkFn<A> for LogitLink {
126    fn link(ypred: &Array1<A>) -> Array1<A> {
127        // logit(ypred)
128        ypred.mapv(|x| (x / (A::one() - x)).ln())
129    }
130
131    fn link_derivative(ypred: &Array1<A>) -> Array1<A> {
132        // 1 / (ypred * (1-ypred)
133        ypred.mapv(|x| A::one() / (x * (A::one() - x)))
134    }
135
136    fn inverse(lin_pred: &Array1<A>) -> Array1<A> {
137        // expit(lin_pred)
138        lin_pred.mapv(|x| A::one() / (A::one() + x.neg().exp()))
139    }
140
141    fn inverse_derivative(lin_pred: &Array1<A>) -> Array1<A> {
142        // expit(lin_pred) * (1 - expit(lin_pred))
143        let expit = lin_pred.mapv(|x| A::one() / (A::one() + x.neg().exp()));
144        let one_minus_expit = expit.mapv(|x| A::one() - x);
145        expit * one_minus_expit
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // Generates one #[test] per entry: applies `link` to each input array and
    // compares it with the matching expected array (absolute tolerance 1e-6).
    macro_rules! test_links {
        ($($func:ident: {input: $input:expr, expected: $expected:expr, link: $link:expr}),*) => {
            $(
                #[test]
                fn $func() {
                    for (expected, input) in $expected.iter().zip($input.iter()) {
                        let output = $link(input);
                        assert_abs_diff_eq!(output, expected, epsilon = 1e-6);
                    }
                }
            )*
        };
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<Link>();
        has_autotraits::<IdentityLink>();
        has_autotraits::<LogLink>();
        has_autotraits::<LogitLink>();
    }

    test_links! [
        test_identity_link: {
            input: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            expected: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            link: IdentityLink::link
        },
        test_identity_link_derivative: {
            input: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            expected: &[array![1., 1., 1., 1.], array![1., 1., 1., 1.]],
            link: IdentityLink::link_derivative
        },
        test_identity_inverse: {
            input: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            expected: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            link: IdentityLink::inverse
        },
        test_identity_inverse_derivative: {
            input: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            expected: &[array![1., 1., 1., 1.], array![1., 1., 1., 1.]],
            link: IdentityLink::inverse_derivative
        }
    ];

    test_links! [
        test_log_link: {
            input: &[
                array![1.382, 1.329, 1.32, 1.322],
                array![4.56432e+01, 4.30000e+01, 2.00000e-07, 3.42000e-01],
            ],
            expected: &[
                array![0.32353173, 0.28442678, 0.27763174, 0.27914574],
                array![3.82085464, 3.76120012, -15.42494847, -1.07294454],
            ],
            link: LogLink::link
        },
        test_log_link_derivative: {
            input: &[
                array![1.382, 1.329, 1.32, 1.322],
                array![4.56432e+01, 4.30000e+01, 2.00000e-07, 3.42000e-01],
            ],
            expected: &[
                array![0.723589, 0.75244545, 0.75757576, 0.75642965],
                array![
                    2.19090686e-02,
                    2.32558140e-02,
                    5.00000000e+06,
                    2.92397661e+00
                ],
            ],
            link: LogLink::link_derivative
        },
        test_log_inverse: {
            input: &[
                array![1.382f32, 1.329f32, 1.32f32, 1.322f32],
                array![4.56432e+01, 4.30000e+01, 2.00000e-07, 3.42000e-01],
            ],
            expected: &[
                array![3.982_859_4, 3.777_264, 3.743_421_3, 3.750_915_8],
                array![6.646_452e19, 4.727_839_5e18, 1.000_000_2e0, 1.407_760_3e0],
            ],
            link: LogLink::inverse
        },
        test_log_inverse_derivative: {
            input: &[
                array![1.382f32, 1.329f32, 1.32f32, 1.322f32],
                array![4.56432e+01, 4.30000e+01, 2.00000e-07, 3.42000e-01],
            ],
            expected: &[
                array![3.982_859_4, 3.777_264, 3.743_421_3, 3.750_915_8],
                array![6.646_452e19, 4.727_839_5e18, 1.000_000_2e0, 1.407_760_3e0],
            ],
            link: LogLink::inverse_derivative
        }
    ];

    test_links! [
        test_logit_link: {
            input: &[
                array![0.934, 0.323, 0.989, 0.412], array![0.044, 0.023, 0.999, 0.124]
            ],
            expected: &[
                array![2.6498217, -0.74001895, 4.49879906, -0.3557036 ],
                array![-3.07856828, -3.74899244,  6.90675478, -1.95508453],
            ],
            link: LogitLink::link
        },
        test_logit_link_derivative: {
            input: &[array![0.934, 0.323, 0.989, 0.412], array![0.044, 0.023, 0.999, 0.124]],
            expected: &[
                array![16.22217896, 4.57308011, 91.92021325, 4.12786474],
                array![23.77329783, 44.50180232, 1001.001001, 9.20606864],
            ],
            link: LogitLink::link_derivative
        },
        test_logit_inverse: {
            input: &[array![0.934, 0.323, 0.989, 0.412], array![0.044, 0.023, 0.999, 0.124]],
            expected: &[
                array![0.71788609, 0.5800552, 0.72889036, 0.60156734],
                array![0.51099823, 0.50574975, 0.73086192, 0.53096034],
            ],
            link: LogitLink::inverse
        },
        test_logit_inverse_derivative: {
            input: &[array![0.934, 0.323, 0.989, 0.412], array![0.044, 0.023, 0.999, 0.124]],
            expected: &[
                array![0.20252565, 0.24359116, 0.1976092, 0.23968407],
                array![0.24987904, 0.24996694, 0.19670277, 0.24904146],
            ],
            link: LogitLink::inverse_derivative
        }
    ];

    // The public `Link` enum must route each call to the matching private link type.
    #[test]
    fn link_enum_dispatches_to_link_functions() {
        let mean = array![0.1, 0.25, 0.5, 0.9];
        let lin_pred = array![-2.0, -0.5, 0.0, 1.5];

        assert_abs_diff_eq!(Link::Identity.link(&mean), IdentityLink::link(&mean));
        assert_abs_diff_eq!(Link::Log.link(&mean), LogLink::link(&mean));
        assert_abs_diff_eq!(Link::Logit.link(&mean), LogitLink::link(&mean));

        assert_abs_diff_eq!(
            Link::Identity.link_derivative(&mean),
            IdentityLink::link_derivative(&mean)
        );
        assert_abs_diff_eq!(
            Link::Log.link_derivative(&mean),
            LogLink::link_derivative(&mean)
        );
        assert_abs_diff_eq!(
            Link::Logit.link_derivative(&mean),
            LogitLink::link_derivative(&mean)
        );

        assert_abs_diff_eq!(Link::Identity.inverse(&lin_pred), IdentityLink::inverse(&lin_pred));
        assert_abs_diff_eq!(Link::Log.inverse(&lin_pred), LogLink::inverse(&lin_pred));
        assert_abs_diff_eq!(Link::Logit.inverse(&lin_pred), LogitLink::inverse(&lin_pred));

        assert_abs_diff_eq!(
            Link::Identity.inverse_derviative(&lin_pred),
            IdentityLink::inverse_derivative(&lin_pred)
        );
        assert_abs_diff_eq!(
            Link::Log.inverse_derviative(&lin_pred),
            LogLink::inverse_derivative(&lin_pred)
        );
        assert_abs_diff_eq!(
            Link::Logit.inverse_derviative(&lin_pred),
            LogitLink::inverse_derivative(&lin_pred)
        );
    }

    // h(g(mu)) == mu on each link's domain.
    #[test]
    fn inverse_undoes_link() {
        let positive = array![1e-3, 0.5, 1.0, 42.0];
        let probabilities = array![1e-3, 0.25, 0.5, 0.999];
        for (link, mean) in [
            (Link::Identity, &positive),
            (Link::Log, &positive),
            (Link::Logit, &probabilities),
        ] {
            let recovered = link.inverse(&link.link(mean));
            assert_abs_diff_eq!(recovered, mean.to_owned(), epsilon = 1e-9);
        }
    }

    // Inverse function rule: h'(g(mu)) * g'(mu) == 1. The inputs stay above the
    // 1e-7 floor applied by the log link derivative.
    #[test]
    fn derivatives_satisfy_inverse_function_rule() {
        let positive = array![1e-3, 0.5, 1.0, 42.0];
        let probabilities = array![1e-3, 0.25, 0.5, 0.999];
        let ones = array![1.0, 1.0, 1.0, 1.0];
        for (link, mean) in [
            (Link::Identity, &positive),
            (Link::Log, &positive),
            (Link::Logit, &probabilities),
        ] {
            let product = link.link_derivative(mean) * link.inverse_derviative(&link.link(mean));
            assert_abs_diff_eq!(product, ones.clone(), epsilon = 1e-9);
        }
    }
}
287        }
288    ];
289}