October 21, 2018

The Topology of Neural Networks, Part 1: The Generalization Problem

A few months ago I gave a talk at the Thompson-Scharlemann-Kirby conference about a theorem I recently proved on topological limitations of certain families of neural networks. Most of the talk described how neural networks work in more abstract mathematical terms than they’re usually described in, and I thought that part would make a good blog post. Because the post was getting quite long, I’ve split it into two parts: this first post describes the general approach to defining machine learning models, and the second will cover neural networks in particular.

I’ve found it surprisingly hard to learn about machine learning, including neural networks, from a mathematical perspective because most of the literature on the subject is (understandably) targeted towards future practitioners coming from an undergraduate computer science background, rather than towards mathematicians. It spends a lot of time carefully introducing basic mathematical concepts, goes very deep into subtle technical details and relies heavily on computer science notation/terminology with a bit of statistics terminology and notation thrown in. So even though many of the concepts are very closely related to ideas that we pure mathematicians know well, you still have to read through all the details to get to those concepts. It turns out it’s very hard to learn the notation and terminology of a subject without also starting from scratch (re)learning the concepts.

So I’ve been doing that in fits and starts for the last few years. When I started writing the Shape of Data blog, I focussed on extracting interesting undergraduate-level mathematics from the most introductory material in the field, because that was as far as I had gotten. But after a few years of writing that blog, and a few more years of not writing but continuing to learn, I think I now understand the field well enough to extract some abstract ideas that are mathematically rich enough to interest the readers of this blog. So that’s what this post is about.

At its core, machine learning is about making predictions, so we have to start by defining what that means. We’ll start with a problem called memorization that isn’t exactly what we’re looking for, but is pedagogically important.

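The memorization problem: Given a finite set of points $D \subset \mathbb{R}^n \times \mathbb{R}^m$, find a function $f: \mathbb{R}^n \to \mathbb{R}^m$ such that for each $(x, y) \in D$, $f(x) = y$.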

The function $f$ is called a model, and the interpretation of this problem is that each point $(x, y) \in D$ is a vector extracted from data about an instance of something that you want the model to make a prediction about. The $x$ comes from data that you know beforehand, and the $y$ comes from the data that you want to predict. For example, $x$ may be information about a person, and $y$ may encode where they’ll click on a web page or how they’ll respond to a particular medical treatment.

The memorization problem is not a very interesting math problem because it’s either trivial and underspecified (if all the $x$s are distinct, any function that takes the right value at each $x$ works, and there are infinitely many of them) or impossible (if two points have the same $x$ but different $y$s). It’s also not a very good practical problem because it only allows you to make predictions about the data that you used to construct the function $f$. So the real statement of the prediction problem is more subtle:
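To make the dichotomy concrete, here is a minimal Python sketch (an illustration, not a method from this post) that solves the memorization problem with a lookup table. It assumes each $x$ is given as a hashable tuple, and the name `memorize` is hypothetical. When all the $x$s are distinct it succeeds; when two points share an $x$ but have different $y$s it fails; and the returned model is simply undefined away from the training inputs, so any extension to the rest of $\mathbb{R}^n$ is an arbitrary choice.

```python
def memorize(D):
    """Return a model f with f(x) == y for every (x, y) in D.

    Hypothetical helper for illustration. Each x must be hashable (e.g. a tuple).
    Raises ValueError when two points share an x but have different y values,
    which is exactly the case where memorization is impossible.
    """
    table = {}
    for x, y in D:
        if x in table and table[x] != y:
            raise ValueError(f"conflicting outputs for x={x!r}")
        table[x] = y

    def f(x):
        # Defined only on the inputs seen in D: any other x raises KeyError,
        # so the table says nothing about data outside the training set.
        return table[x]

    return f


D = [((0.0, 1.0), 2.0), ((1.0, 0.0), 3.0)]
f = memorize(D)
print(f((0.0, 1.0)))  # 2.0

try:
    memorize([((0.0, 1.0), 2.0), ((0.0, 1.0), 5.0)])
except ValueError as err:
    print("impossible:", err)
```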

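The generalization problem: Given a finite set of points $D \subset \mathbb{R}^n \times \mathbb{R}^m$ called the training set, find a function $f: \mathbb{R}^n \to \mathbb{R}^m$ such that, given a second set $D' \subset \mathbb{R}^n \times \mathbb{R}^m$ called the evaluation set, for each $(x, y) \in D'$, $f(x)$ is “as close as possible” to $y$.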

This is a much better practical problem because you’re constructing the model based on the training set – examples of the data that you’ve seen in the past – but evaluating it based on a second set of points – data that you haven’t seen before, but want to make predictions about.

However, as a mathematical problem it’s even worse than the first one. In fact, it isn’t even really a mathematical problem because there’s no logical connection between what’s given and what you have to find.

So to make it a mathematical problem, you have to add in assumptions about how $D$ and $D'$ are related. There are many ways to do this, and any algorithm or technique in machine learning/statistics that addresses this problem makes such assumptions, either implicitly or explicitly.

The approach that I’m going to describe below, which leads to the definition of neural networks, makes this assumption implicitly by restricting the family of functions from which $f$ can be chosen. You solve the memorization problem for $D$ on this restricted family (or, since an exact solution usually won’t exist there, find the function in the family that comes closest), and the resulting model is your candidate solution to the generalization problem.

Before we go into more detail about what this means, let’s return to the notion of “as close as possible”. We’ll make this precise by defining what’s called a cost function $c: \mathbb{R}^m \times \mathbb{R}^m \to \mathbb{R}$ such that $c(y, y')$ defines the “cost” of the difference between $y$ and $y'$. A common example is $c(y, y') = \|y - y'\|^2$. Given a dataset $D$, the cost of a given model $f$ is $c_D(f) = \sum_{(x, y) \in D} c(f(x), y)$.
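In code, the cost function and the dataset cost $c_D$ translate almost directly. This sketch assumes vectors are represented as Python tuples and uses the squared-difference cost from the text; the names `c`, `dataset_cost`, and `model` are illustrative.

```python
def c(y, y_prime):
    """Squared-difference cost ||y - y'||^2 for vectors given as tuples."""
    return sum((a - b) ** 2 for a, b in zip(y, y_prime))


def dataset_cost(f, D):
    """c_D(f): the sum of c(f(x), y) over all points (x, y) in D."""
    return sum(c(f(x), y) for x, y in D)


def model(x):
    """An example model R^2 -> R^1 that adds the two input coordinates."""
    return (x[0] + x[1],)


D = [((1.0, 2.0), (3.0,)), ((0.0, 0.0), (1.0,))]
print(dataset_cost(model, D))  # 1.0
```

The first point is predicted exactly and contributes 0; the second is off by 1 and contributes $1^2 = 1$.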

With this terminology, the memorization problem is to minimize $c_D$, while the generalization problem is to minimize $c_{D'}$ while only knowing $D$.

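Note that the cost function $c_D$ is a function on the space $C(\mathbb{R}^n, \mathbb{R}^m)$ of continuous functions $\mathbb{R}^n \to \mathbb{R}^m$. We will be interested in subspaces of $C(\mathbb{R}^n, \mathbb{R}^m)$, and we can define one by choosing a map $M: \mathbb{R}^k \to C(\mathbb{R}^n, \mathbb{R}^m)$ and taking its image. The domain $\mathbb{R}^k$ is often called parameter space, and we will use the symbol $\Theta$ for it.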

The canonical example of this is linear regression. In the case where $n = 1$ and $m = 1$, we define $\Theta = \mathbb{R}^2$, and we’ll let $a, b$ be the coordinates of $\Theta$. Define $M$ to be the function that takes $(a, b)$ to the function $f(x) = ax + b$, and use the difference-squared cost I mentioned above.
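Here is a sketch of the linear regression parametrization in Python, with $n = m = 1$ so that inputs and outputs are plain floats. The map from parameter space to models is called `M` here (an illustrative name); it returns a Python function, mirroring the idea that each point of parameter space is sent to a point of the function space.

```python
def M(theta):
    """Send a point theta = (a, b) of parameter space R^2 to the model x -> a*x + b."""
    a, b = theta

    def f(x):
        return a * x + b

    return f


line = M((2.0, 1.0))
print(line(3.0))  # 7.0
```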

Now, given $D$ and $c$, we get the cost function $c_D$ on $C(\mathbb{R}^n, \mathbb{R}^m)$, which we can pull back along $M$ to the function $c_D \circ M$ on $\Theta$, and then choose a point $\theta \in \Theta$ that minimizes this pulled-back cost. To determine how close this particular family of models comes to solving the generalization problem compared to other families of models, we evaluate $c_{D'}(M(\theta))$.
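For linear regression the whole pipeline fits in a few lines. The sketch below pulls the squared-difference cost back to $\Theta = \mathbb{R}^2$, minimizes it with the standard ordinary least-squares formulas (slope equals covariance over variance), and then evaluates the resulting model on a held-out set. The datasets are made-up numbers for illustration, and the helper names are hypothetical.

```python
def pulled_back_cost(theta, D):
    """(c_D o M)(theta): squared-difference cost of the line a*x + b on D."""
    a, b = theta
    return sum((a * x + b - y) ** 2 for x, y in D)


def fit_line(D):
    """Closed-form minimizer of the pulled-back cost (ordinary least squares).

    Assumes the x values in D are not all equal, so the minimizer is unique.
    """
    n = len(D)
    mean_x = sum(x for x, _ in D) / n
    mean_y = sum(y for _, y in D) / n
    sxx = sum((x - mean_x) ** 2 for x, _ in D)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in D)
    a = sxy / sxx
    return a, mean_y - a * mean_x


# Made-up training and evaluation sets, for illustration only.
D_train = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.5)]
D_eval = [(4.0, 9.0), (5.0, 11.2)]

theta = fit_line(D_train)                     # minimize c_D o M over Theta
train_cost = pulled_back_cost(theta, D_train)
eval_cost = pulled_back_cost(theta, D_eval)   # c_{D'}(M(theta))
print(theta, train_cost, eval_cost)
```

With these numbers the fitted line is $y = 2.15x + 0.9$ (up to floating-point rounding), with a training cost of about 0.075 and an evaluation cost of about 0.45.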

The nice thing about this setup is that it gives us a relatively objective way to evaluate families of models. You can often find a family with a lower minimum for $c_D$ by increasing the dimension of $\Theta$ and making the set of possible models more flexible. However, if you take this too far, it will eventually increase $c_{D'}$ for the model that minimizes $c_D$. This is called the bias/variance tradeoff, and going too far is called overfitting.
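Here is a small illustration of the tradeoff, under stated assumptions: the data are made-up points near $y = 2x + 1$ with hand-picked noise in the training set, and we compare the two-parameter family of lines with the four-parameter family of cubics. Minimizing $c_D$ over cubics gives the Lagrange interpolating polynomial through the four training points, which drives the training cost to zero, i.e. it memorizes $D$. The helper names are hypothetical.

```python
# Made-up data near the line y = 2x + 1, with hand-picked noise in the training set.
D_train = [(0.0, 1.0), (1.0, 3.5), (2.0, 4.5), (3.0, 7.0)]
D_eval = [(4.0, 9.0), (5.0, 11.0)]


def cost(f, D):
    """c_D(f) with the squared-difference cost, for scalar inputs and outputs."""
    return sum((f(x) - y) ** 2 for x, y in D)


def fit_line(D):
    """Least-squares line: the minimizer of c_D over the family ax + b."""
    n = len(D)
    mean_x = sum(x for x, _ in D) / n
    mean_y = sum(y for _, y in D) / n
    sxx = sum((x - mean_x) ** 2 for x, _ in D)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in D)
    a = sxy / sxx
    b = mean_y - a * mean_x
    return lambda x: a * x + b


def interpolant(D):
    """Lagrange interpolating polynomial of degree < len(D); assumes distinct x values.

    Among polynomials of degree len(D) - 1 this drives c_D to zero.
    """
    def f(x):
        total = 0.0
        for i, (xi, yi) in enumerate(D):
            term = yi
            for j, (xj, _) in enumerate(D):
                if j != i:
                    term *= (x - xj) / (xi - xj)
            total += term
        return total

    return f


line = fit_line(D_train)      # Theta = R^2
cubic = interpolant(D_train)  # Theta = R^4
for name, model in (("line", line), ("cubic", cubic)):
    print(name, "train:", cost(model, D_train), "eval:", cost(model, D_eval))
```

With these numbers the cubic’s training cost is zero (up to rounding) and the line’s is about 0.45, but on the evaluation set the line’s cost is about 0.185 while the cubic’s is about 331: the more flexible family wins on $c_D$ and loses badly on $c_{D'}$.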

There’s also a question of how you find the minimum in practice. For linear regression, the pulled-back cost $c_D \circ M$ is a convex quadratic on $\Theta$, and as long as the $x$ values in $D$ are not all equal it has a unique critical point, which is the global minimum and has a closed-form solution. However, for most model families you have to use an algorithm called gradient descent, which takes discrete steps against the gradient of the cost function until it reaches a local minimum, which may or may not be global.
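The sketch below runs plain gradient descent on $c_D \circ M$ for the linear regression family, using the partial derivatives $\partial/\partial a = \sum 2(ax + b - y)x$ and $\partial/\partial b = \sum 2(ax + b - y)$. The fixed learning rate, step cap, and stopping tolerance are hypothetical choices that happen to suit this tiny dataset. Since this cost has a unique critical point, the result should agree with the closed-form solution, which for these three collinear points is $a = 2$, $b = 1$.

```python
def gradient(theta, D):
    """Gradient of (c_D o M)(a, b) = sum over D of (a*x + b - y)^2."""
    a, b = theta
    ga = sum(2 * (a * x + b - y) * x for x, y in D)
    gb = sum(2 * (a * x + b - y) for x, y in D)
    return ga, gb


def gradient_descent(D, theta0=(0.0, 0.0), learning_rate=0.01,
                     max_steps=10_000, tol=1e-10):
    """Take discrete steps against the gradient until it is (nearly) zero."""
    a, b = theta0
    for _ in range(max_steps):
        ga, gb = gradient((a, b), D)
        if ga * ga + gb * gb < tol:  # close enough to a critical point
            break
        a -= learning_rate * ga
        b -= learning_rate * gb
    return a, b


D = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]  # points on y = 2x + 1
a, b = gradient_descent(D)
print(round(a, 4), round(b, 4))  # 2.0 1.0
```

For a non-convex cost, the same loop would stop at whichever critical point it reached first, which is where spurious local minima come from.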

So rather than just adding flexibility to a model family, the trick is to add the right kind of flexibility for a given dataset, i.e. in a way that strikes a good balance in the bias/variance tradeoff and reduces the number of spurious local minima. This is where things get interesting from the perspective of geometry and topology, since it becomes a question of how to characterize the different ways a model family can be flexible, and how to connect these to properties of a dataset.

For example, the simplest way to make linear regression more flexible is to replace the linear function $ax + b$ with a polynomial of a fixed degree $d$. However, this often isn’t practical, because for higher-dimensional inputs $x \in \mathbb{R}^n$ the number of parameters grows very rapidly with the degree (roughly like $n^d / d!$ when $n$ is much larger than $d$). So you end up with a lot of flexibility that is either redundant or isn’t useful for your given dataset. One reason neural networks have become so popular is that they manage to be extremely flexible with relatively few parameters, at least compared to polynomials.
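The growth in the number of parameters is easy to check numerically. A general polynomial of degree at most $d$ in $n$ variables with a single real output has one coefficient per monomial, $\binom{n+d}{d}$ in total (multiply by $m$ for outputs in $\mathbb{R}^m$). The choice $n = 100$ below is an arbitrary example.

```python
from math import comb


def num_polynomial_params(n, d):
    """Number of coefficients in a general polynomial of degree <= d in n variables."""
    return comb(n + d, d)


for d in range(1, 5):
    print(d, num_polynomial_params(100, d))
# 1 101
# 2 5151
# 3 176851
# 4 4598126
```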

In the follow-up to this post, I’ll describe how a neural network defines a family of models, and I’ll outline my recent result about topological constraints on certain of these families.