use std::cmp::Reverse;
use std::collections::{HashMap, HashSet};
use std::io::Read;
use std::iter::IntoIterator;

use encoding::types::EncodingRef;
use encoding::DecoderTrap;
use itertools::sorted;
use ndarray::{Array1, ArrayBase, ArrayViewMut1, Data, Ix1};
use regex::Regex;
use sprs::{CsMat, CsVec};
use unicode_normalization::UnicodeNormalization;

use crate::error::{PreprocessingError, Result};
use crate::helpers::NGramList;
pub use hyperparams::{CountVectorizerParams, CountVectorizerValidParams};
use linfa::ParamGuard;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

mod hyperparams;

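/// Function pointer used by [`Tokenizer::Function`] to split a document into tokens.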
pub(crate) type Tokenizerfp = fn(&str) -> Vec<&str>;

/// Tokenization strategy: either a user-supplied function or a regex pattern whose
/// matches are taken as tokens.
pub enum Tokenizer {
    Function(Tokenizerfp),
    Regex(String),
}

impl CountVectorizerValidParams {
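    /// Learns a vocabulary from the documents in `x`, according to the params on this object,
    /// and returns a [`CountVectorizer`] that uses it.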
    pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
        &self,
        x: &ArrayBase<D, Ix1>,
    ) -> Result<CountVectorizer> {
        let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::new();
        for string in x.iter().map(|s| transform_string(s.to_string(), self)) {
            self.read_document_into_vocabulary(string, &self.split_regex(), &mut vocabulary);
        }

        let mut vocabulary = self.filter_vocabulary(vocabulary, x.len());

        let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);

        Ok(CountVectorizer {
            vocabulary,
            vec_vocabulary,
            properties: self.clone(),
        })
    }

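    /// Learns a vocabulary from the documents contained in the files at `input`, decoded with
    /// the given `encoding` and [`DecoderTrap`].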
    pub fn fit_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> Result<CountVectorizer> {
        let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::new();
        let documents_count = input.len();
        for path in input {
            let mut file = std::fs::File::open(path)?;
            let mut document_bytes = Vec::new();
            file.read_to_end(&mut document_bytes)?;
            let document = encoding::decode(&document_bytes, trap, encoding).0;
            if document.is_err() {
                return Err(PreprocessingError::EncodingError(document.err().unwrap()));
            }
            let document = transform_string(document.unwrap(), self);
            self.read_document_into_vocabulary(document, &self.split_regex(), &mut vocabulary);
        }

        let mut vocabulary = self.filter_vocabulary(vocabulary, documents_count);
        let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);

        Ok(CountVectorizer {
            vocabulary,
            vec_vocabulary,
            properties: self.clone(),
        })
    }

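    /// Builds a [`CountVectorizer`] from a caller-supplied list of words instead of learning the
    /// vocabulary from documents; repeated words keep the index of their first occurrence.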
    pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<CountVectorizer> {
        let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::with_capacity(words.len());
        for item in words.iter().map(|w| w.to_string()) {
            let len = vocabulary.len();
            vocabulary.entry(item).or_insert((len, 1));
        }
        let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);
        Ok(CountVectorizer {
            vocabulary,
            vec_vocabulary,
            properties: self.clone(),
        })
    }

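    /// Filters the fitted vocabulary: entries outside the configured document-frequency bounds
    /// and entries listed as stopwords are dropped; if `max_features` is set, only the most
    /// frequent entries are kept.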
    fn filter_vocabulary(
        &self,
        vocabulary: HashMap<String, (usize, usize)>,
        n_documents: usize,
    ) -> HashMap<String, (usize, usize)> {
        let (min_df, max_df) = self.document_frequency();
        let len_f32 = n_documents as f32;
        let (min_abs_df, max_abs_df) = ((min_df * len_f32) as usize, (max_df * len_f32) as usize);

        let vocabulary = if min_abs_df == 0 && max_abs_df == n_documents {
            match &self.stopwords() {
                None => vocabulary,
                Some(stopwords) => vocabulary
                    .into_iter()
                    .filter(|(entry, (_, _))| !stopwords.contains(entry))
                    .collect(),
            }
        } else {
            match &self.stopwords() {
                None => vocabulary
                    .into_iter()
                    .filter(|(_, (_, abs_count))| {
                        *abs_count >= min_abs_df && *abs_count <= max_abs_df
                    })
                    .collect(),
                Some(stopwords) => vocabulary
                    .into_iter()
                    .filter(|(entry, (_, abs_count))| {
                        *abs_count >= min_abs_df
                            && *abs_count <= max_abs_df
                            && !stopwords.contains(entry)
                    })
                    .collect(),
            }
        };

        if let Some(max_features) = self.max_features() {
            sorted(
                vocabulary
                    .into_iter()
                    .map(|(word, (x, freq))| (Reverse(freq), Reverse(word), x)),
            )
            .take(max_features)
            .map(|(freq, word, x)| (word.0, (x, freq.0)))
            .collect()
        } else {
            vocabulary
        }
    }

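    /// Tokenizes one document, expands the tokens into n-grams, and counts each distinct n-gram
    /// once towards its document frequency in `vocabulary`.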
    fn read_document_into_vocabulary(
        &self,
        doc: String,
        regex: &Regex,
        vocabulary: &mut HashMap<String, (usize, usize)>,
    ) {
        let words = if let Some(tokenizer) = self.tokenizer_function() {
            tokenizer(&doc)
        } else {
            regex.find_iter(&doc).map(|mat| mat.as_str()).collect()
        };
        let list = NGramList::new(words, self.n_gram_range());
        let document_vocabulary: HashSet<String> = list.into_iter().flatten().collect();
        for word in document_vocabulary {
            let len = vocabulary.len();
            if let Some((_, freq)) = vocabulary.get_mut(&word) {
                *freq += 1;
            } else {
                vocabulary.insert(word, (len, 1));
            }
        }
    }
}

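/// Fitting entry points on the unchecked params: each one validates the hyperparameters through
/// [`ParamGuard::check_ref`] and then delegates to the corresponding method on the valid params.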
impl CountVectorizerParams {
    pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
        &self,
        x: &ArrayBase<D, Ix1>,
    ) -> Result<CountVectorizer> {
        self.check_ref().and_then(|params| params.fit(x))
    }

    pub fn fit_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> Result<CountVectorizer> {
        self.check_ref()
            .and_then(|params| params.fit_files(input, encoding, trap))
    }

    pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<CountVectorizer> {
        self.check_ref()
            .and_then(|params| params.fit_vocabulary(words))
    }
}

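/// A fitted count-vectorizer: holds the vocabulary learnt (or supplied) during fitting and
/// counts the occurrences of its entries in new documents.
///
/// Illustrative sketch of the fit/transform flow, mirroring the tests at the bottom of this file
/// (the `linfa_preprocessing` import path is assumed here):
///
/// ```ignore
/// use linfa_preprocessing::CountVectorizer;
/// use ndarray::array;
///
/// let texts = array!["one two three four", "two three four", "three four", "four"];
/// // learn a vocabulary from the documents
/// let vectorizer = CountVectorizer::params().fit(&texts).unwrap();
/// // sparse (documents x vocabulary entries) count matrix
/// let counts = vectorizer.transform(&texts).unwrap();
/// assert_eq!(counts.rows(), 4);
/// ```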
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
pub struct CountVectorizer {
    pub(crate) vocabulary: HashMap<String, (usize, usize)>,
    pub(crate) vec_vocabulary: Vec<String>,
    pub(crate) properties: CountVectorizerValidParams,
}

impl CountVectorizer {
    pub fn params() -> CountVectorizerParams {
        CountVectorizerParams::default()
    }

    /// Number of vocabulary entries learnt during fitting.
    pub fn nentries(&self) -> usize {
        self.vocabulary.len()
    }

    /// Replaces the tokenizer function. This is needed when a function tokenizer was configured
    /// and the vectorizer has been deserialized, since function pointers cannot be serialized.
    pub fn force_tokenizer_function_redefinition(&mut self, tokenizer: Tokenizerfp) {
        self.properties.tokenizer_function = Some(tokenizer);
    }

    pub(crate) fn validate_deserialization(&self) -> Result<()> {
        if self.properties.tokenizer_function().is_none()
            && self.properties.tokenizer_deserialization_guard
        {
            return Err(PreprocessingError::TokenizerNotSet);
        }

        Ok(())
    }

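    /// Counts the occurrences of each vocabulary entry in every document of `x`, returning a
    /// sparse matrix with one row per document and one column per vocabulary entry.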
    pub fn transform<T: ToString, D: Data<Elem = T>>(
        &self,
        x: &ArrayBase<D, Ix1>,
    ) -> Result<CsMat<usize>> {
        self.validate_deserialization()?;
        let (vectorized, _) = self.get_term_and_document_frequencies(x);
        Ok(vectorized)
    }

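    /// Like [`CountVectorizer::transform`], but reads the documents from the files at `input`,
    /// decoded with the given `encoding` and [`DecoderTrap`].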
    pub fn transform_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> Result<CsMat<usize>> {
        self.validate_deserialization()?;
        let (vectorized, _) = self.get_term_and_document_frequencies_files(input, encoding, trap);
        Ok(vectorized)
    }

    /// Contains all vocabulary entries, in the same order as the columns of the matrices
    /// produced by the transform methods.
    pub fn vocabulary(&self) -> &Vec<String> {
        &self.vec_vocabulary
    }

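    /// Builds the sparse document-term count matrix for `x` along with, for each vocabulary
    /// entry, the number of documents it appears in.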
    pub(crate) fn get_term_and_document_frequencies<T: ToString, D: Data<Elem = T>>(
        &self,
        x: &ArrayBase<D, Ix1>,
    ) -> (CsMat<usize>, Array1<usize>) {
        let mut document_frequencies = Array1::zeros(self.vocabulary.len());
        let mut sprs_vectorized = CsMat::empty(sprs::CompressedStorage::CSR, self.vocabulary.len());
        sprs_vectorized.reserve_outer_dim_exact(x.len());
        let regex = self.properties.split_regex();
        for string in x.into_iter().map(|s| s.to_string()) {
            let row = self.analyze_document(string, &regex, document_frequencies.view_mut());
            sprs_vectorized = sprs_vectorized.append_outer_csvec(row.view());
        }
        (sprs_vectorized, document_frequencies)
    }

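    /// Same as [`CountVectorizer::get_term_and_document_frequencies`], but the documents are
    /// read and decoded from the files at `input`.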
    pub(crate) fn get_term_and_document_frequencies_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> (CsMat<usize>, Array1<usize>) {
        let mut document_frequencies = Array1::zeros(self.vocabulary.len());
        let mut sprs_vectorized = CsMat::empty(sprs::CompressedStorage::CSR, self.vocabulary.len());
        sprs_vectorized.reserve_outer_dim_exact(input.len());
        let regex = self.properties.split_regex();
        for file_path in input.iter() {
            let mut file = std::fs::File::open(file_path).unwrap();
            let mut document_bytes = Vec::new();
            file.read_to_end(&mut document_bytes).unwrap();
            let document = encoding::decode(&document_bytes, trap, encoding).0.unwrap();
            sprs_vectorized = sprs_vectorized.append_outer_csvec(
                self.analyze_document(document, &regex, document_frequencies.view_mut())
                    .view(),
            );
        }
        (sprs_vectorized, document_frequencies)
    }

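    /// Computes the term-frequency vector of a single document and, for every vocabulary entry
    /// found in it, increments that entry's document frequency in `doc_freqs`.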
    fn analyze_document(
        &self,
        document: String,
        regex: &Regex,
        mut doc_freqs: ArrayViewMut1<usize>,
    ) -> CsVec<usize> {
        let mut term_frequencies: Array1<usize> = Array1::zeros(self.vocabulary.len());
        let string = transform_string(document, &self.properties);
        let words = if let Some(tokenizer) = self.properties.tokenizer_function() {
            tokenizer(&string)
        } else {
            regex.find_iter(&string).map(|mat| mat.as_str()).collect()
        };
        let list = NGramList::new(words, self.properties.n_gram_range());
        for ngram_items in list {
            for item in ngram_items {
                if let Some((item_index, _)) = self.vocabulary.get(&item) {
                    let term_freq = term_frequencies.get_mut(*item_index).unwrap();
                    *term_freq += 1;
                }
            }
        }
        let mut sprs_term_frequencies = CsVec::empty(self.vocabulary.len());

        for (i, freq) in term_frequencies
            .into_iter()
            .enumerate()
            .filter(|(_, f)| *f > 0)
        {
            sprs_term_frequencies.append(i, freq);
            doc_freqs[i] += 1;
        }
        sprs_term_frequencies
    }
}

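/// Applies the configured preprocessing to a raw document: optional NFKD unicode normalization
/// followed by optional lowercasing.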
fn transform_string(mut string: String, properties: &CountVectorizerValidParams) -> String {
    if properties.normalize() {
        string = string.nfkd().collect();
    }
    if properties.convert_to_lowercase() {
        string = string.to_lowercase();
    }
    string
}

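/// Collects the vocabulary words into a `Vec`, re-assigning each entry's index so that it
/// matches the word's position in the returned vector (and thus its column in the count matrix).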
fn hashmap_to_vocabulary(map: &mut HashMap<String, (usize, usize)>) -> Vec<String> {
    let mut vec = Vec::with_capacity(map.len());
    for (word, (ref mut idx, _)) in map {
        *idx = vec.len();
        vec.push(word.clone());
    }
    vec
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::column_for_word;
    use ndarray::{array, Array2};
    use std::fs::File;
    use std::io::Write;

    macro_rules! assert_counts_for_word {
        ($voc:expr, $transf:expr, $(($word:expr, $counts:expr)),*) => {
            $(
                assert_eq!(column_for_word!($voc, $transf, $word), $counts);
            )*
        }
    }

    #[test]
    fn simple_count_test() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params().fit(&texts).unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec!["one", "two", "three", "four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(2, 2)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec!["one two", "two three", "three four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one two", array![1, 0, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three four", array![1, 1, 1, 0])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(1, 2)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec![
            "one",
            "one two",
            "two",
            "two three",
            "three",
            "three four",
            "four",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("one two", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("three four", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );
    }

    #[test]
    fn simple_count_test_vocabulary() {
        let texts = array![
            "apples.and.trees fi",
            "flowers,and,bees",
            "trees!here;and trees:there",
            "four bees and apples and apples again \u{FB01}"
        ];
        let vocabulary = ["apples", "bees", "flowers", "trees", "fi"];
        let vectorizer = CountVectorizer::params()
            .fit_vocabulary(&vocabulary)
            .unwrap();
        let vect_vocabulary = vectorizer.vocabulary();
        assert_vocabulary_eq(&vocabulary, vect_vocabulary);
        let transformed: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        assert_counts_for_word!(
            vect_vocabulary,
            transformed,
            ("apples", array![1, 0, 0, 2]),
            ("bees", array![0, 1, 0, 1]),
            ("flowers", array![0, 1, 0, 0]),
            ("trees", array![1, 0, 2, 0]),
            ("fi", array![1, 0, 0, 1])
        );
    }

    #[test]
    fn simple_count_no_punctuation_test() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params()
            .tokenizer(Tokenizer::Regex(r"\b[^ ][^ ]+\b".to_string()))
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec!["one", "two", "three", "four", "three;four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("three", array![1, 1, 0, 0]),
            ("four", array![1, 1, 0, 1]),
            ("three;four", array![0, 0, 1, 0])
        );
    }

    #[test]
    fn simple_count_no_lowercase_test() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params()
            .convert_to_lowercase(false)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
        let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("oNe", array![1, 0, 0, 0]),
            ("two", array![1, 0, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1]),
            ("TWO", array![0, 1, 0, 0])
        );
    }

    #[test]
    fn simple_count_no_both_test() {
        let texts = array![
            "oNe oNe two three four",
            "TWO three four",
            "three;four",
            "four"
        ];
        for vectorizer in [
            CountVectorizer::params()
                .convert_to_lowercase(false)
                .tokenizer(Tokenizer::Regex(r"\b[^ ][^ ]+\b".to_string()))
                .fit(&texts)
                .unwrap(),
            CountVectorizer::params()
                .convert_to_lowercase(false)
                .tokenizer(Tokenizer::Function(|x| x.split(" ").collect()))
                .fit(&texts)
                .unwrap(),
        ] {
            let vocabulary = vectorizer.vocabulary();
            let counts: Array2<usize> = vectorizer.transform(&texts).unwrap().to_dense();
            let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO", "three;four"];
            assert_vocabulary_eq(&true_vocabulary, vocabulary);
            assert_counts_for_word!(
                vocabulary,
                counts,
                ("oNe", array![2, 0, 0, 0]),
                ("two", array![1, 0, 0, 0]),
                ("three", array![1, 1, 0, 0]),
                ("four", array![1, 1, 0, 1]),
                ("TWO", array![0, 1, 0, 0]),
                ("three;four", array![0, 0, 1, 0])
            );
        }
    }

    #[test]
    fn test_min_max_df() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let vectorizer = CountVectorizer::params()
            .document_frequency(2. / 5., 3. / 5.)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let true_vocabulary = vec![
            "one", "two", "three", "four", "five", "seven", "eight", "ten", "eleven",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
    }

    #[test]
    fn test_fit_transform_files() {
        let text_files = create_test_files();
        let vectorizer = CountVectorizer::params()
            .fit_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer
            .transform_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap()
            .to_dense();
        let true_vocabulary = vec!["one", "two", "three", "four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(2, 2)
            .fit_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer
            .transform_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap()
            .to_dense();
        let true_vocabulary = vec!["one two", "two three", "three four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one two", array![1, 0, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three four", array![1, 1, 1, 0])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(1, 2)
            .fit_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer
            .transform_files(
                &text_files[..],
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap()
            .to_dense();
        let true_vocabulary = vec![
            "one",
            "one two",
            "two",
            "two three",
            "three",
            "three four",
            "four",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("one two", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("three four", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );
        delete_test_files(&text_files);
    }

    #[test]
    fn test_stopwords() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let stopwords = ["and", "maybe", "an"];
        let vectorizer = CountVectorizer::params()
            .stopwords(&stopwords)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let true_vocabulary = vec![
            "one",
            "two",
            "three",
            "four",
            "five",
            "seven",
            "eight",
            "ten",
            "eleven",
            "avoid",
            "singletons",
        ];
        println!("voc: {:?}", vocabulary);
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
    }

    #[test]
    fn test_invalid_gram_boundaries() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params().n_gram_range(0, 1).fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params().n_gram_range(1, 0).fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params().n_gram_range(2, 1).fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params()
            .document_frequency(1.1, 1.)
            .fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params()
            .document_frequency(1., -0.1)
            .fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params()
            .document_frequency(0.5, 0.2)
            .fit(&texts);
        assert!(vectorizer.is_err());
    }

    #[test]
    fn test_invalid_regex() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params()
            .tokenizer(Tokenizer::Regex(r"[".to_string()))
            .fit(&texts);
        assert!(vectorizer.is_err())
    }

    fn assert_vocabulary_eq<T: ToString>(true_voc: &[T], voc: &[String]) {
        for word in true_voc {
            assert!(voc.contains(&word.to_string()));
        }
        assert_eq!(true_voc.len(), voc.len());
    }

    fn create_test_files() -> Vec<&'static str> {
        let file_names = vec![
            "./count_vectorization_test_file_1",
            "./count_vectorization_test_file_2",
            "./count_vectorization_test_file_3",
            "./count_vectorization_test_file_4",
        ];
        let contents = &["oNe two three four", "TWO three four", "three;four", "four"];
        for (f_name, f_content) in file_names.iter().zip(contents.iter()) {
            let mut file = File::create(f_name).unwrap();
            file.write_all(f_content.as_bytes()).unwrap();
        }
        file_names
    }

    fn delete_test_files(file_names: &[&'static str]) {
        for f_name in file_names.iter() {
            std::fs::remove_file(f_name).unwrap();
        }
    }
}