linfa_preprocessing/countgrams/
mod.rs

1//! Count vectorization methods
2
3use std::cmp::Reverse;
4use std::collections::{HashMap, HashSet};
5use std::io::Read;
6use std::iter::IntoIterator;
7
8use encoding::types::EncodingRef;
9use encoding::DecoderTrap;
10use itertools::sorted;
11use ndarray::{Array1, ArrayBase, ArrayViewMut1, Data, Ix1};
12use regex::Regex;
13use sprs::{CsMat, CsVec};
14use unicode_normalization::UnicodeNormalization;
15
16use crate::error::{PreprocessingError, Result};
17use crate::helpers::NGramList;
18pub use hyperparams::{CountVectorizerParams, CountVectorizerValidParams};
19use linfa::ParamGuard;
20
21#[cfg(feature = "serde")]
22use serde_crate::{Deserialize, Serialize};
23
24mod hyperparams;
25
/// Signature of a tokenizer function: it receives a document and returns the tokens
/// found in it, as string slices borrowed from that document.
pub(crate) type Tokenizerfp = fn(&str) -> Vec<&str>;
/// Strategy used to split a document into tokens before counting.
pub enum Tokenizer {
    /// A plain function that receives the document and returns its tokens.
    Function(Tokenizerfp),
    /// A regular expression pattern, given as a string.
    // NOTE(review): presumably every match of the pattern becomes one token (this file
    // applies the split regex with `find_iter`) — confirm against `hyperparams`.
    Regex(String),
}
31
32impl CountVectorizerValidParams {
33    /// Learns a vocabulary from the documents in `x`, according to the specified attributes and maps each
34    /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
35    ///
36    /// Returns an error if:
37    /// * one of the `n_gram` boundaries is set to zero or the minimum value is greater than the maximum value
38    /// * if the minimum document frequency is greater than one or than the maximum frequency, or if the maximum frequency is  
39    ///   smaller than zero
40    /// * if the regex expression for the split is invalid
41    pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
42        &self,
43        x: &ArrayBase<D, Ix1>,
44    ) -> Result<CountVectorizer> {
45        // word, (integer mapping for word, document frequency for word)
46        let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::new();
47        for string in x.iter().map(|s| transform_string(s.to_string(), self)) {
48            self.read_document_into_vocabulary(string, &self.split_regex(), &mut vocabulary);
49        }
50
51        let mut vocabulary = self.filter_vocabulary(vocabulary, x.len());
52
53        let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);
54
55        Ok(CountVectorizer {
56            vocabulary,
57            vec_vocabulary,
58            properties: self.clone(),
59        })
60    }
61
62    /// Learns a vocabulary from the documents contained in the files in `input`, according to the specified attributes and maps each
63    /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
64    ///
65    /// The files will be read using the specified `encoding`, and any sequence unrecognized by the encoding will be handled
66    /// according to `trap`.
67    ///
68    /// Returns an error if:
69    /// * one of the `n_gram` boundaries is set to zero or the minimum value is greater than the maximum value
70    /// * if the minimum document frequency is greater than one or than the maximum frequency, or if the maximum frequency is  
71    ///   smaller than zero
72    /// * if the regex expression for the split is invalid
73    /// * if one of the files couldn't be opened
74    /// * if the trap is strict and an unrecognized sequence is encountered in one of the files
75    pub fn fit_files<P: AsRef<std::path::Path>>(
76        &self,
77        input: &[P],
78        encoding: EncodingRef,
79        trap: DecoderTrap,
80    ) -> Result<CountVectorizer> {
81        // word, (integer mapping for word, document frequency for word)
82        let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::new();
83        let documents_count = input.len();
84        for path in input {
85            let mut file = std::fs::File::open(path)?;
86            let mut document_bytes = Vec::new();
87            file.read_to_end(&mut document_bytes)?;
88            let document = encoding::decode(&document_bytes, trap, encoding).0;
89            // encoding error contains a cow string, can't just use ?, must go through the unwrap
90            if document.is_err() {
91                return Err(PreprocessingError::EncodingError(document.err().unwrap()));
92            }
93            // safe unwrap now that error has been handled
94            let document = transform_string(document.unwrap(), self);
95            self.read_document_into_vocabulary(document, &self.split_regex(), &mut vocabulary);
96        }
97
98        let mut vocabulary = self.filter_vocabulary(vocabulary, documents_count);
99        let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);
100
101        Ok(CountVectorizer {
102            vocabulary,
103            vec_vocabulary,
104            properties: self.clone(),
105        })
106    }
107
108    /// Produces a [CountVectorizer](CountVectorizer) with the input vocabulary.
109    /// All struct attributes are ignored in the fitting but will be used by the [CountVectorizer](CountVectorizer)
110    /// to transform any text to be examined. As such this will return an error in the same cases as the `fit` method.
111    pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<CountVectorizer> {
112        let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::with_capacity(words.len());
113        for item in words.iter().map(|w| w.to_string()) {
114            let len = vocabulary.len();
115            // do not care about frequencies/stopwords if a vocabulary is given. Always 1 frequency
116            vocabulary.entry(item).or_insert((len, 1));
117        }
118        let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);
119        Ok(CountVectorizer {
120            vocabulary,
121            vec_vocabulary,
122            properties: self.clone(),
123        })
124    }
125
126    /// Removes vocabulary items that do not satisfy the document frequencies constraints or if they appear in the
127    /// optional stopwords test.
128    /// The total number of documents is needed to convert from relative document frequencies to
129    /// their absolute counterparts.
130    fn filter_vocabulary(
131        &self,
132        vocabulary: HashMap<String, (usize, usize)>,
133        n_documents: usize,
134    ) -> HashMap<String, (usize, usize)> {
135        let (min_df, max_df) = self.document_frequency();
136        let len_f32 = n_documents as f32;
137        let (min_abs_df, max_abs_df) = ((min_df * len_f32) as usize, (max_df * len_f32) as usize);
138
139        let vocabulary = if min_abs_df == 0 && max_abs_df == n_documents {
140            match &self.stopwords() {
141                None => vocabulary,
142                Some(stopwords) => vocabulary
143                    .into_iter()
144                    .filter(|(entry, (_, _))| !stopwords.contains(entry))
145                    .collect(),
146            }
147        } else {
148            match &self.stopwords() {
149                None => vocabulary
150                    .into_iter()
151                    .filter(|(_, (_, abs_count))| {
152                        *abs_count >= min_abs_df && *abs_count <= max_abs_df
153                    })
154                    .collect(),
155                Some(stopwords) => vocabulary
156                    .into_iter()
157                    .filter(|(entry, (_, abs_count))| {
158                        *abs_count >= min_abs_df
159                            && *abs_count <= max_abs_df
160                            && !stopwords.contains(entry)
161                    })
162                    .collect(),
163            }
164        };
165
166        if let Some(max_features) = self.max_features() {
167            sorted(
168                vocabulary
169                    .into_iter()
170                    .map(|(word, (x, freq))| (Reverse(freq), Reverse(word), x)),
171            )
172            .take(max_features)
173            .map(|(freq, word, x)| (word.0, (x, freq.0)))
174            .collect()
175        } else {
176            vocabulary
177        }
178    }
179
180    /// Inserts all vocabulary entries learned from a single document (`doc`) into the
181    /// shared `vocabulary`, setting the document frequency to one for new entries and
182    /// incrementing it by one for entries which were already present.
183    fn read_document_into_vocabulary(
184        &self,
185        doc: String,
186        regex: &Regex,
187        vocabulary: &mut HashMap<String, (usize, usize)>,
188    ) {
189        let words = if let Some(tokenizer) = self.tokenizer_function() {
190            tokenizer(&doc)
191        } else {
192            regex.find_iter(&doc).map(|mat| mat.as_str()).collect()
193        };
194        let list = NGramList::new(words, self.n_gram_range());
195        let document_vocabulary: HashSet<String> = list.into_iter().flatten().collect();
196        for word in document_vocabulary {
197            let len = vocabulary.len();
198            // If vocabulary item was already present then increase its document frequency
199            if let Some((_, freq)) = vocabulary.get_mut(&word) {
200                *freq += 1;
201            // otherwise set it to one
202            } else {
203                vocabulary.insert(word, (len, 1));
204            }
205        }
206    }
207}
208
209impl CountVectorizerParams {
210    /// Learns a vocabulary from the documents in `x`, according to the specified attributes and maps each
211    /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
212    ///
213    /// Returns an error if:
214    /// * one of the `n_gram` boundaries is set to zero or the minimum value is greater than the maximum value
215    /// * if the minimum document frequency is greater than one or than the maximum frequency, or if the maximum frequency is  
216    ///   smaller than zero
217    /// * if the regex expression for the split is invalid
218    pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
219        &self,
220        x: &ArrayBase<D, Ix1>,
221    ) -> Result<CountVectorizer> {
222        self.check_ref().and_then(|params| params.fit(x))
223    }
224
225    /// Learns a vocabulary from the documents contained in the files in `input`, according to the specified attributes and maps each
226    /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
227    ///
228    /// The files will be read using the specified `encoding`, and any sequence unrecognized by the encoding will be handled
229    /// according to `trap`.
230    ///
231    /// Returns an error if:
232    /// * one of the `n_gram` boundaries is set to zero or the minimum value is greater than the maximum value
233    /// * if the minimum document frequency is greater than one or than the maximum frequency, or if the maximum frequency is  
234    ///   smaller than zero
235    /// * if the regex expression for the split is invalid
236    /// * if one of the files couldn't be opened
237    /// * if the trap is strict and an unrecognized sequence is encountered in one of the files
238    pub fn fit_files<P: AsRef<std::path::Path>>(
239        &self,
240        input: &[P],
241        encoding: EncodingRef,
242        trap: DecoderTrap,
243    ) -> Result<CountVectorizer> {
244        self.check_ref()
245            .and_then(|params| params.fit_files(input, encoding, trap))
246    }
247
248    /// Produces a [CountVectorizer](CountVectorizer) with the input vocabulary.
249    /// All struct attributes are ignored in the fitting but will be used by the [CountVectorizer](CountVectorizer)
250    /// to transform any text to be examined. As such this will return an error in the same cases as the `fit` method.
251    pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<CountVectorizer> {
252        self.check_ref()
253            .and_then(|params| params.fit_vocabulary(words))
254    }
255}
256
/// Counts the occurrences of each vocabulary entry, learned during fitting, in a sequence of documents. Each vocabulary entry is mapped
/// to an integer value that is used to index the count in the result.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
pub struct CountVectorizer {
    /// Maps each vocabulary entry to `(integer index, document frequency)`.
    pub(crate) vocabulary: HashMap<String, (usize, usize)>,
    /// Vocabulary entries stored by integer index: the entry at position `j` is the one
    /// mapped to index `j` in `vocabulary`.
    pub(crate) vec_vocabulary: Vec<String>,
    /// Parameters the vectorizer was fitted with; they are applied again when transforming documents.
    pub(crate) properties: CountVectorizerValidParams,
}
270
271impl CountVectorizer {
272    /// Construct a new set of parameters
273    pub fn params() -> CountVectorizerParams {
274        CountVectorizerParams::default()
275    }
276
277    /// Number of vocabulary entries learned during fitting
278    pub fn nentries(&self) -> usize {
279        self.vocabulary.len()
280    }
281
282    pub fn force_tokenizer_function_redefinition(&mut self, tokenizer: Tokenizerfp) {
283        self.properties.tokenizer_function = Some(tokenizer);
284    }
285
286    pub(crate) fn validate_deserialization(&self) -> Result<()> {
287        if self.properties.tokenizer_function().is_none()
288            && self.properties.tokenizer_deserialization_guard
289        {
290            return Err(PreprocessingError::TokenizerNotSet);
291        }
292
293        Ok(())
294    }
295
296    /// Given a sequence of `n` documents, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
297    /// is the number of occurrences of vocabulary entry `j` in the document of index `i`. Vocabulary entry `j` is the string
298    /// at the `j`-th position in the vocabulary. If a vocabulary entry was not encountered in a document, then the relative
299    /// cell in the sparse matrix will be set to `None`.
300    pub fn transform<T: ToString, D: Data<Elem = T>>(
301        &self,
302        x: &ArrayBase<D, Ix1>,
303    ) -> Result<CsMat<usize>> {
304        self.validate_deserialization()?;
305        let (vectorized, _) = self.get_term_and_document_frequencies(x);
306        Ok(vectorized)
307    }
308
309    /// Given a sequence of `n` file names, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
310    /// is the number of occurrences of vocabulary entry `j` in the document contained in the file of index `i`. Vocabulary entry `j` is the string
311    /// at the `j`-th position in the vocabulary. If a vocabulary entry was not encountered in a document, then the relative
312    /// cell in the sparse matrix will be set to `None`.
313    ///
314    /// The files will be read using the specified `encoding`, and any sequence unrecognized by the encoding will be handled
315    /// according to `trap`.
316    pub fn transform_files<P: AsRef<std::path::Path>>(
317        &self,
318        input: &[P],
319        encoding: EncodingRef,
320        trap: DecoderTrap,
321    ) -> Result<CsMat<usize>> {
322        self.validate_deserialization()?;
323        let (vectorized, _) = self.get_term_and_document_frequencies_files(input, encoding, trap);
324        Ok(vectorized)
325    }
326
327    /// Contains all vocabulary entries, in the same order used by the `transform` methods.
328    pub fn vocabulary(&self) -> &Vec<String> {
329        &self.vec_vocabulary
330    }
331
332    /// Counts the occurrence of each vocabulary entry in each document and keeps track of the overall
333    /// document frequency of each entry.
334    pub(crate) fn get_term_and_document_frequencies<T: ToString, D: Data<Elem = T>>(
335        &self,
336        x: &ArrayBase<D, Ix1>,
337    ) -> (CsMat<usize>, Array1<usize>) {
338        let mut document_frequencies = Array1::zeros(self.vocabulary.len());
339        let mut sprs_vectorized = CsMat::empty(sprs::CompressedStorage::CSR, self.vocabulary.len());
340        sprs_vectorized.reserve_outer_dim_exact(x.len());
341        let regex = self.properties.split_regex();
342        for string in x.into_iter().map(|s| s.to_string()) {
343            let row = self.analyze_document(string, &regex, document_frequencies.view_mut());
344            sprs_vectorized = sprs_vectorized.append_outer_csvec(row.view());
345        }
346        (sprs_vectorized, document_frequencies)
347    }
348
349    /// Counts the occurrence of each vocabulary entry in each document and keeps track of the overall
350    /// document frequency of each entry.
351    pub(crate) fn get_term_and_document_frequencies_files<P: AsRef<std::path::Path>>(
352        &self,
353        input: &[P],
354        encoding: EncodingRef,
355        trap: DecoderTrap,
356    ) -> (CsMat<usize>, Array1<usize>) {
357        let mut document_frequencies = Array1::zeros(self.vocabulary.len());
358        let mut sprs_vectorized = CsMat::empty(sprs::CompressedStorage::CSR, self.vocabulary.len());
359        sprs_vectorized.reserve_outer_dim_exact(input.len());
360        let regex = self.properties.split_regex();
361        for file_path in input.iter() {
362            let mut file = std::fs::File::open(file_path).unwrap();
363            let mut document_bytes = Vec::new();
364            file.read_to_end(&mut document_bytes).unwrap();
365            let document = encoding::decode(&document_bytes, trap, encoding).0.unwrap();
366            sprs_vectorized = sprs_vectorized.append_outer_csvec(
367                self.analyze_document(document, &regex, document_frequencies.view_mut())
368                    .view(),
369            );
370        }
371        (sprs_vectorized, document_frequencies)
372    }
373
374    /// Produces a sparse array which counts the occurrences of each vocbulary entry in the given document. Also increases
375    /// the document frequency of all entries found.
376    fn analyze_document(
377        &self,
378        document: String,
379        regex: &Regex,
380        mut doc_freqs: ArrayViewMut1<usize>,
381    ) -> CsVec<usize> {
382        // A dense array is needed to parse each document, since sparse arrays can be mutated only
383        // if all insertions are made with increasing index. Since  vocabulary entries can be
384        // encountered in any order this condition does not hold true in this case.
385        // However, keeping only one dense array at a time, greatly limits memory consumption
386        // in sparse cases.
387        let mut term_frequencies: Array1<usize> = Array1::zeros(self.vocabulary.len());
388        let string = transform_string(document, &self.properties);
389        let words = if let Some(tokenizer) = self.properties.tokenizer_function() {
390            tokenizer(&string)
391        } else {
392            regex.find_iter(&string).map(|mat| mat.as_str()).collect()
393        };
394        let list = NGramList::new(words, self.properties.n_gram_range());
395        for ngram_items in list {
396            for item in ngram_items {
397                if let Some((item_index, _)) = self.vocabulary.get(&item) {
398                    let term_freq = term_frequencies.get_mut(*item_index).unwrap();
399                    *term_freq += 1;
400                }
401            }
402        }
403        let mut sprs_term_frequencies = CsVec::empty(self.vocabulary.len());
404
405        // only insert non-zero elements in order to keep a sparse representation
406        for (i, freq) in term_frequencies
407            .into_iter()
408            .enumerate()
409            .filter(|(_, f)| *f > 0)
410        {
411            sprs_term_frequencies.append(i, freq);
412            doc_freqs[i] += 1;
413        }
414        sprs_term_frequencies
415    }
416}
417
418fn transform_string(mut string: String, properties: &CountVectorizerValidParams) -> String {
419    if properties.normalize() {
420        string = string.nfkd().collect();
421    }
422    if properties.convert_to_lowercase() {
423        string = string.to_lowercase();
424    }
425    string
426}
427
/// Reassigns contiguous integer indices (`0..map.len()`) to the entries of `map`, following the
/// map's iteration order, and returns the words arranged so that the word at position `j` is the
/// one now mapped to index `j`. The second element of each value is left untouched.
fn hashmap_to_vocabulary(map: &mut HashMap<String, (usize, usize)>) -> Vec<String> {
    let mut vec = Vec::with_capacity(map.len());
    // `iter_mut` already yields `&mut` values, so the index can be bound with a plain pattern;
    // an explicit `ref mut` inside such a pattern is rejected by the 2024 edition.
    for (word, (idx, _)) in map.iter_mut() {
        *idx = vec.len();
        vec.push(word.clone());
    }
    vec
}
436
#[cfg(test)]
mod tests {

