linfa_clustering/optics/
hyperparams.rs1use crate::optics::errors::OpticsError;
2use linfa::{param_guard::TransformGuard, Float, ParamGuard};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
/// Validated hyperparameters for an OPTICS clustering run.
///
/// Values of this type are obtained by validating an [`OpticsParams`] through
/// [`ParamGuard`], which guarantees `tolerance > 0` and `min_points > 1`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct OpticsValidParams<F, D, N> {
    /// Neighbourhood tolerance (presumably the maximum distance at which two
    /// points are considered neighbours — confirm against the algorithm).
    /// Defaults to infinity and must be strictly positive.
    tolerance: F,
    /// Distance function used to compare points.
    dist_fn: D,
    /// Nearest-neighbour search algorithm.
    nn_algo: N,
    /// Minimum number of points required for a neighbourhood; must exceed 1.
    min_points: usize,
}
26impl<F: Float, D, N> OpticsValidParams<F, D, N> {
27 pub fn tolerance(&self) -> F {
30 self.tolerance
31 }
32
33 pub fn minimum_points(&self) -> usize {
36 self.min_points
37 }
38
39 pub fn dist_fn(&self) -> &D {
41 &self.dist_fn
42 }
43
44 pub fn nn_algo(&self) -> &N {
46 &self.nn_algo
47 }
48}
49
50#[derive(Clone, Debug, PartialEq)]
51#[cfg_attr(
52 feature = "serde",
53 derive(Serialize, Deserialize),
54 serde(crate = "serde_crate")
55)]
56pub struct OpticsParams<F, D, N>(OpticsValidParams<F, D, N>);
57
58impl<F: Float, D, N> OpticsParams<F, D, N> {
59 pub fn new(min_points: usize, dist_fn: D, nn_algo: N) -> Self {
60 Self(OpticsValidParams {
61 min_points,
62 tolerance: F::infinity(),
63 dist_fn,
64 nn_algo,
65 })
66 }
67
68 pub fn tolerance(mut self, tolerance: F) -> Self {
73 self.0.tolerance = tolerance;
74 self
75 }
76
77 pub fn dist_fn(mut self, dist_fn: D) -> Self {
79 self.0.dist_fn = dist_fn;
80 self
81 }
82
83 pub fn nn_algo(mut self, nn_algo: N) -> Self {
85 self.0.nn_algo = nn_algo;
86 self
87 }
88}
89
90impl<F: Float, D, N> ParamGuard for OpticsParams<F, D, N> {
91 type Checked = OpticsValidParams<F, D, N>;
92 type Error = OpticsError;
93
94 fn check_ref(&self) -> Result<&Self::Checked, OpticsError> {
95 if self.0.tolerance <= F::zero() {
96 Err(OpticsError::InvalidValue(
97 "`tolerance` must be greater than 0!".to_string(),
98 ))
99 } else if self.0.min_points <= 1 {
100 Err(OpticsError::InvalidValue(
102 "`min_points` must be greater than 1!".to_string(),
103 ))
104 } else {
105 Ok(&self.0)
106 }
107 }
108
109 fn check(self) -> Result<Self::Checked, OpticsError> {
110 self.check_ref()?;
111 Ok(self.0)
112 }
113}
114impl<F: Float, D, N> TransformGuard for OpticsParams<F, D, N> {}