The proper use of Cross Validation


While teaching Data Science and Machine Learning to people coming into the industry (or just trying to get some exposure to it), one of the first “tricky” theory subjects that comes up is why and how to properly use Cross Validation as a model assessment tool.

Let’s first quickly define what Cross Validation is, and then let’s dive into the most common gotchas when trying to apply it.

“Cross-validation is a procedure for model assessment, with the ultimate objective of estimating how well a model will perform on unseen data.”

Given the above definition, it sounds a lot like the usual train/test split you might use: you train your model on one set of data, then assess it on the test (or hold-out) set.
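For reference, here is a minimal sketch of that hold-out procedure. It is written in plain Python so it runs without dependencies (in practice you would likely reach for scikit-learn's `train_test_split`). The data (a noisy line y = 2x + 1), the 20% test fraction, the straight-line model and the helper names `fit_line`, `mse` and `train_test_split` are all assumptions made for illustration.

```python
import random

def fit_line(xs, ys):
    """Ordinary least-squares fit of y = a + b*x; returns (a, b)."""
    mean_x, mean_y = sum(xs) / len(xs), sum(ys) / len(ys)
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    b = sxy / sxx
    return mean_y - b * mean_x, b

def mse(model, xs, ys):
    """Mean squared error of the line `model` = (a, b) on the given points."""
    a, b = model
    return sum((y - (a + b * x)) ** 2 for x, y in zip(xs, ys)) / len(xs)

def train_test_split(xs, ys, test_fraction=0.2, seed=0):
    """Shuffle the indices once and hold out the first `test_fraction` of them."""
    indices = list(range(len(xs)))
    random.Random(seed).shuffle(indices)
    n_test = int(len(xs) * test_fraction)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return ([xs[i] for i in train_idx], [ys[i] for i in train_idx],
            [xs[i] for i in test_idx], [ys[i] for i in test_idx])

# Synthetic data, assumed for illustration: y = 2x + 1 plus Gaussian noise.
rng = random.Random(42)
xs = [rng.uniform(0, 10) for _ in range(100)]
ys = [2 * x + 1 + rng.gauss(0, 1) for x in xs]

# Train on one part, assess on the held-out part: a single point estimate.
x_train, y_train, x_test, y_test = train_test_split(xs, ys)
model = fit_line(x_train, y_train)
test_mse = mse(model, x_test, y_test)
print(f"train size={len(x_train)}, test size={len(x_test)}, test MSE={test_mse:.3f}")
```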

There are a few issues with using the humble (and straightforward) train/test split, though.

First, when prototyping a model, I would like to expose it to as much data as possible, with the widest possible range of values. A train/test split sets part of the data aside, so the model never learns from it, and if the split is not random we also risk training on a biased sample. Either way, our training set might not represent the whole dataset appropriately.

Second, with a single split, simple chance can easily give us a misleading picture of model performance. One symptom is getting a better test score than training score. At the end of the day, we get a point estimate of performance, but no sense of how much it varies with the particular sample we happened to draw.
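To see that variability, a sketch like the one below repeats the same hold-out split with five different random seeds and collects the test MSE each time. The synthetic data, the straight-line model, the seeds and the helper names are illustrative assumptions, not part of any particular workflow.

```python
import random

def fit_line(xs, ys):
    """Ordinary least-squares fit of y = a + b*x; returns (a, b)."""
    mean_x, mean_y = sum(xs) / len(xs), sum(ys) / len(ys)
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    b = sxy / sxx
    return mean_y - b * mean_x, b

def mse(model, xs, ys):
    """Mean squared error of the line `model` = (a, b) on the given points."""
    a, b = model
    return sum((y - (a + b * x)) ** 2 for x, y in zip(xs, ys)) / len(xs)

def train_test_split(xs, ys, test_fraction=0.2, seed=0):
    """Shuffle the indices once and hold out the first `test_fraction` of them."""
    indices = list(range(len(xs)))
    random.Random(seed).shuffle(indices)
    n_test = int(len(xs) * test_fraction)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return ([xs[i] for i in train_idx], [ys[i] for i in train_idx],
            [xs[i] for i in test_idx], [ys[i] for i in test_idx])

# Synthetic data, assumed for illustration: y = 2x + 1 plus Gaussian noise.
rng = random.Random(42)
xs = [rng.uniform(0, 10) for _ in range(100)]
ys = [2 * x + 1 + rng.gauss(0, 1) for x in xs]

# Same model, same data, only the random split changes.
test_scores = []
for seed in range(5):
    x_train, y_train, x_test, y_test = train_test_split(xs, ys, seed=seed)
    test_scores.append(mse(fit_line(x_train, y_train), x_test, y_test))
print([round(s, 3) for s in test_scores])
```

Any one of these numbers on its own would be reported as "the" test score, which is exactly the point-estimate problem described above.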

Given those two main caveats (not using all the data at our disposal, and only getting a point estimate), we can think of ways of extending this procedure to make it more robust. And that’s where Cross-Validation comes in.

So, how does Cross Validation actually work?

Cross Validation (or more specifically, K-Fold Cross Validation) works by splitting our labeled data into K different folds (usually between 5 and 10). This K is also the number of distinct models that we will end up training.

Once we have the K folds, we will train a model on each combination of K-1 folds, using the remaining fold to test our model.
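The fold bookkeeping can be sketched as a small generator that shuffles the indices once, deals them into K folds whose sizes differ by at most one, and yields a (train, test) index pair per fold. The name `k_fold_indices` and its parameters are hypothetical; scikit-learn's `KFold` class provides the same splitting if you prefer a library.

```python
import random

def k_fold_indices(n_samples, k, seed=0):
    """Yield (train_idx, test_idx) pairs for K-fold cross-validation.

    Indices are shuffled once, then split into k contiguous folds whose
    sizes differ by at most one. Each fold is the test set exactly once.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # The first n_samples % k folds get one extra index.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        test_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]  # the other K-1 folds
        yield train_idx, test_idx
        start += size

for train_idx, test_idx in k_fold_indices(10, k=5):
    print(len(train_idx), len(test_idx))
```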

The output of this procedure is K trained models and K performance metrics, one per test fold (for example, accuracy or MSE). This list of metrics provides both a point estimate (their mean) and a sense of how much variability we have (for example, their standard deviation or variance).
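Putting it together, the sketch below trains one model per fold and summarises the K test-fold MSEs with their mean and standard deviation. It is a dependency-free illustration under the same assumptions as before (synthetic linear data, a least-squares line as "the model", hypothetical helper names); the folds here are built by striding over shuffled indices, which is one of several valid ways to partition the data. scikit-learn's `cross_val_score` does the equivalent for its own estimators.

```python
import random
import statistics

def fit_line(xs, ys):
    """Ordinary least-squares fit of y = a + b*x; returns (a, b)."""
    mean_x, mean_y = sum(xs) / len(xs), sum(ys) / len(ys)
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    b = sxy / sxx
    return mean_y - b * mean_x, b

def mse(model, xs, ys):
    """Mean squared error of the line `model` = (a, b) on the given points."""
    a, b = model
    return sum((y - (a + b * x)) ** 2 for x, y in zip(xs, ys)) / len(xs)

def cross_validate(xs, ys, k=5, seed=0):
    """Train k models, each on k-1 folds, and return the k test-fold MSEs."""
    indices = list(range(len(xs)))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]  # k disjoint folds covering every index
    scores = []
    for i in range(k):
        test_idx = folds[i]
        train_idx = [j for fold in folds[:i] + folds[i + 1:] for j in fold]
        model = fit_line([xs[j] for j in train_idx], [ys[j] for j in train_idx])
        scores.append(mse(model, [xs[j] for j in test_idx], [ys[j] for j in test_idx]))
    return scores

# Synthetic data, assumed for illustration: y = 2x + 1 plus Gaussian noise.
rng = random.Random(42)
xs = [rng.uniform(0, 10) for _ in range(100)]
ys = [2 * x + 1 + rng.gauss(0, 1) for x in xs]

scores = cross_validate(xs, ys, k=5)
print(f"mean MSE = {statistics.mean(scores):.3f}, std = {statistics.stdev(scores):.3f}")
```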

The above provides a better picture of our model performance than just using a simple train/test split.

Is Cross Validation a free lunch for model assessment?

I would argue that it almost is. The main drawback here is that you will end up training K models instead of 1, which of course requires more computation.

In practice, though, training times usually range from almost instantaneous, for simpler models on “not too big” datasets, to hours long, for more complex models on bigger datasets. In the first case, the difference between training 1 model and 5 or 10 will be almost imperceptible. If the latter case is closer to the modeller’s reality, they usually already have a way of “off-loading” training to the cloud, and it’s just a matter of using more instances for training.

Finally, let’s go through some common gotchas when starting out with Cross Validation (CV, for short):

CV is NOT a performance optimisation technique; it’s a model assessment technique. Don’t try different values of K until you get the best mean error. 🙂

There’s a tradeoff between higher and lower values of K; please check this link. But in general, don’t overthink it: default values like 5, 10 or 15 are perfectly appropriate.
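One side of that tradeoff is simple arithmetic: with K folds you fit K models, and each one is trained on roughly (K-1)/K of the data. The hypothetical helper below just tabulates that for a dataset size assumed to be 1000 samples; it says nothing about the statistical side of the tradeoff.

```python
def fold_summary(n_samples, ks):
    """For each K, return (K, training-set size per fit) under K-fold CV.

    When K does not divide n_samples evenly, fold sizes differ by one; the
    size reported here (n_samples minus the smallest fold) is the largest.
    """
    return [(k, n_samples - n_samples // k) for k in ks]

for k, train_size in fold_summary(1000, (5, 10, 15)):
    print(f"K={k:2d}: {k} model fits, each trained on about {train_size} of 1000 samples")
```

Going from K=5 to K=15 triples the number of fits while each training set only grows from 80% to about 93% of the data.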

Once you are happy with the results of assessing your model, what do you do? Cross Validation doesn’t really “output” one model that you can use. But the assumption is that a given model specification will perform at least as well with more data as with less, so once you are happy with the model specification (the model type, e.g. Logistic Regression, plus its hyperparameter values), you should re-train it on ALL the available data before shipping it.
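As a final sketch, re-training on everything is just one more fit of the same specification, this time on the full dataset rather than on K-1 folds. The least-squares line and synthetic data are the same illustrative assumptions as above, and the cross-validation step is assumed to have been done already.

```python
import random

def fit_line(xs, ys):
    """Ordinary least-squares fit of y = a + b*x; returns (a, b)."""
    mean_x, mean_y = sum(xs) / len(xs), sum(ys) / len(ys)
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    b = sxy / sxx
    return mean_y - b * mean_x, b

# Synthetic data, assumed for illustration: y = 2x + 1 plus Gaussian noise.
rng = random.Random(42)
xs = [rng.uniform(0, 10) for _ in range(100)]
ys = [2 * x + 1 + rng.gauss(0, 1) for x in xs]

# Assumed: cross-validation has already been run and this specification
# (a straight line fitted by least squares) looked acceptable.
# The K fold models are discarded; re-train the same specification on ALL the data.
final_model = fit_line(xs, ys)
a, b = final_model
print(f"final model: y = {a:.2f} + {b:.2f} * x")
```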