1use linfa::{
3 dataset::{AsSingleTargets, DatasetBase},
4 traits::Fit,
5 traits::Transformer,
6 traits::{Predict, PredictInplace},
7};
8use linfa_kernel::Kernel;
9use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix2};
10
11use super::error::{Result, SvmError};
12use super::permutable_kernel::PermutableKernelRegression;
13use super::solver_smo::SolverState;
14use super::SolverParams;
15use super::{Float, Svm, SvmValidParams};
16
17pub fn fit_epsilon<F: Float>(
29 params: SolverParams<F>,
30 dataset: ArrayView2<F>,
31 kernel: Kernel<F>,
32 target: &[F],
33 c: F,
34 p: F,
35) -> Svm<F, F> {
36 let mut linear_term = vec![F::zero(); 2 * target.len()];
37 let mut targets = vec![true; 2 * target.len()];
38
39 for i in 0..target.len() {
40 linear_term[i] = p - target[i];
41 targets[i] = true;
42
43 linear_term[i + target.len()] = p + target[i];
44 targets[i + target.len()] = false;
45 }
46
47 let kernel = PermutableKernelRegression::new(kernel);
48 let solver = SolverState::new(
49 vec![F::zero(); 2 * target.len()],
50 linear_term,
51 targets.to_vec(),
52 dataset,
53 kernel,
54 vec![c; 2 * target.len()],
55 params,
56 false,
57 );
58
59 let res = solver.solve();
60
61 res.with_phantom()
62}
63
64pub fn fit_nu<F: Float>(
76 params: SolverParams<F>,
77 dataset: ArrayView2<F>,
78 kernel: Kernel<F>,
79 target: &[F],
80 nu: F,
81 c: F,
82) -> Svm<F, F> {
83 let mut alpha = vec![F::zero(); 2 * target.len()];
84 let mut linear_term = vec![F::zero(); 2 * target.len()];
85 let mut targets = vec![true; 2 * target.len()];
86
87 let mut sum = c * nu * F::cast(target.len()) / F::cast(2.0);
88 for i in 0..target.len() {
89 alpha[i] = F::min(sum, c);
90 alpha[i + target.len()] = F::min(sum, c);
91 sum -= alpha[i];
92
93 linear_term[i] = -target[i];
94 targets[i] = true;
95
96 linear_term[i + target.len()] = target[i];
97 targets[i + target.len()] = false;
98 }
99
100 let kernel = PermutableKernelRegression::new(kernel);
101 let solver = SolverState::new(
102 alpha,
103 linear_term,
104 targets.to_vec(),
105 dataset,
106 kernel,
107 vec![c; 2 * target.len()],
108 params,
109 false,
110 );
111
112 let res = solver.solve();
113
114 res.with_phantom()
115}
116
117macro_rules! impl_regression {
121 ($records:ty, $targets:ty, $f:ty) => {
122 impl Fit<$records, $targets, SvmError> for SvmValidParams<$f, $f> {
123 type Object = Svm<$f, $f>;
124
125 fn fit(&self, dataset: &DatasetBase<$records, $targets>) -> Result<Self::Object> {
126 let kernel = self.kernel_params().transform(dataset.records());
127 let target = dataset.as_single_targets();
128 let target = target.as_slice().unwrap();
129
130 let ret = match (self.c(), self.nu()) {
131 (Some((c, p)), _) => fit_epsilon(
132 self.solver_params().clone(),
133 dataset.records().view(),
134 kernel,
135 target,
136 c,
137 p,
138 ),
139 (None, Some((nu, c))) => fit_nu(
140 self.solver_params().clone(),
141 dataset.records().view(),
142 kernel,
143 target,
144 nu,
145 c,
146 ),
147 _ => panic!("Set either C value or Nu value"),
148 };
149
150 Ok(ret)
151 }
152 }
153 };
154}
155
156impl_regression!(Array2<f32>, Array1<f32>, f32);
157impl_regression!(Array2<f64>, Array1<f64>, f64);
158impl_regression!(ArrayView2<'_, f32>, ArrayView1<'_, f32>, f32);
159impl_regression!(ArrayView2<'_, f64>, ArrayView1<'_, f64>, f64);
160
161macro_rules! impl_predict {
162 ( $($t:ty),* ) => {
163 $(
164 impl Predict<Array1<$t>, $t> for Svm<$t, $t> {
166 fn predict(&self, data: Array1<$t>) -> $t {
167 self.weighted_sum(&data) - self.rho
168 }
169 }
170 impl<'a> Predict<ArrayView1<'a, $t>, $t> for Svm<$t, $t> {
172 fn predict(&self, data: ArrayView1<'a, $t>) -> $t {
173 self.weighted_sum(&data) - self.rho
174 }
175 }
176
177 impl<D: Data<Elem = $t>> PredictInplace<ArrayBase<D, Ix2>, Array1<$t>> for Svm<$t, $t> {
182 fn predict_inplace(&'_ self, data: &ArrayBase<D, Ix2>, targets: &mut Array1<$t>) {
183 assert_eq!(data.nrows(), targets.len(), "The number of data points must match the number of output targets.");
184
185 for (data, target) in data.outer_iter().zip(targets.iter_mut()) {
186 *target = self.weighted_sum(&data) - self.rho;
187 }
188 }
189
190 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<$t> {
191 Array1::zeros(x.nrows())
192 }
193 }
194
195 ) *
196 }
197}
198
199impl_predict!(f32, f64);
200
#[cfg(test)]
pub mod tests {
    use super::Svm;
    use crate::error::Result;

