Using GMMs in Rust

24 Jun 2016

This post aims to introduce Gaussian Mixture Models (from now on referred to as GMMs) and explain what they can be used for. To do that I’ll be creating some synthetic data and training a GMM on it using rusty-machine. This post is fairly heavy on theory, but I promise there is some code.

Some familiarity with the following will be helpful (though not entirely necessary):

Basic probability concepts

Bayesian statistics

Gaussian distributions

Basic machine learning concepts

What are you talking about?

Before jumping into GMMs let’s define a more general Mixture Model. A Mixture Model is a probabilistic model used to represent subclasses within a whole population. We use Mixture Models to make sense of datasets that we believe are composed of a mixture of different groups. We want to learn the composition of the wider population from which our data has been drawn.

GMMs are a specific case of a Mixture Model which can be used when the population we are considering is made up of continuous measurements (real numbers). We attempt to model this population as a combination of Gaussian random variables each representing a sub-group. A more formal definition will follow the example.

A GMM Example

Imagine that we have a large room filled with males aged 10, 15 and 20. We’d expect the heights within each age group to be roughly normally distributed* - with perhaps some overlap between the groups. This lends itself naturally to a GMM: we consider the whole population to be made up of subgroups, each with heights normally distributed around some average.

What does all this mean in practice? We want to determine the mean and variance of these underlying Gaussians to learn something about our population. To do this we train the GMM on the data we have gathered, using an iterative algorithm - most commonly Expectation Maximization (EM) [2].

A bit more formally…

Mixture models tend to have some pretty hefty notation [1]. I’ll try to introduce things in a sensible order…

First we have K, the number of Gaussians we mix in the model (in the example above, K = 3), and N, the total number of samples.

For each of our K Gaussians we will have some mean and variance; we’ll denote those $\mu_1, \dots, \mu_K$ and $\sigma^2_1, \dots, \sigma^2_K$, respectively.

The last thing we’ll need is the mixture weights, which we’ll denote $\phi_1, \dots, \phi_K$, where $\phi_i$ is the (prior) probability that a sample belongs to subgroup $i$. Because they are probabilities, the weights are non-negative and sum to 1. In many cases we may be content assuming each subgroup is equally likely, i.e. $\phi_i = 1/K$.

Putting this back into comprehensible language: we have K Gaussian random variables which represent some subcategories in our data. And we have some belief (φ) of how often each subcategory appears in our data.
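Putting the pieces together, the probability density that a GMM with these parameters assigns to a single one-dimensional sample $x$ is the weighted sum of its $K$ Gaussian densities:

```latex
p(x) = \sum_{k=1}^{K} \phi_k \, \mathcal{N}(x \mid \mu_k, \sigma^2_k),
\qquad
\mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\!\left(-\frac{(x - \mu)^2}{2\sigma^2}\right),
\qquad
\sum_{k=1}^{K} \phi_k = 1 .
```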

If you’ve kept up so far, great! Next I’m going to write about a great way to test a broad class of models.

Using simulated data to test models

Suppose we’ve just finished writing some code that trains a GMM on a dataset. How can we be sure that our code works? We probably want to test our code out on some data that we believe contains subclasses that can be represented by Gaussians. We could try to find such a dataset - or we can create one!**

This is a general technique that can be applied to a broad class of models. The plan is to construct some data which explicitly contains the properties described by our model. We then train the model on this dataset and hope to recover the properties as we defined them.

In the case of GMMs: we define some Gaussian random variables by choosing pairs of means and variances, $(\mu_k, \sigma^2_k)$. We also define the mixture probabilities, which determine how much each Gaussian contributes to the population as a whole. We then draw samples from the model by choosing a Gaussian (according to the mixture probabilities) and, in turn, drawing a sample from it.

By repeating this process we get a dataset which should have the properties defined by the chosen means, variances, and mixture weights. Now we train our GMM on the generated data and try to recover the means, variances and mixture weights we chose above.

Finally, some code!

As promised, we’ll be using Rust. In this section we’ll walk through some code for simulating samples from a GMM. We’ll then use rusty-machine to train a GMM on this simulated data - and verify that we can learn the underlying model parameters.

use rand::thread_rng;
use rusty_machine::stats::dist::gaussian::Gaussian;

pub fn simulate_gmm_1d_data(count: usize,
                            means: Vec<f64>,
                            vars: Vec<f64>,
                            mixture_weights: Vec<f64>)
                            -> Vec<f64> {
    // Every component needs a mean, a variance and a mixture weight
    assert_eq!(means.len(), vars.len());
    assert_eq!(means.len(), mixture_weights.len());

    let gmm_count = means.len();
    let mut gaussians = Vec::with_capacity(gmm_count);

    for i in 0..gmm_count {
        // Create a gaussian with mean and var
        gaussians.push(Gaussian::new(means[i], vars[i]));
    }

    let mut rng = thread_rng();
    let mut out_samples = Vec::with_capacity(count);

    for _ in 0..count {
        // We'll write this part next
    }

    out_samples
}

This is our function for generating a dataset with the properties of a GMM. We do some sensible length checking and then fill a vector with Gaussians built from the incoming means and variances. Next we pick one of these Gaussians according to the mixture weights and take a sample from it.

use rand::distributions::IndependentSample;

pub fn simulate_gmm_1d_data(count: usize,
                            means: Vec<f64>,
                            vars: Vec<f64>,
                            mixture_weights: Vec<f64>)
                            -> Vec<f64> {
    // Setting up the Gaussians above...

    let mut rng = thread_rng();
    let mut out_samples = Vec::with_capacity(count);

    for _ in 0..count {
        // Pick a gaussian from the mixture weights
        // (borrowed: indexing by value would try to move it out of the Vec)
        let chosen_gaussian = &gaussians[pick_gaussian_idx(&mixture_weights, &mut rng)];

        // Draw a sample from it; ind_sample comes from rand's
        // IndependentSample trait, which the Gaussian type must implement
        let sample = chosen_gaussian.ind_sample(&mut rng);

        // Add to data
        out_samples.push(sample);
    }

    out_samples
}

And, just for good measure, here’s the helper function pick_gaussian_idx:

use rand::Rng;

fn pick_gaussian_idx<R: Rng>(unnorm_dist: &[f64], rng: &mut R) -> usize {
    assert!(unnorm_dist.len() > 0);

    // Get the sum of the unnormalized distribution
    let sum = unnorm_dist.iter().fold(0f64, |acc, &x| acc + x);

    // Sum must be positive
    assert!(sum > 0.0);

    // A random number between 0 and sum
    let rand = rng.gen_range(0.0f64, sum);

    let mut unnorm_cdf = 0.0;
    for (i, p) in unnorm_dist.iter().enumerate() {
        // Add the current weight to the running (unnormalized) cumulative sum
        unnorm_cdf += *p;

        // Return i if rand falls in the correct interval
        if rand < unnorm_cdf {
            return i;
        }
    }

    panic!("No random value was sampled!");
}

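This function implements a standard method for sampling from a discrete distribution: draw a uniform number in [0, sum) and return the first index at which the running sum of the weights exceeds it. Each index i is therefore chosen with probability proportional to its weight, and the weights do not need to be normalized first.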
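To see the cumulative-sum logic in isolation, here is a deterministic sketch that takes the uniform draw as an argument instead of generating it. `pick_index_from_uniform` is a hypothetical name; it mirrors `pick_gaussian_idx` but scales a value `u` in [0, 1) onto [0, sum) itself, which makes the interval boundaries easy to check by hand.

```rust
/// Hypothetical deterministic variant of `pick_gaussian_idx`: the caller
/// supplies a uniform value `u` in [0, 1) instead of an RNG.
fn pick_index_from_uniform(unnorm_dist: &[f64], u: f64) -> usize {
    assert!(!unnorm_dist.is_empty());
    assert!((0.0..1.0).contains(&u));

    // Total weight; the weights do not need to sum to 1.
    let sum: f64 = unnorm_dist.iter().sum();
    assert!(sum > 0.0);

    // Map u onto [0, sum), the same range `gen_range(0.0, sum)` draws from.
    let target = u * sum;

    // Walk the running (cumulative) sum and return the first interval containing target.
    let mut cumulative = 0.0;
    for (i, &w) in unnorm_dist.iter().enumerate() {
        cumulative += w;
        if target < cumulative {
            return i;
        }
    }

    // Only reachable through floating-point rounding (target == sum);
    // fall back to the last index with positive weight.
    unnorm_dist.iter().rposition(|&w| w > 0.0).unwrap()
}

fn main() {
    let weights = [0.5, 0.25, 0.25];
    for &u in &[0.1, 0.6, 0.8] {
        println!("u = {} -> index {}", u, pick_index_from_uniform(&weights, u));
    }
}
```

With weights 0.5, 0.25 and 0.25 the cumulative sums are 0.5, 0.75 and 1.0, so values of u below 0.5 select index 0, values in [0.5, 0.75) select index 1, and the rest select index 2.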

What is all of the above actually good for?

Using the code above we can generate some samples from a GMM and then train a GMM on them using rusty-machine. The hope is that we can recover the original parameters that generated the samples (the means, variances, and mixture probabilities).

Luckily, rusty-machine makes training a GMM very easy!

use rusty_machine::learning::gmm::GaussianMixtureModel;
use rusty_machine::prelude::*;

fn main() {
    // Number of Gaussians and samples
    let gmm_count = 3;
    let count = 1000;

    // Parameters for our model
    let means = vec![-3f64, 0., 3.];
    let vars = vec![1f64, 0.5, 0.25];
    let weights = vec![0.5, 0.25, 0.25];

    // Simulate some data using the parameters above
    let samples = simulate_gmm_1d_data(count, means, vars, weights);

    // Create a GMM with the same number of Gaussians
    let mut gmm = GaussianMixtureModel::new(gmm_count);

    // Train the model on the samples (a count x 1 matrix: one row per sample)
    gmm.train(&Matrix::new(count, 1, samples));

    println!("Means = {:?}", gmm.means());
    println!("Covs = {:?}", gmm.covariances());
    println!("Mix Weights = {:?}", gmm.mixture_weights());
}
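One practical wrinkle when checking the output: the fitted components come back in no particular order, so the first learned mean need not correspond to the first true mean. A simple check in one dimension is to sort both sets of means before comparing them. The helper `max_sorted_mean_error` below is hypothetical, it assumes the learned means have already been copied into a plain `f64` slice, and the estimated values in `main` are made up for illustration rather than output from the code above.

```rust
/// Hypothetical helper: compare true and estimated 1-D means after sorting both,
/// because the fitted components are returned in no particular order.
fn max_sorted_mean_error(true_means: &[f64], estimated: &[f64]) -> f64 {
    assert_eq!(true_means.len(), estimated.len());
    let mut t = true_means.to_vec();
    let mut e = estimated.to_vec();

    // partial_cmp only fails on NaN, which a successful fit should not produce.
    t.sort_by(|a, b| a.partial_cmp(b).unwrap());
    e.sort_by(|a, b| a.partial_cmp(b).unwrap());

    // Largest absolute difference between matched (sorted) means.
    t.iter().zip(&e).map(|(a, b)| (a - b).abs()).fold(0.0, f64::max)
}

fn main() {
    // Made-up estimates, deliberately out of order, for illustration only.
    let estimated = [2.9, -3.05, 0.1];
    let err = max_sorted_mean_error(&[-3.0, 0.0, 3.0], &estimated);
    println!("max abs error in means = {}", err);
}
```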

Peeking under the hood (a grueling detour)

As I mentioned above, it is common to use the Expectation Maximization algorithm [2] to train GMMs. I’m not going to go into too much detail here but will touch on the basic steps.

The EM algorithm is split into two parts - an E-step and an M-step.

For GMMs the E-step consists of computing the posterior probability (called the membership weight in this case) of each data point belonging to each subgroup. Roughly, this means: given our current estimates of the parameters and the data we have, how likely is it that data point i belongs to subgroup k (for every i and k)?
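In symbols (writing $\gamma_{ik}$ for the membership weight of point $x_i$ in subgroup $k$; the symbol is chosen here for illustration), the E-step is Bayes' rule applied with the current parameter estimates, using the same Gaussian density $\mathcal{N}$ as the mixture model itself:

```latex
\gamma_{ik} = \frac{\phi_k \, \mathcal{N}(x_i \mid \mu_k, \sigma^2_k)}
                   {\sum_{j=1}^{K} \phi_j \, \mathcal{N}(x_i \mid \mu_j, \sigma^2_j)},
\qquad i = 1, \dots, N, \quad k = 1, \dots, K .
```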

The M-step involves using the posterior probabilities computed above with the data to compute new parameter estimates. We compute new updates for the mixture weights, the means, and the variances. These notes provide a very good introduction and explanation.
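To make both steps concrete, here is a minimal, dependency-free sketch of one EM iteration for a one-dimensional GMM, with the M-step update formulas written out in the comments. It is not rusty-machine's implementation: `normal_pdf` and `em_step` are hypothetical names, the data in `main` is made up, and it assumes no component ever ends up with zero total membership weight. A more robust version would work in log space and guard against collapsing variances.

```rust
/// Density of a 1-D Gaussian with mean `mu` and variance `var` at `x`.
fn normal_pdf(x: f64, mu: f64, var: f64) -> f64 {
    let diff = x - mu;
    (-diff * diff / (2.0 * var)).exp() / (2.0 * std::f64::consts::PI * var).sqrt()
}

/// One EM iteration for a 1-D GMM; updates `weights`, `means` and `vars` in place.
fn em_step(data: &[f64], weights: &mut [f64], means: &mut [f64], vars: &mut [f64]) {
    let n = data.len();
    let k = weights.len();

    // E-step: gamma[i][j] = posterior probability that data[i] came from component j.
    let mut gamma = vec![vec![0.0; k]; n];
    for (i, &x) in data.iter().enumerate() {
        let mut total = 0.0;
        for j in 0..k {
            gamma[i][j] = weights[j] * normal_pdf(x, means[j], vars[j]);
            total += gamma[i][j];
        }
        // Normalize so each point's membership weights sum to 1.
        for j in 0..k {
            gamma[i][j] /= total;
        }
    }

    // M-step: re-estimate each component from the membership weights.
    for j in 0..k {
        // N_j = sum_i gamma_ij, the "effective" number of points in component j.
        let n_j: f64 = (0..n).map(|i| gamma[i][j]).sum();
        // phi_j = N_j / N
        weights[j] = n_j / n as f64;
        // mu_j = (1 / N_j) * sum_i gamma_ij * x_i
        means[j] = (0..n).map(|i| gamma[i][j] * data[i]).sum::<f64>() / n_j;
        // sigma^2_j = (1 / N_j) * sum_i gamma_ij * (x_i - mu_j)^2, using the new mu_j
        vars[j] = (0..n)
            .map(|i| gamma[i][j] * (data[i] - means[j]).powi(2))
            .sum::<f64>()
            / n_j;
    }
}

fn main() {
    // Two well-separated clusters (illustrative data only).
    let data = [-3.1, -2.9, -3.0, -2.8, 3.0, 3.2, 2.9, 3.1];
    let mut weights = vec![0.5, 0.5];
    let mut means = vec![-1.0, 1.0];
    let mut vars = vec![1.0, 1.0];

    for _ in 0..20 {
        em_step(&data, &mut weights, &mut means, &mut vars);
    }

    println!("weights = {:?}", weights);
    println!("means = {:?}", means);
    println!("vars = {:?}", vars);
}
```

Each call runs one E-step followed by one M-step; in practice the loop would stop once the log-likelihood or the parameters stop changing, rather than after a fixed number of iterations.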

The current rusty-machine implementation is fairly basic but does allow for some type-based niceness. A common issue with GMMs is numerical instability when computing covariance inverses. The rusty-machine implementation uses a CovOption enum which lets the user specify how the covariance updates should be computed.

pub enum CovOption {
    /// The full covariance structure.
    Full,
    /// Adds a regularization constant to the covariance diagonal.
    Regularized(f64),
    /// Only the diagonal covariance structure.
    Diagonal,
}
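To see why these options help, the following sketch applies each variant to a 2x2 covariance estimate. `CovOptionSketch`, `apply_cov_option` and `det` are hypothetical stand-ins written for this illustration, not rusty-machine's types or internals. The example starts from a singular covariance (two perfectly correlated features), which has determinant zero and therefore cannot be inverted.

```rust
/// Hypothetical local stand-in mirroring the three CovOption variants;
/// this is NOT the rusty-machine type.
#[derive(Clone, Copy, Debug)]
enum CovOptionSketch {
    Full,
    Regularized(f64),
    Diagonal,
}

/// Post-process a raw 2x2 covariance estimate according to `opt`.
fn apply_cov_option(cov: [[f64; 2]; 2], opt: CovOptionSketch) -> [[f64; 2]; 2] {
    match opt {
        // Keep every entry, including the off-diagonal correlations.
        CovOptionSketch::Full => cov,
        // Add eps to the diagonal: for a positive semi-definite cov and eps > 0,
        // cov + eps * I is positive definite and hence invertible.
        CovOptionSketch::Regularized(eps) => [
            [cov[0][0] + eps, cov[0][1]],
            [cov[1][0], cov[1][1] + eps],
        ],
        // Drop the off-diagonal terms; inverting is then just taking reciprocals.
        CovOptionSketch::Diagonal => [[cov[0][0], 0.0], [0.0, cov[1][1]]],
    }
}

/// Determinant of a 2x2 matrix; zero means the matrix has no inverse.
fn det(m: [[f64; 2]; 2]) -> f64 {
    m[0][0] * m[1][1] - m[0][1] * m[1][0]
}

fn main() {
    // A singular covariance: the two features are perfectly correlated.
    let cov = [[1.0, 1.0], [1.0, 1.0]];
    let options = [
        CovOptionSketch::Full,
        CovOptionSketch::Regularized(0.1),
        CovOptionSketch::Diagonal,
    ];
    for &opt in &options {
        let adjusted = apply_cov_option(cov, opt);
        println!("{:?}: det = {}", opt, det(adjusted));
    }
}
```

Here the full estimate keeps a zero determinant, while both the regularized and the diagonal versions end up with a positive determinant, which is what a covariance inverse needs.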

Though an exceedingly simple example, this follows the general mandate of trying to keep rusty-machine’s models simple (and obvious) whilst fully customisable.

Remarks

Thanks for reading! Please give feedback on this post and rusty-machine. If you see some obvious improvements I’d love to hear them!

* - I have absolutely no idea if they actually are. But it sounds sort of right?

** - Of course it is also a good idea to test on real data and test out some other edge cases.

*** - Though, in my excitement to do machine learning I knowingly neglected some error handling among other things. :(