1use crate::float::Float;
10use argmin_math::{
11 ArgminAdd, ArgminDot, ArgminL1Norm, ArgminL2Norm, ArgminMinMax, ArgminMul, ArgminSignum,
12 ArgminSub, ArgminZeroLike,
13};
14use ndarray::{Array, ArrayBase, Data, Dimension, Ix1, Ix2, Zip};
15#[cfg(feature = "serde")]
16use serde_crate::{Deserialize, Serialize};
17
18pub fn elem_dot<F: linfa::Float, A1: Data<Elem = F>, A2: Data<Elem = F>, D: Dimension>(
19 a: &ArrayBase<A1, D>,
20 b: &ArrayBase<A2, D>,
21) -> F {
22 Zip::from(a)
23 .and(b)
24 .fold(F::zero(), |acc, &a, &b| acc + a * b)
25}
26
27#[derive(Debug, Clone, PartialEq, Default)]
28#[cfg_attr(
29 feature = "serde",
30 derive(Serialize, Deserialize),
31 serde(crate = "serde_crate")
32)]
33pub struct ArgminParam<F, D: Dimension>(pub Array<F, D>);
34
35impl<F, D: Dimension> ArgminParam<F, D> {
36 #[inline]
37 pub fn as_array(&self) -> &Array<F, D> {
38 &self.0
39 }
40}
41
42impl<F: Float, D: Dimension> ArgminSub<F, ArgminParam<F, D>> for ArgminParam<F, D> {
43 fn sub(&self, other: &F) -> ArgminParam<F, D> {
44 ArgminParam(&self.0 - *other)
45 }
46}
47
48impl<F: Float, D: Dimension> ArgminSub<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
49 fn sub(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
50 ArgminParam(&self.0 - &other.0)
51 }
52}
53
54impl<F: Float, D: Dimension> ArgminAdd<F, ArgminParam<F, D>> for ArgminParam<F, D> {
55 fn add(&self, other: &F) -> ArgminParam<F, D> {
56 ArgminParam(&self.0 + *other)
57 }
58}
59
60impl<F: Float, D: Dimension> ArgminAdd<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
61 fn add(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
62 ArgminParam(&self.0 + &other.0)
63 }
64}
65
66impl<F: Float, D: Dimension> ArgminDot<ArgminParam<F, D>, F> for ArgminParam<F, D> {
67 fn dot(&self, other: &ArgminParam<F, D>) -> F {
68 elem_dot(&self.0, &other.0)
69 }
70}
71
72impl<F: Float, D: Dimension> ArgminL1Norm<F> for ArgminParam<F, D> {
73 fn l1_norm(&self) -> F {
74 num_traits::Float::sqrt(elem_dot(&self.0, &self.0))
75 }
76}
77
78impl<F: Float, D: Dimension> ArgminL2Norm<F> for ArgminParam<F, D> {
79 fn l2_norm(&self) -> F {
80 num_traits::Float::sqrt(elem_dot(&self.0, &self.0))
81 }
82}
83
84impl<F: Float, D: Dimension> ArgminMul<F, ArgminParam<F, D>> for ArgminParam<F, D> {
85 fn mul(&self, other: &F) -> ArgminParam<F, D> {
86 ArgminParam(&self.0 * *other)
87 }
88}
89
90impl<F: Float, D: Dimension> ArgminMul<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
91 fn mul(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
92 ArgminParam(&self.0 * &other.0)
93 }
94}
95
96impl<F: Float, D: Dimension> ArgminSignum for ArgminParam<F, D> {
97 fn signum(self) -> ArgminParam<F, D> {
98 self
99 }
100}
101
102impl<F: Float, D: Dimension> ArgminZeroLike for ArgminParam<F, D> {
103 fn zero_like(&self) -> ArgminParam<F, D> {
104 let dims = self.as_array().raw_dim();
105 ArgminParam(Array::zeros(dims))
106 }
107}
108
109impl<F: Float> ArgminMinMax for ArgminParam<F, Ix1> {
110 fn min(x: &Self, y: &Self) -> ArgminParam<F, Ix1> {
111 let x_array = x.as_array();
112 let y_array = y.as_array();
113
114 assert_eq!(x_array.shape(), y_array.shape());
115 ArgminParam(
116 x_array
117 .iter()
118 .zip(y_array)
119 .map(|(&a, &b)| if a < b { a } else { b })
120 .collect(),
121 )
122 }
123
124 fn max(x: &Self, y: &Self) -> ArgminParam<F, Ix1> {
125 let x_array = x.as_array();
126 let y_array = y.as_array();
127
128 assert_eq!(x_array.shape(), y_array.shape());
129 ArgminParam(
130 x_array
131 .iter()
132 .zip(y_array)
133 .map(|(&a, &b)| if a > b { a } else { b })
134 .collect(),
135 )
136 }
137}
138
139impl<F: Float> ArgminMinMax for ArgminParam<F, Ix2> {
140 fn min(x: &Self, y: &Self) -> ArgminParam<F, Ix2> {
141 let x_array = x.as_array();
142 let y_array = y.as_array();
143
144 assert_eq!(x_array.shape(), y_array.shape());
145 let m = x_array.shape()[0];
146 let n = x_array.shape()[1];
147 let mut out = x_array.clone();
148 for i in 0..m {
149 for j in 0..n {
150 let a = x_array[(i, j)];
151 let b = y_array[(i, j)];
152 out[(i, j)] = if a < b { a } else { b };
153 }
154 }
155 ArgminParam(out)
156 }
157
158 fn max(x: &Self, y: &Self) -> ArgminParam<F, Ix2> {
159 let x_array = x.as_array();
160 let y_array = y.as_array();
161
162 assert_eq!(x_array.shape(), y_array.shape());
163 let m = x_array.shape()[0];
164 let n = x_array.shape()[1];
165 let mut out = x_array.clone();
166 for i in 0..m {
167 for j in 0..n {
168 let a = x_array[(i, j)];
169 let b = y_array[(i, j)];
170 out[(i, j)] = if a > b { a } else { b };
171 }
172 }
173 ArgminParam(out)
174 }
175}