    use linfa::dataset::Dataset;
    use linfa::metrics::SingleTargetRegression;
    use linfa::traits::{Fit, Predict};
    use linfa::DatasetBase;
    use ndarray::{Array, Array1, Array2};

    /// Largest mean squared error a model may have on its own training data.
    const MAX_MSE: f64 = 1e-2;

    /// Predicts on the training records and asserts the MSE is below `MAX_MSE`.
    fn check_model(model: Svm<f64, f64>, dataset: &DatasetBase<Array2<f64>, Array1<f64>>) {
        println!("{model}");
        let predicted = model.predict(dataset.records());
        let err = predicted.mean_squared_error(&dataset).unwrap();
        println!("err={err}");
        assert!(
            err < MAX_MSE,
            "mean squared error {err} exceeds tolerance {MAX_MSE}"
        );
    }

    #[test]
    fn test_epsilon_regression_linear() -> Result<()> {
        let targets = Array::linspace(0f64, 10., 100);
        let records = targets.clone().into_shape((100, 1)).unwrap();
        let dataset = Dataset::new(records, targets);

        let model = Svm::params()
            .c_svr(5., None)
            .linear_kernel()
            .fit(&dataset)?;
        check_model(model, &dataset);

        // Deliberately exercises the deprecated builder to keep it covered.
        #[allow(deprecated)]
        let model2 = Svm::params()
            .c_eps(5., 1e-3)
            .linear_kernel()
            .fit(&dataset)?;
        check_model(model2, &dataset);

        Ok(())
    }

    #[test]
    fn test_nu_regression_linear() -> Result<()> {
        let targets = Array::linspace(0f64, 10., 100);
        let records = targets.clone().into_shape((100, 1)).unwrap();
        let dataset = Dataset::new(records, targets);

        let model = Svm::params()
            .nu_svr(0.5, Some(1.))
            .linear_kernel()
            .fit(&dataset)?;
        check_model(model, &dataset);

        // Deliberately exercises the deprecated builder to keep it covered.
        #[allow(deprecated)]
        let model2 = Svm::params()
            .nu_eps(0.5, 1e-3)
            .linear_kernel()
            .fit(&dataset)?;
        check_model(model2, &dataset);
        Ok(())
    }

    #[test]
    fn test_epsilon_regression_gaussian() -> Result<()> {
        let records = Array::linspace(0f64, 10., 100)
            .into_shape((100, 1))
            .unwrap();
        let sin_curve = records.mapv(|v| v.sin()).into_shape((100,)).unwrap();
        let dataset = Dataset::new(records, sin_curve);

        let model = Svm::params()
            .c_svr(100., Some(0.1))
            .gaussian_kernel(10.)
            .eps(1e-3)
            .fit(&dataset)?;
        check_model(model, &dataset);
        Ok(())
    }

    #[test]
    fn test_nu_regression_polynomial() -> Result<()> {
        let n = 100;
        let records = Array::linspace(0f64, 5., n).into_shape((n, 1)).unwrap();
        let sin_curve = records.mapv(|v| v.sin()).into_shape((n,)).unwrap();
        let dataset = Dataset::new(records, sin_curve);

        let model = Svm::params()
            .nu_svr(0.01, None)
            .polynomial_kernel(1., 3.)
            .eps(1e-3)
            .fit(&dataset)?;
        check_model(model, &dataset);
        Ok(())
    }
}