1use crate::countgrams::{CountVectorizer, CountVectorizerParams, Tokenizerfp};
4use crate::error::Result;
5use crate::Tokenizer;
6use encoding::types::EncodingRef;
7use encoding::DecoderTrap;
8use ndarray::{Array1, ArrayBase, Data, Ix1};
9use sprs::CsMat;
10
11#[cfg(feature = "serde")]
12use serde_crate::{Deserialize, Serialize};
13
/// The IDF (inverse document frequency) formula applied on top of raw
/// term counts. `n` is the number of documents and `df` the number of
/// documents containing the term.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TfIdfMethod {
    /// `ln((1 + n) / (1 + df)) + 1` — as if one extra document contained
    /// every term, so the denominator can never be zero.
    Smooth,
    /// `ln(n / df) + 1` — assumes every vocabulary term occurs in at
    /// least one document (`df > 0`).
    NonSmooth,
    /// `ln(n / (1 + df))` — classic textbook formula with a smoothed
    /// denominator only.
    Textbook,
}

impl TfIdfMethod {
    /// Compute the IDF weight for one term, given the document count `n`
    /// and the term's document frequency `df`.
    pub fn compute_idf(&self, n: usize, df: usize) -> f64 {
        let (n, df) = (n as f64, df as f64);
        match self {
            TfIdfMethod::Smooth => ((1. + n) / (1. + df)).ln() + 1.,
            TfIdfMethod::NonSmooth => (n / df).ln() + 1.,
            TfIdfMethod::Textbook => (n / (1. + df)).ln(),
        }
    }
}
43
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
/// Builder for a tf-idf vectorizer: pairs the counting configuration
/// (tokenizer, n-grams, stopwords, frequency filters, ...) with the IDF
/// formula to apply. Use one of the `fit*` methods to obtain a
/// [`FittedTfIdfVectorizer`].
pub struct TfIdfVectorizer {
    // Configuration for the underlying token-counting stage.
    count_vectorizer: CountVectorizerParams,
    // IDF variant used to weight the raw counts after fitting.
    method: TfIdfMethod,
}
60
61impl std::default::Default for TfIdfVectorizer {
62 fn default() -> Self {
63 Self {
64 count_vectorizer: CountVectorizerParams::default(),
65 method: TfIdfMethod::Smooth,
66 }
67 }
68}
69
70impl TfIdfVectorizer {
71 pub fn tokenizer(self, tokenizer: Tokenizer) -> Self {
74 Self {
75 count_vectorizer: self.count_vectorizer.tokenizer(tokenizer),
76 method: self.method,
77 }
78 }
79
80 pub fn max_features(self, max_features: Option<usize>) -> Self {
83 Self {
84 count_vectorizer: self.count_vectorizer.max_features(max_features),
85 method: self.method,
86 }
87 }
88
89 pub fn convert_to_lowercase(self, convert_to_lowercase: bool) -> Self {
91 Self {
92 count_vectorizer: self
93 .count_vectorizer
94 .convert_to_lowercase(convert_to_lowercase),
95 method: self.method,
96 }
97 }
98
99 pub fn n_gram_range(self, min_n: usize, max_n: usize) -> Self {
105 Self {
106 count_vectorizer: self.count_vectorizer.n_gram_range(min_n, max_n),
107 method: self.method,
108 }
109 }
110
111 pub fn normalize(self, normalize: bool) -> Self {
113 Self {
114 count_vectorizer: self.count_vectorizer.normalize(normalize),
115 method: self.method,
116 }
117 }
118
119 pub fn document_frequency(self, min_freq: f32, max_freq: f32) -> Self {
122 Self {
123 count_vectorizer: self.count_vectorizer.document_frequency(min_freq, max_freq),
124 method: self.method,
125 }
126 }
127
128 pub fn stopwords<T: ToString>(self, stopwords: &[T]) -> Self {
130 Self {
131 count_vectorizer: self.count_vectorizer.stopwords(stopwords),
132 method: self.method,
133 }
134 }
135
136 pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
144 &self,
145 x: &ArrayBase<D, Ix1>,
146 ) -> Result<FittedTfIdfVectorizer> {
147 let fitted_vectorizer = self.count_vectorizer.fit(x)?;
148 Ok(FittedTfIdfVectorizer {
149 fitted_vectorizer,
150 method: self.method.clone(),
151 })
152 }
153
154 pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<FittedTfIdfVectorizer> {
158 let fitted_vectorizer = self.count_vectorizer.fit_vocabulary(words)?;
159 Ok(FittedTfIdfVectorizer {
160 fitted_vectorizer,
161 method: self.method.clone(),
162 })
163 }
164
165 pub fn fit_files<P: AsRef<std::path::Path>>(
166 &self,
167 input: &[P],
168 encoding: EncodingRef,
169 trap: DecoderTrap,
170 ) -> Result<FittedTfIdfVectorizer> {
171 let fitted_vectorizer = self.count_vectorizer.fit_files(input, encoding, trap)?;
172 Ok(FittedTfIdfVectorizer {
173 fitted_vectorizer,
174 method: self.method.clone(),
175 })
176 }
177}
178
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
/// A tf-idf vectorizer with a learned vocabulary, ready to transform
/// documents into sparse tf-idf matrices. Produced by the `fit*` methods
/// of [`TfIdfVectorizer`].
pub struct FittedTfIdfVectorizer {
    // Fitted counting stage holding the learned vocabulary.
    fitted_vectorizer: CountVectorizer,
    // IDF formula applied to the raw counts in `transform*`.
    method: TfIdfMethod,
}
192
193impl FittedTfIdfVectorizer {
194 pub fn force_tokenizer_redefinition(&mut self, tokenizer: Tokenizerfp) {
195 self.fitted_vectorizer
196 .force_tokenizer_function_redefinition(tokenizer);
197 }
198
199 pub fn nentries(&self) -> usize {
201 self.fitted_vectorizer.vocabulary.len()
202 }
203
204 pub fn vocabulary(&self) -> &Vec<String> {
206 self.fitted_vectorizer.vocabulary()
207 }
208
209 pub fn method(&self) -> &TfIdfMethod {
211 &self.method
212 }
213
214 pub fn transform<T: ToString, D: Data<Elem = T>>(
218 &self,
219 x: &ArrayBase<D, Ix1>,
220 ) -> Result<CsMat<f64>> {
221 self.fitted_vectorizer.validate_deserialization()?;
222 let (term_freqs, doc_freqs) = self.fitted_vectorizer.get_term_and_document_frequencies(x);
223 Ok(self.apply_tf_idf(term_freqs, doc_freqs))
224 }
225
226 pub fn transform_files<P: AsRef<std::path::Path>>(
227 &self,
228 input: &[P],
229 encoding: EncodingRef,
230 trap: DecoderTrap,
231 ) -> Result<CsMat<f64>> {
232 self.fitted_vectorizer.validate_deserialization()?;
233 let (term_freqs, doc_freqs) = self
234 .fitted_vectorizer
235 .get_term_and_document_frequencies_files(input, encoding, trap);
236 Ok(self.apply_tf_idf(term_freqs, doc_freqs))
237 }
238
239 fn apply_tf_idf(&self, term_freqs: CsMat<usize>, doc_freqs: Array1<usize>) -> CsMat<f64> {
240 let mut term_freqs: CsMat<f64> = term_freqs.map(|x| *x as f64);
241 let inv_doc_freqs =
242 doc_freqs.mapv(|doc_freq| self.method.compute_idf(term_freqs.rows(), doc_freq));
243 for mut row_vec in term_freqs.outer_iterator_mut() {
244 for (col_i, val) in row_vec.iter_mut() {
245 *val *= inv_doc_freqs[col_i];
246 }
247 }
248 term_freqs
249 }
250}
251
#[cfg(test)]
mod tests {

