Exploring Computer Vision (Part II): Transfer Learning

Nathan Lintz
February 22, 2016

Welcome back to our three-part series on computer vision. In the previous post, we discussed convolutional neural networks (CNNs). This post will assume that you have a basic understanding of CNNs; we encourage you to reread the first post if you want a refresher on convolutional networks.

Introduction to Transfer Learning

When we start learning about programming, we tend to think of algorithms as “black boxes” that transform an input to an output. Sorting algorithms convert unordered sequences into ordered sequences, graphics algorithms convert matrices into pretty pictures, recommendation systems take a history of events and yield a prediction for the next best action, and so on. At first, it’s tempting to write long, monolithic functions to perform a single task, and this can be very valuable because now you have an automated solution! However, you can often optimize and re-use pieces of that monolithic solution. By composing a solution from reusable chunks, our algorithms can be applied in many different projects. Similarly, in the context of machine learning, given a good objective and training, we can reuse parts of a model trained in one domain to solve problems in different domains. This is called transfer learning.

Transfer Learning From CNNs

Based on the previous post, you’ve seen how stacking layers of convolution filters allows us to accurately classify the content of images. This raises the question – is it necessary to train a brand new CNN for every task? In traditional machine learning, you collect a new dataset and train a new algorithm to make predictions on those examples. This feels awfully similar to when we were learning how to program and didn’t realize that we could share parts of code in different algorithms. Which pieces of the current solution can be optimized and reused to compose new solutions?

Let’s say we have a CNN which can accurately classify images as containing either cats or dogs. This model contains a stack of convolutional kernels which learn the features necessary for distinguishing between the two animals (e.g., patterns of edges that represent hair, fur color, head shape, etc.). Can the same features that distinguish cats from dogs be used to distinguish between other animals like horses and wombats? To test this hypothesis, we’d swap in a new classifier that predicts horses vs. wombats while reusing the kernels from our pre-trained model. This idea – that features from one model can be used to inform the classifier of another model – is the heart of transfer learning.

Visualization of transfer learning
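
To make the kernel-reuse idea concrete, here is a minimal sketch in plain Python. Everything in it is an assumption made for illustration: the two "pre-trained" kernels are hand-written edge detectors rather than learned weights, the images are tiny 4x4 grids, and horizontal vs. vertical bars stand in for horses vs. wombats. The point is only the structure: the kernels stay fixed while a brand new classifier is trained on top of them.

```python
import math

# Hypothetical "pre-trained" 2x2 kernels. In a real CNN these would have been
# learned on the original task; here they are hand-written edge detectors.
# Tuples are immutable, which mirrors freezing the convolutional layers.
PRETRAINED_KERNELS = (
    ((1.0, 1.0), (-1.0, -1.0)),  # fires on a bright row above a dark row
    ((1.0, -1.0), (1.0, -1.0)),  # fires on a bright column left of a dark column
)


def conv_features(image, kernels):
    """Slide each 2x2 kernel over the image, apply ReLU, then global max pool."""
    rows, cols = len(image), len(image[0])
    features = []
    for kernel in kernels:
        best = 0.0  # starting at 0 acts as the ReLU floor
        for i in range(rows - 1):
            for j in range(cols - 1):
                response = sum(
                    kernel[a][b] * image[i + a][j + b]
                    for a in range(2)
                    for b in range(2)
                )
                best = max(best, response)
        features.append(best)
    return features


def train_head(feature_rows, labels, epochs=200, lr=0.5):
    """Fit a fresh logistic-regression classifier on top of fixed features."""
    weights, bias = [0.0] * len(feature_rows[0]), 0.0
    for _ in range(epochs):
        for x, y in zip(feature_rows, labels):
            z = sum(w * xi for w, xi in zip(weights, x)) + bias
            error = 1.0 / (1.0 + math.exp(-z)) - y
            weights = [w - lr * error * xi for w, xi in zip(weights, x)]
            bias -= lr * error
    return weights, bias


def predict(image, kernels, head):
    """Classify an image using the fixed kernels plus the given head."""
    weights, bias = head
    x = conv_features(image, kernels)
    return int(sum(w * xi for w, xi in zip(weights, x)) + bias > 0)


def bar_image(orientation, position, size=4):
    """Toy image: one bright horizontal ("h") or vertical ("v") bar."""
    image = [[0.0] * size for _ in range(size)]
    for t in range(size):
        if orientation == "h":
            image[position][t] = 1.0
        else:
            image[t][position] = 1.0
    return image


# New task (stand-in for horses vs. wombats): label 0 = horizontal, 1 = vertical.
train_images = [bar_image(o, p) for o in ("h", "v") for p in (0, 1)]
train_labels = [0, 0, 1, 1]

# Only the head is trained; the kernels are reused exactly as they were.
train_features = [conv_features(img, PRETRAINED_KERNELS) for img in train_images]
new_head = train_head(train_features, train_labels)

# Classify bars at a position that never appeared in the training set.
print(predict(bar_image("h", 2), PRETRAINED_KERNELS, new_head))
print(predict(bar_image("v", 2), PRETRAINED_KERNELS, new_head))
```

Because the kernels are never modified, the same feature extractor could sit underneath several different heads, one per task. Whether the reused features are actually useful for the new task is exactly the question the next paragraph raises.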

You might still be skeptical about reusing kernels from one model to make another model. After all, if the model learns extremely specific features about cats and dogs, it will get completely confused when it sees a wombat. If your kernels/features are specific to cats and dogs, this could very well happen. That’s one reason why it’s so important to use large, diverse datasets, to encourage our models to learn generalizable (reusable) features about images.

Real World Example of Transfer Learning

If you’re still unsure about the power of transfer learning, let’s go through a concrete case study of how transfer learning can be used for image recognition tasks. There is a Kaggle competition sponsored by Yelp where you are given photos of restaurants and are asked to predict characteristics of that restaurant, such as whether or not it serves alcohol. Competitors have had success creating models from scratch on the provided training set. We trained a fresh CNN on the data and achieved a 66% accuracy rate. This seems pretty solid, but we wondered what would happen if we used transfer learning instead.

We decided to try to transfer features from Google’s CNN, Inception, which they recently released to the public. In order to use this model for transfer learning, we took the training images, ran them through the Inception network, and extracted the output of the network from the layer before the classifier. We then trained a new classifier, which predicted restaurant characteristics based on the features we extracted from Inception. This approach increased our accuracy by nearly 7 percentage points, bringing us up to 73% accuracy and putting us in the top 22% of contestants. Clearly, the transfer learning approach was a success. Because Inception was trained on a massive dataset and has such a deep architecture, it can extract general features from a wide variety of images.

Architecture of Google’s Inception network
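
The extract-then-classify workflow described above can be sketched as follows. This is not the pipeline we actually ran, and it makes several assumptions that the text does not spell out: `extract_features` is a hypothetical stand-in for the pre-trained network (it returns two simple statistics instead of real activations), each restaurant’s photos are pooled by averaging, one small classifier is trained per attribute, and the data and the `outdoor_seating` label are invented for the example.

```python
import math


def extract_features(photo):
    """Hypothetical stand-in for the frozen network's layer before the classifier.

    A real pipeline would return that layer's activations; this toy version
    returns [mean brightness, contrast] for a flat list of pixel values.
    """
    return [sum(photo) / len(photo), max(photo) - min(photo)]


def pool(feature_vectors):
    """Average per-photo feature vectors into one vector per restaurant."""
    return [sum(column) / len(feature_vectors) for column in zip(*feature_vectors)]


def fit_logistic(xs, ys, epochs=500, lr=0.5):
    """Train a small logistic-regression classifier with plain SGD."""
    weights, bias = [0.0] * len(xs[0]), 0.0
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            z = sum(w * xi for w, xi in zip(weights, x)) + bias
            error = 1.0 / (1.0 + math.exp(-z)) - y
            weights = [w - lr * error * xi for w, xi in zip(weights, x)]
            bias -= lr * error
    return weights, bias


def predict_attributes(photos, classifiers):
    """Extract, pool, then apply every per-attribute classifier."""
    x = pool([extract_features(p) for p in photos])
    return {
        attr: int(sum(w * xi for w, xi in zip(weights, x)) + bias > 0)
        for attr, (weights, bias) in classifiers.items()
    }


# Made-up toy data: two tiny "photos" per restaurant, dark for bars, bright for cafes.
train_photos = {
    "bar_a": [[0.1, 0.2, 0.1, 0.2], [0.2, 0.1, 0.2, 0.3]],
    "bar_c": [[0.2, 0.2, 0.1, 0.1], [0.1, 0.3, 0.2, 0.2]],
    "cafe_b": [[0.9, 0.8, 0.9, 1.0], [0.8, 0.9, 0.7, 0.8]],
    "cafe_d": [[1.0, 0.9, 0.8, 0.9], [0.9, 0.9, 0.8, 0.8]],
}
# "serves_alcohol" follows the example in the text; "outdoor_seating" is invented.
train_labels = {
    "bar_a": {"serves_alcohol": 1, "outdoor_seating": 0},
    "bar_c": {"serves_alcohol": 1, "outdoor_seating": 0},
    "cafe_b": {"serves_alcohol": 0, "outdoor_seating": 1},
    "cafe_d": {"serves_alcohol": 0, "outdoor_seating": 1},
}

# Step 1: run every photo through the frozen extractor once and cache the result.
cached = {
    name: pool([extract_features(p) for p in photos])
    for name, photos in train_photos.items()
}

# Step 2: train one new classifier per attribute on the cached features.
names = list(cached)
classifiers = {
    attr: fit_logistic(
        [cached[n] for n in names],
        [train_labels[n][attr] for n in names],
    )
    for attr in ("serves_alcohol", "outdoor_seating")
}

# Step 3: predict attributes for a restaurant that was not in the training set.
new_restaurant = [[0.1, 0.1, 0.2, 0.2], [0.2, 0.1, 0.1, 0.2]]
print(predict_attributes(new_restaurant, classifiers))
```

Caching the extracted features is the practical payoff of this design: the expensive network runs once per photo, and the small classifiers on top can be retrained or swapped cheaply.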

You don’t always need to train new models to solve new problems. Indeed, one strength of deep learning is that it allows us to reuse parts of one model to solve new problems without collecting a massive dataset for the new task.

Transfer Learning at indico

If you want to start tackling image recognition problems via transfer learning, indico exposes high-quality pre-trained convolutional features via the Custom Collection API and the Image Features API. The Custom Collection API enables you to quickly train a custom text or image classifier when you only have a small amount of data; the Image Features API transforms text or image samples into rich feature vectors. Under the hood, both APIs leverage a deep neural network that has been trained to classify image and text samples. Before you begin making predictions from your Custom Collection, the underlying model has already seen millions of training examples. Because the model has already seen so much data, you only have to send it a few new examples before it learns how to solve the classification task at hand.

A good way to understand transfer learning is to think of it like writing modular code. If you want to write fewer lines of code, you break your programs down into reusable parts. Similarly, if you want to solve classification problems and don’t have enough data or the capacity to train deep models, you can reuse parts of existing models. Next time you find yourself trying to solve a problem and only have a small amount of labeled data, we encourage you to try out our Custom Collection API or our Image Features API.