linfa_reduction/random_projection/
hyperparams.rs1use std::{fmt::Debug, marker::PhantomData};
2
3use linfa::ParamGuard;
4
5use rand::Rng;
6
7use crate::ReductionError;
8
9use super::methods::ProjectionMethod;
10
11pub struct RandomProjectionParams<Proj: ProjectionMethod, R: Rng + Clone>(
24 pub(crate) RandomProjectionValidParams<Proj, R>,
25);
26
27impl<Proj: ProjectionMethod, R: Rng + Clone> RandomProjectionParams<Proj, R> {
28 pub fn target_dim(mut self, dim: usize) -> Self {
33 self.0.params = RandomProjectionParamsInner::Dimension { target_dim: dim };
34
35 self
36 }
37
38 pub fn eps(mut self, eps: f64) -> Self {
43 self.0.params = RandomProjectionParamsInner::Epsilon { eps };
44
45 self
46 }
47
48 pub fn with_rng<R2: Rng + Clone>(self, rng: R2) -> RandomProjectionParams<Proj, R2> {
50 RandomProjectionParams(RandomProjectionValidParams {
51 params: self.0.params,
52 rng,
53 marker: PhantomData,
54 })
55 }
56}
57
58#[derive(Debug, Clone, PartialEq)]
71pub struct RandomProjectionValidParams<Proj: ProjectionMethod, R: Rng + Clone> {
72 pub(super) params: RandomProjectionParamsInner,
73 pub(super) rng: R,
74 pub(crate) marker: PhantomData<Proj>,
75}
76
/// How the embedding dimension is specified: explicitly, or via a precision `eps`.
/// Exactly one of the two is active at a time.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum RandomProjectionParamsInner {
    Dimension { target_dim: usize },
    Epsilon { eps: f64 },
}

impl RandomProjectionParamsInner {
    /// Returns the explicit target dimension, or `None` when `eps` is set instead.
    fn target_dim(&self) -> Option<usize> {
        match *self {
            Self::Dimension { target_dim } => Some(target_dim),
            Self::Epsilon { .. } => None,
        }
    }

    /// Returns the precision `eps`, or `None` when an explicit dimension is set instead.
    fn eps(&self) -> Option<f64> {
        match *self {
            Self::Epsilon { eps } => Some(eps),
            Self::Dimension { .. } => None,
        }
    }
}
103
104impl<Proj: ProjectionMethod, R: Rng + Clone> RandomProjectionValidParams<Proj, R> {
105 pub fn target_dim(&self) -> Option<usize> {
106 self.params.target_dim()
107 }
108
109 pub fn eps(&self) -> Option<f64> {
110 self.params.eps()
111 }
112
113 pub fn rng(&self) -> &R {
114 &self.rng
115 }
116}
117
118impl<Proj: ProjectionMethod, R: Rng + Clone> ParamGuard for RandomProjectionParams<Proj, R> {
119 type Checked = RandomProjectionValidParams<Proj, R>;
120 type Error = ReductionError;
121
122 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
123 match self.0.params {
124 RandomProjectionParamsInner::Dimension { target_dim } => {
125 if target_dim == 0 {
126 return Err(ReductionError::NonPositiveEmbeddingSize);
127 }
128 }
129 RandomProjectionParamsInner::Epsilon { eps } => {
130 if eps <= 0. || eps >= 1. {
131 return Err(ReductionError::InvalidPrecision);
132 }
133 }
134 };
135 Ok(&self.0)
136 }
137
138 fn check(self) -> Result<Self::Checked, Self::Error> {
139 self.check_ref()?;
140 Ok(self.0)
141 }
142}