1use ndarray::Array1;
4#[cfg(feature = "serde")]
5use serde_crate::{Deserialize, Serialize};
6
7use crate::float::Float;
8
/// Link function of a generalized linear model.
///
/// [`Link::link`] maps predictions (`ypred`) onto the linear-predictor scale,
/// and [`Link::inverse`] maps a linear predictor (`lin_pred`) back.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub enum Link {
    /// `g(y) = y`; the inverse is also the identity.
    Identity,
    /// `g(y) = ln(y)`; the inverse is `exp(eta)`.
    Log,
    /// `g(y) = ln(y / (1 - y))`; the inverse is `1 / (1 + exp(-eta))`.
    Logit,
}
24
25impl Link {
26 pub fn link<A: Float>(&self, ypred: &Array1<A>) -> Array1<A> {
31 match self {
32 Self::Identity => IdentityLink::link(ypred),
33 Self::Log => LogLink::link(ypred),
34 Self::Logit => LogitLink::link(ypred),
35 }
36 }
37
38 pub fn link_derivative<A: Float>(&self, ypred: &Array1<A>) -> Array1<A> {
40 match self {
41 Self::Identity => IdentityLink::link_derivative(ypred),
42 Self::Log => LogLink::link_derivative(ypred),
43 Self::Logit => LogitLink::link_derivative(ypred),
44 }
45 }
46
47 pub fn inverse<A: Float>(&self, lin_pred: &Array1<A>) -> Array1<A> {
52 match self {
53 Self::Identity => IdentityLink::inverse(lin_pred),
54 Self::Log => LogLink::inverse(lin_pred),
55 Self::Logit => LogitLink::inverse(lin_pred),
56 }
57 }
58
59 pub fn inverse_derviative<A: Float>(&self, lin_pred: &Array1<A>) -> Array1<A> {
61 match self {
62 Self::Identity => IdentityLink::inverse_derivative(lin_pred),
63 Self::Log => LogLink::inverse_derivative(lin_pred),
64 Self::Logit => LogitLink::inverse_derivative(lin_pred),
65 }
66 }
67}
68
69trait LinkFn<A> {
70 fn link(ypred: &Array1<A>) -> Array1<A>;
71 fn link_derivative(ypred: &Array1<A>) -> Array1<A>;
72 fn inverse(lin_pred: &Array1<A>) -> Array1<A>;
73 fn inverse_derivative(lin_pred: &Array1<A>) -> Array1<A>;
74}
75
76struct IdentityLink;
77
78impl<A: Float> LinkFn<A> for IdentityLink {
79 fn link(ypred: &Array1<A>) -> Array1<A> {
80 ypred.clone()
81 }
82
83 fn link_derivative(ypred: &Array1<A>) -> Array1<A> {
84 Array1::ones(ypred.shape()[0])
85 }
86
87 fn inverse(lin_pred: &Array1<A>) -> Array1<A> {
88 lin_pred.clone()
89 }
90
91 fn inverse_derivative(lin_pred: &Array1<A>) -> Array1<A> {
92 Array1::ones(lin_pred.shape()[0])
93 }
94}
95
96struct LogLink;
97
98impl<A: linfa::Float> LinkFn<A> for LogLink {
99 fn link(ypred: &Array1<A>) -> Array1<A> {
100 ypred.mapv(|x| x.ln())
101 }
102
103 fn link_derivative(ypred: &Array1<A>) -> Array1<A> {
104 ypred.mapv(|x| {
106 let lower_bound = A::from(1e-7).unwrap();
107 if x < lower_bound {
108 return lower_bound.recip();
109 }
110 x.recip()
111 })
112 }
113
114 fn inverse(lin_pred: &Array1<A>) -> Array1<A> {
115 lin_pred.mapv(|x| x.exp())
116 }
117
118 fn inverse_derivative(lin_pred: &Array1<A>) -> Array1<A> {
119 lin_pred.mapv(|x| x.exp())
120 }
121}
122
123struct LogitLink;
124
125impl<A: linfa::Float> LinkFn<A> for LogitLink {
126 fn link(ypred: &Array1<A>) -> Array1<A> {
127 ypred.mapv(|x| (x / (A::one() - x)).ln())
129 }
130
131 fn link_derivative(ypred: &Array1<A>) -> Array1<A> {
132 ypred.mapv(|x| A::one() / (x * (A::one() - x)))
134 }
135
136 fn inverse(lin_pred: &Array1<A>) -> Array1<A> {
137 lin_pred.mapv(|x| A::one() / (A::one() + x.neg().exp()))
139 }
140
141 fn inverse_derivative(lin_pred: &Array1<A>) -> Array1<A> {
142 let expit = lin_pred.mapv(|x| A::one() / (A::one() + x.neg().exp()));
144 let one_minus_expit = expit.mapv(|x| A::one() - x);
145 expit * one_minus_expit
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // Generates one `#[test]` function per entry. Each generated test applies
    // `link` to every array in `input` and compares the result against the array
    // at the same position in `expected`, with an absolute tolerance of 1e-6.
    macro_rules! test_links {
        ($($func:ident: {input: $input:expr, expected: $expected:expr, link: $link:expr}),*) => {
            $(
                #[test]
                fn $func() {
                    // `zip` stops at the shorter slice, so `input` and
                    // `expected` must list the same number of arrays for every
                    // case to be checked.
                    for (expected, input) in $expected.iter().zip($input.iter()) {
                        let output = $link(input);
                        assert_abs_diff_eq!(output, expected, epsilon = 1e-6);
                    }
                }
            )*
        };
    }

    // Compile-time check that the public enum and the marker types are
    // Send + Sync + Sized + Unpin.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<Link>();
        has_autotraits::<IdentityLink>();
        has_autotraits::<LogLink>();
        has_autotraits::<LogitLink>();
    }

    // Identity link: link/inverse return the input, derivatives are all ones.
    test_links! [
        test_identity_link: {
            input: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            expected: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            link: IdentityLink::link
        },
        test_identity_link_derivative: {
            input: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            expected: &[array![1., 1., 1., 1.], array![1., 1., 1., 1.]],
            link: IdentityLink::link_derivative
        },
        test_identity_inverse: {
            input: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            expected: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            link: IdentityLink::inverse
        },
        test_identity_inverse_derivative: {
            input: &[array![1., 1., 1., 1.], array![1.348, 2.879, 4.545, 3.232]],
            expected: &[array![1., 1., 1., 1.], array![1., 1., 1., 1.]],
            link: IdentityLink::inverse_derivative
        }
    ];

    // Log link. In the derivative case, 2e-7 lies above the 1e-7 floor used by
    // `LogLink::link_derivative`, so its expected value is the plain reciprocal
    // 5e6. The inverse cases suffix the first array with `f32`, which forces
    // every array in that slice (and the comparison) to f32.
    test_links! [
        test_log_link: {
            input: &[
                array![1.382, 1.329, 1.32, 1.322],
                array![4.56432e+01, 4.30000e+01, 2.00000e-07, 3.42000e-01],
            ],
            expected: &[
                array![0.32353173, 0.28442678, 0.27763174, 0.27914574],
                array![3.82085464, 3.76120012, -15.42494847, -1.07294454],
            ],
            link: LogLink::link
        },
        test_log_link_derivative: {
            input: &[
                array![1.382, 1.329, 1.32, 1.322],
                array![4.56432e+01, 4.30000e+01, 2.00000e-07, 3.42000e-01],
            ],
            expected: &[
                array![0.723589, 0.75244545, 0.75757576, 0.75642965],
                array![
                    2.19090686e-02,
                    2.32558140e-02,
                    5.00000000e+06,
                    2.92397661e+00
                ],
            ],
            link: LogLink::link_derivative
        },
        test_log_inverse: {
            input: &[
                array![1.382f32, 1.329f32, 1.32f32, 1.322f32],
                array![4.56432e+01, 4.30000e+01, 2.00000e-07, 3.42000e-01],
            ],
            expected: &[
                array![3.982_859_4, 3.777_264, 3.743_421_3, 3.750_915_8],
                array![6.646_452e19, 4.727_839_5e18, 1.000_000_2e0, 1.407_760_3e0],
            ],
            link: LogLink::inverse
        },
        // exp is its own derivative, so the fixtures match `test_log_inverse`.
        test_log_inverse_derivative: {
            input: &[
                array![1.382f32, 1.329f32, 1.32f32, 1.322f32],
                array![4.56432e+01, 4.30000e+01, 2.00000e-07, 3.42000e-01],
            ],
            expected: &[
                array![3.982_859_4, 3.777_264, 3.743_421_3, 3.750_915_8],
                array![6.646_452e19, 4.727_839_5e18, 1.000_000_2e0, 1.407_760_3e0],
            ],
            link: LogLink::inverse_derivative
        }
    ];

    // Logit link, with inputs in (0, 1) so that `y / (1 - y)` stays positive.
    test_links! [
        test_logit_link: {
            input: &[
                array![0.934, 0.323, 0.989, 0.412], array![0.044, 0.023, 0.999, 0.124]
            ],
            expected: &[
                array![2.6498217, -0.74001895, 4.49879906, -0.3557036 ],
                array![-3.07856828, -3.74899244, 6.90675478, -1.95508453],
            ],
            link: LogitLink::link
        },
        test_logit_link_derivative: {
            input: &[array![0.934, 0.323, 0.989, 0.412], array![0.044, 0.023, 0.999, 0.124]],
            expected: &[
                array![16.22217896, 4.57308011, 91.92021325, 4.12786474],
                array![23.77329783, 44.50180232, 1001.001001, 9.20606864],
            ],
            link: LogitLink::link_derivative
        },
        test_logit_inverse: {
            input: &[array![0.934, 0.323, 0.989, 0.412], array![0.044, 0.023, 0.999, 0.124]],
            expected: &[
                array![0.71788609, 0.5800552, 0.72889036, 0.60156734],
                array![0.51099823, 0.50574975, 0.73086192, 0.53096034],
            ],
            link: LogitLink::inverse
        },
        test_logit_inverse_derivative: {
            input: &[array![0.934, 0.323, 0.989, 0.412], array![0.044, 0.023, 0.999, 0.124]],
            expected: &[
                array![0.20252565, 0.24359116, 0.1976092, 0.23968407],
                array![0.24987904, 0.24996694, 0.19670277, 0.24904146],
            ],
            link: LogitLink::inverse_derivative
        }
    ];
}