A comprehensive toolkit for Statistical Learning in Rust.

Linfa is a comprehensive toolkit for statistical learning, providing algorithms for optimal model and density estimation.

It contains common learning algorithms such as Elastic Net, Support Vector Machines, Kernel Methods, Linear Decision Trees, etc., and bundles them together.

You can find the source code at rust-ml/linfa.

```rust
let (train, valid) = linfa_datasets::iris()
    .split_with_ratio(0.8);

// train the model
let model = GaussianNbParams::params()
    .fit(&train)?;

// predict the validation dataset
let pred = model.predict(&valid);

// construct confusion matrix
let cm = pred.confusion_matrix(&valid)?;

// print confusion matrix, accuracy and precision
println!("{:?}", cm);
println!("accuracy {}, precision {}",
    cm.accuracy(), cm.precision());
```
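For reference, accuracy and precision follow the usual confusion-matrix definitions. The sketch below computes them with the standard library for a binary problem only; it does not use linfa's `ConfusionMatrix`, the names `Counts`, `confusion`, `accuracy` and `precision` are hypothetical, and linfa may aggregate precision differently for multi-class data such as iris.

```rust
/// Counts of a binary confusion matrix (positive class = `true`).
#[derive(Debug, Default, PartialEq)]
struct Counts {
    tp: usize,
    fp: usize,
    tn: usize,
    fn_: usize,
}

/// Tally predictions against ground-truth labels.
fn confusion(pred: &[bool], truth: &[bool]) -> Counts {
    let mut c = Counts::default();
    for (&p, &t) in pred.iter().zip(truth) {
        match (p, t) {
            (true, true) => c.tp += 1,
            (true, false) => c.fp += 1,
            (false, false) => c.tn += 1,
            (false, true) => c.fn_ += 1,
        }
    }
    c
}

/// Fraction of all samples classified correctly.
fn accuracy(c: &Counts) -> f64 {
    (c.tp + c.tn) as f64 / (c.tp + c.fp + c.tn + c.fn_) as f64
}

/// Fraction of predicted positives that are truly positive.
fn precision(c: &Counts) -> f64 {
    c.tp as f64 / (c.tp + c.fp) as f64
}

fn main() {
    let pred = [true, true, false, false, true];
    let truth = [true, false, false, true, true];
    let c = confusion(&pred, &truth);
    println!("{:?} accuracy {} precision {}", c, accuracy(&c), precision(&c));
}
```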
```rust
// assume we have a dataset with multiple,
// uncorrelated targets and we want to train
// a single model for each target variable
let model = train.target_iter()
    .map(|x| params.fit(&x).unwrap())
    .collect::<MultiTarget<_, _>>();

// the composed `model` returns multiple targets
let valid_est = model.predict(valid);
println!("{}", valid_est.ntargets());
```
```rust
use std::fs::File;
use std::io::Write;

// project the iris dataset onto three whitened principal components
let ds = linfa_datasets::iris();
let ds = Pca::params(3).whiten(true).fit(&ds).transform(ds);

// transform to two-dimensional embeddings
let ds = TSne::embedding_size(2)
    .perplexity(10.0)
    .approx_threshold(0.1)
    .transform(ds)?;

// write embedding to file: two coordinates followed by the class label
let mut f = File::create("iris.dat")?;
for (x, y) in ds.sample_iter() {
    writeln!(f, "{} {} {}", x[0], x[1], y[0])?;
}
```
```rust
// perform cross-validation with the F1 score
let f1_runs = dataset
    .iter_fold(8, |v| params.fit(&v).unwrap())
    .map(|(model, valid)| {
        let cm = model
            .predict(&valid)
            .mapv(|x| x > Pr::even())
            .confusion_matrix(&valid).unwrap();
        cm.f1_score()
    })
    .collect::<Array1<_>>();

// calculate mean and standard deviation
println!("F1 score: {}±{}",
    f1_runs.mean().unwrap(),
    f1_runs.std_axis(Axis(0), 0.0),
);
```
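Cross-validation trains one model per fold and scores it on the held-out part. The std-only sketch below shows the index bookkeeping of k-fold splitting, the F1 score and the mean/standard-deviation summary. It assumes contiguous folds without shuffling, which may differ from how `iter_fold` assigns samples; `k_fold`, `f1` and `mean_std` are hypothetical helpers.

```rust
/// Split `0..n` into `k` contiguous folds, returning (train, validation) index pairs.
/// The first `n % k` folds receive one extra sample.
fn k_fold(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    let (base, extra) = (n / k, n % k);
    let mut start = 0;
    (0..k)
        .map(|i| {
            let len = base + usize::from(i < extra);
            let valid: Vec<usize> = (start..start + len).collect();
            let train: Vec<usize> = (0..start).chain(start + len..n).collect();
            start += len;
            (train, valid)
        })
        .collect()
}

/// F1 = 2 * precision * recall / (precision + recall) = 2TP / (2TP + FP + FN).
fn f1(tp: usize, fp: usize, fn_: usize) -> f64 {
    (2 * tp) as f64 / (2 * tp + fp + fn_) as f64
}

/// Mean and population standard deviation (ddof = 0, as in `std_axis(Axis(0), 0.0)`).
fn mean_std(xs: &[f64]) -> (f64, f64) {
    let n = xs.len() as f64;
    let mean = xs.iter().sum::<f64>() / n;
    let var = xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    (mean, var.sqrt())
}

fn main() {
    for (train, valid) in k_fold(10, 3) {
        println!("train {:?} valid {:?}", train, valid);
    }
    let scores = [f1(8, 2, 2), f1(6, 2, 4)];
    let (m, s) = mean_std(&scores);
    println!("F1 score: {}±{}", m, s);
}
```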
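```rust
// generate RBF kernel with sparsity constraints
let kernel = Kernel::params()
    .kind(KernelType::Sparse(15))
    .method(KernelMethod::Gaussian(2.0))
    .transform(dataset.view());

let embedding = DiffusionMap::<f64>::params(2)
    .steps(1)
    .transform(&kernel)?;

// get embedding
let embedding = embedding.embedding();
```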
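In the standard diffusion-map construction, the kernel matrix K is turned into a Markov transition matrix P = D^-1 K by dividing each row by its sum (the node degree); the embedding then comes from the leading eigenvectors of P. The sketch below shows only that normalization step with the standard library. It is an illustration of the textbook construction, not linfa's `DiffusionMap` code, which may apply additional density normalization; `transition_matrix` is a hypothetical helper.

```rust
/// Row-normalize a kernel matrix into a Markov transition matrix P = D^-1 K,
/// where D is the diagonal matrix of row sums.
fn transition_matrix(kernel: &[Vec<f64>]) -> Vec<Vec<f64>> {
    kernel
        .iter()
        .map(|row| {
            let degree: f64 = row.iter().sum();
            row.iter().map(|k| k / degree).collect::<Vec<f64>>()
        })
        .collect()
}

fn main() {
    let kernel = vec![
        vec![1.0, 0.5, 0.0],
        vec![0.5, 1.0, 0.25],
        vec![0.0, 0.25, 1.0],
    ];
    // Every row of the result sums to one.
    for row in transition_matrix(&kernel) {
        println!("{:?}", row);
    }
}
```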
```rust
let params = Svm::<_, Pr>::params()
    .gaussian_kernel(30.0);

// assume we have a binary decision model (here SVM)
// predicting probability. We can merge them into a
// multi-class model by collecting several of them
// into a `MultiClassModel`
let model = train
    .one_vs_all()?
    .into_iter()
    .map(|(l, x)| (l, params.fit(&x).unwrap()))
    .collect::<MultiClassModel<_, _>>();

// predict multi-class label
let pred = model.predict(&valid);
```
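Conceptually, a one-vs-all combination asks every binary model for the probability of its own class and returns the label with the highest score. The sketch below shows that argmax step with plain closures standing in for the fitted SVMs; it illustrates the idea rather than the internals of `MultiClassModel`, and `Scorer` and `predict_one_vs_all` are hypothetical names.

```rust
/// A binary scorer returning the probability that a sample belongs to its class
/// (hypothetical stand-in for a fitted binary model).
type Scorer = Box<dyn Fn(&[f64]) -> f64>;

/// Pick the label whose binary model reports the highest probability.
fn predict_one_vs_all<'a>(models: &'a [(&'a str, Scorer)], x: &[f64]) -> &'a str {
    models
        .iter()
        .map(|(label, score)| (*label, score(x)))
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(label, _)| label)
        .expect("at least one model")
}

fn main() {
    // Toy scorers keyed on the first feature, for illustration only.
    let models: Vec<(&str, Scorer)> = vec![
        ("setosa", Box::new(|x: &[f64]| if x[0] < 5.5 { 0.9 } else { 0.1 }) as Scorer),
        ("versicolor", Box::new(|x: &[f64]| if (5.5..6.5).contains(&x[0]) { 0.8 } else { 0.2 }) as Scorer),
        ("virginica", Box::new(|x: &[f64]| if x[0] >= 6.5 { 0.7 } else { 0.1 }) as Scorer),
    ];
    println!("{}", predict_one_vs_all(&models, &[6.0, 2.9]));
}
```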
```rust
// parameters to compare
let ratios = vec![0.1, 0.2, 0.5, 0.7, 1.0];

// create a model for each parameter
let models = ratios
    .iter()
    .map(|ratio| ElasticNet::params().penalty(0.3).l1_ratio(*ratio))
    .collect::<Vec<_>>();

// get the mean r2 validation score across 5 folds for each model
let r2_values =
    dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(&truth))?;

// show the mean r2 score for each parameter choice
for (ratio, r2) in ratios.iter().zip(r2_values.iter()) {
    println!("L1 ratio: {}, r2 score: {}", ratio, r2);
}
```
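The score averaged across folds here is the coefficient of determination, R^2 = 1 - SS_res / SS_tot, where SS_res is the residual sum of squares and SS_tot the total sum of squares around the mean of the truth. A std-only sketch of the textbook definition, for reference; it is not linfa's `r2` implementation.

```rust
/// Coefficient of determination: 1 - (residual sum of squares / total sum of squares).
fn r2(prediction: &[f64], truth: &[f64]) -> f64 {
    assert_eq!(prediction.len(), truth.len());
    let mean = truth.iter().sum::<f64>() / truth.len() as f64;
    let ss_res: f64 = prediction.iter().zip(truth).map(|(p, t)| (t - p).powi(2)).sum();
    let ss_tot: f64 = truth.iter().map(|t| (t - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn main() {
    let truth = [3.0, -0.5, 2.0, 7.0];
    let prediction = [2.5, 0.0, 2.0, 8.0];
    println!("r2 score: {}", r2(&prediction, &truth));
}
```

A perfect prediction scores 1.0, and always predicting the mean of the truth scores 0.0.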
```rust
// everything above 6.5 is considered a good wine
let (train, valid) = linfa_datasets::winequality()
    .map_targets(|x| *x > 6)
    .split_with_ratio(0.9);

// train SVM with nu=0.01 and RBF with eps=80.0
let model = Svm::params()
    .nu_weight(0.01)
    .gaussian_kernel(80.0)
    .fit(&train)?;

// print model performance and number of SVs
println!("{}", model);
```
```rust
let (train, valid) = linfa_datasets::diabetes()
    .split_with_ratio(0.9);

// train pure LASSO model with 0.1 penalty
let model = ElasticNet::params()
    .penalty(0.1)
    .l1_ratio(1.0)
    .fit(&train)?;

println!("z score: {:?}", model.z_score());

// validate
let y_est = model.predict(&valid);
println!("predicted variance: {}", y_est.r2(&valid)?);
```
```rust
use ndarray::array;

// Load the Linnerud dataset with 20 samples,
// 3 input features, 3 output features
let ds = linfa_datasets::linnerud();

// Fit PLS2 method using 2 principal components
// (latent variables)
let pls = PlsRegression::params(2).fit(&ds)?;

// We can either apply the dimension reduction to the dataset
let reduced_ds = pls.transform(ds);

// ... or predict outputs given a new input sample.
let exercises = array![[14., 146., 61.], [6., 80., 60.]];
let physio_measures = pls.predict(exercises);
```
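```rust
let (train, valid) = linfa_datasets::iris()
    .split_with_ratio(0.8);

// Train model with Gini criterion
let gini_model = DecisionTree::params()
    .split_quality(SplitQuality::Gini)
    .fit(&train)?;

let cm = gini_model.predict(&valid)
    .confusion_matrix(&valid)?;

println!("{:?}", cm);
println!("Accuracy {}%", cm.accuracy() * 100.0);
```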
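The Gini criterion scores a candidate split by the impurity G = 1 - sum_k p_k^2 of each child node, where p_k is the fraction of samples of class k, weighted by the child sizes; lower is better. A std-only sketch of that computation, assuming integer class labels; it is not linfa's `DecisionTree` code, and `gini` and `split_impurity` are hypothetical helpers.

```rust
use std::collections::HashMap;

/// Gini impurity 1 - sum_k p_k^2 of a set of class labels.
fn gini(labels: &[usize]) -> f64 {
    if labels.is_empty() {
        return 0.0;
    }
    let mut counts = HashMap::new();
    for &l in labels {
        *counts.entry(l).or_insert(0usize) += 1;
    }
    let n = labels.len() as f64;
    1.0 - counts.values().map(|&c| (c as f64 / n).powi(2)).sum::<f64>()
}

/// Size-weighted impurity of splitting a node into `left` and `right` children.
fn split_impurity(left: &[usize], right: &[usize]) -> f64 {
    let n = (left.len() + right.len()) as f64;
    left.len() as f64 / n * gini(left) + right.len() as f64 / n * gini(right)
}

fn main() {
    let parent = [0, 0, 1, 1];
    println!("parent impurity {}", gini(&parent));
    println!("pure split {}", split_impurity(&[0, 0], &[1, 1]));
    println!("useless split {}", split_impurity(&[0, 1], &[0, 1]));
}
```

A split that separates the classes perfectly drives the weighted impurity to zero, while a split that leaves both children as mixed as the parent does not improve it.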

Fast by default

The Rust language allows us to implement everything directly without the need to trade performance for readability.

There is no need for a second language for performance reasons.

Embedded in the Rust ecosystem

We use many existing tools from the Rust ecosystem to improve the code quality.

We use rustfmt, clippy, Tarpaulin, GitHub Actions and approx for testing and verification.

Reason more clearly

Benefit from a clear ownership model and structure your code flow with functional programming.

Many composite algorithms can be expressed more clearly with features that older imperative languages lack.

Published by rust-ml group

We are a group of ML/stats people who are exploring the language space for more performant and readable implementations.

Beginner friendly

No need to be an expert or have years of programming experience.

If you are interested in a specific implementation or need some inspiration to learn Rust for data processing, reach out to us. We are happy to help on both sides.

FOSS and accessible

Everything is published under the Apache/MIT license, and we constantly try to make the code as accessible as possible for humans and machines.

The development of Linfa happens in public and every contribution is welcome.


Clustering

Typical implementations for clustering algorithms, such as K-Means, Gaussian Mixture and DBSCAN.
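K-Means alternates between assigning each point to its nearest centroid and moving each centroid to the mean of its assigned points (Lloyd's algorithm). A minimal std-only sketch on 2-D points, assuming the caller supplies the initial centroids and a fixed number of iterations; a full implementation would also choose the initial centroids (e.g. with k-means++) and stop on convergence. The function names are hypothetical, not linfa's API.

```rust
type Point = [f64; 2];

/// Squared Euclidean distance between two points.
fn dist2(a: &Point, b: &Point) -> f64 {
    (a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)
}

/// Index of the centroid closest to `p`.
fn nearest(p: &Point, centroids: &[Point]) -> usize {
    (0..centroids.len())
        .min_by(|&i, &j| dist2(p, &centroids[i]).total_cmp(&dist2(p, &centroids[j])))
        .expect("at least one centroid")
}

/// Run `iters` rounds of Lloyd's algorithm; returns final centroids and assignments.
fn kmeans(points: &[Point], mut centroids: Vec<Point>, iters: usize) -> (Vec<Point>, Vec<usize>) {
    let mut labels = vec![0; points.len()];
    for _ in 0..iters {
        // Assignment step: attach every point to its nearest centroid.
        for (label, p) in labels.iter_mut().zip(points) {
            *label = nearest(p, &centroids);
        }
        // Update step: move each centroid to the mean of its assigned points.
        let mut sums = vec![[0.0, 0.0]; centroids.len()];
        let mut counts = vec![0usize; centroids.len()];
        for (&label, p) in labels.iter().zip(points) {
            sums[label][0] += p[0];
            sums[label][1] += p[1];
            counts[label] += 1;
        }
        for (c, (s, &n)) in centroids.iter_mut().zip(sums.iter().zip(&counts)) {
            // Empty clusters keep their previous centroid.
            if n > 0 {
                *c = [s[0] / n as f64, s[1] / n as f64];
            }
        }
    }
    (centroids, labels)
}

fn main() {
    let points = [[1.0, 1.0], [1.5, 2.0], [8.0, 8.0], [9.0, 8.5]];
    let (centroids, labels) = kmeans(&points, vec![points[0], points[2]], 10);
    println!("centroids {:?} labels {:?}", centroids, labels);
}
```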

Dimensionality reduction

Variance preserving algorithms with linear (PCA) and non-linear (Diffusion Maps) implementations.

The implementation uses LOBPCG and can handle high-dimensional data very efficiently.
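PCA keeps the directions of largest variance, i.e. the leading eigenvectors of the sample covariance matrix. The text names LOBPCG as the eigensolver; the sketch below swaps in plain power iteration, a simpler method, to find only the first principal component of small 2-D data with the standard library. It assumes the starting vector `[1.0, 1.0]` is not orthogonal to the leading eigenvector, and it is an illustration of the idea rather than linfa's code.

```rust
/// Sample covariance matrix (2x2) of 2-D data, centering each column first.
fn covariance(data: &[[f64; 2]]) -> [[f64; 2]; 2] {
    let n = data.len() as f64;
    let mean = [
        data.iter().map(|r| r[0]).sum::<f64>() / n,
        data.iter().map(|r| r[1]).sum::<f64>() / n,
    ];
    let mut cov = [[0.0; 2]; 2];
    for r in data {
        for i in 0..2 {
            for j in 0..2 {
                cov[i][j] += (r[i] - mean[i]) * (r[j] - mean[j]) / (n - 1.0);
            }
        }
    }
    cov
}

/// Leading eigenvector of a symmetric 2x2 matrix via power iteration.
fn first_component(cov: &[[f64; 2]; 2], iters: usize) -> [f64; 2] {
    let mut v = [1.0, 1.0];
    for _ in 0..iters {
        // Multiply by the matrix, then renormalize to unit length.
        let w = [
            cov[0][0] * v[0] + cov[0][1] * v[1],
            cov[1][0] * v[0] + cov[1][1] * v[1],
        ];
        let norm = (w[0] * w[0] + w[1] * w[1]).sqrt();
        v = [w[0] / norm, w[1] / norm];
    }
    v
}

fn main() {
    let data = [[-2.0, 0.0], [-1.0, 0.1], [1.0, -0.1], [2.0, 0.0]];
    let pc = first_component(&covariance(&data), 100);
    println!("first principal component: {:?}", pc);
}
```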

Target modeling

Regression with shrinkage operators, such as LASSO and Ridge.
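For intuition on shrinkage operators: with an orthonormal design, the LASSO and Ridge estimates are simple transforms of the ordinary least-squares coefficient z. LASSO soft-thresholds it, setting small coefficients exactly to zero, while Ridge scales it toward zero. The sketch assumes the objectives (1/2)(z - b)^2 + lambda |b| and (1/2)(z - b)^2 + (lambda/2) b^2, plus a glmnet-style elastic-net mix of the two; linfa's `ElasticNet` may scale `penalty` and `l1_ratio` differently, and the function names are hypothetical.

```rust
/// LASSO (L1) shrinkage: soft-thresholding, sign(z) * max(|z| - lambda, 0).
fn soft_threshold(z: f64, lambda: f64) -> f64 {
    z.signum() * (z.abs() - lambda).max(0.0)
}

/// Ridge (L2) shrinkage: proportional scaling z / (1 + lambda).
fn ridge_shrink(z: f64, lambda: f64) -> f64 {
    z / (1.0 + lambda)
}

/// Elastic-net shrinkage (assumed glmnet-style parametrization):
/// soft-threshold by penalty * l1_ratio, then divide by 1 + penalty * (1 - l1_ratio).
fn elastic_net_shrink(z: f64, penalty: f64, l1_ratio: f64) -> f64 {
    soft_threshold(z, penalty * l1_ratio) / (1.0 + penalty * (1.0 - l1_ratio))
}

fn main() {
    for &z in [-2.0, -0.05, 0.05, 2.0].iter() {
        println!(
            "z = {}: lasso {:.3}, ridge {:.3}, elastic net {:.3}",
            z,
            soft_threshold(z, 0.1),
            ridge_shrink(z, 0.1),
            elastic_net_shrink(z, 0.1, 0.5)
        );
    }
}
```

With `l1_ratio = 1.0` the elastic-net operator reduces to pure LASSO, matching the "pure LASSO" setting in the example above.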

Classification of linear, as well as non-linear data with SVM and kernel-methods.
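Kernel methods such as the SVM see the data only through pairwise similarities, which lets a linear algorithm separate non-linear data. A common choice is the Gaussian (RBF) kernel. The sketch below builds a dense Gaussian kernel matrix with the standard library, assuming the parametrization k(x, y) = exp(-||x - y||^2 / eps); this is our reading of the `eps` used in the SVM example, so check it against the linfa-kernel documentation. The helper names are hypothetical.

```rust
/// Squared Euclidean distance between two samples.
fn sq_dist(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
}

/// Dense Gaussian kernel matrix K[i][j] = exp(-||x_i - x_j||^2 / eps).
fn gaussian_kernel(samples: &[Vec<f64>], eps: f64) -> Vec<Vec<f64>> {
    samples
        .iter()
        .map(|a| {
            samples
                .iter()
                .map(|b| (-sq_dist(a, b) / eps).exp())
                .collect::<Vec<f64>>()
        })
        .collect()
}

fn main() {
    let samples = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![5.0, 5.0]];
    // Nearby samples get similarity close to 1, distant ones close to 0.
    for row in gaussian_kernel(&samples, 2.0) {
        println!("{:?}", row);
    }
}
```

A larger `eps` widens the kernel, so more distant samples are still considered similar.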

There are many more implementations, which you can find in the rust-ml/linfa repository.