linfa_preprocessing/tf_idf_vectorization.rs

//! Term frequency - inverse document frequency vectorization methods

use crate::countgrams::{CountVectorizer, CountVectorizerParams, Tokenizerfp};
use crate::error::Result;
use crate::Tokenizer;
use encoding::types::EncodingRef;
use encoding::DecoderTrap;
use ndarray::{Array1, ArrayBase, Data, Ix1};
use sprs::CsMat;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
/// Methods for computing the inverse document frequency of a vocabulary entry
pub enum TfIdfMethod {
    /// Computes the idf as `log((1 + n) / (1 + document_frequency)) + 1`. The "plus ones" inside the log
    /// add an artificial document containing every vocabulary entry, preventing divisions by zero.
    /// The "plus one" after the log allows vocabulary entries that appear in every document to still be
    /// considered with a weight of one instead of being completely discarded.
    Smooth,
    /// Computes the idf as `log(n / document_frequency) + 1`. The "plus one" after the log allows vocabulary
    /// entries that appear in every document to still be considered with a weight of one instead of being
    /// completely discarded. If a vocabulary entry has a document frequency of zero, this will produce a
    /// division by zero.
    NonSmooth,
    /// Textbook definition of idf, computed as `log(n / (1 + document_frequency))`, which prevents divisions
    /// by zero and discards entries that appear in every document.
    Textbook,
}

impl TfIdfMethod {
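    /// Computes the idf of a vocabulary entry from `n`, the total number of documents,
    /// and `df`, the number of documents that contain the entry.
    ///
    /// A sketch of how the three variants differ (the import path assumes this
    /// `tf_idf_vectorization` module is exposed publicly by the crate):
    ///
    /// ```
    /// use linfa_preprocessing::tf_idf_vectorization::TfIdfMethod;
    ///
    /// // n = 4 documents, with the entry appearing in df = 2 of them
    /// assert!((TfIdfMethod::Smooth.compute_idf(4, 2) - ((5.0f64 / 3.0).ln() + 1.0)).abs() < 1e-12);
    /// assert!((TfIdfMethod::NonSmooth.compute_idf(4, 2) - (2.0f64.ln() + 1.0)).abs() < 1e-12);
    /// assert!((TfIdfMethod::Textbook.compute_idf(4, 2) - (4.0f64 / 3.0).ln()).abs() < 1e-12);
    /// ```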
    pub fn compute_idf(&self, n: usize, df: usize) -> f64 {
        match self {
            TfIdfMethod::Smooth => ((1. + n as f64) / (1. + df as f64)).ln() + 1.,
            TfIdfMethod::NonSmooth => (n as f64 / df as f64).ln() + 1.,
            TfIdfMethod::Textbook => (n as f64 / (1. + df as f64)).ln(),
        }
    }
}

/// Similar to [`CountVectorizer`] but instead of
/// just counting the term frequency of each vocabulary entry in each given document,
/// it computes the term frequency times the inverse document frequency, thus giving more importance
/// to entries that appear many times but only in some documents. The weight function can be adjusted
/// by setting the appropriate [method](TfIdfMethod). This struct provides the same string
/// processing customizations described in [`CountVectorizer`].
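///
/// # Examples
///
/// A configuration sketch using the builder methods defined below (the import
/// path assumes this `tf_idf_vectorization` module is exposed publicly by the crate):
///
/// ```
/// use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
///
/// let vectorizer = TfIdfVectorizer::default()
///     .convert_to_lowercase(true)
///     .n_gram_range(1, 2)
///     .document_frequency(0.05, 0.95);
/// ```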
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct TfIdfVectorizer {
    count_vectorizer: CountVectorizerParams,
    method: TfIdfMethod,
}

impl std::default::Default for TfIdfVectorizer {
    fn default() -> Self {
        Self {
            count_vectorizer: CountVectorizerParams::default(),
            method: TfIdfMethod::Smooth,
        }
    }
}

impl TfIdfVectorizer {
    /// Sets the tokenizer, either as a function pointer or as a regex.
    /// If this method is not called, the default is to use the regex `"\b\w\w+\b"`.
    pub fn tokenizer(self, tokenizer: Tokenizer) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.tokenizer(tokenizer),
            method: self.method,
        }
    }

    /// When building the vocabulary, only consider the top `max_features` (by term frequency).
    /// If `None`, all features are used.
    pub fn max_features(self, max_features: Option<usize>) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.max_features(max_features),
            method: self.method,
        }
    }

    /// If true, all documents used for fitting will be converted to lowercase.
    pub fn convert_to_lowercase(self, convert_to_lowercase: bool) -> Self {
        Self {
            count_vectorizer: self
                .count_vectorizer
                .convert_to_lowercase(convert_to_lowercase),
            method: self.method,
        }
    }

    /// If set to `(1,1)` single tokens will be candidate vocabulary entries, if `(2,2)` then adjacent token pairs will be considered,
    /// if `(1,2)` then both single tokens and adjacent token pairs will be considered, and so on. The definition of a token depends on the
    /// regex used for splitting the documents.
    ///
    /// `min_n` should not be greater than `max_n`
    pub fn n_gram_range(self, min_n: usize, max_n: usize) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.n_gram_range(min_n, max_n),
            method: self.method,
        }
    }

    /// If true, all characters in the documents used for fitting will be normalized according to Unicode's NFKD normalization.
    pub fn normalize(self, normalize: bool) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.normalize(normalize),
            method: self.method,
        }
    }

    /// Specifies the minimum and maximum (relative) document frequencies that each vocabulary entry must satisfy.
    /// `min_freq` and `max_freq` must lie in `0..=1` and `min_freq` should not be greater than `max_freq`.
    pub fn document_frequency(self, min_freq: f32, max_freq: f32) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.document_frequency(min_freq, max_freq),
            method: self.method,
        }
    }

    /// List of entries to be excluded from the generated vocabulary.
    pub fn stopwords<T: ToString>(self, stopwords: &[T]) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.stopwords(stopwords),
            method: self.method,
        }
    }

    /// Learns a vocabulary from the texts in `x`, according to the specified attributes, and maps each
    /// vocabulary entry to an integer value, producing a [FittedTfIdfVectorizer].
    ///
    /// Returns an error if:
    /// * one of the `n_gram` boundaries is set to zero or the minimum value is greater than the maximum value
    /// * the minimum document frequency is greater than one or greater than the maximum frequency, or the
    ///   maximum frequency is smaller than zero
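    ///
    /// A usage sketch, mirroring this file's tests (the import path assumes this
    /// `tf_idf_vectorization` module is exposed publicly by the crate):
    ///
    /// ```
    /// use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
    /// use ndarray::array;
    ///
    /// let texts = array!["one and two", "two and three"];
    /// let fitted = TfIdfVectorizer::default().fit(&texts).unwrap();
    /// // with the default tokenizer regex "\b\w\w+\b", the vocabulary is {"and", "one", "three", "two"}
    /// assert_eq!(fitted.nentries(), 4);
    /// ```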
    pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
        &self,
        x: &ArrayBase<D, Ix1>,
    ) -> Result<FittedTfIdfVectorizer> {
        let fitted_vectorizer = self.count_vectorizer.fit(x)?;
        Ok(FittedTfIdfVectorizer {
            fitted_vectorizer,
            method: self.method.clone(),
        })
    }

    /// Produces a [FittedTfIdfVectorizer] with the input vocabulary.
    /// All struct attributes are ignored during fitting but will be used by the [FittedTfIdfVectorizer]
    /// to transform any text to be examined. As such, this will return an error in the same cases as the `fit` method.
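    ///
    /// A sketch with a fixed vocabulary (import path assumed as in [TfIdfVectorizer]'s example):
    ///
    /// ```
    /// use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
    ///
    /// let fitted = TfIdfVectorizer::default()
    ///     .fit_vocabulary(&["one", "two", "three"])
    ///     .unwrap();
    /// assert_eq!(fitted.nentries(), 3);
    /// ```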
    pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<FittedTfIdfVectorizer> {
        let fitted_vectorizer = self.count_vectorizer.fit_vocabulary(words)?;
        Ok(FittedTfIdfVectorizer {
            fitted_vectorizer,
            method: self.method.clone(),
        })
    }

    /// Like `fit`, but reads the documents to fit on from the files at the given paths,
    /// decoded with the specified `encoding` and `trap`.
    pub fn fit_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> Result<FittedTfIdfVectorizer> {
        let fitted_vectorizer = self.count_vectorizer.fit_files(input, encoding, trap)?;
        Ok(FittedTfIdfVectorizer {
            fitted_vectorizer,
            method: self.method.clone(),
        })
    }
}

/// Counts the occurrences of each vocabulary entry, learned during fitting, in a sequence of texts, and scales
/// them by the inverse document frequency defined by the [method](TfIdfMethod). Each vocabulary entry is mapped
/// to an integer value that is used to index the count in the result.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct FittedTfIdfVectorizer {
    fitted_vectorizer: CountVectorizer,
    method: TfIdfMethod,
}

impl FittedTfIdfVectorizer {
    /// Replaces the tokenizer function pointer used by the fitted vectorizer.
    pub fn force_tokenizer_redefinition(&mut self, tokenizer: Tokenizerfp) {
        self.fitted_vectorizer
            .force_tokenizer_function_redefinition(tokenizer);
    }

    /// Number of vocabulary entries learned during fitting
    pub fn nentries(&self) -> usize {
        self.fitted_vectorizer.vocabulary.len()
    }

    /// Contains all vocabulary entries, in the same order used by the `transform` method.
    pub fn vocabulary(&self) -> &Vec<String> {
        self.fitted_vectorizer.vocabulary()
    }

    /// Returns the inverse document frequency method used in the `transform` method
    pub fn method(&self) -> &TfIdfMethod {
        &self.method
    }

    /// Given a sequence of `n` documents, produces a sparse matrix of size `(n, vocabulary_entries)` where column `j` of row `i`
    /// is the number of occurrences of vocabulary entry `j` in the text of index `i`, scaled by the inverse document frequency.
    /// Vocabulary entry `j` is the string at the `j`-th position in the vocabulary.
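    ///
    /// A usage sketch, converting the sparse result to a dense array for inspection
    /// (import path assumed as in [TfIdfVectorizer]'s example):
    ///
    /// ```
    /// use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
    /// use ndarray::array;
    ///
    /// let texts = array!["one and two", "two and three"];
    /// let fitted = TfIdfVectorizer::default().fit(&texts).unwrap();
    /// let tf_idf = fitted.transform(&texts).unwrap().to_dense();
    /// assert_eq!(tf_idf.dim(), (2, fitted.nentries()));
    /// ```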
    pub fn transform<T: ToString, D: Data<Elem = T>>(
        &self,
        x: &ArrayBase<D, Ix1>,
    ) -> Result<CsMat<f64>> {
        self.fitted_vectorizer.validate_deserialization()?;
        let (term_freqs, doc_freqs) = self.fitted_vectorizer.get_term_and_document_frequencies(x);
        Ok(self.apply_tf_idf(term_freqs, doc_freqs))
    }

    /// Like `transform`, but reads the documents to transform from the files at the given paths,
    /// decoded with the specified `encoding` and `trap`.
    pub fn transform_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> Result<CsMat<f64>> {
        self.fitted_vectorizer.validate_deserialization()?;
        let (term_freqs, doc_freqs) = self
            .fitted_vectorizer
            .get_term_and_document_frequencies_files(input, encoding, trap);
        Ok(self.apply_tf_idf(term_freqs, doc_freqs))
    }

    /// Scales each term count by the inverse document frequency of the corresponding vocabulary entry.
    fn apply_tf_idf(&self, term_freqs: CsMat<usize>, doc_freqs: Array1<usize>) -> CsMat<f64> {
        let mut term_freqs: CsMat<f64> = term_freqs.map(|x| *x as f64);
        // `term_freqs.rows()` is the number of documents `n`; compute one idf value
        // per vocabulary entry from that entry's document frequency.
        let inv_doc_freqs =
            doc_freqs.mapv(|doc_freq| self.method.compute_idf(term_freqs.rows(), doc_freq));
        // Scale each non-zero count by the idf of its vocabulary entry (column index).
        for mut row_vec in term_freqs.outer_iterator_mut() {
            for (col_i, val) in row_vec.iter_mut() {
                *val *= inv_doc_freqs[col_i];
            }
        }
        term_freqs
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::column_for_word;
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    use std::fs::File;
    use std::io::Write;

    macro_rules! assert_tf_idfs_for_word {
        ($voc:expr, $transf:expr, $(($word:expr, $counts:expr)),*) => {
            $(
                assert_abs_diff_eq!(column_for_word!($voc, $transf, $word), $counts, epsilon = 1e-3);
            )*
        }
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TfIdfMethod>();
    }

    #[test]
    fn test_tf_idf() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let vectorizer = TfIdfVectorizer::default().fit(&texts).unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer.transform(&texts).unwrap().to_dense();
        assert_eq!(transformed.dim(), (texts.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
    }

    #[test]
    fn test_tf_idf_files() {
        let text_files = create_test_files();
        let vectorizer = TfIdfVectorizer::default()
            .fit_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer
            .transform_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap()
            .to_dense();
        assert_eq!(transformed.dim(), (text_files.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
        delete_test_files(&text_files)
    }

    fn create_test_files() -> Vec<&'static str> {
        let file_names = vec![
            "./tf_idf_vectorization_test_file_1",
            "./tf_idf_vectorization_test_file_2",
            "./tf_idf_vectorization_test_file_3",
            "./tf_idf_vectorization_test_file_4",
            "./tf_idf_vectorization_test_file_5",
        ];
        let contents = &[
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and",
        ];
        // create files and write contents
        for (f_name, f_content) in file_names.iter().zip(contents.iter()) {
            let mut file = File::create(f_name).unwrap();
            file.write_all(f_content.as_bytes()).unwrap();
        }
        file_names
    }

    fn delete_test_files(file_names: &[&'static str]) {
        for f_name in file_names.iter() {
            std::fs::remove_file(f_name).unwrap();
        }
    }
}