1use crate::{KernelMethod, SolverParams, Svm, SvmError};
2use linfa::{platt_scaling::PlattParams, Float, ParamGuard, Platt};
3use linfa_kernel::{Kernel, KernelParams};
4use std::marker::PhantomData;
5
/// Validated SVM hyper-parameters.
///
/// This is the `Checked` type produced from an [`SvmParams`] builder by
/// [`ParamGuard::check`] / [`ParamGuard::check_ref`]; its contents are read back
/// through the accessor methods (`c`, `nu`, `solver_params`, `kernel_params`,
/// `platt_params`).
#[derive(Debug, Clone, PartialEq)]
pub struct SvmValidParams<F: Float, T> {
    // C-parametrisation: `(c_pos, c_neg)` when set via `pos_neg_weights`,
    // `(c, loss_eps)` when set via `c_svr`. The setters in this file clear `nu`
    // whenever they set `c`.
    c: Option<(F, F)>,
    // nu-parametrisation: `(nu, nu)` when set via `nu_weight`, `(nu, c)` when set
    // via `nu_svr`. The setters in this file clear `c` whenever they set `nu`.
    nu: Option<(F, F)>,
    // Solver settings (`eps` and `shrinking`).
    solver_params: SolverParams<F>,
    // Ties the parameter set to the target type `T` without storing a `T`.
    phantom: PhantomData<T>,
    // Kernel configuration used by the SVM.
    kernel: KernelParams<F>,
    // Platt-scaling parameters; validated as part of `check_ref`.
    platt: PlattParams<F, ()>,
}
45
46impl<F: Float, T> SvmValidParams<F, T> {
47 pub fn c(&self) -> Option<(F, F)> {
48 self.c
49 }
50
51 pub fn nu(&self) -> Option<(F, F)> {
52 self.nu
53 }
54
55 pub fn solver_params(&self) -> &SolverParams<F> {
56 &self.solver_params
57 }
58
59 pub fn kernel_params(&self) -> &KernelParams<F> {
60 &self.kernel
61 }
62
63 pub fn platt_params(&self) -> &PlattParams<F, ()> {
64 &self.platt
65 }
66}
67
/// Builder for SVM hyper-parameters.
///
/// Wraps a not-yet-validated [`SvmValidParams`]; the builder methods mutate the
/// wrapped value, and [`ParamGuard::check`] / [`ParamGuard::check_ref`] validate it.
#[derive(Debug, Clone, PartialEq)]
pub struct SvmParams<F: Float, T>(SvmValidParams<F, T>);
70
71impl<F: Float, T> SvmParams<F, T> {
72 pub fn new() -> Self {
80 Self(SvmValidParams {
81 c: Some((F::one(), F::one())),
82 nu: None,
83 solver_params: SolverParams {
84 eps: F::cast(1e-7),
85 shrinking: false,
86 },
87 phantom: PhantomData,
88 kernel: Kernel::params().method(KernelMethod::Linear),
89 platt: Platt::params(),
90 })
91 }
92
93 pub fn eps(mut self, new_eps: F) -> Self {
98 self.0.solver_params.eps = new_eps;
99 self
100 }
101
102 pub fn shrinking(mut self, shrinking: bool) -> Self {
107 self.0.solver_params.shrinking = shrinking;
108 self
109 }
110
111 pub fn with_kernel_params(mut self, kernel: KernelParams<F>) -> Self {
119 self.0.kernel = kernel;
120 self
121 }
122
123 pub fn with_platt_params(mut self, platt: PlattParams<F, ()>) -> Self {
125 self.0.platt = platt;
126 self
127 }
128
129 pub fn gaussian_kernel(mut self, eps: F) -> Self {
132 self.0.kernel = Kernel::params().method(KernelMethod::Gaussian(eps));
133 self
134 }
135
136 pub fn polynomial_kernel(mut self, constant: F, degree: F) -> Self {
139 self.0.kernel = Kernel::params().method(KernelMethod::Polynomial(constant, degree));
140 self
141 }
142
143 pub fn linear_kernel(mut self) -> Self {
146 self.0.kernel = Kernel::params().method(KernelMethod::Linear);
147 self
148 }
149}
150
151impl<F: Float, T> SvmParams<F, T> {
152 pub fn pos_neg_weights(mut self, c_pos: F, c_neg: F) -> Self {
154 self.0.c = Some((c_pos, c_neg));
155 self.0.nu = None;
156 self
157 }
158
159 pub fn nu_weight(mut self, nu: F) -> Self {
164 self.0.nu = Some((nu, nu));
165 self.0.c = None;
166 self
167 }
168}
169
170impl<F: Float> SvmParams<F, F> {
171 #[deprecated(since = "0.7.2", note = "Use .c_svr() and .eps()")]
174 pub fn c_eps(mut self, c: F, eps: F) -> Self {
175 self.0.c = Some((c, F::cast(0.1)));
176 self.0.nu = None;
177 self.0.solver_params.eps = eps;
178 self
179 }
180
181 #[deprecated(since = "0.7.2", note = "Use .nu_svr() and .eps()")]
184 pub fn nu_eps(mut self, nu: F, eps: F) -> Self {
185 self.0.nu = Some((nu, F::one()));
186 self.0.c = None;
187 self.0.solver_params.eps = eps;
188 self
189 }
190
191 pub fn c_svr(mut self, c: F, loss_eps: Option<F>) -> Self {
193 self.0.c = Some((c, loss_eps.unwrap_or(F::cast(0.1))));
194 self.0.nu = None;
195 self
196 }
197
198 pub fn nu_svr(mut self, nu: F, c: Option<F>) -> Self {
200 self.0.nu = Some((nu, c.unwrap_or(F::one())));
201 self.0.c = None;
202 self
203 }
204}
205
206impl<F: Float, L> Default for SvmParams<F, L> {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl<F: Float, L> Svm<F, L> {
213 pub fn params() -> SvmParams<F, L> {
214 SvmParams::new()
215 }
216}
217
218impl<F: Float, L> ParamGuard for SvmParams<F, L> {
219 type Checked = SvmValidParams<F, L>;
220 type Error = SvmError;
221
222 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
223 self.0.platt_params().check_ref()?;
224
225 if self.0.solver_params.eps.is_negative()
226 || self.0.solver_params.eps.is_nan()
227 || self.0.solver_params.eps.is_infinite()
228 {
229 return Err(SvmError::InvalidEps(
230 self.0.solver_params.eps.to_f32().unwrap(),
231 ));
232 }
233 if let Some((c1, c2)) = self.0.c {
234 if c1 <= F::zero() || c2 <= F::zero() {
235 return Err(SvmError::InvalidC((
236 c1.to_f32().unwrap(),
237 c2.to_f32().unwrap(),
238 )));
239 }
240 }
241 if let Some((nu, _)) = self.0.nu {
242 if nu <= F::zero() || nu > F::one() {
243 return Err(SvmError::InvalidNu(nu.to_f32().unwrap()));
244 }
245 }
246
247 Ok(&self.0)
248 }
249
250 fn check(self) -> Result<Self::Checked, Self::Error> {
251 self.check_ref()?;
252 Ok(self.0)
253 }
254}