linfa_clustering/dbscan/
hyperparams.rs

1use linfa::{param_guard::TransformGuard, prelude::*, Float};
2use linfa_nn::{distance::Distance, NearestNeighbour};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
5use thiserror::Error;
6
7#[cfg_attr(
8    feature = "serde",
9    derive(Serialize, Deserialize),
10    serde(crate = "serde_crate")
11)]
12#[derive(Debug, Clone, PartialEq)]
13/// The set of hyperparameters that can be specified for the execution of
14/// the [DBSCAN algorithm](crate::Dbscan).
15pub struct DbscanValidParams<F: Float, D: Distance<F>, N: NearestNeighbour> {
16    pub(crate) tolerance: F,
17    pub(crate) min_points: usize,
18    pub(crate) dist_fn: D,
19    pub(crate) nn_algo: N,
20}
21
22#[derive(Debug, Clone, PartialEq)]
23/// Helper struct for building a set of [DBSCAN hyperparameters](DbscanParams)
24pub struct DbscanParams<F: Float, D: Distance<F>, N: NearestNeighbour>(DbscanValidParams<F, D, N>);
25
26#[derive(Error, Debug)]
27pub enum DbscanParamsError {
28    #[error("min_points must be greater than 1")]
29    MinPoints,
30    #[error("tolerance must be greater than 0")]
31    Tolerance,
32}
33
34impl<F: Float, D: Distance<F>, N: NearestNeighbour> DbscanParams<F, D, N> {
35    pub(crate) fn new(min_points: usize, dist_fn: D, nn_algo: N) -> Self {
36        Self(DbscanValidParams {
37            min_points,
38            tolerance: F::cast(1e-4),
39            dist_fn,
40            nn_algo,
41        })
42    }
43
44    /// Set the tolerance
45    pub fn tolerance(mut self, tolerance: F) -> Self {
46        self.0.tolerance = tolerance;
47        self
48    }
49
50    /// Set the nearest neighbour algorithm to be used
51    pub fn nn_algo(mut self, nn_algo: N) -> Self {
52        self.0.nn_algo = nn_algo;
53        self
54    }
55
56    /// Set the distance metric
57    pub fn dist_fn(mut self, dist_fn: D) -> Self {
58        self.0.dist_fn = dist_fn;
59        self
60    }
61}
62
63impl<F: Float, D: Distance<F>, N: NearestNeighbour> ParamGuard for DbscanParams<F, D, N> {
64    type Checked = DbscanValidParams<F, D, N>;
65    type Error = DbscanParamsError;
66
67    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
68        if self.0.min_points <= 1 {
69            Err(DbscanParamsError::MinPoints)
70        } else if self.0.tolerance <= F::zero() {
71            Err(DbscanParamsError::Tolerance)
72        } else {
73            Ok(&self.0)
74        }
75    }
76
77    fn check(self) -> Result<Self::Checked, Self::Error> {
78        self.check_ref()?;
79        Ok(self.0)
80    }
81}
82impl<F: Float, D: Distance<F>, N: NearestNeighbour> TransformGuard for DbscanParams<F, D, N> {}
83
84impl<F: Float, D: Distance<F>, N: NearestNeighbour> DbscanValidParams<F, D, N> {
85    /// Maximum distance between two points to be considered neighbors
86    pub fn tolerance(&self) -> F {
87        self.tolerance
88    }
89
90    /// Minimum number of neighboring points a point needs to have to be a core                                                                                                
91    /// point and not a noise point.
92    pub fn minimum_points(&self) -> usize {
93        self.min_points
94    }
95
96    /// Distance metric used in the DBSCAN calculation
97    pub fn dist_fn(&self) -> &D {
98        &self.dist_fn
99    }
100
101    /// Nearest neighbour algorithm used for range queries
102    pub fn nn_algo(&self) -> &N {
103        &self.nn_algo
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    use linfa_nn::{distance::L2Dist, CommonNearestNeighbour, KdTree};

    /// Builds unchecked `f64` parameters with the L2 metric and the kd-tree
    /// neighbour search, varying only the two validated values.
    fn unchecked(
        min_points: usize,
        tolerance: f64,
    ) -> DbscanParams<f64, L2Dist, CommonNearestNeighbour> {
        DbscanParams::new(min_points, L2Dist, CommonNearestNeighbour::KdTree).tolerance(tolerance)
    }

    /// The public types must stay `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T>()
        where
            T: Send + Sync + Sized + Unpin,
        {
        }

        has_autotraits::<DbscanParamsError>();
        has_autotraits::<DbscanParams<f64, L2Dist, KdTree>>();
        has_autotraits::<DbscanValidParams<f64, L2Dist, KdTree>>();
    }

    /// A zero tolerance is rejected even when `min_points` is valid.
    #[test]
    fn tolerance_cannot_be_zero() {
        let outcome = unchecked(2, 0.0).check();
        assert!(matches!(outcome, Err(DbscanParamsError::Tolerance)));
    }

    /// A single minimum point is rejected even when the tolerance is valid.
    #[test]
    fn min_points_at_least_2() {
        let outcome = unchecked(1, 3.3).check();
        assert!(matches!(outcome, Err(DbscanParamsError::MinPoints)));
    }
}