    use super::*;
    use crate::column_for_word;
    use ndarray::{array, Array2};
    use std::fs::File;
    use std::io::Write;

    /// For each `(word, counts)` pair, asserts that the column of `$transf` associated with
    /// `word` in the vocabulary `$voc` is equal to `counts`.
    macro_rules! assert_counts_for_word {

        ($voc:expr, $transf:expr, $(($word:expr, $counts:expr)),*) => {
            $ (
                assert_eq!(column_for_word!($voc, $transf, $word), $counts);
            )*
        }
    }

    #[test]
    fn simple_count_test() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params().fit(&texts).unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec!["one", "two", "three", "four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(2, 2)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec!["one two", "two three", "three four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one two", array![1, 0, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three four", array![1, 1, 1, 0])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(1, 2)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec![
            "one",
            "one two",
            "two",
            "two three",
            "three",
            "three four",
            "four",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("one two", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("three four", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );
    }

    #[test]
    fn simple_count_test_vocabulary() {
        let texts = array![
            "apples.and.trees fi",
            "flowers,and,bees",
            "trees!here;and trees:there",
            "four bees and apples and apples again \u{FB01}"
        ];
        let vocabulary = ["apples", "bees", "flowers", "trees", "fi"];
        let vectorizer = CountVectorizer::params()
            .fit_vocabulary(&vocabulary)
            .unwrap();
        let vect_vocabulary = vectorizer.vocabulary();
        assert_vocabulary_eq(&vocabulary, vect_vocabulary);
        let transformed: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        // The last document contains the ligature U+FB01, which is expected to count as "fi".
        assert_counts_for_word!(
            vect_vocabulary,
            transformed,
            ("apples", array![1, 0, 0, 2]),
            ("bees", array![0, 1, 0, 1]),
            ("flowers", array![0, 1, 0, 0]),
            ("trees", array![1, 0, 2, 0]),
            ("fi", array![1, 0, 0, 1])
        );
    }

    #[test]
    fn simple_count_no_punctuation_test() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params()
            .tokenizer(Tokenizer::Regex(r"\b[^ ][^ ]+\b".to_string()))
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec!["one", "two", "three", "four", "three;four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("three", array![1, 1, 0, 0]),
            ("four", array![1, 1, 0, 1]),
            ("three;four", array![0, 0, 1, 0])
        );
    }

