linfa_clustering/optics/
hyperparams.rs

1use crate::optics::errors::OpticsError;
2use linfa::{param_guard::TransformGuard, Float, ParamGuard};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
5
/// Validated hyperparameters for a run of the
/// [OPTICS algorithm](crate::Optics).
///
/// All fields are private; read them through the accessor methods.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub struct OpticsValidParams<F, D, N> {
    /// Maximum distance at which two points still count as neighbours.
    tolerance: F,
    /// Metric used by the algorithm to measure distances.
    dist_fn: D,
    /// Nearest-neighbour search algorithm used to look up nearby points.
    nn_algo: N,
    /// Number of neighbouring points a point must have to be treated as a
    /// core point instead of noise.
    min_points: usize,
}
25
26impl<F: Float, D, N> OpticsValidParams<F, D, N> {
27    /// Two points are considered neighbors if the euclidean distance between
28    /// them is below the tolerance
29    pub fn tolerance(&self) -> F {
30        self.tolerance
31    }
32
33    /// Minimum number of a points in a neighborhood around a point for it to
34    /// not be considered noise
35    pub fn minimum_points(&self) -> usize {
36        self.min_points
37    }
38
39    /// Distance metric to be used for the algorithm
40    pub fn dist_fn(&self) -> &D {
41        &self.dist_fn
42    }
43
44    /// Nearest Neighbour algorithm to use to find the nearest points
45    pub fn nn_algo(&self) -> &N {
46        &self.nn_algo
47    }
48}
49
50#[derive(Clone, Debug, PartialEq)]
51#[cfg_attr(
52    feature = "serde",
53    derive(Serialize, Deserialize),
54    serde(crate = "serde_crate")
55)]
56pub struct OpticsParams<F, D, N>(OpticsValidParams<F, D, N>);
57
58impl<F: Float, D, N> OpticsParams<F, D, N> {
59    pub fn new(min_points: usize, dist_fn: D, nn_algo: N) -> Self {
60        Self(OpticsValidParams {
61            min_points,
62            tolerance: F::infinity(),
63            dist_fn,
64            nn_algo,
65        })
66    }
67
68    /// Distance between points for them to be considered neighbors. Compared to DBSCAN this
69    /// parameter isn't strictly necessary but improves execution time by not considering every
70    /// point. If the tolerance is too low the distances calculated are undefined and no clusters
71    /// will be returned.
72    pub fn tolerance(mut self, tolerance: F) -> Self {
73        self.0.tolerance = tolerance;
74        self
75    }
76
77    /// Distance metric to be used for the algorithm
78    pub fn dist_fn(mut self, dist_fn: D) -> Self {
79        self.0.dist_fn = dist_fn;
80        self
81    }
82
83    /// Nearest Neighbour algorithm to use to find the nearest points
84    pub fn nn_algo(mut self, nn_algo: N) -> Self {
85        self.0.nn_algo = nn_algo;
86        self
87    }
88}
89
90impl<F: Float, D, N> ParamGuard for OpticsParams<F, D, N> {
91    type Checked = OpticsValidParams<F, D, N>;
92    type Error = OpticsError;
93
94    fn check_ref(&self) -> Result<&Self::Checked, OpticsError> {
95        if self.0.tolerance <= F::zero() {
96            Err(OpticsError::InvalidValue(
97                "`tolerance` must be greater than 0!".to_string(),
98            ))
99        } else if self.0.min_points <= 1 {
100            // There is always at least one neighbor to a point (itself)
101            Err(OpticsError::InvalidValue(
102                "`min_points` must be greater than 1!".to_string(),
103            ))
104        } else {
105            Ok(&self.0)
106        }
107    }
108
109    fn check(self) -> Result<Self::Checked, OpticsError> {
110        self.check_ref()?;
111        Ok(self.0)
112    }
113}
114impl<F: Float, D, N> TransformGuard for OpticsParams<F, D, N> {}