// linfa_preprocessing/countgrams/hyperparams.rs
use crate::PreprocessingError;
2use linfa::ParamGuard;
3use regex::Regex;
4use std::cell::{Ref, RefCell};
5use std::collections::HashSet;
6
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
10use super::{Tokenizer, Tokenizerfp};
11
// Newtype around `Regex` that lets the params struct derive serde traits.
// Without the "serde" feature it is a plain wrapper; with the feature,
// `serde_regex::Serde` supplies the `Serialize`/`Deserialize` impls that
// `Regex` itself lacks.
#[derive(Clone, Debug)]
#[cfg(not(feature = "serde"))]
struct SerdeRegex(Regex);
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(crate = "serde_crate")]
#[cfg(feature = "serde")]
struct SerdeRegex(serde_regex::Serde<Regex>);
19
20#[cfg(not(feature = "serde"))]
21impl SerdeRegex {
22 fn new(re: &str) -> Result<Self, regex::Error> {
23 Ok(Self(Regex::new(re)?))
24 }
25
26 fn as_re(&self) -> &Regex {
27 &self.0
28 }
29}
30
#[cfg(feature = "serde")]
impl SerdeRegex {
    /// Compiles `re` and wraps it in `serde_regex::Serde` so the regex
    /// can round-trip through (de)serialization.
    fn new(re: &str) -> Result<Self, regex::Error> {
        Ok(Self(serde_regex::Serde(Regex::new(re)?)))
    }

    /// Borrows the compiled regex.
    fn as_re(&self) -> &Regex {
        // `Serde<Regex>` derefs to `Regex`; `&*` re-borrows it directly.
        // The previous `&self.0.deref()` produced a needless `&&Regex`
        // (clippy `needless_borrow`) that only worked via ref coercion.
        &*self.0
    }
}
42
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
// Validated hyperparameter set for the count vectorizer; obtained from
// `CountVectorizerParams` via `ParamGuard::check`/`check_ref`.
pub struct CountVectorizerValidParams {
    // Lowercase documents before tokenizing.
    convert_to_lowercase: bool,
    // Source pattern for the token-splitting regex below.
    split_regex_expr: String,
    // Compiled form of `split_regex_expr`; populated during `check_ref`,
    // `None` until validation has run. Interior mutability lets `check_ref`
    // fill it through a shared reference.
    split_regex: RefCell<Option<SerdeRegex>>,
    // Inclusive (min_n, max_n) range of n-gram sizes to extract.
    n_gram_range: (usize, usize),
    normalize: bool,
    // (min, max) document-frequency bounds.
    document_frequency: (f32, f32),
    stopwords: Option<HashSet<String>>,
    max_features: Option<usize>,
    // Function pointers cannot be (de)serialized, so this field is skipped...
    #[cfg_attr(feature = "serde", serde(skip))]
    pub(crate) tokenizer_function: Option<Tokenizerfp>,
    // ...and this flag records that a tokenizer function was configured and
    // must be supplied again after deserialization.
    pub(crate) tokenizer_deserialization_guard: bool,
}
81
82impl CountVectorizerValidParams {
83 pub fn tokenizer_function(&self) -> Option<Tokenizerfp> {
84 self.tokenizer_function
85 }
86
87 pub fn max_features(&self) -> Option<usize> {
88 self.max_features
89 }
90
91 pub fn convert_to_lowercase(&self) -> bool {
92 self.convert_to_lowercase
93 }
94
95 pub fn split_regex(&self) -> Ref<'_, Regex> {
96 Ref::map(self.split_regex.borrow(), |x| x.as_ref().unwrap().as_re())
97 }
98
99 pub fn n_gram_range(&self) -> (usize, usize) {
100 self.n_gram_range
101 }
102
103 pub fn normalize(&self) -> bool {
104 self.normalize
105 }
106
107 pub fn document_frequency(&self) -> (f32, f32) {
108 self.document_frequency
109 }
110
111 pub fn stopwords(&self) -> &Option<HashSet<String>> {
112 &self.stopwords
113 }
114}
115
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
// Unchecked builder for [`CountVectorizerValidParams`]; turned into the
// validated form through the `ParamGuard` impl below.
pub struct CountVectorizerParams(CountVectorizerValidParams);
123
124impl std::default::Default for CountVectorizerParams {
125 fn default() -> Self {
126 Self(CountVectorizerValidParams {
127 convert_to_lowercase: true,
128 split_regex_expr: r"\b\w\w+\b".to_string(),
129 split_regex: RefCell::new(None),
130 n_gram_range: (1, 1),
131 normalize: true,
132 document_frequency: (0., 1.),
133 stopwords: None,
134 max_features: None,
135 tokenizer_function: None,
136 tokenizer_deserialization_guard: false,
137 })
138 }
139}
140
141impl CountVectorizerParams {
142 pub fn tokenizer(mut self, tokenizer: Tokenizer) -> Self {
145 match tokenizer {
146 Tokenizer::Function(fp) => {
147 self.0.tokenizer_function = Some(fp);
148 self.0.tokenizer_deserialization_guard = true;
149 }
150 Tokenizer::Regex(regex_str) => {
151 self.0.split_regex_expr = regex_str.to_string();
152 self.0.tokenizer_deserialization_guard = false;
153 }
154 }
155
156 self
157 }
158
159 pub fn max_features(mut self, max_features: Option<usize>) -> Self {
162 self.0.max_features = max_features;
163 self
164 }
165
166 pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
168 self.0.convert_to_lowercase = convert_to_lowercase;
169 self
170 }
171
172 pub fn n_gram_range(mut self, min_n: usize, max_n: usize) -> Self {
178 self.0.n_gram_range = (min_n, max_n);
179 self
180 }
181
182 pub fn normalize(mut self, normalize: bool) -> Self {
184 self.0.normalize = normalize;
185 self
186 }
187
188 pub fn document_frequency(mut self, min_freq: f32, max_freq: f32) -> Self {
191 self.0.document_frequency = (min_freq, max_freq);
192 self
193 }
194
195 pub fn stopwords<T: ToString>(mut self, stopwords: &[T]) -> Self {
197 self.0.stopwords = Some(stopwords.iter().map(|t| t.to_string()).collect());
198 self
199 }
200}
201
202impl ParamGuard for CountVectorizerParams {
203 type Checked = CountVectorizerValidParams;
204 type Error = PreprocessingError;
205
206 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
207 let (n_gram_min, n_gram_max) = self.0.n_gram_range;
208 let (min_freq, max_freq) = self.0.document_frequency;
209
210 if n_gram_min == 0 || n_gram_max == 0 {
211 Err(PreprocessingError::InvalidNGramBoundaries(
212 n_gram_min, n_gram_max,
213 ))
214 } else if n_gram_min > n_gram_max {
215 Err(PreprocessingError::FlippedNGramBoundaries(
216 n_gram_min, n_gram_max,
217 ))
218 } else if min_freq < 0. || max_freq < 0. {
219 Err(PreprocessingError::InvalidDocumentFrequencies(
220 min_freq, max_freq,
221 ))
222 } else if max_freq < min_freq {
223 Err(PreprocessingError::FlippedDocumentFrequencies(
224 min_freq, max_freq,
225 ))
226 } else {
227 *self.0.split_regex.borrow_mut() = Some(SerdeRegex::new(&self.0.split_regex_expr)?);
228
229 Ok(&self.0)
230 }
231 }
232
233 fn check(self) -> Result<Self::Checked, Self::Error> {
234 self.check_ref()?;
235 Ok(self.0)
236 }
237}