    #[test]
    fn simple_count_no_lowercase_test() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params()
            .convert_to_lowercase(false)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("oNe", array![1, 0, 0, 0]),
            ("two", array![1, 0, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1]),
            ("TWO", array![0, 1, 0, 0])
        );
    }

    #[test]
    fn simple_count_no_both_test() {
        let texts = array![
            "oNe oNe two three four",
            "TWO three four",
            "three;four",
            "four"
        ];
        // The regex tokenizer and the function tokenizer are expected to give the same result.
        for vectorizer in [
            CountVectorizer::params()
                .convert_to_lowercase(false)
                .tokenizer(Tokenizer::Regex(r"\b[^ ][^ ]+\b".to_string()))
                .fit(&texts)
                .unwrap(),
            CountVectorizer::params()
                .convert_to_lowercase(false)
                .tokenizer(Tokenizer::Function(|x| x.split(" ").collect()))
                .fit(&texts)
                .unwrap(),
        ] {
            let vocabulary = vectorizer.vocabulary();
            let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
            let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO", "three;four"];
            assert_vocabulary_eq(&true_vocabulary, vocabulary);
            assert_counts_for_word!(
                vocabulary,
                counts,
                ("oNe", array![2, 0, 0, 0]),
                ("two", array![1, 0, 0, 0]),
                ("three", array![1, 1, 0, 0]),
                ("four", array![1, 1, 0, 1]),
                ("TWO", array![0, 1, 0, 0]),
                ("three;four", array![0, 0, 1, 0])
            );
        }
    }

    #[test]
    fn test_min_max_df() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let vectorizer = CountVectorizer::params()
            .document_frequency(2. / 5., 3. / 5.)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let true_vocabulary = vec![
            "one", "two", "three", "four", "five", "seven", "eight", "ten", "eleven",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
    }

    #[test]
    fn test_fit_transform_files() {
        let text_files = create_test_files();
        let vectorizer = CountVectorizer::params()
            .fit_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer
            .transform_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap()
            .to_dense();
        let true_vocabulary = vec!["one", "two", "three", "four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(2, 2)
            .fit_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer
            .transform_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap()
            .to_dense();
        let true_vocabulary = vec!["one two", "two three", "three four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one two", array![1, 0, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three four", array![1, 1, 1, 0])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(1, 2)
            .fit_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer
            .transform_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap()
            .to_dense();
        let true_vocabulary = vec![
            "one",
            "one two",
            "two",
            "two three",
            "three",
            "three four",
            "four",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("one two", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("three four", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );
        delete_test_files(&text_files);
    }

    #[test]
    fn test_stopwords() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let stopwords = ["and", "maybe", "an"];
        let vectorizer = CountVectorizer::params()
            .stopwords(&stopwords)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let true_vocabulary = vec![
            "one",
            "two",
            "three",
            "four",
            "five",
            "seven",
            "eight",
            "ten",
            "eleven",
            "avoid",
            "singletons",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
    }

    #[test]
    fn test_invalid_gram_boundaries() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params().n_gram_range(0, 1).fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params().n_gram_range(1, 0).fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params().n_gram_range(2, 1).fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params()
            .document_frequency(1.1, 1.)
            .fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params()
            .document_frequency(1., -0.1)
            .fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params()
            .document_frequency(0.5, 0.2)
            .fit(&texts);
        assert!(vectorizer.is_err());
    }

    #[test]
    fn test_invalid_regex() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params()
            .tokenizer(Tokenizer::Regex(r"[".to_string()))
            .fit(&texts);
        assert!(vectorizer.is_err())
    }

    /// Asserts that `voc` contains every entry of `true_voc` and has the same number of
    /// entries; the order of the entries is not checked.
    fn assert_vocabulary_eq<T: ToString>(true_voc: &[T], voc: &[String]) {
        for word in true_voc {
            let word = word.to_string();
            assert!(
                voc.contains(&word),
                "expected entry {:?} is missing from vocabulary {:?}",
                word,
                voc
            );
        }
        assert_eq!(
            true_voc.len(),
            voc.len(),
            "unexpected number of entries in vocabulary {:?}",
            voc
        );
    }

    /// Writes the four test documents to files in the system temporary directory and returns
    /// their paths. The process id is part of each file name, so concurrent test runs do not
    /// overwrite each other's files and the crate directory is left untouched.
    fn create_test_files() -> Vec<std::path::PathBuf> {
        let contents = ["oNe two three four", "TWO three four", "three;four", "four"];
        let directory = std::env::temp_dir();
        contents
            .iter()
            .enumerate()
            .map(|(index, content)| {
                let path = directory.join(format!(
                    "count_vectorization_test_file_{}_{}",
                    std::process::id(),
                    index + 1
                ));
                let mut file = File::create(&path).unwrap();
                file.write_all(content.as_bytes()).unwrap();
                path
            })
            .collect()
    }

    /// Removes the files created by `create_test_files`.
    fn delete_test_files(file_names: &[std::path::PathBuf]) {
        for f_name in file_names.iter() {
            std::fs::remove_file(f_name).unwrap();
        }
    }
}