    use super::*;
    use crate::column_for_word;
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    use std::fs::File;
    use std::io::Write;

    // For each (word, expected_column) pair, assert that the tf-idf
    // column for `word` in the transformed matrix matches the expected
    // values to within 1e-3.
    macro_rules! assert_tf_idfs_for_word {

        ($voc:expr, $transf:expr, $(($word:expr, $counts:expr)),*) => {
            $ (
                assert_abs_diff_eq!(column_for_word!($voc, $transf, $word), $counts, epsilon=1e-3);
            )*
        }
    }

    #[test]
    fn autotraits() {
        // Compile-time check that TfIdfMethod stays Send + Sync + Unpin.
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TfIdfMethod>();
    }

    #[test]
    fn test_tf_idf() {
        // Five documents; expected values below assume the default
        // (smooth) IDF: idf = ln((1 + n) / (1 + df)) + 1 with n = 5,
        // e.g. df = 2 -> 1.693..., df = 1 -> 2.098..., and tf * idf
        // for "and" (df = 5, idf = 1) equals its raw count.
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let vectorizer = TfIdfVectorizer::default().fit(&texts).unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer.transform(&texts).unwrap().to_dense();
        // One row per document, one column per vocabulary entry.
        assert_eq!(transformed.dim(), (texts.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
    }

    #[test]
    fn test_tf_idf_files() {
        // Same corpus and expectations as `test_tf_idf`, but the
        // documents are read from files via fit_files/transform_files.
        let text_files = create_test_files();
        let vectorizer = TfIdfVectorizer::default()
            .fit_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer
            .transform_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap()
            .to_dense();
        assert_eq!(transformed.dim(), (text_files.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
        // Clean up; note the files are leaked if an assertion above fails.
        delete_test_files(&text_files)
    }

    // Write the five test documents to files in the working directory
    // and return their paths.
    fn create_test_files() -> Vec<&'static str> {
        let file_names = vec![
            "./tf_idf_vectorization_test_file_1",
            "./tf_idf_vectorization_test_file_2",
            "./tf_idf_vectorization_test_file_3",
            "./tf_idf_vectorization_test_file_4",
            "./tf_idf_vectorization_test_file_5",
        ];
        let contents = &[
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and",
        ];
        for (f_name, f_content) in file_names.iter().zip(contents.iter()) {
            let mut file = File::create(f_name).unwrap();
            file.write_all(f_content.as_bytes()).unwrap();
        }
        file_names
    }

    // Remove the files created by `create_test_files`.
    fn delete_test_files(file_names: &[&'static str]) {
        for f_name in file_names.iter() {
            std::fs::remove_file(f_name).unwrap();
        }
    }
}