linfa_preprocessing/countgrams/hyperparams.rs

use crate::PreprocessingError;
use linfa::ParamGuard;
use regex::Regex;
use std::cell::{Ref, RefCell};
use std::collections::HashSet;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use super::{Tokenizer, Tokenizerfp};

#[derive(Clone, Debug)]
#[cfg(not(feature = "serde"))]
struct SerdeRegex(Regex);
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(crate = "serde_crate")]
#[cfg(feature = "serde")]
struct SerdeRegex(serde_regex::Serde<Regex>);

#[cfg(not(feature = "serde"))]
impl SerdeRegex {
    fn new(re: &str) -> Result<Self, regex::Error> {
        Ok(Self(Regex::new(re)?))
    }

    fn as_re(&self) -> &Regex {
        &self.0
    }
}

#[cfg(feature = "serde")]
impl SerdeRegex {
    fn new(re: &str) -> Result<Self, regex::Error> {
        Ok(Self(serde_regex::Serde(Regex::new(re)?)))
    }

    fn as_re(&self) -> &Regex {
        use std::ops::Deref;
        self.0.deref()
    }
}

/// Count vectorizer: learns a vocabulary from a sequence of documents (or file paths) and maps each
/// vocabulary entry to an integer value, producing a [CountVectorizer](crate::CountVectorizer) that can
/// be used to count the occurrences of each vocabulary entry in any sequence of documents. Alternatively,
/// a user-specified vocabulary can be used for fitting.
///
/// ### Attributes
///
/// If a user-defined vocabulary is used for fitting, then the following attributes will not be considered during the fitting phase,
/// but they will still be used by the [CountVectorizer](crate::CountVectorizer) to transform any text to be examined.
///
/// * `split_regex`: the regex expression used to split documents into tokens. Defaults to `r"\b\w\w+\b"`, which selects "words", using whitespace and
///   punctuation symbols as separators.
/// * `convert_to_lowercase`: if true, all documents used for fitting will be converted to lowercase. Defaults to `true`.
/// * `n_gram_range`: if set to `(1,1)` single tokens will be candidate vocabulary entries, if `(2,2)` then adjacent token pairs will be considered,
///   if `(1,2)` then both single tokens and adjacent token pairs will be considered, and so on. The definition of a token depends on the
///   regex used for splitting the documents. The default value is `(1,1)`.
/// * `normalize`: if true, all characters in the documents used for fitting will be normalized according to Unicode's NFKD normalization. Defaults to `true`.
/// * `document_frequency`: specifies the minimum and maximum (relative) document frequencies that each vocabulary entry must satisfy. Defaults to `(0., 1.)` (i.e. 0% minimum and 100% maximum).
/// * `stopwords`: optional list of entries to be excluded from the generated vocabulary. Defaults to `None`.
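///
/// ### Example
///
/// An illustrative sketch of configuring and validating the parameters; it assumes that
/// `CountVectorizerParams` is re-exported at the crate root, like [CountVectorizer](crate::CountVectorizer).
///
/// ```
/// use linfa::ParamGuard;
/// use linfa_preprocessing::CountVectorizerParams;
///
/// // Keep unigrams and bigrams, drop terms that appear in fewer than 5% or more
/// // than 95% of the documents, and exclude a couple of stopwords.
/// let params = CountVectorizerParams::default()
///     .n_gram_range(1, 2)
///     .document_frequency(0.05, 0.95)
///     .stopwords(&["the", "and"]);
///
/// // `check` validates the boundaries above and compiles the splitting regex.
/// let valid_params = params.check().unwrap();
/// assert_eq!(valid_params.n_gram_range(), (1, 2));
/// ```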
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct CountVectorizerValidParams {
    convert_to_lowercase: bool,
    split_regex_expr: String,
    split_regex: RefCell<Option<SerdeRegex>>,
    n_gram_range: (usize, usize),
    normalize: bool,
    document_frequency: (f32, f32),
    stopwords: Option<HashSet<String>>,
    max_features: Option<usize>,
    #[cfg_attr(feature = "serde", serde(skip))]
    pub(crate) tokenizer_function: Option<Tokenizerfp>,
    pub(crate) tokenizer_deserialization_guard: bool,
}

impl CountVectorizerValidParams {
    pub fn tokenizer_function(&self) -> Option<Tokenizerfp> {
        self.tokenizer_function
    }

    pub fn max_features(&self) -> Option<usize> {
        self.max_features
    }

    pub fn convert_to_lowercase(&self) -> bool {
        self.convert_to_lowercase
    }

    pub fn split_regex(&self) -> Ref<'_, Regex> {
        Ref::map(self.split_regex.borrow(), |x| x.as_ref().unwrap().as_re())
    }

    pub fn n_gram_range(&self) -> (usize, usize) {
        self.n_gram_range
    }

    pub fn normalize(&self) -> bool {
        self.normalize
    }

    pub fn document_frequency(&self) -> (f32, f32) {
        self.document_frequency
    }

    pub fn stopwords(&self) -> &Option<HashSet<String>> {
        &self.stopwords
    }
}

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct CountVectorizerParams(CountVectorizerValidParams);

impl std::default::Default for CountVectorizerParams {
    fn default() -> Self {
        Self(CountVectorizerValidParams {
            convert_to_lowercase: true,
            split_regex_expr: r"\b\w\w+\b".to_string(),
            split_regex: RefCell::new(None),
            n_gram_range: (1, 1),
            normalize: true,
            document_frequency: (0., 1.),
            stopwords: None,
            max_features: None,
            tokenizer_function: None,
            tokenizer_deserialization_guard: false,
        })
    }
}

impl CountVectorizerParams {
    /// Set the tokenizer as either a function pointer or a regex.
    /// If this method is not called, the default is to split documents with the regex `r"\b\w\w+\b"`.
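    ///
    /// ### Example
    ///
    /// An illustrative sketch of switching to a custom splitting regex; it assumes that the
    /// `Tokenizer` enum (with a `Regex` variant holding a `&str` pattern) is re-exported at the
    /// crate root alongside `CountVectorizerParams`.
    ///
    /// ```
    /// use linfa::ParamGuard;
    /// use linfa_preprocessing::{CountVectorizerParams, Tokenizer};
    ///
    /// // Treat runs of two or more ASCII letters as tokens instead of the default word regex.
    /// let params = CountVectorizerParams::default()
    ///     .tokenizer(Tokenizer::Regex(r"\b[a-zA-Z]{2,}\b"));
    ///
    /// // Checking the parameters also compiles the regex, surfacing invalid patterns early.
    /// let valid_params = params.check().unwrap();
    /// assert_eq!(valid_params.split_regex().as_str(), r"\b[a-zA-Z]{2,}\b");
    /// ```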
    pub fn tokenizer(mut self, tokenizer: Tokenizer) -> Self {
        match tokenizer {
            Tokenizer::Function(fp) => {
                self.0.tokenizer_function = Some(fp);
                self.0.tokenizer_deserialization_guard = true;
            }
            Tokenizer::Regex(regex_str) => {
                self.0.split_regex_expr = regex_str.to_string();
                self.0.tokenizer_deserialization_guard = false;
            }
        }

        self
    }

    /// When building the vocabulary, only consider the top `max_features` (by term frequency).
    /// If `None`, all features are used.
    pub fn max_features(mut self, max_features: Option<usize>) -> Self {
        self.0.max_features = max_features;
        self
    }

    /// If true, all documents used for fitting will be converted to lowercase.
    pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
        self.0.convert_to_lowercase = convert_to_lowercase;
        self
    }

    /// If set to `(1,1)` single tokens will be candidate vocabulary entries, if `(2,2)` then adjacent token pairs will be considered,
    /// if `(1,2)` then both single tokens and adjacent token pairs will be considered, and so on. The definition of a token depends on the
    /// regex used for splitting the documents.
    ///
    /// `min_n` should not be greater than `max_n`.
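    ///
    /// ### Example
    ///
    /// A sketch of how a flipped range is rejected at validation time (again assuming the
    /// crate-root re-export of `CountVectorizerParams`):
    ///
    /// ```
    /// use linfa::ParamGuard;
    /// use linfa_preprocessing::CountVectorizerParams;
    ///
    /// // A `min_n` greater than `max_n` is only reported when the parameters are checked.
    /// let params = CountVectorizerParams::default().n_gram_range(2, 1);
    /// assert!(params.check().is_err());
    /// ```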
    pub fn n_gram_range(mut self, min_n: usize, max_n: usize) -> Self {
        self.0.n_gram_range = (min_n, max_n);
        self
    }

    /// If true, all characters in the documents used for fitting will be normalized according to Unicode's NFKD normalization.
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.0.normalize = normalize;
        self
    }

    /// Specifies the minimum and maximum (relative) document frequencies that each vocabulary entry must satisfy.
    /// `min_freq` and `max_freq` must lie in `0..=1` and `min_freq` should not be greater than `max_freq`.
    pub fn document_frequency(mut self, min_freq: f32, max_freq: f32) -> Self {
        self.0.document_frequency = (min_freq, max_freq);
        self
    }

    /// List of entries to be excluded from the generated vocabulary.
    pub fn stopwords<T: ToString>(mut self, stopwords: &[T]) -> Self {
        self.0.stopwords = Some(stopwords.iter().map(|t| t.to_string()).collect());
        self
    }
}

impl ParamGuard for CountVectorizerParams {
    type Checked = CountVectorizerValidParams;
    type Error = PreprocessingError;

    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let (n_gram_min, n_gram_max) = self.0.n_gram_range;
        let (min_freq, max_freq) = self.0.document_frequency;

        if n_gram_min == 0 || n_gram_max == 0 {
            Err(PreprocessingError::InvalidNGramBoundaries(
                n_gram_min, n_gram_max,
            ))
        } else if n_gram_min > n_gram_max {
            Err(PreprocessingError::FlippedNGramBoundaries(
                n_gram_min, n_gram_max,
            ))
        } else if min_freq < 0. || max_freq < 0. {
            Err(PreprocessingError::InvalidDocumentFrequencies(
                min_freq, max_freq,
            ))
        } else if max_freq < min_freq {
            Err(PreprocessingError::FlippedDocumentFrequencies(
                min_freq, max_freq,
            ))
        } else {
            *self.0.split_regex.borrow_mut() = Some(SerdeRegex::new(&self.0.split_regex_expr)?);

            Ok(&self.0)
        }
    }

    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}