I know it’s a weird way to start a blog with a negative, but there was a wave of discussion in the last few days that I think serves as a good hook for some topics on which I’ve been thinking recently. It all started with a post in the Simply Stats blog by Jeff Leek on the caveats of using deep learning in the small sample size regime. In sum, he argues that when the sample size is small (which happens a lot in the bio domain), linear models with few parameters perform better than deep nets even with a modicum of layers and hidden units. He goes on to show that a very simple linear predictor, using the ten most informative features, performs better than a simple deep net when trying to classify zeros and ones in the MNIST dataset using only 80 or so samples. This prompted Andrew Beam to write a rebuttal in which a properly trained deep net was able to beat the simple linear model, even with very few training samples. This back-and-forth comes at a time when more and more researchers in biomedical informatics are adopting deep learning for various problems. Is the hype real or are linear models really all we need? The answer, as always, is that it depends. In this post, I want to visit use cases in machine learning where using deep learning does not really make sense, as well as tackle preconceptions that I think prevent deep learning from being used effectively, especially by newcomers.

Breaking deep learning preconceptions

First, let’s tackle some preconceptions that I perceive most folks outside the field have that turn out to be half-truths. There are two broad ones and one a bit more technical that I’m going to elaborate on. This is somewhat of an extension to Andrew Beam’s excellent “Misconceptions” section in his post.

Deep learning can really work on small sample sizes

Deep learning’s claim to fame was in a context with lots of data (remember that the first Google Brain project was feeding lots of YouTube videos to a deep net), and ever since it has constantly been publicized as complex algorithms running on lots of data. Unfortunately, this big data/deep learning pair somehow translated into the converse as well: the myth that it cannot be used in the small sample regime. If you have just a few samples, tapping into a neural net with a high parameter-per-sample ratio may superficially seem like a sure road to overfitting. However, considering only the sample size and dimensionality of a given problem, be it supervised or unsupervised, amounts to modeling the data in a vacuum, without any context. It is probably the case that you have data sources that are related to your problem, or that there’s a strong prior that a domain expert can provide, or that the data is structured in a very particular way (e.g. is encoded in a graph or image).

In all of these cases, there’s a chance deep learning can make sense as a method of choice – for example, you can encode useful representations of bigger, related datasets and use those representations in your problem. A classic illustration of this is common in natural language processing, where you can learn word embeddings on a large corpus like Wikipedia and then use those as embeddings in a smaller, narrower corpus for a supervised task. In the extreme, you can have a set of neural nets jointly learn a representation and an effective way to reuse the representation in small sets of samples. This is called one-shot learning and has been successfully applied in a number of fields with high-dimensional data including computer vision and drug discovery.
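To make the embedding-reuse idea concrete, here is a minimal sketch using only NumPy. Everything in it is an assumption for illustration: the vocabulary, the four toy documents, and especially the `pretrained` matrix, which is filled with random numbers as a stand-in for embeddings that would really be learned on a large corpus. The point is only the mechanics: the embedding table stays frozen, and the small labeled dataset is used to fit just a handful of classifier weights on top of it.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical vocabulary and embedding table. In practice `pretrained` would be
# loaded from embeddings trained on a large corpus; random values stand in here.
vocab = ["gene", "protein", "cell", "tumor", "market", "stock", "price", "trade"]
word_to_id = {word: i for i, word in enumerate(vocab)}
pretrained = rng.normal(size=(len(vocab), 16))  # frozen, never updated below


def featurize(doc):
    """Represent a document as the average of its (frozen) word embeddings."""
    ids = [word_to_id[w] for w in doc.split() if w in word_to_id]
    return pretrained[ids].mean(axis=0)


# A tiny labeled dataset: the "small, narrow corpus" of the supervised task.
docs = ["gene protein cell", "tumor cell gene", "market stock price", "trade price stock"]
labels = np.array([1.0, 1.0, 0.0, 0.0])
X = np.stack([featurize(d) for d in docs])


def train_logreg(X, y, lr=0.1, n_steps=300):
    """Fit a logistic regression on top of the fixed features by gradient descent."""
    w, b = np.zeros(X.shape[1]), 0.0
    losses = []
    for _ in range(n_steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
        losses.append(-np.mean(y * np.log(p + 1e-12) + (1 - y) * np.log(1 - p + 1e-12)))
        w -= lr * X.T @ (p - y) / len(y)
        b -= lr * np.mean(p - y)
    return w, b, losses


w, b, losses = train_logreg(X, labels)
print(f"training loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Only 17 parameters are estimated from the small dataset here; the representation itself comes from elsewhere, which is what keeps the parameter-per-sample ratio manageable.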

Deep learning is not the answer to everything

The second preconception I hear the most is the hype. Many yet-to-be practitioners expect deep nets to give them a mythical performance boost just because it worked in other fields. Others are inspired by impressive work in modeling and manipulating images, music, and language (three data types close to any human heart) and rush headfirst into the field by trying to train the latest GAN architecture. The hype is real in many ways. Deep learning has become an undeniable force in machine learning and an important tool in the arsenal of any data modeler. Its popularity has brought forth essential frameworks such as TensorFlow and PyTorch that are incredibly useful even outside deep learning. Its underdog-to-superstar origin story has inspired researchers to revisit other previously obscure methods like evolutionary strategies and reinforcement learning. But it’s not a panacea by any means. Aside from no-free-lunch considerations, deep learning models can be very nuanced and require careful and sometimes very expensive hyperparameter searches, tuning, and testing (much more on this later in the post). Besides, there are many cases where using deep learning just doesn’t make sense from a practical perspective and simpler models work much better.

Deep learning is more than .fit()

There is also an aspect of deep learning models that I see gets sort of lost in translation when coming from other fields of machine learning. Most tutorials and introductory material on deep learning describe these models as composed of hierarchically-connected layers of nodes, where the first layer is the input and the last layer is the output, and say that you can train them using some form of stochastic gradient descent. After maybe some brief mentions of how stochastic gradient descent works and what backpropagation is, the bulk of the explanation focuses on the rich landscape of neural network types (convolutional, recurrent, etc.). The optimization methods themselves receive little additional attention, which is unfortunate since it’s likely that a big (if not the biggest) part of why deep learning works is because of those particular methods (check out, e.g. this post from Ferenc Huszár and this paper taken from that post), and knowing how to optimize their parameters and how to partition data to use them effectively is crucial to get good convergence in a reasonable amount of time.

Exactly why stochastic gradients matter so much is still unknown, but some clues are emerging here and there. One of my favorites is the interpretation of the methods as part of performing Bayesian inference. In essence, every time that you do some form of numerical optimization, you’re performing some Bayesian inference with particular assumptions and priors. Indeed, there’s a whole field, called probabilistic numerics, that has emerged from taking this view. Stochastic gradient descent is no different, and recent work suggests that the procedure is really a Markov chain that, under certain assumptions, has a stationary distribution that can be seen as a sort of variational approximation to the posterior. So when you stop your SGD and take the final parameters, you’re basically sampling from this approximate distribution. I found this idea to be illuminating, because the optimizer’s parameters (in this case, the learning rate) make so much more sense that way. As an example, as you increase the learning rate of SGD the Markov chain becomes unstable until it finds wide local minima, where it samples a large area; that is, you increase the variance of the procedure. On the other hand, if you decrease the learning rate, the Markov chain slowly approximates narrower minima until it converges in a tight region; that is, you increase the bias for a certain region. Another parameter, the batch size in SGD, also controls what type of region the algorithm converges to: wider regions for small batches and sharper regions for larger batches.
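A toy simulation can make the Markov chain picture tangible. The sketch below is an illustration under strong assumptions, not the setting of the work mentioned above: it runs SGD on a one-dimensional quadratic loss with Gaussian data, where the iterates form a simple autoregressive chain. Because the loss has a single minimum, it cannot show the preference for wide versus sharp minima; it only shows that the spread of the iterates around the minimum grows with the learning rate and shrinks with the batch size. The function names `sgd_iterates` and `spread` are made up for this example.

```python
import numpy as np


def sgd_iterates(lr, batch_size, n_steps=20000, burn_in=2000, seed=0):
    """Run SGD on the loss E[0.5 * (theta - x)^2] with x ~ N(0, 1).

    Returns the iterates visited after burn-in, i.e. approximate samples
    from the stationary distribution of the SGD Markov chain.
    """
    rng = np.random.default_rng(seed)
    theta = 5.0  # start far from the minimum at theta = 0
    kept = []
    for step in range(n_steps):
        batch = rng.normal(0.0, 1.0, size=batch_size)
        grad = theta - batch.mean()  # minibatch gradient of the quadratic loss
        theta -= lr * grad
        if step >= burn_in:
            kept.append(theta)
    return np.array(kept)


def spread(lr, batch_size):
    """Empirical variance of the SGD iterates around the minimum."""
    return sgd_iterates(lr, batch_size).var()


# For this toy problem the stationary variance is lr / (batch_size * (2 - lr)).
var_big_lr = spread(lr=0.5, batch_size=1)
var_small_lr = spread(lr=0.05, batch_size=1)
var_big_batch = spread(lr=0.5, batch_size=32)

print(f"lr=0.5,  batch=1:  variance {var_big_lr:.4f}")
print(f"lr=0.05, batch=1:  variance {var_small_lr:.4f}")
print(f"lr=0.5,  batch=32: variance {var_big_batch:.4f}")
```

Even in this stripped-down case, stopping SGD at an arbitrary step amounts to drawing one sample from a distribution whose width is set by the optimizer's hyperparameters rather than by the model.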

This complexity means that optimizers of deep nets become first class citizens: they are a very central part of the model, every bit as important as the layer architecture. This doesn’t quite happen with many other models in machine learning. Linear models (even regularized ones, like the LASSO) and SVMs are convex optimization problems for which there is not as much nuance and really only one answer. That’s why folks who come from other fields and/or use tools like scikit-learn are puzzled when they don’t find a very simple API with a .fit() method (although there are some tools, like skflow, that attempt to bottle simple nets into a .fit() signature, I think it’s a bit misguided since the whole point of deep learning is its flexibility).
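The "really only one answer" point is easy to check numerically. The sketch below uses ridge regression as a stand-in (an assumption of this example; it is chosen instead of the LASSO or an SVM because it has a closed-form solution to compare against) on synthetic data. Gradient descent started from two very different points lands on the same weights as the closed-form solve, so the choice of optimizer and initialization does not change the fitted model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic regression problem (sizes and noise level are arbitrary choices).
n, d, lam = 50, 5, 0.1
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + 0.1 * rng.normal(size=n)

# Ridge objective: (1 / 2n) * ||X w - y||^2 + (lam / 2) * ||w||^2.
# Its Hessian H is constant and positive definite, so the minimum is unique.
H = X.T @ X / n + lam * np.eye(d)
c = X.T @ y / n
w_closed = np.linalg.solve(H, c)


def gradient_descent(w0, n_steps=2000):
    """Plain gradient descent on the ridge objective from the start point w0."""
    lr = 1.0 / np.linalg.eigvalsh(H).max()  # step size 1/L guarantees convergence
    w = w0.copy()
    for _ in range(n_steps):
        w -= lr * (H @ w - c)  # gradient of the ridge objective
    return w


w_a = gradient_descent(10.0 * rng.normal(size=d))  # far-away random start
w_b = gradient_descent(np.zeros(d))                # start at the origin

print(np.allclose(w_a, w_closed), np.allclose(w_b, w_closed))
```

For a deep net there is no such guarantee: different initializations, learning rates, and batch sizes generally end in different parameters, which is why the optimizer cannot be hidden behind a single call.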

When not to use deep learning

So, when is deep learning not ideal for a task? From my perspective, these are the main scenarios where deep learning is more of a hindrance than a boon.