// linfa_clustering/dbscan/hyperparams.rs
use std::cmp::Ordering;

use linfa::{param_guard::TransformGuard, prelude::*, Float};
use linfa_nn::{distance::Distance, NearestNeighbour};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use thiserror::Error;

7#[cfg_attr(
8 feature = "serde",
9 derive(Serialize, Deserialize),
10 serde(crate = "serde_crate")
11)]
12#[derive(Debug, Clone, PartialEq)]
13pub struct DbscanValidParams<F: Float, D: Distance<F>, N: NearestNeighbour> {
16 pub(crate) tolerance: F,
17 pub(crate) min_points: usize,
18 pub(crate) dist_fn: D,
19 pub(crate) nn_algo: N,
20}
21
22#[derive(Debug, Clone, PartialEq)]
23pub struct DbscanParams<F: Float, D: Distance<F>, N: NearestNeighbour>(DbscanValidParams<F, D, N>);
25
26#[derive(Error, Debug)]
27pub enum DbscanParamsError {
28 #[error("min_points must be greater than 1")]
29 MinPoints,
30 #[error("tolerance must be greater than 0")]
31 Tolerance,
32}
33
34impl<F: Float, D: Distance<F>, N: NearestNeighbour> DbscanParams<F, D, N> {
35 pub(crate) fn new(min_points: usize, dist_fn: D, nn_algo: N) -> Self {
36 Self(DbscanValidParams {
37 min_points,
38 tolerance: F::cast(1e-4),
39 dist_fn,
40 nn_algo,
41 })
42 }
43
44 pub fn tolerance(mut self, tolerance: F) -> Self {
46 self.0.tolerance = tolerance;
47 self
48 }
49
50 pub fn nn_algo(mut self, nn_algo: N) -> Self {
52 self.0.nn_algo = nn_algo;
53 self
54 }
55
56 pub fn dist_fn(mut self, dist_fn: D) -> Self {
58 self.0.dist_fn = dist_fn;
59 self
60 }
61}
62
63impl<F: Float, D: Distance<F>, N: NearestNeighbour> ParamGuard for DbscanParams<F, D, N> {
64 type Checked = DbscanValidParams<F, D, N>;
65 type Error = DbscanParamsError;
66
67 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
68 if self.0.min_points <= 1 {
69 Err(DbscanParamsError::MinPoints)
70 } else if self.0.tolerance <= F::zero() {
71 Err(DbscanParamsError::Tolerance)
72 } else {
73 Ok(&self.0)
74 }
75 }
76
77 fn check(self) -> Result<Self::Checked, Self::Error> {
78 self.check_ref()?;
79 Ok(self.0)
80 }
81}
82impl<F: Float, D: Distance<F>, N: NearestNeighbour> TransformGuard for DbscanParams<F, D, N> {}
83
84impl<F: Float, D: Distance<F>, N: NearestNeighbour> DbscanValidParams<F, D, N> {
85 pub fn tolerance(&self) -> F {
87 self.tolerance
88 }
89
90 pub fn minimum_points(&self) -> usize {
93 self.min_points
94 }
95
96 pub fn dist_fn(&self) -> &D {
98 &self.dist_fn
99 }
100
101 pub fn nn_algo(&self) -> &N {
103 &self.nn_algo
104 }
105}
106
#[cfg(test)]
mod tests {
    use linfa_nn::{distance::L2Dist, CommonNearestNeighbour, KdTree};

    use super::*;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<DbscanParamsError>();
        has_autotraits::<DbscanParams<f64, L2Dist, KdTree>>();
        has_autotraits::<DbscanValidParams<f64, L2Dist, KdTree>>();
    }

    #[test]
    fn tolerance_cannot_be_zero() {
        let res = DbscanParams::new(2, L2Dist, CommonNearestNeighbour::KdTree)
            .tolerance(0.0)
            .check();
        assert!(matches!(res, Err(DbscanParamsError::Tolerance)));
    }

    #[test]
    fn tolerance_cannot_be_negative() {
        let res = DbscanParams::new(2, L2Dist, CommonNearestNeighbour::KdTree)
            .tolerance(-1.0)
            .check();
        assert!(matches!(res, Err(DbscanParamsError::Tolerance)));
    }

    #[test]
    fn min_points_at_least_2() {
        let res = DbscanParams::new(1, L2Dist, CommonNearestNeighbour::KdTree)
            .tolerance(3.3)
            .check();
        assert!(matches!(res, Err(DbscanParamsError::MinPoints)));
    }

    #[test]
    fn min_points_cannot_be_zero() {
        let res = DbscanParams::new(0, L2Dist, CommonNearestNeighbour::KdTree)
            .tolerance(3.3)
            .check();
        assert!(matches!(res, Err(DbscanParamsError::MinPoints)));
    }

    #[test]
    fn valid_params_are_accepted() {
        let params = DbscanParams::new(2, L2Dist, CommonNearestNeighbour::KdTree)
            .tolerance(0.5)
            .check()
            .unwrap();
        assert_eq!(params.minimum_points(), 2);
        assert_eq!(params.tolerance(), 0.5);
    }

    #[test]
    fn default_tolerance_is_accepted() {
        // `new` seeds the tolerance with 1e-4, which must pass validation as-is.
        let params = DbscanParams::<f64, _, _>::new(2, L2Dist, CommonNearestNeighbour::KdTree)
            .check()
            .unwrap();
        assert_eq!(params.tolerance(), 1e-4);
    }

    #[test]
    fn check_ref_agrees_with_check() {
        let builder = DbscanParams::new(3, L2Dist, CommonNearestNeighbour::KdTree).tolerance(1.5);
        // Copy the scalars out so the borrow from `check_ref` ends before `check` consumes the builder.
        let (ref_tol, ref_min) = {
            let checked = builder.check_ref().unwrap();
            (checked.tolerance(), checked.minimum_points())
        };
        let owned = builder.check().unwrap();
        assert_eq!(owned.tolerance(), ref_tol);
        assert_eq!(owned.minimum_points(), ref_min);
    }
}