linfa_logistic/
argmin_param.rs

1//! This module defines newtypes for ndarray's Array.
2//!
3//! This is necessary to be able to abstract over floats (f32 and f64) so that
4//! the logistic regression code can be abstract in the float type it works
5//! with.
6//!
7//! Unfortunately, this requires that we re-implement some traits from Argmin.
8
9use crate::float::Float;
10use argmin_math::{
11    ArgminAdd, ArgminDot, ArgminL1Norm, ArgminL2Norm, ArgminMinMax, ArgminMul, ArgminSignum,
12    ArgminSub, ArgminZeroLike,
13};
14use ndarray::{Array, ArrayBase, Data, Dimension, Ix1, Ix2, Zip};
15#[cfg(feature = "serde")]
16use serde_crate::{Deserialize, Serialize};
17
18pub fn elem_dot<F: linfa::Float, A1: Data<Elem = F>, A2: Data<Elem = F>, D: Dimension>(
19    a: &ArrayBase<A1, D>,
20    b: &ArrayBase<A2, D>,
21) -> F {
22    Zip::from(a)
23        .and(b)
24        .fold(F::zero(), |acc, &a, &b| acc + a * b)
25}
26
27#[derive(Debug, Clone, PartialEq, Default)]
28#[cfg_attr(
29    feature = "serde",
30    derive(Serialize, Deserialize),
31    serde(crate = "serde_crate")
32)]
33pub struct ArgminParam<F, D: Dimension>(pub Array<F, D>);
34
35impl<F, D: Dimension> ArgminParam<F, D> {
36    #[inline]
37    pub fn as_array(&self) -> &Array<F, D> {
38        &self.0
39    }
40}
41
42impl<F: Float, D: Dimension> ArgminSub<F, ArgminParam<F, D>> for ArgminParam<F, D> {
43    fn sub(&self, other: &F) -> ArgminParam<F, D> {
44        ArgminParam(&self.0 - *other)
45    }
46}
47
48impl<F: Float, D: Dimension> ArgminSub<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
49    fn sub(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
50        ArgminParam(&self.0 - &other.0)
51    }
52}
53
54impl<F: Float, D: Dimension> ArgminAdd<F, ArgminParam<F, D>> for ArgminParam<F, D> {
55    fn add(&self, other: &F) -> ArgminParam<F, D> {
56        ArgminParam(&self.0 + *other)
57    }
58}
59
60impl<F: Float, D: Dimension> ArgminAdd<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
61    fn add(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
62        ArgminParam(&self.0 + &other.0)
63    }
64}
65
66impl<F: Float, D: Dimension> ArgminDot<ArgminParam<F, D>, F> for ArgminParam<F, D> {
67    fn dot(&self, other: &ArgminParam<F, D>) -> F {
68        elem_dot(&self.0, &other.0)
69    }
70}
71
72impl<F: Float, D: Dimension> ArgminL1Norm<F> for ArgminParam<F, D> {
73    fn l1_norm(&self) -> F {
74        num_traits::Float::sqrt(elem_dot(&self.0, &self.0))
75    }
76}
77
78impl<F: Float, D: Dimension> ArgminL2Norm<F> for ArgminParam<F, D> {
79    fn l2_norm(&self) -> F {
80        num_traits::Float::sqrt(elem_dot(&self.0, &self.0))
81    }
82}
83
84impl<F: Float, D: Dimension> ArgminMul<F, ArgminParam<F, D>> for ArgminParam<F, D> {
85    fn mul(&self, other: &F) -> ArgminParam<F, D> {
86        ArgminParam(&self.0 * *other)
87    }
88}
89
90impl<F: Float, D: Dimension> ArgminMul<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
91    fn mul(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
92        ArgminParam(&self.0 * &other.0)
93    }
94}
95
96impl<F: Float, D: Dimension> ArgminSignum for ArgminParam<F, D> {
97    fn signum(self) -> ArgminParam<F, D> {
98        self
99    }
100}
101
102impl<F: Float, D: Dimension> ArgminZeroLike for ArgminParam<F, D> {
103    fn zero_like(&self) -> ArgminParam<F, D> {
104        let dims = self.as_array().raw_dim();
105        ArgminParam(Array::zeros(dims))
106    }
107}
108
109impl<F: Float> ArgminMinMax for ArgminParam<F, Ix1> {
110    fn min(x: &Self, y: &Self) -> ArgminParam<F, Ix1> {
111        let x_array = x.as_array();
112        let y_array = y.as_array();
113
114        assert_eq!(x_array.shape(), y_array.shape());
115        ArgminParam(
116            x_array
117                .iter()
118                .zip(y_array)
119                .map(|(&a, &b)| if a < b { a } else { b })
120                .collect(),
121        )
122    }
123
124    fn max(x: &Self, y: &Self) -> ArgminParam<F, Ix1> {
125        let x_array = x.as_array();
126        let y_array = y.as_array();
127
128        assert_eq!(x_array.shape(), y_array.shape());
129        ArgminParam(
130            x_array
131                .iter()
132                .zip(y_array)
133                .map(|(&a, &b)| if a > b { a } else { b })
134                .collect(),
135        )
136    }
137}
138
139impl<F: Float> ArgminMinMax for ArgminParam<F, Ix2> {
140    fn min(x: &Self, y: &Self) -> ArgminParam<F, Ix2> {
141        let x_array = x.as_array();
142        let y_array = y.as_array();
143
144        assert_eq!(x_array.shape(), y_array.shape());
145        let m = x_array.shape()[0];
146        let n = x_array.shape()[1];
147        let mut out = x_array.clone();
148        for i in 0..m {
149            for j in 0..n {
150                let a = x_array[(i, j)];
151                let b = y_array[(i, j)];
152                out[(i, j)] = if a < b { a } else { b };
153            }
154        }
155        ArgminParam(out)
156    }
157
158    fn max(x: &Self, y: &Self) -> ArgminParam<F, Ix2> {
159        let x_array = x.as_array();
160        let y_array = y.as_array();
161
162        assert_eq!(x_array.shape(), y_array.shape());
163        let m = x_array.shape()[0];
164        let n = x_array.shape()[1];
165        let mut out = x_array.clone();
166        for i in 0..m {
167            for j in 0..n {
168                let a = x_array[(i, j)];
169                let b = y_array[(i, j)];
170                out[(i, j)] = if a > b { a } else { b };
171            }
172        }
173        ArgminParam(out)
174    }
175}