Eli Bendersky's website - Machine Learning (https://eli.thegreenplace.net/)

Minimal character-based LSTM implementation (2018-06-07, Eli Bendersky)

<p>Following up on <a class="reference external" href="https://eli.thegreenplace.net/2018/understanding-how-to-implement-a-character-based-rnn-language-model/">the earlier post</a>
deciphering a minimal vanilla RNN implementation, here I'd like to extend the
example to a simple LSTM model.</p>
<p>Once again, the idea is to combine a well-commented code sample
(<a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/min-char-rnn/min-char-lstm.py">available here</a>)
with some high-level diagrams and math to enable someone to
fully understand the code. The LSTM architecture presented herein is the
standard one originating from Hochreiter and Schmidhuber's <a class="reference external" href="https://www.google.com/search?q=lstm+hochreiter">1997 paper</a>. It's described pretty much
everywhere; <a class="reference external" href="http://colah.github.io/posts/2015-08-Understanding-LSTMs/">Chris Olah's post</a> has particularly
nice diagrams and is worth reading.</p>
<div class="section" id="lstm-cell-structure">
<h2>LSTM cell structure</h2>
<p>From 30,000 feet, LSTMs look just like regular RNNs; there's a &quot;cell&quot; that has a
recurrent connection (output tied to input), and when trained this cell is
usually unrolled to some fixed length.</p>
<p>So we can take the basic RNN structure from the <a class="reference external" href="https://eli.thegreenplace.net/2018/understanding-how-to-implement-a-character-based-rnn-language-model">previous post</a>:</p>
<img alt="Basic RNN diagram" class="align-center" src="https://eli.thegreenplace.net/images/2018/rnnbasic.png" />
<p>LSTMs are a bit trickier because there are two recurrent connections; these
can be &quot;packed&quot; into a single vector <em>h</em>, so the above diagram still applies.
Here's how an LSTM cell looks inside:</p>
<img alt="LSTM cell" class="align-center" src="https://eli.thegreenplace.net/images/2018/lstm-cell.png" />
<p><em>x</em> is the input; <em>p</em> is the vector of probabilities computed from the output <em>y</em>
(these symbols are named consistently with my earlier RNN post); they exit the cell at
the bottom purely due to topological convenience. The two memory vectors are <em>h</em>
and <em>c</em> - as mentioned earlier, they could be combined into a single vector, but
are shown here separately for clarity.</p>
<p>The main idea of LSTMs is to enable training of longer sequences by providing
a &quot;fast-path&quot; to back-propagate information farther down in memory. Hence the
<em>c</em> vector is not multiplied by any matrices on its path. The circle-in-circle
block means element-wise multiplication of two vectors; plus-in-square is
element-wise addition. The funny greek letter is the Sigmoid non-linearity:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/8b0db8368e8a617143fa6566f42c1e47cd833c9c.svg" style="height: 38px;" type="image/svg+xml">
\[\sigma(x) =\frac{1}{1+e^{-x}}\]</object>
<p>The only other block we haven't seen in the vanilla RNN diagram is the
colon-in-square in the bottom-left corner; this is simply the concatenation of
<em>h</em> and <em>x</em> into a single column vector. In addition, I've combined the
&quot;multiply by matrix <em>W</em>, then add bias <em>b</em>&quot; operation into a single rectangular
box to save on precious diagram space.</p>
<p>Here are the equations computed by a cell:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/c2cc966ba7ce8075317b87885bc9c432aafe2dba.svg" style="height: 249px;" type="image/svg+xml">
\[\begin{align*}
xh&amp;=x^{[t]}:h^{[t-1]}\\
f&amp;=\sigma(W_f\cdot xh+b_f)\\
i&amp;=\sigma(W_i\cdot xh+b_i)\\
o&amp;=\sigma(W_o\cdot xh+b_o)\\
cc&amp;=tanh(W_{cc}\cdot xh+b_{cc})\\
c^{[t]}&amp;=c^{[t-1]}\odot f +cc\odot i\\
h^{[t]}&amp;=tanh(c^{[t]})\odot o\\
y^{[t]}&amp;=W_{y}\cdot h^{[t]}+b_y\\
p^{[t]}&amp;=softmax(y^{[t]})\\
\end{align*}\]</object>
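<p>To make these equations concrete, here's a minimal numpy sketch of a single forward step. The variable names follow the equations above rather than the actual code in <tt class="docutils literal"><span class="pre">min-char-lstm.py</span></tt>, and the column-vector shapes are assumptions for illustration:</p>
<div class="highlight"><pre>import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def lstm_step_forward(x, h_prev, c_prev, Wf, bf, Wi, bi, Wo, bo, Wcc, bcc, Wy, by):
    # Concatenate the input and the previous state into one column vector.
    xh = np.vstack((x, h_prev))

    # Gates: forget, input, output; cc is the candidate cell value.
    f = sigmoid(Wf @ xh + bf)
    i = sigmoid(Wi @ xh + bi)
    o = sigmoid(Wo @ xh + bo)
    cc = np.tanh(Wcc @ xh + bcc)

    # New cell state (the "fast path") and new hidden state.
    c = c_prev * f + cc * i
    h = np.tanh(c) * o

    # Output and probabilities over the next character.
    y = Wy @ h + by
    p = np.exp(y) / np.sum(np.exp(y))
    return h, c, y, p
</pre></div>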
</div>
<div class="section" id="backpropagating-through-an-lstm-cell">
<h2>Backpropagating through an LSTM cell</h2>
<p>This works <em>exactly</em> like backprop through a vanilla RNN; we have to carefully
compute how the gradient flows through every node and make sure we properly
combine gradients at fork points. Most of the elements in the LSTM diagram are
familiar from the <a class="reference external" href="https://eli.thegreenplace.net/2018/understanding-how-to-implement-a-character-based-rnn-language-model">previous post</a>.
Let's briefly work through the new ones.</p>
<p>First, the Sigmoid function; it's an elementwise function, and computing its
derivative is very similar to the <em>tanh</em> function discussed in the previous
post. As usual, given <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e9ef6bd037537d5fe08743736acadccc09e70b06.svg" style="height: 18px;" type="image/svg+xml">f=\sigma(k)</object>, from the chain rule we have the
following derivative w.r.t. some weight <em>w</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/57e3f2cab3c9b46a03d763a2f73b83963a1cd500.svg" style="height: 39px;" type="image/svg+xml">
\[\frac{\partial f}{\partial w}=\frac{\partial \sigma(k)}{\partial k}\frac{\partial k}{\partial w}\]</object>
<p>To compute the derivative <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/8aa59f2f536b727cf97239b345ddcc98e41c2c91.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial \sigma(k)}{\partial k}</object>, we'll
use the ratio-derivative formula:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/9e006cf5e9f1f8ccac82ba1f2bcdabd710731756.svg" style="height: 42px;" type="image/svg+xml">
\[(\frac{f}{g})&#x27;=\frac{f&#x27;g-g&#x27;f}{g^2}\]</object>
<p>So:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/e3f7af782f52215e8326b389271709a440993984.svg" style="height: 44px;" type="image/svg+xml">
\[\sigma &#x27;(k)=\frac{e^{-k}}{(1+e^{-k})^2}\]</object>
<p>A clever way to express this is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/eb1953be928287ff01ae23dfb4ff1cb2290854c9.svg" style="height: 20px;" type="image/svg+xml">
\[\sigma &#x27;(k)=\sigma(k)(1-\sigma(k))\]</object>
<p>Going back to the chain rule with <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e9ef6bd037537d5fe08743736acadccc09e70b06.svg" style="height: 18px;" type="image/svg+xml">f=\sigma(k)</object>, we get:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/885829ecab969c96daed7f0df6e5864339ad9d8b.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial f}{\partial w}=f(1-f)\frac{\partial k}{\partial w}\]</object>
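<p>This form is handy in code because it reuses the value computed on the forward pass instead of evaluating any exponents again. A tiny sketch with made-up numpy values for <em>k</em> and its derivative:</p>
<div class="highlight"><pre>import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# Toy values: k is some intermediate vector, dk_dw its derivative w.r.t. w.
k = np.array([0.5, -1.0, 2.0])
dk_dw = np.array([1.0, 0.3, -0.2])

f = sigmoid(k)               # cached on the forward pass
df_dw = f * (1 - f) * dk_dw  # sigma'(k) = f*(1-f), then the chain rule
</pre></div>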
<p>The other new operation we'll have to find the derivative of is element-wise
multiplication. Let's say we have the column vectors <em>x</em>, <em>y</em> and <em>z</em>, each with
<em>m</em> rows, and we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/660b1e0dacc15aa3737b8170c3ecfdcbc6e77db4.svg" style="height: 18px;" type="image/svg+xml">z(x)=x\odot y</object>. Since <em>z</em> as a function of <em>x</em>
has <em>m</em> inputs and <em>m</em> outputs, its Jacobian has dimensions [m,m].</p>
<p><object class="valign-m6" data="https://eli.thegreenplace.net/images/math/0ab96cb4e5d8c6ba3ac8038fda07d518bbe1f388.svg" style="height: 18px;" type="image/svg+xml">D_{j}z_{i}</object> is the derivative of the i-th element of <em>z</em> w.r.t. the j-th
element of <em>x</em>. For <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/660b1e0dacc15aa3737b8170c3ecfdcbc6e77db4.svg" style="height: 18px;" type="image/svg+xml">z(x)=x\odot y</object> this is non-zero only
when <em>i</em> and <em>j</em> are equal, and in that case the derivative is <img alt="y_i" class="valign-m4" src="https://eli.thegreenplace.net/images/math/35c2ac2f82d0ff8f9011b596ed7e54bfcc55f471.png" style="height: 12px;" />.</p>
<p>Therefore, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e6631f3b13f877a8bb7b3b6a0c0d2ca110ecce23.svg" style="height: 18px;" type="image/svg+xml">Dz(x)</object> is a square matrix with the elements of <em>y</em> on the
diagonal and zeros elsewhere:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/2450b2e2a827054f5d292822ff292eaa63c77d1b.svg" style="height: 97px;" type="image/svg+xml">
\[Dz=\begin{bmatrix}
y_1 &amp; 0 &amp; \cdots &amp; 0 \\
0 &amp; y_2 &amp; \cdots &amp; 0 \\
\vdots &amp; \ddots &amp; \ddots &amp; \vdots \\
0 &amp; 0 &amp; \cdots &amp; y_m \\
\end{bmatrix}\]</object>
<p>If we want to backprop some loss <em>L</em> through this function, we get:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/48b17da284ae52bc4b9fdeb7b98b73f398bd4458.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial L}{\partial x}=\frac{\partial L}{\partial z}Dz\]</object>
<p>As <em>x</em> has <em>m</em> elements, the right-hand side of this equation multiplies a [1,m]
vector by a [m,m] matrix which is diagonal, resulting in element-wise multiplication
with the matrix's diagonal elements. In other words:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/e2a6c0742fb006e35e3001d3b3d33f78316fb1e8.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial L}{\partial x}=\frac{\partial L}{\partial z}\odot y\]</object>
<p>In code, it looks like this:</p>
<div class="highlight"><pre><span></span><span class="c1"># Assuming dz is the gradient of loss w.r.t. z; dz, y and dx are all</span>
<span class="c1"># column vectors.</span>
<span class="n">dx</span> <span class="o">=</span> <span class="n">dz</span> <span class="o">*</span> <span class="n">y</span>
</pre></div>
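<p>Putting these pieces together, here is a rough sketch of the whole backward pass through one unrolled LSTM step. It assumes the forward-pass values were cached, and that <tt class="docutils literal">dy</tt>, <tt class="docutils literal">dh_next</tt> and <tt class="docutils literal">dc_next</tt> are the gradients flowing in from the loss and from the following step; the names here are mine and don't necessarily match <tt class="docutils literal"><span class="pre">min-char-lstm.py</span></tt>:</p>
<div class="highlight"><pre>import numpy as np

def lstm_step_backward(dy, dh_next, dc_next, cache):
    # cache holds values saved on the forward pass.
    x, xh, f, i, o, cc, c, c_prev, h, Wf, Wi, Wo, Wcc, Wy = cache

    # y = Wy . h + by
    dWy, dby = dy @ h.T, dy
    # h branches into the output layer and the next step, so gradients add.
    dh = Wy.T @ dy + dh_next

    # h = tanh(c) * o
    do = dh * np.tanh(c)
    dc = dh * o * (1 - np.tanh(c) ** 2) + dc_next

    # c = c_prev * f + cc * i
    df, dcc, di = dc * c_prev, dc * i, dc * cc
    dc_prev = dc * f

    # Through the gate non-linearities to the pre-activation values.
    df_raw = df * f * (1 - f)
    di_raw = di * i * (1 - i)
    do_raw = do * o * (1 - o)
    dcc_raw = dcc * (1 - cc ** 2)

    # Weight/bias gradients; xh is the concatenated [x ; h_prev] vector.
    dWf, dbf = df_raw @ xh.T, df_raw
    dWi, dbi = di_raw @ xh.T, di_raw
    dWo, dbo = do_raw @ xh.T, do_raw
    dWcc, dbcc = dcc_raw @ xh.T, dcc_raw

    # Gradient w.r.t. the concatenated input; split into dx and dh_prev.
    dxh = Wf.T @ df_raw + Wi.T @ di_raw + Wo.T @ do_raw + Wcc.T @ dcc_raw
    dx, dh_prev = dxh[:len(x)], dxh[len(x):]

    grads = (dWf, dbf, dWi, dbi, dWo, dbo, dWcc, dbcc, dWy, dby)
    return dx, dh_prev, dc_prev, grads
</pre></div>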
</div>
<div class="section" id="model-quality">
<h2>Model quality</h2>
<p>In the <a class="reference external" href="https://eli.thegreenplace.net/2018/understanding-how-to-implement-a-character-based-rnn-language-model/">post about min-char-rnn</a>,
we've seen that the vanilla RNN generates fairly low quality text:</p>
<blockquote>
one, my dred, roriny. qued bamp gond hilves non froange saws, to mold
his a work, you shirs larcs anverver strepule thunboler
muste, thum and cormed sightourd
so was rewa her besee pilman</blockquote>
<p>The LSTM's generated text quality is somewhat better when trained with roughly
the same hyper-parameters:</p>
<blockquote>
the she, over is was besiving the fact to seramed for i said over he
will round, such when a where, &quot;i went of where stood it at eye heardul rrawed
only coside the showed had off with the refaurtoned</blockquote>
<p>I'm fairly sure that it can be made to perform even better with larger memory
vectors and more training data. That said, an even more advanced architecture
can be helpful here. Moreover, since this is a <em>character</em>-based model, to
really capture effects between words that are a few words apart, we'll need a much
deeper LSTM (I'm unrolling to 16 characters, so we can only capture 2-3 words), and hence
much more training data and time.</p>
<p>Once again, the goal here is not to develop a state-of-the-art language model,
but to show a simple, comprehensible example of how an LSTM is implemented
end-to-end in Python code. <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/min-char-rnn/min-char-lstm.py">The full code is here</a>
- please let me know if you find any issues with it or something still remains
unclear.</p>
</div>
Understanding how to implement a character-based RNN language model (2018-05-25, Eli Bendersky)

<p>In <a class="reference external" href="https://gist.github.com/karpathy/d4dee566867f8291f086">a single gist</a>,
<a class="reference external" href="https://cs.stanford.edu/people/karpathy/">Andrej Karpathy</a> did something
truly impressive. In a little over 100 lines of Python - without relying on any
heavy-weight machine learning frameworks - he presents a fairly complete
implementation of training a character-based recurrent neural network (RNN)
language model; this includes the full backpropagation learning with Adagrad
optimization.</p>
<p>I love such minimal examples because they allow me to understand some topic in
full depth, connecting the math to the code and having a complete picture of how
everything works. In this post I want to present a companion explanation to
Karpathy's gist, showing the diagrams and math that hide in its Python code.</p>
<p>My own fork of the code <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/min-char-rnn/min-char-rnn.py">is here</a>;
it's semantically equivalent to Karpathy's gist, but includes many more comments
and some debugging options. I won't reproduce the whole program here; instead,
the idea is that you'd go through the code while reading this article. The
diagrams, formulae and explanations here are complementary to the code comments.</p>
<div class="section" id="what-rnns-do">
<h2>What RNNs do</h2>
<p>I expect readers to have a basic idea of what RNNs do and why they work well for
some problems. RNNs are well-suited for problem domains where the input (and/or
output) is some sort of a sequence - time-series financial data, words or
sentences in natural language, speech, etc.</p>
<p>There is <em>a lot</em> of material about this online, and the basics
are easy to understand for anyone with even a bit of machine learning
background. However, there is not enough coherent material online about how RNNs
are implemented and trained - this is the goal of this post.</p>
</div>
<div class="section" id="character-based-rnn-language-model">
<h2>Character-based RNN language model</h2>
<p>The basic structure of <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> is represented by this recurrent
diagram, where <em>x</em> is the input vector (at time step <em>t</em>), <em>y</em> is the output
vector and <em>h</em> is the <em>state vector</em> kept inside the model.</p>
<img alt="Basic RNN diagram" class="align-center" src="https://eli.thegreenplace.net/images/2018/rnnbasic.png" />
<p>The line leaving and returning to the cell represents that the state is retained
between invocations of the network. When a new time step arrives, some things
are still the same (the weights inherent to the network, as we shall soon see)
but some things are different - <em>h</em> may have changed. Therefore, unlike
stateless NNs, <em>y</em> is not simply a function of <em>x</em>; in RNNs, identical <em>x</em>s can
produce different <em>y</em>s, because <em>y</em> is a function of <em>x</em> and <em>h</em>, and <em>h</em> can
change between steps.</p>
<p>The <em>character-based</em> part of the model's name means that every input vector
represents a single character (as opposed to, say, a word or part of an image).
<tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> uses one-hot vectors to represent different characters.</p>
<p>A <em>language model</em> is a particular kind of machine learning algorithm that
learns the statistical structure of language by &quot;reading&quot; a large corpus of
text. This model can then reproduce authentic language segments - by predicting
the next character (or word, for word-based models) based on past characters.</p>
</div>
<div class="section" id="internal-structure-of-the-rnn-cell">
<h2>Internal structure of the RNN cell</h2>
<p>Let's proceed by looking into the internal structure of the RNN cell in
<tt class="docutils literal"><span class="pre">min-char-rnn</span></tt>:</p>
<img alt="RNN cell for min-char-rnn" class="align-center" src="https://eli.thegreenplace.net/images/2018/min-char-rnn-cell.png" />
<ul class="simple">
<li>Bold-faced symbols in reddish color are the model's parameters, weights for
matrix multiplication and biases.</li>
<li>The state vector <em>h</em> is shown twice - once for its past value, and once for
its currently computed value. Whenever the RNN cell is invoked in sequence,
the last computed state <em>h</em> is passed in from the left.</li>
<li>In this diagram <em>y</em> is not the final answer of the cell - we compute a softmax
function on it to obtain <em>p</em> - the probabilities for output characters <a class="footnote-reference" href="#id7" id="id1">[1]</a>.
I'm using these symbols for consistency with the code of <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt>,
though it would probably be more readable to flip the uses of <em>p</em> and <em>y</em>
(making <em>y</em> the actual output of the cell).</li>
</ul>
<p>Mathematically, this cell computes:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/886e94d526c2e538f1ba4414696ae9bf6618f0ff.svg" style="height: 82px;" type="image/svg+xml">
\[\begin{align*}
h^{[t]}&amp;=tanh(W_{hh}\cdot h^{[t-1]}+W_{xh}\cdot x^{[t]}+b_h)\\
y^{[t]}&amp;=W_{hy}\cdot h^{[t]}+b_y\\
p^{[t]}&amp;=softmax(y^{[t]})
\end{align*}\]</object>
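<p>In numpy terms, a single invocation of this cell is roughly the following; the real <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> code folds this into a loop over time steps together with the loss computation, but the math is the same:</p>
<div class="highlight"><pre>import numpy as np

def rnn_step(x, h_prev, Wxh, Whh, Why, bh, by):
    # x and h_prev are column vectors; the Ws and bs are the model parameters.
    h = np.tanh(Whh @ h_prev + Wxh @ x + bh)
    y = Why @ h + by
    p = np.exp(y) / np.sum(np.exp(y))   # softmax over output characters
    return h, y, p
</pre></div>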
</div>
<div class="section" id="learning-model-parameters-with-backpropagation">
<h2>Learning model parameters with backpropagation</h2>
<p>This section will examine how we can <em>learn</em> the parameters <em>W</em> and <em>b</em> for this
model. Mostly it's standard neural-network fare; we'll compute the derivatives
of all the steps involved and will then employ backpropagation to find a
parameter update based on some computed loss.</p>
<p>There's one serious issue we'll have to address first. Backpropagation is
usually defined on <em>acyclic</em> graphs, so it's not entirely clear how to apply it
to our RNN. Is <em>h</em> an input? An output? Both? In the original high-level diagram
of the RNN cell, <em>h</em> is both an input and an output - how can we compute the
gradient for it when we don't know its value yet? <a class="footnote-reference" href="#id8" id="id2">[2]</a></p>
<p>The way out of this conundrum is to <em>unroll</em> the RNN for a few steps. Note that
we're already doing this in the detailed diagram by distinguishing between
<object class="valign-0" data="https://eli.thegreenplace.net/images/math/057276c060e575533321773afb483e778e6a03f1.svg" style="height: 16px;" type="image/svg+xml">h^{[t]}</object> and <object class="valign-0" data="https://eli.thegreenplace.net/images/math/e4bc0503e20a8e6b82d9c86e10eb2c8e1dfe3471.svg" style="height: 16px;" type="image/svg+xml">h^{[t-1]}</object>. This makes every RNN cell <em>locally
acyclic</em>, which makes it possible to use backpropagation on it. This approach
has a cool-sounding name - <em>Backpropagation Through Time</em> (BPTT) - although it's
really the same as regular backpropagation.</p>
<p>Note that the architecture used here is called &quot;synced many-to-many&quot; in
Karpathy's <a class="reference external" href="http://karpathy.github.io/2015/05/21/rnn-effectiveness/">Unreasonable Effectiveness of RNNs post</a>, and it's useful for
training a simple char-based language model - we immediately observe the output
sequence produced by the model while reading the input. Similar unrolling can be
applied to other architectures, like encoder-decoder.</p>
<p>Here's our RNN again, unrolled for 3 steps:</p>
<img alt="Unrolled RNN diagram" class="align-center" src="https://eli.thegreenplace.net/images/2018/rnnunroll.png" />
<p>Now the same diagram, with the gradient flows depicted with orange-ish arrows:</p>
<img alt="Unrolled RNN diagram with gradient flow arrows shown" class="align-center" src="https://eli.thegreenplace.net/images/2018/rnnunrollgrad.png" />
<p>With this unrolling, we have everything we need to compute the actual weight
updates during learning, because when we want to compute the gradients through
step 2, we already have the incoming gradient <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/41bad72882e3d266373df060e8ab3ce36a819679.svg" style="height: 18px;" type="image/svg+xml">\Delta h[2]</object>, and so on.</p>
<p>You may now wonder: what is <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/0bdf36986644d54bc1bccf1410d2b9f0f86cf697.svg" style="height: 18px;" type="image/svg+xml">\Delta h[t]</object> for the final step at time <em>t</em>?</p>
<p>In some models, sequence lengths are fairly limited. For example, when we
translate a single sentence, the sequence length is rarely over a couple dozen
words; for such models we can fully unroll the RNN. The <em>h</em> state output of the
final step doesn't really &quot;go anywhere&quot;, and we assume its gradient is zero.
Similarly, the incoming state <em>h</em> for the first step is zero.</p>
<p>Other models work on potentially infinite sequence lengths, or sequences much
too long for unrolling. The language model in <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> is a good
example, because it can theoretically ingest and emit text of any length. For
these models we'll perform <em>truncated</em> BPTT, by just assuming that the influence
of the current state extends only <em>N</em> steps into the future. We'll then unroll
the model <em>N</em> times and assume that <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/9851a3637afe3f6d70466ac3a1d1c104935647fd.svg" style="height: 18px;" type="image/svg+xml">\Delta h[N]</object> is zero. Although it
really isn't, for a large enough <em>N</em> this is a fairly safe assumption. RNNs are
hard to train on very long sequences for other reasons, anyway (we'll touch upon
this point again towards the end of the post).</p>
<p>Finally, it's important to remember that although we unroll the RNN cells, all
parameters (weights, biases) are <em>shared</em>. This plays an important part in
ensuring <em>translation invariance</em> for the models - patterns learned in one place
apply to another place <a class="footnote-reference" href="#id9" id="id3">[3]</a>. It leaves the question of how to update the
weights, since we compute gradients for them separately in each step. The answer
is very simple - just add them up. This is similar to other cases where the
output of a cell branches off in two directions - when gradients are computed,
their values are added up along the branches - this is just the basic chain rule
in action.</p>
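<p>In code this shows up as in-place accumulation while walking the unrolled steps backwards. Here's a small sketch of how the per-step contributions to the shared <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object> gradient are simply summed; the argument names are hypothetical, not taken from the gist:</p>
<div class="highlight"><pre>import numpy as np

def accumulate_dWhh(hs, dhraws):
    # hs[t] is the state column vector at step t; dhraws[t] is the gradient
    # just before the tanh at step t. Each unrolled step adds its own
    # contribution to the single, shared Whh gradient.
    H = hs[0].shape[0]
    dWhh = np.zeros((H, H))
    for t in reversed(range(1, len(hs))):
        dWhh += dhraws[t] @ hs[t - 1].T
    return dWhh
</pre></div>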
<p>We now have all the necessary background to understand how an RNN learns. What
remains before looking at the code is figuring out how the gradients propagate
<em>inside</em> the cell; in other words, the derivatives of each operation comprising
the cell.</p>
</div>
<div class="section" id="flowing-the-gradient-inside-an-rnn-cell">
<h2>Flowing the gradient inside an RNN cell</h2>
<p>As we saw above, the formulae for computing the cell's output from its inputs
are:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/886e94d526c2e538f1ba4414696ae9bf6618f0ff.svg" style="height: 82px;" type="image/svg+xml">
\[\begin{align*}
h^{[t]}&amp;=tanh(W_{hh}\cdot h^{[t-1]}+W_{xh}\cdot x^{[t]}+b_h)\\
y^{[t]}&amp;=W_{hy}\cdot h^{[t]}+b_y\\
p^{[t]}&amp;=softmax(y^{[t]})
\end{align*}\]</object>
<p>To be able to learn weights, we have to find the derivatives of the cell's
output w.r.t. the weights. The full backpropagation process was
explained <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus/">in this post</a>, so here is
only a brief refresher.</p>
<p>Recall that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f75a9c33c546d725557a4d452769bfd8fbb6cc22.svg" style="height: 20px;" type="image/svg+xml">p^{[t]}</object> is the predicted output; we compare it with the
&quot;real&quot; output (<object class="valign-0" data="https://eli.thegreenplace.net/images/math/e44181afdf5e5f0f8ad4379f7d5f3ff924379c82.svg" style="height: 16px;" type="image/svg+xml">r^{[t]}</object>) during learning, to find the loss (error):</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/788a210d4bab7831a28e0ae7713ff9c1cd5aef12.svg" style="height: 22px;" type="image/svg+xml">
\[L=L(p^{[t]}, r^{[t]})\]</object>
<p>To perform a gradient descent update, we'll need to find
<object class="valign-m7" data="https://eli.thegreenplace.net/images/math/c9e2c4ffca9564929c45a5244c7fb064465ab005.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial L}{\partial w}</object>, for every weight value <em>w</em>. To do this,
we'll have to:</p>
<ol class="arabic simple">
<li>Find the &quot;local&quot; gradients for every mathematical operation leading from
<em>w</em> to <em>L</em>.</li>
<li>Use the chain rule to propagate the error backwards through these local
gradients until we find <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/c9e2c4ffca9564929c45a5244c7fb064465ab005.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial L}{\partial w}</object>.</li>
</ol>
<p>We start by formulating the chain rule to compute
<object class="valign-m7" data="https://eli.thegreenplace.net/images/math/c9e2c4ffca9564929c45a5244c7fb064465ab005.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial L}{\partial w}</object>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/45ad1052f6c6b78265143f4d41f2f12f1714ebfb.svg" style="height: 45px;" type="image/svg+xml">
\[\frac{\partial L}{\partial w}=\frac{\partial L}{\partial p^{[t]}}\frac{\partial p^{[t]}}{\partial w}\]</object>
<p>Next comes:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/f7a68f105f4e7483f2781a7bebeaad0ce659bf06.svg" style="height: 45px;" type="image/svg+xml">
\[\frac{\partial p^{[t]}}{\partial w}=\frac{\partial softmax}{\partial y^{[t]}}\frac{\partial y^{[t]}}{\partial w}\]</object>
<p>Let's say the weight <em>w</em> we're interested in is part of <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object>, so we
have to propagate some more:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/53bcf70971d45064242463ddfad70e3ba6fb0ec9.svg" style="height: 42px;" type="image/svg+xml">
\[\frac{\partial y^{[t]}}{\partial w}=\frac{\partial y^{[t]}}{\partial h^{[t]}}\frac{\partial h^{[t]}}{\partial w}\]</object>
<p>We'll then proceed to propagate through the <em>tanh</em> function, bias addition and
finally the multiplication by <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object>, for which the derivative by <em>w</em> is
computed directly without further chaining.</p>
<p>Let's now see how to compute all the relevant local gradients.</p>
</div>
<div class="section" id="cross-entropy-loss-gradient">
<h2>Cross-entropy loss gradient</h2>
<p>We'll start with the derivative of the loss function, which is cross-entropy in
the <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> model. I went through a detailed derivation of the gradient
of softmax followed by cross-entropy in <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative">this post</a>;
here is only a brief recap:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/b26f68a12667ba254facf9815252f52ebf2238d9.svg" style="height: 38px;" type="image/svg+xml">
\[xent(p,q)=-\sum_{k}p(k)log(q(k))\]</object>
<p>Re-formulating this for our specific case, the loss is a function of
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f75a9c33c546d725557a4d452769bfd8fbb6cc22.svg" style="height: 20px;" type="image/svg+xml">p^{[t]}</object>, assuming the &quot;real&quot; class <em>r</em> is constant for every training
example:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/9ff2ef0e3dbe188129b93dddeb12759fdf909bcb.svg" style="height: 39px;" type="image/svg+xml">
\[L(p^{[t]})=-\sum_{k}r(k)log(p^{[t]}(k))\]</object>
<p>Since the cell's inputs and expected outputs are one-hot encoded, let's just use <em>r</em> to
denote the index where <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/20aea9dc9718c5f2b3e11b3ebec11518202f0af1.svg" style="height: 18px;" type="image/svg+xml">r(k)</object> is non-zero. Then the Jacobian of <em>L</em> is
only non-zero at index <em>r</em> and its value there is <object class="valign-m11" data="https://eli.thegreenplace.net/images/math/c4efb22a708d798abd641a16679976b8829f500d.svg" style="height: 27px;" type="image/svg+xml">-\frac{1}{p^{[t]}(r)}</object>.</p>
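<p>In code, this Jacobian is just a vector with a single non-zero entry. A minimal sketch (with <em>r</em> as the index of the correct character):</p>
<div class="highlight"><pre>import numpy as np

def xent_loss_and_grad(p, r):
    # p: column vector of predicted probabilities; r: index of the true char.
    loss = -np.log(p[r, 0])
    dL_dp = np.zeros_like(p)
    dL_dp[r, 0] = -1.0 / p[r, 0]
    return loss, dL_dp
</pre></div>
<p>In <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> itself this gradient is never materialized separately: the cross-entropy and softmax gradients are fused, which is why the code simply copies <em>p</em> and subtracts 1 at the correct index to get the gradient w.r.t. <em>y</em>.</p>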
</div>
<div class="section" id="softmax-gradient">
<h2>Softmax gradient</h2>
<p>A detailed computation of the gradient for the softmax function was also
presented in <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative">this post</a>.
For <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a197fbc6c2f0e9e1d6b4c51c6fca2756927a3055.svg" style="height: 18px;" type="image/svg+xml">S(y)</object> being the softmax of <em>y</em>, the Jacobian is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/87fbe94e6b409d31b512cb7a4581c24907d4dd4a.svg" style="height: 42px;" type="image/svg+xml">
\[D_{j}S_{i}=\frac{\partial S_i}{\partial y_j}=S_{i}(\delta_{ij}-S_j)\]</object>
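<p>As a quick sanity check, this Jacobian can be built directly with an outer product; a small numpy sketch:</p>
<div class="highlight"><pre>import numpy as np

def softmax(y):
    e = np.exp(y - np.max(y))       # shift for numerical stability
    return e / np.sum(e)

def softmax_jacobian(y):
    S = softmax(y).reshape(-1)
    # D_j S_i = S_i * (delta_ij - S_j)
    return np.diag(S) - np.outer(S, S)
</pre></div>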
</div>
<div class="section" id="fully-connected-layer-gradient">
<h2>Fully-connected layer gradient</h2>
<p>Next on our path backwards is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/ead3bdf11cf41b04164a83008db4a7dd0db5a074.svg" style="height: 24px;" type="image/svg+xml">
\[y^{[t]}=W_{hy}\cdot h^{[t]}+b_y\]</object>
<p>From my earlier <a class="reference external" href="http://eli.thegreenplace.net/2018/backpropagation-through-a-fully-connected-layer/">post on backpropagating through a fully-connected layer</a>,
we know that <object class="valign-m9" data="https://eli.thegreenplace.net/images/math/413d530fbd3e019cc3f49aec6e8f7cb7a8f0c622.svg" style="height: 29px;" type="image/svg+xml">\frac{\partial y^{[t]}}{\partial h^{[t]}}=W_{hy}</object>. But
that's not all; note that on the forward pass <object class="valign-0" data="https://eli.thegreenplace.net/images/math/057276c060e575533321773afb483e778e6a03f1.svg" style="height: 16px;" type="image/svg+xml">h^{[t]}</object> splits in two -
one edge goes into the fully-connected layer, another goes to the next RNN cell
as the state. When we backpropagate the loss gradient to <object class="valign-0" data="https://eli.thegreenplace.net/images/math/057276c060e575533321773afb483e778e6a03f1.svg" style="height: 16px;" type="image/svg+xml">h^{[t]}</object>, we
have to take both edges into account; more specifically, we have to <em>add</em> the
gradients along the two edges. This leads to the following backpropagation
equation:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/e7d5afe050e8e2f3f4b867ae4d9eb510fbe2e583.svg" style="height: 45px;" type="image/svg+xml">
\[\frac{\partial L}{\partial h^{[t]}} =
\frac{\partial y^{[t]}}{\partial h^{[t]}}\frac{\partial L}{\partial y^{[t]}}+\frac{\partial L}{\partial h^{[t+1]}}\frac{\partial h^{[t+1]}}{\partial h^{[t]}}
=W_{hy}\cdot \frac{\partial L}{\partial y^{[t]}}+\frac{\partial L}{\partial h^{[t+1]}}\frac{\partial h^{[t+1]}}{\partial h^{[t]}}\]</object>
<p>In addition, note that this layer already has model parameters that need to be
learned - <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/20c1b19b6e71072b92080b2eb00b5b99123cf057.svg" style="height: 18px;" type="image/svg+xml">W_{hy}</object> and <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/9bd872acdafb9ea752b3ba10b2670499cb65469f.svg" style="height: 19px;" type="image/svg+xml">b_y</object> - a &quot;final&quot; destination for
backpropagation. Please refer to my fully-connected layer backpropagation post
to see how the gradients for these are computed.</p>
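<p>In code, this "add the two incoming gradients" rule is a one-liner; a sketch along the lines of what <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> does in its backward loop:</p>
<div class="highlight"><pre>def backprop_into_h(Why, dy, dhnext):
    # dy: loss gradient w.r.t. y at this step (through the output layer).
    # dhnext: gradient arriving from the next unrolled step's state edge.
    return Why.T @ dy + dhnext
</pre></div>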
</div>
<div class="section" id="gradient-of-tanh">
<h2>Gradient of tanh</h2>
<p>The vector <object class="valign-0" data="https://eli.thegreenplace.net/images/math/057276c060e575533321773afb483e778e6a03f1.svg" style="height: 16px;" type="image/svg+xml">h^{[t]}</object> is produced by applying a hyperbolic tangent
nonlinearity to another fully connected layer.</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/4ce55619f9cd4d96083ec3dadf303cc83a426543.svg" style="height: 22px;" type="image/svg+xml">
\[h^{[t]}=tanh(W_{hh}\cdot h^{[t-1]}+W_{xh}\cdot x^{[t]}+b_h)\]</object>
<p>To get to the model parameters <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object>, <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/4ee22236c608ad6f49adc4807465b6e6896092ec.svg" style="height: 15px;" type="image/svg+xml">W_{xh}</object> and <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/14e5d8599f43750d0cf9dda2d90b085c69079049.svg" style="height: 16px;" type="image/svg+xml">b_h</object>,
we have to first backpropagate the loss gradient through <em>tanh</em>. <em>tanh</em> is a
scalar function; when it's applied to a vector we apply it in <em>element-wise</em>
fashion to every element in the vector independently, and collect the results in
a similarly-shaped result vector.</p>
<p>Its mathematical definition is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/326a49518bfe326c6be2de37838971407fa5175d.svg" style="height: 39px;" type="image/svg+xml">
\[tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}\]</object>
<p>To find the derivative of this function, we'll use the formula for deriving
a ratio:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/9e006cf5e9f1f8ccac82ba1f2bcdabd710731756.svg" style="height: 42px;" type="image/svg+xml">
\[(\frac{f}{g})&#x27;=\frac{f&#x27;g-g&#x27;f}{g^2}\]</object>
<p>So:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/dc73b540394a92b823ec3eaabe4d02a7735f146f.svg" style="height: 43px;" type="image/svg+xml">
\[tanh&#x27;(x)=\frac{(e^x+e^{-x})(e^x+e^{-x})-(e^x-e^{-x})(e^x-e^{-x})}{(e^x+e^{-x})^2}=1-(tanh(x))^2\]</object>
<p>Just like for softmax, it turns out that there's a convenient way to express the
derivative of <em>tanh</em> in terms of <em>tanh</em> itself. Applying the chain rule with,
for example, <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7f748d5017b1817f6d3912d339e85871b81d93b4.svg" style="height: 18px;" type="image/svg+xml">h=tanh(k)</object> where <em>k</em> is a function of
<em>w</em>, we get:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/9137818273fdbac7d5dc0e05df4fcf3f8cb7ea9d.svg" style="height: 39px;" type="image/svg+xml">
\[\frac{\partial h}{\partial w}=\frac{\partial tanh(k)}{\partial k}\frac{\partial k}{\partial w}=(1-h^2)\frac{\partial k}{\partial w}\]</object>
<p>In our case <em>k(w)</em> is a fully-connected layer; to find its derivatives w.r.t.
the weight matrices and bias, please refer to the <a class="reference external" href="http://eli.thegreenplace.net/2018/backpropagation-through-a-fully-connected-layer/">backpropagation through a
fully-connected layer post</a>.</p>
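<p>In code, backpropagating through the <em>tanh</em> only needs the cached forward value of <em>h</em>; a sketch:</p>
<div class="highlight"><pre>def backprop_through_tanh(h, dh):
    # h = tanh(k) was cached on the forward pass; dh is dL/dh.
    # Returns dL/dk, the gradient w.r.t. the pre-tanh value.
    return (1 - h * h) * dh
</pre></div>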
</div>
<div class="section" id="learning-model-parameters-with-adagrad">
<h2>Learning model parameters with Adagrad</h2>
<p>We've just gone through all the major parts of the RNN cell and computed local
gradients. Armed with these formulae and the chain rule, it should be possible
to understand how the <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> code flows the loss gradient backwards.
But that's not the end of the story; once we have the loss derivatives w.r.t.
some model parameter, how do we update this parameter?</p>
<p>The most straightforward way to do this would be using the gradient descent
algorithm, with some constant learning rate. <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent/">I've written about gradient
descent</a> in
the past - please take a look for a refresher.</p>
<p>Most real-world learning is done with more advanced algorithms these days,
however. One such algorithm is called Adagrad, <a class="reference external" href="http://jmlr.org/papers/v12/duchi11a.html">proposed in 2011</a> by some experts in mathematical
optimization. <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> happens to use Adagrad, so here is a simplified
explanation of how it works.</p>
<p>The main idea is to adjust the learning rate separately per parameter, because
in practice some parameters change much more often than others. This could be
due to rare examples in the training data set that affect a parameter that's
not often affected; we'd like to amplify these changes because they are rare,
and dampen changes to parameters that change often.</p>
<p>Therefore the Adagrad algorithm works as follows:</p>
<div class="highlight"><pre><span></span><span class="c1"># Same shape as the parameter array x</span>
<span class="n">memory</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
<span class="n">dx</span> <span class="o">=</span> <span class="n">compute_grad</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># Elementwise: each memory element gets the corresponding dx^2 added to it.</span>
<span class="n">memory</span> <span class="o">+=</span> <span class="n">dx</span> <span class="o">*</span> <span class="n">dx</span>
<span class="c1"># The actual parameter update for this step. Note how the learning rate is</span>
<span class="c1"># modified by the memory. epsilon is some very small number to avoid dividing</span>
<span class="c1"># by 0.</span>
<span class="n">x</span> <span class="o">-=</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">dx</span> <span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">memory</span><span class="p">)</span> <span class="o">+</span> <span class="n">epsilon</span><span class="p">)</span>
</pre></div>
<p>If a given element in <tt class="docutils literal">dx</tt> was updated significantly in the past, its
corresponding <tt class="docutils literal">memory</tt> element will grow and thus the learning rate is
effectively decreased.</p>
</div>
<div class="section" id="gradient-clipping">
<h2>Gradient clipping</h2>
<p>If we unroll the RNN cell 10 times, the gradient will be multiplied by
<object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object> ten times on its way from the last cell to the first. For some
structures of <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/5b9174fc1cf8afbecdab52326985d41be6fbc2c8.svg" style="height: 15px;" type="image/svg+xml">W_{hh}</object>, this may lead to an &quot;exploding gradient&quot; effect
where the value keeps growing <a class="footnote-reference" href="#id10" id="id5">[4]</a>.</p>
<p>To mitigate this, <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> uses the <em>gradient clipping</em> trick. Whenever
the gradients are updated, they are &quot;clipped&quot; to some reasonable range (like -5
to 5) so they will never get out of this range. This method is crude, but it
works reasonably well for training RNNs.</p>
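<p>With numpy this amounts to one clipping call per gradient array, applied right before the parameter update; a sketch matching the [-5, 5] range mentioned above (the gradient names are the ones used throughout this post):</p>
<div class="highlight"><pre>import numpy as np

# dWxh, dWhh, dWhy, dbh, dby are the gradient arrays computed by backprop.
for dparam in (dWxh, dWhh, dWhy, dbh, dby):
    np.clip(dparam, -5, 5, out=dparam)    # clip in place to [-5, 5]
</pre></div>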
<p>The flip side problem of <em>vanishing gradient</em> (wherein the gradients keep
getting smaller with each step) is much harder to solve, and usually requires
more advanced recurrent NN architectures.</p>
</div>
<div class="section" id="min-char-rnn-model-quality">
<h2>min-char-rnn model quality</h2>
<p>While <tt class="docutils literal"><span class="pre">min-char-rnn</span></tt> is a complete RNN implementation that manages to learn,
it's not really good enough for learning a reasonable model for the English
language. The model is too simple for this, and suffers seriously from the
vanishing gradient problem.</p>
<p>For example, when training a 16-step unrolled model on a corpus of Sherlock
Holmes books, it produces the following text after 60,000 iterations (learning
on about a MiB of text):</p>
<blockquote>
one, my dred, roriny. qued bamp gond hilves non froange saws, to mold
his a work, you shirs larcs anverver strepule thunboler
muste, thum and cormed sightourd
so was rewa her besee pilman</blockquote>
<p>It's not complete gibberish, but not really English either. Just for fun, I
wrote a simple <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/min-char-rnn/markov-model.py">Markov chain generator</a>
and trained it on the same text with a 4-character state. Here's a sample of its
output:</p>
<blockquote>
though throughted with to taken as when it diabolice, and intered the
stairhead, the stood initions of indeed, as burst, his mr. holmes' room,
and now i fellows. the stable. he retails arm</blockquote>
<p>Which, you'll admit, is quite a bit better than our &quot;fancy&quot; deep learning
approach! And it was much faster to train too...</p>
<p>To have a better chance of learning a good model, we'll need a more advanced
architecture like LSTM. LSTMs employ a bunch of tricks to preserve long-term
dependencies through the cells and can learn much better language models. For
example, Andrej Karpathy's char-rnn model from the <a class="reference external" href="http://karpathy.github.io/2015/05/21/rnn-effectiveness/">Unreasonable Effectiveness
of RNNs post</a> is a
multi-layer LSTM, and it can learn fairly nice models for a varied set of
domains, ranging from Shakespeare sonnets to C code snippets in the Linux
kernel.</p>
</div>
<div class="section" id="conclusion">
<h2>Conclusion</h2>
<p>The goal of this post wasn't to develop a very good RNN model; rather, it was to
explain in detail the math behind training a simple RNN. More advanced RNN
architectures like LSTM are somewhat more complicated, but all the core ideas
are very similar and this post should be helpful in nailing the basics.</p>
<p><em>Update:</em> <a class="reference external" href="https://eli.thegreenplace.net/2018/minimal-character-based-lstm-implementation/">An extension of this post to LSTMs</a>.</p>
<hr class="docutils" />
<table class="docutils footnote" frame="void" id="id7" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id1">[1]</a></td><td><p class="first">Computing a softmax makes sense because <em>x</em> is encoded with one-hot over
a vocabulary-sized vector, meaning there's a 1 in the position of the
letter it represents with 0s in all other positions. For example, if we
only care about the 26 lower-case alphabet letters, <em>x</em> could be a
26-element vector. To represent 'a' it would have 1 in position 0 and
zeros elsewhere; to represent 'd' it would have 1 in position 3 and zeros
elsewhere.</p>
<p class="last">The output <em>p</em> here models what the RNN cell thinks the next generated
character should be. Using softmax, it would have probabilities for each
character in the corresponding position, all of them properly summing up
to 1.</p>
</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id8" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id2">[2]</a></td><td><p class="first">A slightly more technical explanation: to compute the gradient for the
error w.r.t. weights in the typical backpropagation flow, we'll need
input gradients for <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/0ede500c5edc819b5f962923f98724936ef9d593.svg" style="height: 18px;" type="image/svg+xml">p[t]</object> and <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/897ad2ab624c79d6dcb687ad28f7a3767a76712c.svg" style="height: 18px;" type="image/svg+xml">h[t]</object>. Then, when learning
happens we use the measured error and propagate it backwards. But what is
the measured error for <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/897ad2ab624c79d6dcb687ad28f7a3767a76712c.svg" style="height: 18px;" type="image/svg+xml">h[t]</object>? We don't know it before we compute
the error of the next iteration, and so on - a bit of a chicken-egg
problem.</p>
<p class="last">Unrolling/BPTT helps approximate a solution for this issue. An
alternative solution is to use <em>forward-mode</em> gradient propagation
instead, with an algorithm called RTRL (Real Time Recurrent Learning).
This algorithm works well but has a high computational cost compared to
BPTT. I'd love to explore this topic in more depth, as it ties into the
difference between forward-mode and reverse-mode auto differentiation.
But that would be a topic for another post.</p>
</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id9" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id3">[3]</a></td><td>This is similar to convolutional networks, where the convolution filter
weights are reused many times when processing a much larger input. In
such models the invariance is <em>spatial</em>; in sequence models the
invariance is <em>temporal</em>. In fact, space vs. time in models is just a
matter of convention, and it turns out that 1D convolutional models
perform very well on some sequence tasks!</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id10" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id5">[4]</a></td><td><p class="first">An easy way to think about it is to imagine some initial value <em>v</em>,
multiplied by another value <em>c</em> many times. We get <object class="valign-0" data="https://eli.thegreenplace.net/images/math/f24f59611d0e1f3043785fe772138687cfd6da97.svg" style="height: 15px;" type="image/svg+xml">vc^N</object> for <em>N</em>
multiplications. If <em>c</em> is larger than 1, it means the result will keep
growing with each multiplication. How quickly will depend on the actual
value of <em>c</em>, but this is basically exponential runaway. We actually
care about the absolute value of <em>c</em>, of course, since the runaway is equally
bad in the positive or negative direction.</p>
<p class="last">Similarly with the absolute value of <em>c</em> smaller than 1, we'll get a
&quot;vanishing&quot; effect since the result will keep getting smaller with each
iteration.</p>
</td></tr>
</tbody>
</table>
</div>
Backpropagation through a fully-connected layer (2018-05-22, Eli Bendersky)

<p>The goal of this post is to show the math of backpropagating a derivative for a
fully-connected (FC) neural network layer consisting of matrix multiplication
and bias addition. I have briefly mentioned this in an <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative">earlier post dedicated
to Softmax</a>,
but here I want to give some more attention to FC layers specifically.</p>
<p>Here is a fully-connected layer for input vectors with <em>N</em> elements, producing
output vectors with <em>T</em> elements:</p>
<img alt="Diagram of a fully connected layer" class="align-center" src="https://eli.thegreenplace.net/images/2018/fclayer.png" />
<p>As a formula, we can write:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/0f980ab4c97ad86b4d0a15ede6e9c05901323702.svg" style="height: 17px;" type="image/svg+xml">
\[y=Wx+b\]</object>
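<p>As a tiny concrete example of the shapes involved (here with made-up sizes N=3 and T=2):</p>
<div class="highlight"><pre>import numpy as np

N, T = 3, 2
W = np.random.randn(T, N)      # weight matrix: T rows, N columns
b = np.random.randn(T, 1)      # one bias per output element
x = np.random.randn(N, 1)      # input column vector with N elements

y = W @ x + b                  # output column vector with T elements
</pre></div>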
<p>Presumably, this layer is part of a network that ends up computing some loss
<em>L</em>. We'll assume we already have the derivative of the loss w.r.t. the output
of the layer <object class="valign-m9" data="https://eli.thegreenplace.net/images/math/9a5154f5e8d64cc77db745d8d3baa723bc6df829.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial{L}}{\partial{y}}</object>.</p>
<p>We'll be interested in two other derivatives:
<object class="valign-m7" data="https://eli.thegreenplace.net/images/math/33d2709b664fdd69317758b433b61b13c1cdc62f.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{W}}</object> and
<object class="valign-m7" data="https://eli.thegreenplace.net/images/math/5f12a50803653cf2ee02135944343ec70506d31c.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{b}}</object>.</p>
<div class="section" id="jacobians-and-the-chain-rule">
<h2>Jacobians and the chain rule</h2>
<p>As a reminder from <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus">The Chain Rule of Calculus</a>,
we're dealing with functions that map from <em>n</em> dimensions to <em>m</em> dimensions:
<img alt="f:\mathbb{R}^{n} \to \mathbb{R}^{m}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/13f219789047343729036279bb11630db317d98d.png" style="height: 16px;" />. We'll consider the outputs of <em>f</em>
to be numbered from 1 to <em>m</em> as <img alt="f_1,f_2 \dots f_m" class="valign-m4" src="https://eli.thegreenplace.net/images/math/93b446c5209263534d09d617bbede21101d6536e.png" style="height: 16px;" />. For each such
<img alt="f_i" class="valign-m4" src="https://eli.thegreenplace.net/images/math/68bd0dc647944d362ec8df628a22967b91d82c80.png" style="height: 16px;" /> we can compute its partial derivative by any of the <em>n</em> inputs as:</p>
<img alt="\[D_j f_i(a)=\frac{\partial f_i}{\partial a_j}(a)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/30881b5a92e45259714ba01c7a12fbf8f6c56109.png" style="height: 42px;" />
<p>Where <em>j</em> goes from 1 to <em>n</em> and <em>a</em> is a vector with <em>n</em> components. If <em>f</em>
is differentiable at <em>a</em> then the derivative of <em>f</em> at <em>a</em> is the <em>Jacobian
matrix</em>:</p>
<img alt="\[Df(a)=\begin{bmatrix} D_1 f_1(a) &amp;amp; \cdots &amp;amp; D_n f_1(a) \\ \vdots &amp;amp; &amp;amp; \vdots \\ D_1 f_m(a) &amp;amp; \cdots &amp;amp; D_n f_m(a) \\ \end{bmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/ab09367d48e9ef4d8bc2314a60313dec700193af.png" style="height: 76px;" />
<p>The multivariate chain rule states: given <img alt="g:\mathbb{R}^n \to \mathbb{R}^m" class="valign-m4" src="https://eli.thegreenplace.net/images/math/b4b7d25491897b053abf7e48688fada4a85368bd.png" style="height: 16px;" />
and <img alt="f:\mathbb{R}^m \to \mathbb{R}^p" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ac8a6cea4e02e885538fc3ef969c5733e84712f9.png" style="height: 16px;" /> and a point <img alt="a \in \mathbb{R}^n" class="valign-m1" src="https://eli.thegreenplace.net/images/math/43a85f2c59f396fe5c4e2c403a0453c463fcfb0d.png" style="height: 13px;" />,
if <em>g</em> is differentiable at <em>a</em> and <em>f</em> is differentiable at <img alt="g(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e7373233d49e18a0882e0dce41d9d6aa26964d6b.png" style="height: 18px;" /> then
the composition <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" /> is differentiable at <em>a</em> and its derivative
is:</p>
<img alt="\[D(f \circ g)(a)=Df(g(a)) \cdot Dg(a)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/00bdefa904bd34df2dfb50cc385e6497c4e5096e.png" style="height: 18px;" />
<p>Which is the matrix multiplication of <img alt="Df(g(a))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e567730c48bb2f95c258b630b4d6e997043e09ab.png" style="height: 18px;" /> and <img alt="Dg(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2575fc98e794a733a7aa6237fe67246a41e6c8c5.png" style="height: 18px;" />.</p>
</div>
<div class="section" id="back-to-the-fully-connected-layer">
<h2>Back to the fully-connected layer</h2>
<p>Circling back to our fully-connected layer, we have the loss <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/abf7408ae6d9fb4683480735dc1ebc8555b8fef8.svg" style="height: 18px;" type="image/svg+xml">L(y)</object> - a
scalar function <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/ddef8b9ca23fb246b2a984c719d812f37a41a406.svg" style="height: 16px;" type="image/svg+xml">L:\mathbb{R}^{T} \to \mathbb{R}</object>. We also have the
function <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f09c295439296549a068b64ffe69a48dd77d1078.svg" style="height: 17px;" type="image/svg+xml">y=Wx+b</object>. If we're interested in the derivative w.r.t the
weights, what are the dimensions of this function? Our &quot;variable part&quot; is then
<em>W</em>, which has <em>NT</em> elements overall, and the output has <em>T</em> elements, so
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/06178f1d07375b8286afcd48f02bcd34d71537f0.svg" style="height: 19px;" type="image/svg+xml">y:\mathbb{R}^{NT} \to \mathbb{R}^{T}</object> <a class="footnote-reference" href="#id3" id="id1">[1]</a>.</p>
<p>The chain rule tells us how to compute the derivative of <em>L</em> w.r.t. <em>W</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/ee6bc25a34980031f93f0c7eefccc40663b05c76.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial{L}}{\partial{W}}=D(L \circ y)(W)=DL(y(W)) \cdot Dy(W)\]</object>
<p>Since we're backpropagating, we already know <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/dcb2eda345045dac22c425a1ee19113e047126cf.svg" style="height: 18px;" type="image/svg+xml">DL(y(W))</object>; because of the
dimensionality of the <em>L</em> function, the dimensions of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/dcb2eda345045dac22c425a1ee19113e047126cf.svg" style="height: 18px;" type="image/svg+xml">DL(y(W))</object> are
[1,T] (one row, <em>T</em> columns). <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d0064a180ddb231bb6868ce25c68ef3ec1c2a464.svg" style="height: 18px;" type="image/svg+xml">y(W)</object> has <em>NT</em> inputs and <em>T</em> outputs,
so the dimensions of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b22fe7345e02ae50c68605696f3a447435cd1f9d.svg" style="height: 18px;" type="image/svg+xml">Dy(W)</object> are [T,NT]. Overall, the dimensions of
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2f6b6eded4ba20b3eeb59b2b687f84de1e91c04c.svg" style="height: 18px;" type="image/svg+xml">D(L \circ y)(W)</object> are then [1,NT]. This makes sense if you think about it,
because as a function of <em>W</em>, the loss has <em>NT</em> inputs and a single scalar
output.</p>
<p>What remains is to compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b22fe7345e02ae50c68605696f3a447435cd1f9d.svg" style="height: 18px;" type="image/svg+xml">Dy(W)</object>, the Jacobian of <em>y</em> w.r.t. <em>W</em>. As
mentioned above, it has <em>T</em> rows - one for each output element of <em>y</em>, and <em>NT</em>
columns - one for each element in the weight matrix <em>W</em>. Computing such a large
Jacobian may seem daunting, but we'll soon see that it's very easy to generalize
from a simple example. Let's start with <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/6a53741f2a8810da3cae4efadde63c8e7ee2662f.svg" style="height: 12px;" type="image/svg+xml">y_1</object>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7190e002ac69968b674aecacfd5a8531ad9cd208.svg" style="height: 55px;" type="image/svg+xml">
\[y_1=\sum_{j=1}^{N}W_{1,j}x_{j}+b_1\]</object>
<p>What's the derivative of this result element w.r.t. each element in <em>W</em>? When
the element is in row 1, the derivative is <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/73058e43db0f4edc791b10f27f913cbc5d361ab6.svg" style="height: 14px;" type="image/svg+xml">x_j</object> (<em>j</em> being the column
of <em>W</em>); when the element is in any other row, the derivative is 0.</p>
<p>Similarly for <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b9f59182e34baa532fa4e27471acc714f3105d16.svg" style="height: 12px;" type="image/svg+xml">y_2</object>, we'll have non-zero derivatives only for the second
row of <em>W</em> (with the same result of <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/73058e43db0f4edc791b10f27f913cbc5d361ab6.svg" style="height: 14px;" type="image/svg+xml">x_j</object> being the derivative for the
<em>j</em>-th column), and so on.</p>
<p>Generalizing from the example, if we split the index of <em>W</em> to <em>i</em> and <em>j</em>, we
get:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/e28e0d3b44645eb299cceae8dde2319244e86373.svg" style="height: 50px;" type="image/svg+xml">
\[D_{ij}y_t=\frac{\partial(\sum_{j=1}^{N}W_{t,j}x_{j}+b_t)}{\partial W_{ij}}
= \left\{\begin{matrix}
x_j &amp; i = t\\
0 &amp; i \ne t
\end{matrix}\right.\]</object>
<p>This goes into row <em>t</em>, column <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ef7b2d987af3c0ceb75381d096c35e8c19085642.svg" style="height: 18px;" type="image/svg+xml">(i-1)N+j</object> in the Jacobian matrix. Overall,
we get the following Jacobian matrix with shape [T,NT]:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/8a59a6251d12196f12eaadb6537289e3a6368d53.svg" style="height: 76px;" type="image/svg+xml">
\[Dy=\begin{bmatrix}
x_1 &amp; x_2 &amp; \cdots &amp; x_N &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\
\vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\
0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_1 &amp; x_2 &amp; \cdots &amp; x_N
\end{bmatrix}\]</object>
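<p>Here's a quick sketch (mine, not from the derivation above) that builds this [T,NT]
Jacobian explicitly for a tiny layer, just to see its sparse structure; the sizes and
the values of <em>x</em> are arbitrary:</p>
<div class="highlight"><pre># Assumes numpy is imported as np, as in the rest of the post's samples.
T, N = 2, 3
x = np.array([10.0, 20.0, 30.0])

Dy = np.zeros((T, N * T))
for t in range(T):              # row: output element y_t
    for i in range(T):          # weight row index
        for j in range(N):      # weight column index
            if i == t:          # d(y_t)/d(W_ij) = x_j only when i == t
                Dy[t, i * N + j] = x[j]

print(Dy)
# [[10. 20. 30.  0.  0.  0.]
#  [ 0.  0.  0. 10. 20. 30.]]
</pre></div>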
<p>Now we're ready to finally multiply the Jacobians together to complete the
chain rule:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/2e9823350972d4874d201a0f8232d89fea710c6f.svg" style="height: 18px;" type="image/svg+xml">
\[D(L \circ y)(W)=DL(y(W)) \cdot Dy(W)\]</object>
<p>The leftmost factor in this product is the row vector:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/af6d7af820f16b7493d378e8a40daa87031591f4.svg" style="height: 41px;" type="image/svg+xml">
\[DL(y)=(\frac{\partial L}{\partial y_1}, \frac{\partial L}{\partial y_2},\cdots,\frac{\partial L}{\partial y_T})\]</object>
<p>And we're multiplying it by the matrix <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/baf7d8b700759b28ece347bd62793400ef52a8e0.svg" style="height: 16px;" type="image/svg+xml">Dy</object> shown above. Each item in the
result vector will be a dot product between <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/992673c682a388ea8231ebbd8ea28c9cecae874d.svg" style="height: 18px;" type="image/svg+xml">DL(y)</object> and the corresponding
column in the matrix <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/baf7d8b700759b28ece347bd62793400ef52a8e0.svg" style="height: 16px;" type="image/svg+xml">Dy</object>. Since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/baf7d8b700759b28ece347bd62793400ef52a8e0.svg" style="height: 16px;" type="image/svg+xml">Dy</object> has a single non-zero element
in each column, the result is fairly trivial. The first <em>N</em> entries are:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/d0e87e4d5cd9feb9d93d153733a182426d175d7e.svg" style="height: 41px;" type="image/svg+xml">
\[\frac{\partial L}{\partial y_1}x_1,
\frac{\partial L}{\partial y_1}x_2,
\cdots,
\frac{\partial L}{\partial y_1}x_N\]</object>
<p>The next <em>N</em> entries are:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/5bd4d8b6b4eb817b071dbf4ddda71680f4bf0392.svg" style="height: 41px;" type="image/svg+xml">
\[\frac{\partial L}{\partial y_2}x_1,
\frac{\partial L}{\partial y_2}x_2,
\cdots,
\frac{\partial L}{\partial y_2}x_N\]</object>
<p>And so on, until the last (<em>T</em>-th) set of <em>N</em> entries is all <em>x</em>-es multiplied
by <object class="valign-m9" data="https://eli.thegreenplace.net/images/math/b44681f2ca721dae2b24a49d88f01463e3a88e50.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial L}{\partial y_T}</object>.</p>
<p>To better see how to apply each derivative to a corresponding element in <em>W</em>, we
can &quot;re-roll&quot; this result back into a matrix of shape [T,N]:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7cfccbaaa844f8ae994f8e012f12557919927e31.svg" style="height: 129px;" type="image/svg+xml">
\[\frac{\partial{L}}{\partial{W}}=D(L\circ y)(W)=\begin{bmatrix}
\frac{\partial L}{\partial y_1}x_1 &amp; \frac{\partial L}{\partial y_1}x_2 &amp; \cdots &amp; \frac{\partial L}{\partial y_1}x_N \\ \\
\frac{\partial L}{\partial y_2}x_1 &amp; \frac{\partial L}{\partial y_2}x_2 &amp; \cdots &amp; \frac{\partial L}{\partial y_2}x_N \\
\vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
\frac{\partial L}{\partial y_T}x_1 &amp; \frac{\partial L}{\partial y_T}x_2 &amp; \cdots &amp; \frac{\partial L}{\partial y_T}x_N
\end{bmatrix}\]</object>
</div>
<div class="section" id="computational-cost-and-shortcut">
<h2>Computational cost and shortcut</h2>
<p>While the derivation shown above is complete and mathematically correct, it can
also be computationally intensive; in realistic scenarios, the full Jacobian
matrix can be <em>really</em> large. For example, let's say our input is a (modestly
sized) 128x128 image, so <em>N=16,384</em>. Let's also say that <em>T=100</em>. The weight
matrix then has <em>NT=1,638,400</em> elements; respectably big, but nothing out of the
ordinary.</p>
<p>Now consider the size of the full Jacobian matrix: it's <em>T</em> by <em>NT</em>, or over
160 million elements. At 4 bytes per element that's more than half a GiB!</p>
<p>Moreover, to compute every backpropagation we'd be forced to multiply this full
Jacobian matrix by a 100-dimensional vector, performing 160 million
multiply-and-add operations for the dot products. That's a lot of compute.</p>
<p>But the final result <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/89da23923a43fcd95b185bebb6fd362b6d1ac695.svg" style="height: 18px;" type="image/svg+xml">D(L\circ y)(W)</object> is the size of <em>W</em> - 1.6 million
elements. Do we really need 160 million computations to get to it? No, because
the Jacobian is very <em>sparse</em> - most of it is zeros. In fact, looking
at the <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/89da23923a43fcd95b185bebb6fd362b6d1ac695.svg" style="height: 18px;" type="image/svg+xml">D(L\circ y)(W)</object> found above, it's fairly straightforward to
compute using a single multiplication per element.</p>
<p>Moreover, if we stare at the <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/33d2709b664fdd69317758b433b61b13c1cdc62f.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{W}}</object> matrix a
bit, we'll notice it has a familiar pattern: this is just the <a class="reference external" href="https://en.wikipedia.org/wiki/Outer_product">outer product</a> between the vectors
<object class="valign-m9" data="https://eli.thegreenplace.net/images/math/9a5154f5e8d64cc77db745d8d3baa723bc6df829.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial{L}}{\partial{y}}</object> and <em>x</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/89bdcdf27feb489a5e3cb1bb8adc7faffcf0207d.svg" style="height: 41px;" type="image/svg+xml">
\[\frac{\partial L}{\partial W}=\frac{\partial L}{\partial y}\otimes x\]</object>
<p>If we have to compute this backpropagation in Python/Numpy, we'll likely write
code similar to:</p>
<div class="highlight"><pre><span></span><span class="c1"># Assuming dy (gradient of loss w.r.t. y) and x are column vectors, by</span>
<span class="c1"># performing a dot product between dy (column) and x.T (row) we get the</span>
<span class="c1"># outer product.</span>
<span class="n">dW</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">dy</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">T</span><span class="p">)</span>
</pre></div>
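<p>As a sanity check (my own sketch; the toy loss and the shapes are made up), we can
compare this shortcut against a finite-difference gradient:</p>
<div class="highlight"><pre>import numpy as np

N, T = 4, 3
rng = np.random.default_rng(0)
W = rng.normal(size=(T, N))
b = rng.normal(size=(T, 1))
x = rng.normal(size=(N, 1))

def loss(W):
    y = W.dot(x) + b
    return np.sum(y ** 2)            # toy loss: L(y) = sum of y_i squared

y = W.dot(x) + b
dy = 2 * y                           # dL/dy for the toy loss, column vector
dW = np.dot(dy, x.T)                 # the outer-product shortcut

# Compare against central finite differences over each element of W.
eps = 1e-6
dW_num = np.zeros_like(W)
for i in range(T):
    for j in range(N):
        Wp = W.copy(); Wp[i, j] += eps
        Wm = W.copy(); Wm[i, j] -= eps
        dW_num[i, j] = (loss(Wp) - loss(Wm)) / (2 * eps)

print(np.allclose(dW, dW_num, atol=1e-4))   # True
</pre></div>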
</div>
<div class="section" id="bias-gradient">
<h2>Bias gradient</h2>
<p>We've just seen how to compute weight gradients for a fully-connected layer.
Computing the gradients for the bias vector is very similar, and a bit simpler.</p>
<p>This is the chain rule equation applied to the bias vector:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/3aee48692aeafe9ffab07037ad374f4c803787a7.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial{L}}{\partial{b}}=D(L \circ y)(b)=DL(y(b)) \cdot Dy(b)\]</object>
<p>The shapes involved here are: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ef9baa4141fed9b40c4f1b0ebf189e4d8d28badc.svg" style="height: 18px;" type="image/svg+xml">DL(y(b))</object> is still [1,T], because the
number of elements in <em>y</em> remains <em>T</em>. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/966d0c5b07f027b02b0ca9eb418ed9ac12f63386.svg" style="height: 18px;" type="image/svg+xml">Dy(b)</object> has <em>T</em> inputs (bias
elements) and <em>T</em> outputs (<em>y</em> elements), so its shape is [T,T]. Therefore, the
shape of the gradient <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7640378fb78362268ffe48bf5d68a266211673e4.svg" style="height: 18px;" type="image/svg+xml">D(L \circ y)(b)</object> is [1,T].</p>
<p>To see how we'd fill the Jacobian matrix <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/966d0c5b07f027b02b0ca9eb418ed9ac12f63386.svg" style="height: 18px;" type="image/svg+xml">Dy(b)</object>, let's go back to the
formula for <em>y</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7190e002ac69968b674aecacfd5a8531ad9cd208.svg" style="height: 55px;" type="image/svg+xml">
\[y_1=\sum_{j=1}^{N}W_{1,j}x_{j}+b_1\]</object>
<p>When differentiated w.r.t. anything other than <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c7cd24d955e66b8fe5ce45ded69fd98da5c68ba8.svg" style="height: 17px;" type="image/svg+xml">b_1</object>, this is 0; when differentiated
w.r.t. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c7cd24d955e66b8fe5ce45ded69fd98da5c68ba8.svg" style="height: 17px;" type="image/svg+xml">b_1</object> the result is 1. The same applies to every other element of <em>y</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/0511363d44324e04aee704c9cc3094a4e8c8c108.svg" style="height: 44px;" type="image/svg+xml">
\[\frac{\partial y_i}{\partial b_j}=\left\{\begin{matrix}
1 &amp; i=j \\
0 &amp; i\neq j
\end{matrix}\right.\]</object>
<p>In matrix form, this is just an identity matrix with dimensions [T,T].
Therefore:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/a64bd506727621a3e78444f7e158769dae30f93b.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial{L}}{\partial{b}}=D(L \circ y)(b)=DL(y(b)) \cdot I =DL(y(b))\]</object>
<p>For a given element of <em>b</em>, its gradient is just the corresponding element in
<object class="valign-m9" data="https://eli.thegreenplace.net/images/math/f004c6bbe71887354e0aad67dd7cbe6650eb58e9.svg" style="height: 26px;" type="image/svg+xml">\frac{\partial L}{\partial y}</object>.</p>
</div>
<div class="section" id="fully-connected-layer-for-a-batch-of-inputs">
<h2>Fully-connected layer for a batch of inputs</h2>
<p>The derivation shown above applies to a FC layer with a single input vector <em>x</em>
and a single output vector <em>y</em>. When we train models, we almost always try to
do so in <em>batches</em> (or <em>mini-batches</em>) to better leverage the parallelism of
modern hardware. So a more typical layer computation would be:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/c26a36b850dd7b2e4288e475a590f343ec3a18a3.svg" style="height: 15px;" type="image/svg+xml">
\[Y=WX+b\]</object>
<p>Where the shape of <em>X</em> is [N,B]; <em>B</em> is the batch size, typically a
not-too-large power of 2, like 32. <em>W</em> and <em>b</em> still have the same shapes, so
the shape of <em>Y</em> is [T,B]. Each column in <em>X</em> is a new input vector (for a
total of <em>B</em> vectors in a batch); a corresponding column in <em>Y</em> is the output.</p>
<p>As before, given <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/d5de3c7d9e0e1bcb4f6c00ea06b4ad808d2ea998.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{Y}}</object>, our goal is to find
<object class="valign-m7" data="https://eli.thegreenplace.net/images/math/33d2709b664fdd69317758b433b61b13c1cdc62f.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{W}}</object> and
<object class="valign-m7" data="https://eli.thegreenplace.net/images/math/5f12a50803653cf2ee02135944343ec70506d31c.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{b}}</object>. While the end results are fairly simple
and pretty much what you'd expect, I still want to go through the full Jacobian
computation to show how to find the gradients in a rigorous way.</p>
<p>Starting with the weights, the chain rule is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/0ffed6ae9645ea6bd0d02932e2f0ca20fb8e7bc6.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial{L}}{\partial{W}}=D(L \circ Y)(W)=DL(Y(W)) \cdot DY(W)\]</object>
<p>The dimensions are:</p>
<ul class="simple">
<li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86485acc2c4461f7817626204bf6c9148dad9d87.svg" style="height: 18px;" type="image/svg+xml">DL(Y(W))</object>: [1,TB] because <em>Y</em> has <em>T</em> outputs for each input vector in
the batch.</li>
<li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/573b889d69d85759886840570c6970345209b332.svg" style="height: 18px;" type="image/svg+xml">DY(W)</object>: [TB,TN] since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/fe5551953a6c071c738578f2ebc316864078cc81.svg" style="height: 18px;" type="image/svg+xml">Y(W)</object> has <em>TB</em> outputs and <em>TN</em>
inputs overall.</li>
<li><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/b214435777e236879c609900ba7a118e9f0da022.svg" style="height: 18px;" type="image/svg+xml">D(L\circ Y)(W)</object>: [1,TN] same as in the batch-1 case, because the same
weight matrix is used for all inputs in the batch.</li>
</ul>
<p>Also, we'll use the notation <object class="valign-m5" data="https://eli.thegreenplace.net/images/math/5f40e2ad50a0eb5c2f5019c48563f9c6605f84b6.svg" style="height: 24px;" type="image/svg+xml">x_{i}^{[b]}</object> to talk about the <em>i</em>-th
element in the <em>b</em>-th input vector <em>x</em> (out of a total of <em>B</em> such input
vectors).</p>
<p>With this in hand, let's see how the Jacobians look; starting with
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86485acc2c4461f7817626204bf6c9148dad9d87.svg" style="height: 18px;" type="image/svg+xml">DL(Y(W))</object>, it's the same as before except that we have to take the batch
into account. Each batch element is independent of the others in loss
computations, so we'll have:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/943d68b7dbda5009cbfd597b4e0fcc46748204a5.svg" style="height: 47px;" type="image/svg+xml">
\[\frac{\partial L}{\partial y_{i}^{[b]}}\]</object>
<p>as the Jacobian element. How do we arrange them in a 1-dimensional vector with
shape [1,TB]? We'll just have to agree on a linearization here - same as we did
with <em>W</em> before. We'll go for row-major again, so in 1-D the array <em>Y</em> would be:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/13122f18953df14ebb5ff74f59441194d3adb445.svg" style="height: 26px;" type="image/svg+xml">
\[Y=(y_{1}^{[1]},y_{1}^{[2]},\cdots,y_{1}^{[B]},
y_{2}^{[1]},y_{2}^{[2]},\cdots,y_{2}^{[B]},\cdots)\]</object>
<p>And so on for <em>T</em> elements. Therefore, the Jacobian of <em>L</em> w.r.t. <em>Y</em> is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/4e79e866dbfbe0d6b6de5f6617762bce00d5f61f.svg" style="height: 48px;" type="image/svg+xml">
\[\frac{\partial L}{\partial Y}=(
\frac{\partial L}{\partial y_{1}^{[1]}},
\frac{\partial L}{\partial y_{1}^{[2]}},\cdots,
\frac{\partial L}{\partial y_{1}^{[B]}},
\frac{\partial L}{\partial y_{2}^{[1]}},
\frac{\partial L}{\partial y_{2}^{[2]}},\cdots,
\frac{\partial L}{\partial y_{2}^{[B]}},\cdots)\]</object>
<p>To find <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/573b889d69d85759886840570c6970345209b332.svg" style="height: 18px;" type="image/svg+xml">DY(W)</object>, let's first see how to compute <em>Y</em>. The <em>i</em>-th element
of <em>Y</em> for batch <em>b</em> is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/2b9c88f44b9cbec2343ce49418ca3e17dd2e0946.svg" style="height: 55px;" type="image/svg+xml">
\[y_{i}^{[b]}=\sum_{j=1}^{N}W_{i,j}x_{j}^{[b]}+b_i\]</object>
<p>Recall that the Jacobian <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/573b889d69d85759886840570c6970345209b332.svg" style="height: 18px;" type="image/svg+xml">DY(W)</object> now has shape [TB,TN]. Previously we had
to unroll the [T,N] elements of the weight matrix across the columns. Now we'll also
have to unroll the [T,B] elements of the output across the rows. As before, first all <em>b</em>-s for
<em>t=1</em>, then all <em>b</em>-s for <em>t=2</em>, etc. If we carefully compute the derivative,
we'll see that the Jacobian matrix has similar structure to the single-batch
case, just with each line repeated <em>B</em> times for each of the batch elements:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/4fe86285c11962226758ecfad2839b2ce6520d2d.svg" style="height: 291px;" type="image/svg+xml">
\[DY(W)=\begin{bmatrix}
x_{1}^{[1]} &amp; x_{2}^{[1]} &amp; \cdots &amp; x_{N}^{[1]} &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\ \\
x_{1}^{[2]} &amp; x_{2}^{[2]} &amp; \cdots &amp; x_{N}^{[2]} &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\
\vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\
x_{1}^{[B]} &amp; x_{2}^{[B]} &amp; \cdots &amp; x_{N}^{[B]} &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\
\vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\
0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_{1}^{[1]} &amp; x_{2}^{[1]} &amp; \cdots &amp; x_{N}^{[1]} \\ \\
0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_{1}^{[2]} &amp; x_{2}^{[2]} &amp; \cdots &amp; x_{N}^{[2]} \\
\vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\
0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_{1}^{[B]} &amp; x_{2}^{[B]} &amp; \cdots &amp; x_{N}^{[B]} \\
\end{bmatrix}\]</object>
<p>Multiplying the two Jacobians together we get the full gradient of <em>L</em> w.r.t.
each element in the weight matrix. Where previously (in the non-batch case) we
had:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/1c762a65e0003e82f7dc5108f23126989b64112b.svg" style="height: 42px;" type="image/svg+xml">
\[\frac{\partial L}{\partial W_{ij}}=\frac{\partial L}{\partial y_i}x_j\]</object>
<p>Now, instead, we'll have:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/14fb5d857d1e43a9f692ad436c448a43c0ea041f.svg" style="height: 54px;" type="image/svg+xml">
\[\frac{\partial L}{\partial W_{ij}}=\sum_{b=1}^{B}\frac{\partial L}{\partial y_{i}^{[b]}}x_{j}^{[b]}\]</object>
<p>Which makes total sense, since it simply takes the loss gradient computed
from each batch element separately and adds them up. This aligns with our intuition
of how the gradient for a whole batch is computed - compute the gradient for each
batch element separately and add up all the gradients <a class="footnote-reference" href="#id4" id="id2">[2]</a>.</p>
<p>As before, there's a clever way to express the final gradient using matrix
operations. Note the sum across all batch elements when computing
<object class="valign-m10" data="https://eli.thegreenplace.net/images/math/2d41c4c820515c93e916d32532b9bdc7012e8121.svg" style="height: 27px;" type="image/svg+xml">\frac{\partial L}{\partial W_{ij}}</object>. We can express this as the matrix
multiplication:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7f6c176f28de451b2d67fcf7ebf238122de9a970.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial L}{\partial W}=\frac{\partial L}{\partial Y}\cdot X^T\]</object>
<p>This is a good place to recall the computation cost again. Previously we've seen
that for a single-input case, the Jacobian can be extremely large ([T,NT] having
about 160 million elements). In the batch case, the Jacobian would be even
larger since its shape is [TB,NT]; with a reasonable batch of 32, it's something
like 5-billion elements strong. It's good that we don't actually have to hold
the full Jacobian in memory and have a shortcut way of computing the gradient.</p>
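<p>A corresponding Numpy sketch (the variable names are mine; <em>dY</em> is the upstream
gradient of shape [T,B] and <em>X</em> is the input batch of shape [N,B]) could look like
this:</p>
<div class="highlight"><pre># The matrix product sums over the batch dimension automatically, giving the
# same result as adding up the per-element gradients across the batch.
dW = np.dot(dY, X.T)     # shape [T, N], same as W
</pre></div>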
</div>
<div class="section" id="bias-gradient-for-a-batch">
<h2>Bias gradient for a batch</h2>
<p>For the bias, we have:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/1228909398e9618223e551b8b1d394ac20d697f1.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial{L}}{\partial{b}}=D(L \circ Y)(b)=DL(Y(b)) \cdot DY(b)\]</object>
<p><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/40699cf4e67bde5205359e04102f7b0011dac800.svg" style="height: 18px;" type="image/svg+xml">DL(Y(b))</object> here has the shape [1,TB]; <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f260ada7edc55af13f145e5786803198a3452f1e.svg" style="height: 18px;" type="image/svg+xml">DY(b)</object> has the shape [TB,T].
Therefore, the shape of <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/5f12a50803653cf2ee02135944343ec70506d31c.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{b}}</object> is [1,T], as
before.</p>
<p>From the formula for computing <em>Y</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7190e002ac69968b674aecacfd5a8531ad9cd208.svg" style="height: 55px;" type="image/svg+xml">
\[y_1=\sum_{j=1}^{N}W_{1,j}x_{j}+b_1\]</object>
<p>We get, for any batch <em>b</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7ad81febc17ec33d21c8fba2a2e6956a8b43e1ad.svg" style="height: 49px;" type="image/svg+xml">
\[\frac{\partial y_{i}^{[b]}}{\partial b_j}=\left\{\begin{matrix}
1 &amp; i=j \\
0 &amp; i\neq j
\end{matrix}\right.\]</object>
<p>So, whereas <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f260ada7edc55af13f145e5786803198a3452f1e.svg" style="height: 18px;" type="image/svg+xml">DY(b)</object> was an identity matrix in the no-batch case, here it
looks like this:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/4b8148fb1343ab283fa0f8b0cdb6f3723201df15.svg" style="height: 267px;" type="image/svg+xml">
\[DY(b)=\begin{bmatrix}
1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\
1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\
\vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\
0 &amp; 1 &amp; 0 &amp; \cdots &amp; 0 \\
0 &amp; 1 &amp; 0 &amp; \cdots &amp; 0 \\
\vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
0 &amp; 0 &amp; 0 &amp; \cdots &amp; 1 \\
0 &amp; 0 &amp; 0 &amp; \cdots &amp; 1 \\
\vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
0 &amp; 0 &amp; 0 &amp; \cdots &amp; 1 \\
\end{bmatrix}\]</object>
<p>With <em>B</em> identical rows at a time, for a total of <em>TB</em> rows. Since
<object class="valign-m7" data="https://eli.thegreenplace.net/images/math/c7d9499ae5d7e1fc81bc540909deac668210911d.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial L}{\partial Y}</object> is the same as before, their matrix
multiplication result has this in column <em>j</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/910c6d77b9356303c0a01896a09d62fe2963d8ac.svg" style="height: 57px;" type="image/svg+xml">
\[\frac{\partial{L}}{\partial{b_j}}=\sum_{b=1}^{B}\frac{\partial L}{\partial y_{j}^{[b]}}\]</object>
<p>Which just means adding up the gradient effects from every batch element
independently.</p>
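<p>In Numpy this is a one-liner (again a sketch with my own variable names, <em>dY</em>
being the [T,B] upstream gradient):</p>
<div class="highlight"><pre># Sum the upstream gradient across the batch dimension; keepdims keeps the
# result a [T, 1] column vector like b.
db = np.sum(dY, axis=1, keepdims=True)
</pre></div>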
</div>
<div class="section" id="addendum-gradient-w-r-t-x">
<h2>Addendum - gradient w.r.t. x</h2>
<p>This post started by explaining that the parameters of a fully-connected layer
we're usually looking to optimize are the weight matrix and bias. In most cases
this is true; however, in some other cases we're actually interested in
propagating a gradient through <em>x</em> - often when there are more layers before
the fully-connected layer in question.</p>
<p>Let's find the derivative <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/54869ab2743febebc22269d12572c77e057c816e.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{x}}</object>. The chain
rule here is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/a52a21efbf7f0b992db127d495abddd618677709.svg" style="height: 38px;" type="image/svg+xml">
\[\frac{\partial{L}}{\partial{x}}=D(L \circ y)(x)=DL(y(x)) \cdot Dy(x)\]</object>
<p>Dimensions: <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2e9597ce1ffd09d94be733216c3f1c1b2ab5f33c.svg" style="height: 18px;" type="image/svg+xml">DL(y(x))</object> is [1, T] as before; <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a2bef37f23427154c47e53945043549039e36bcf.svg" style="height: 18px;" type="image/svg+xml">Dy(x)</object> has T outputs
(elements of <em>y</em>) and N inputs (elements of <em>x</em>), so its dimensions are [T, N].
Therefore, the dimensions of <object class="valign-m7" data="https://eli.thegreenplace.net/images/math/54869ab2743febebc22269d12572c77e057c816e.svg" style="height: 24px;" type="image/svg+xml">\frac{\partial{L}}{\partial{x}}</object> are [1, N].</p>
<p>From:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7190e002ac69968b674aecacfd5a8531ad9cd208.svg" style="height: 55px;" type="image/svg+xml">
\[y_1=\sum_{j=1}^{N}W_{1,j}x_{j}+b_1\]</object>
<p>We know that <object class="valign-m10" data="https://eli.thegreenplace.net/images/math/c209b0f19299fee08359f73898212bb0d0df8c30.svg" style="height: 28px;" type="image/svg+xml">\frac{\partial y_1}{\partial x_j}=W_{1,j}</object>. Generalizing
this, we get <object class="valign-m10" data="https://eli.thegreenplace.net/images/math/6abadcb3365f09141a9cac088fdbb17418e75171.svg" style="height: 28px;" type="image/svg+xml">\frac{\partial y_i}{\partial x_j}=W_{i,j}</object>; in other words,
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a2bef37f23427154c47e53945043549039e36bcf.svg" style="height: 18px;" type="image/svg+xml">Dy(x)</object> is just the weight matrix <em>W</em>. So
<object class="valign-m8" data="https://eli.thegreenplace.net/images/math/e6c2d66d989f1abdb9e8b492e45f00be1ab2a21b.svg" style="height: 25px;" type="image/svg+xml">\frac{\partial{L}}{\partial{x_i}}</object> is the dot product of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2e9597ce1ffd09d94be733216c3f1c1b2ab5f33c.svg" style="height: 18px;" type="image/svg+xml">DL(y(x))</object>
with the <em>i</em>-th column of <em>W</em>.</p>
<p>Computationally, we can express this as follows:</p>
<div class="highlight"><pre><span></span><span class="n">dx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">dy</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">W</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>
</pre></div>
<p>Again, recall that our vectors are <em>column</em> vectors. Therefore, to place <em>dy</em>
to the left of <em>W</em> in the multiplication, we have to transpose it to a row vector first. The result
of this matrix multiplication is a [1, N] row-vector, so we transpose it again
to get a column.</p>
<p>An alternative method to compute this would transpose <em>W</em> rather than <em>dy</em> and
then swap the order:</p>
<div class="highlight"><pre><span></span><span class="n">dx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">W</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">dy</span><span class="p">)</span>
</pre></div>
<p>These two methods produce exactly the same <em>dx</em>; it's important to be familiar
with these tricks, because otherwise it may be confusing to see a transposed <em>W</em>
when we expect the actual <em>W</em> from gradient computations.</p>
<hr class="docutils" />
<table class="docutils footnote" frame="void" id="id3" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id1">[1]</a></td><td><p class="first">As explained in the
<a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative">softmax post</a>,
we <em>linearize</em> the 2D matrix <em>W</em> into a single vector with <em>NT</em> elements
using some approach like row-major, where the <em>N</em> elements of the first
row go first, then the <em>N</em> elements of the second row, and so on until we
have <em>NT</em> elements for all the rows.</p>
<p class="last">This is a fully general approach as we can linearize any-dimensional
arrays. To work with Jacobians, we're interested in <em>K</em> inputs, no matter
where they came from - they could be a linearization of a 4D array. As
long as we remember which element out of the <em>K</em> corresponds to which
original element, we'll be fine.</p>
</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id4" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id2">[2]</a></td><td>In some cases you may hear about <em>averaging</em> the gradients across the
batch. Averaging just means dividing the sum by <em>B</em>; it's a constant
factor that can be consolidated into the learning rate.</td></tr>
</tbody>
</table>
</div>
Depthwise separable convolutions for machine learning2018-04-04T06:21:00-07:002018-04-04T06:21:00-07:00Eli Benderskytag:eli.thegreenplace.net,2018-04-04:/2018/depthwise-separable-convolutions-for-machine-learning/<p>Convolutions are an important tool in modern deep neural networks (DNNs). This
post is going to discuss some common types of convolutions, specifically
regular and depthwise separable convolutions. My focus will be on the
implementation of these operations, showing from-scratch Numpy-based code to
compute them and diagrams that explain how things work.</p>
<p>Note that my main goal here is to explain how depthwise separable convolutions
differ from regular ones; if you're completely new to convolutions I suggest
reading some more introductory resources first.</p>
<p>The code here is compatible with TensorFlow's definition of convolutions in
the <a class="reference external" href="https://www.tensorflow.org/api_docs/python/tf/nn">tf.nn</a> module. After
reading this post, the documentation of TensorFlow's convolution ops should be
easy to decipher.</p>
<div class="section" id="basic-2d-convolution">
<h2>Basic 2D convolution</h2>
<p>The basic idea behind a 2D convolution is sliding a small window (usually called
a &quot;filter&quot;) over a larger 2D array, and performing a dot product between the
filter elements and the corresponding input array elements at every position.</p>
<p>Here's a diagram demonstrating the application of a 3x3 convolution filter to
a 6x6 array, in 3 different positions. <tt class="docutils literal">W</tt> is the filter, and the yellow-ish
array on the right is the result; the red square shows which element in the
result array is being computed.</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/2018/conv2d-single-block.svg" style="width: 400px;" type="image/svg+xml">
Single-channel 2D convolution</object>
<p>The topmost diagram shows the important concept of <em>padding</em>: what should we do
when the window goes &quot;out of bounds&quot; on the input array. There are several
options, with the following two being most common in DNNs:</p>
<ul class="simple">
<li><em>Valid</em> padding: in which only valid, in-bounds windows are considered. This
also makes the output smaller than the input, because border elements can't be
in the center of a filter (unless the filter is 1x1).</li>
<li><em>Same</em> padding: in which we assume there's some constant value outside the
bounds of the input (usually 0) and the filter is applied to every element.
In this case the output array has the same size as the input array. The
diagrams above depict same padding, which I'll keep using throughout the post.</li>
</ul>
<p>There are other options for the basic 2D convolution case. For example, the
filter can be moving over the input in jumps of more than 1, thus not centering
on all elements. This is called <em>stride</em>, and in this post I'm always using
stride of 1. Convolutions can also be dilated (or <em>atrous</em>), wherein the
filter is expanded with gaps between every element. In this post I'm not going
to discuss dilated convolutions and other options - there are plenty of
resources on these topics online.</p>
</div>
<div class="section" id="implementing-the-2d-convolution">
<h2>Implementing the 2D convolution</h2>
<p>Here is a full Python implementation of the simple 2D convolution. It's called
&quot;single channel&quot; to distinguish it from the more general case in which the input
has more than two dimensions; we'll get to that shortly.</p>
<p>This implementation is fully self-contained, and only needs Numpy to work. All
the loops are fully explicit - I specifically avoided vectorizing them for
efficiency, in order to maintain clarity:</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">conv2d_single_channel</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Two-dimensional convolution of a single channel.</span>
<span class="sd"> Uses SAME padding with 0s, a stride of 1 and no dilation.</span>
<span class="sd"> input: input array with shape (height, width)</span>
<span class="sd"> w: filter array with shape (fd, fd) with odd fd.</span>
<span class="sd"> Returns a result with the same shape as input.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">assert</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="ow">and</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">1</span>
<span class="c1"># SAME padding with zeros: creating a new padded array to simplify index</span>
<span class="c1"># calculations and to avoid checking boundary conditions in the inner loop.</span>
<span class="c1"># padded_input is like input, but padded on all sides with</span>
<span class="c1"># half-the-filter-width of zeros.</span>
<span class="n">padded_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span>
<span class="n">pad_width</span><span class="o">=</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span>
<span class="n">mode</span><span class="o">=</span><span class="s1">&#39;constant&#39;</span><span class="p">,</span>
<span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
<span class="c1"># This inner double loop computes every output element, by</span>
<span class="c1"># multiplying the corresponding window into the input with the</span>
<span class="c1"># filter.</span>
<span class="k">for</span> <span class="n">fi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">fj</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
<span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">padded_input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">fi</span><span class="p">,</span> <span class="n">j</span> <span class="o">+</span> <span class="n">fj</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span><span class="p">[</span><span class="n">fi</span><span class="p">,</span> <span class="n">fj</span><span class="p">]</span>
<span class="k">return</span> <span class="n">output</span>
</pre></div>
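<p>Here's a possible usage sketch (not part of the original code sample, and assuming
<tt class="docutils literal">numpy</tt> is imported as <tt class="docutils literal">np</tt>): a 6x6
input and a 3x3 averaging filter, matching the shapes in the diagram above:</p>
<div class="highlight"><pre>inp = np.arange(36, dtype=np.float64).reshape(6, 6)
w = np.full((3, 3), 1.0 / 9)     # a simple 3x3 averaging filter
out = conv2d_single_channel(inp, w)
print(out.shape)                 # (6, 6) - SAME padding preserves the input shape
</pre></div>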
</div>
<div class="section" id="convolutions-in-3-and-4-dimensions">
<h2>Convolutions in 3 and 4 dimensions</h2>
<p>The convolution computed above works in two dimensions; yet, most convolutions
used in DNNs are 4-dimensional. For example, TensorFlow's <tt class="docutils literal">tf.nn.conv2d</tt> op
takes a 4D input tensor and a 4D filter tensor. How come?</p>
<p>The two additional dimensions in the input tensor are <em>channel</em> and <em>batch</em>. A
canonical example of channels is color images in RGB format. Each pixel has a
value for red, green and blue - three channels overall. So instead of seeing it
as a matrix of triples, we can see it as a 3D tensor where one dimension is
height, another width and another channel (also called the <em>depth</em> dimension).</p>
<p>Batch is somewhat different. ML training - with stochastic gradient descent -
is often done in batches for performance; we train the model not on a single
sample at a time, but a &quot;batch&quot; of samples, usually some power of two.
Performing all the operations in tandem on a batch of data makes it easier to
leverage the SIMD capabilities of modern processors. So it doesn't have any
mathematical significance here - it can be seen as an outer loop over all
operations, performing them for a set of inputs and producing a corresponding
set of outputs.</p>
<p>For filters, the 4 dimensions are height, width, input channel and output
channel. Input channel is the same as the input tensor's; output channel
collects multiple filters, each of which can be different.</p>
<p>This can be slightly difficult to grasp from text, so here's a diagram:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/2018/conv2d-3d.svg" style="width: 300px;" type="image/svg+xml">
Multi-channel 2D convolution</object>
<p>In the diagram and the implementation I'm going to ignore the batch dimension,
since it's not really mathematically interesting. So the input image has three
dimensions - in this diagram height and width are 8 and depth is 3. The filter
is 3x3 with depth 3. In each step, the filter is slid over the input <em>in two
dimensions</em>, and all of its elements are multiplied with the corresponding
elements in the input. That's 3x3x3=27 multiplications added into the output
element.</p>
<p>Note that this is different from a 3D convolution, where a filter is moved
across the input in all 3 dimensions; true 3D convolutions are not widely used
in DNNs at this time.</p>
<p>So, to reiterate, to compute the multi-channel convolution as shown in the
diagram above, we compute each of the 64 output elements by a dot-product of the
filter with the relevant parts of the input tensor. This produces a single
output channel. To produce additional output channels, we perform the
convolution with additional filters. So if our filter has dimensions (3, 3, 3,
4) this means 4 different 3x3x3 filters. The output will thus have dimensions
8x8 for the spatials and 4 for depth.</p>
<p>Here's the Numpy implementation of this algorithm:</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">conv2d_multi_channel</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Two-dimensional convolution with multiple channels.</span>
<span class="sd"> Uses SAME padding with 0s, a stride of 1 and no dilation.</span>
<span class="sd"> input: input array with shape (height, width, in_depth)</span>
<span class="sd"> w: filter array with shape (fd, fd, in_depth, out_depth) with odd fd.</span>
<span class="sd"> in_depth is the number of input channels, and has the be the same as</span>
<span class="sd"> input&#39;s in_depth; out_depth is the number of output channels.</span>
<span class="sd"> Returns a result with shape (height, width, out_depth).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">assert</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="ow">and</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">1</span>
<span class="n">padw</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
<span class="n">padded_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span>
<span class="n">pad_width</span><span class="o">=</span><span class="p">((</span><span class="n">padw</span><span class="p">,</span> <span class="n">padw</span><span class="p">),</span> <span class="p">(</span><span class="n">padw</span><span class="p">,</span> <span class="n">padw</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)),</span>
<span class="n">mode</span><span class="o">=</span><span class="s1">&#39;constant&#39;</span><span class="p">,</span>
<span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">in_depth</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span>
<span class="k">assert</span> <span class="n">in_depth</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="n">out_depth</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">out_depth</span><span class="p">))</span>
<span class="k">for</span> <span class="n">out_c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">out_depth</span><span class="p">):</span>
<span class="c1"># For each output channel, perform 2d convolution summed across all</span>
<span class="c1"># input channels.</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">height</span><span class="p">):</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">width</span><span class="p">):</span>
<span class="c1"># Now the inner loop also works across all input channels.</span>
<span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">in_depth</span><span class="p">):</span>
<span class="k">for</span> <span class="n">fi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">fj</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
<span class="n">w_element</span> <span class="o">=</span> <span class="n">w</span><span class="p">[</span><span class="n">fi</span><span class="p">,</span> <span class="n">fj</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">out_c</span><span class="p">]</span>
<span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">out_c</span><span class="p">]</span> <span class="o">+=</span> <span class="p">(</span>
<span class="n">padded_input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">fi</span><span class="p">,</span> <span class="n">j</span> <span class="o">+</span> <span class="n">fj</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span> <span class="o">*</span> <span class="n">w_element</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
</pre></div>
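<p>A quick shape check (my own sketch) matching the diagram above - an 8x8 input with 3
channels and a filter of shape (3, 3, 3, 4):</p>
<div class="highlight"><pre>inp = np.random.uniform(size=(8, 8, 3))
w = np.random.uniform(size=(3, 3, 3, 4))     # 4 different 3x3x3 filters
out = conv2d_multi_channel(inp, w)
print(out.shape)                             # (8, 8, 4)
</pre></div>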
<p>An interesting point to note here concerns TensorFlow's <tt class="docutils literal">tf.nn.conv2d</tt> op. If
you read its semantics you'll see discussion of <em>layout</em> or <em>data format</em>, which
is <tt class="docutils literal">NHWC</tt> by default. NHWC simply means the order of dimensions in a 4D
tensor is:</p>
<ul class="simple">
<li><strong>N</strong>: batch</li>
<li><strong>H</strong>: height (spatial dimension)</li>
<li><strong>W</strong>: width (spatial dimension)</li>
<li><strong>C</strong>: channel (depth)</li>
</ul>
<p><tt class="docutils literal">NHWC</tt> is the default layout for TensorFlow; another commonly used layout is
<tt class="docutils literal">NCHW</tt>, because it's the format preferred by NVIDIA's DNN libraries. The code
samples here follow the default.</p>
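<p>Converting between the two layouts is just a reordering of dimensions; here's a
minimal sketch (the tensor is a made-up placeholder):</p>
<div class="highlight"><pre>nhwc = np.zeros((32, 128, 128, 3))           # batch, height, width, channel
nchw = np.transpose(nhwc, (0, 3, 1, 2))      # reorder to batch, channel, height, width
print(nchw.shape)                            # (32, 3, 128, 128)
</pre></div>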
</div>
<div class="section" id="depthwise-convolution">
<h2>Depthwise convolution</h2>
<p>Depthwise convolutions are a variation on the operation discussed so far. In the
regular 2D convolution performed over multiple input channels, the filter is as
deep as the input and lets us freely mix channels to generate each element in
the output. Depthwise convolutions don't do that - each channel is kept separate
- hence the name <em>depthwise</em>. Here's a diagram to help explain how that works:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/2018/conv2d-depthwise.svg" style="width: 500px;" type="image/svg+xml">
Depthwise 2D convolution</object>
<p>There are three conceptual stages here:</p>
<ol class="arabic simple">
<li>Split the input into channels, and split the filter into channels (the number
of channels between input and filter must match).</li>
<li>For each of the channels, convolve the input with the corresponding filter,
producing an output tensor (2D).</li>
<li>Stack the output tensors back together.</li>
</ol>
<p>Here's the code implementing it:</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">depthwise_conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Two-dimensional depthwise convolution.</span>
<span class="sd"> Uses SAME padding with 0s, a stride of 1 and no dilation. A single output</span>
<span class="sd"> channel is used per input channel (channel_multiplier=1).</span>
<span class="sd"> input: input array with shape (height, width, in_depth)</span>
<span class="sd"> w: filter array with shape (fd, fd, in_depth)</span>
<span class="sd"> Returns a result with shape (height, width, in_depth).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">assert</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="ow">and</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">1</span>
<span class="n">padw</span> <span class="o">=</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
<span class="n">padded_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span>
<span class="n">pad_width</span><span class="o">=</span><span class="p">((</span><span class="n">padw</span><span class="p">,</span> <span class="n">padw</span><span class="p">),</span> <span class="p">(</span><span class="n">padw</span><span class="p">,</span> <span class="n">padw</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)),</span>
<span class="n">mode</span><span class="o">=</span><span class="s1">&#39;constant&#39;</span><span class="p">,</span>
<span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">in_depth</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span>
<span class="k">assert</span> <span class="n">in_depth</span> <span class="o">==</span> <span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">in_depth</span><span class="p">))</span>
<span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">in_depth</span><span class="p">):</span>
<span class="c1"># For each input channel separately, apply its corresponsing filter</span>
<span class="c1"># to the input.</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">height</span><span class="p">):</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">width</span><span class="p">):</span>
<span class="k">for</span> <span class="n">fi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">fj</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">w</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
<span class="n">w_element</span> <span class="o">=</span> <span class="n">w</span><span class="p">[</span><span class="n">fi</span><span class="p">,</span> <span class="n">fj</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span>
<span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span> <span class="o">+=</span> <span class="p">(</span>
<span class="n">padded_input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="n">fi</span><span class="p">,</span> <span class="n">j</span> <span class="o">+</span> <span class="n">fj</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span> <span class="o">*</span> <span class="n">w_element</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
</pre></div>
<p>In TensorFlow, the corresponding op is <tt class="docutils literal">tf.nn.depthwise_conv2d</tt>; this op has
the notion of <em>channel multiplier</em> which lets us compute multiple outputs for
each input channel (somewhat like the number of output channels concept in
<tt class="docutils literal">conv2d</tt>).</p>
</div>
<div class="section" id="depthwise-separable-convolution">
<h2>Depthwise separable convolution</h2>
<p>The depthwise convolution shown above is more commonly used in combination with
an additional step to mix in the channels - <em>depthwise separable convolution</em>
<a class="footnote-reference" href="#id2" id="id1">[1]</a>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/2018/conv2d-depthwise-separable.svg" style="width: 500px;" type="image/svg+xml">
Depthwise separable convolution</object>
<p>After completing the depthwise convolution, an additional step is performed: a
1x1 convolution across channels. This is exactly the same operation as the
&quot;convolution in 3 dimensions&quot; discussed earlier - just with a 1x1 spatial
filter. This step can be repeated multiple times for different output channels.
The output channels all take the output of the depthwise step and mix it up
with different 1x1 convolutions. Here's the implementation:</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">separable_conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w_depth</span><span class="p">,</span> <span class="n">w_pointwise</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Depthwise separable convolution.</span>
<span class="sd"> Performs 2d depthwise convolution with w_depth, and then applies a pointwise</span>
<span class="sd"> 1x1 convolution with w_pointwise on the result.</span>
<span class="sd"> Uses SAME padding with 0s, a stride of 1 and no dilation. A single output</span>
<span class="sd"> channel is used per input channel (channel_multiplier=1) in w_depth.</span>
<span class="sd"> input: input array with shape (height, width, in_depth)</span>
<span class="sd"> w_depth: depthwise filter array with shape (fd, fd, in_depth)</span>
<span class="sd"> w_pointwise: pointwise filter array with shape (in_depth, out_depth)</span>
<span class="sd"> Returns a result with shape (height, width, out_depth).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># First run the depthwise convolution. Its result has the same shape as</span>
<span class="c1"># input.</span>
<span class="n">depthwise_result</span> <span class="o">=</span> <span class="n">depthwise_conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">w_depth</span><span class="p">)</span>
<span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">in_depth</span> <span class="o">=</span> <span class="n">depthwise_result</span><span class="o">.</span><span class="n">shape</span>
<span class="k">assert</span> <span class="n">in_depth</span> <span class="o">==</span> <span class="n">w_pointwise</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">out_depth</span> <span class="o">=</span> <span class="n">w_pointwise</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">out_depth</span><span class="p">))</span>
<span class="k">for</span> <span class="n">out_c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">out_depth</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">height</span><span class="p">):</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">width</span><span class="p">):</span>
<span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">in_depth</span><span class="p">):</span>
<span class="n">w_element</span> <span class="o">=</span> <span class="n">w_pointwise</span><span class="p">[</span><span class="n">c</span><span class="p">,</span> <span class="n">out_c</span><span class="p">]</span>
<span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">out_c</span><span class="p">]</span> <span class="o">+=</span> <span class="n">depthwise_result</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">c</span><span class="p">]</span> <span class="o">*</span> <span class="n">w_element</span>
<span class="k">return</span> <span class="n">output</span>
</pre></div>
<p>In TensorFlow, this op is called <tt class="docutils literal">tf.nn.separable_conv2d</tt>. Similarly to our
implementation it takes two different filter parameters: <tt class="docutils literal">depthwise_filter</tt>
for the depthwise step and <tt class="docutils literal">pointwise_filter</tt> for the mixing step.</p>
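<p>A similar sanity check can be sketched for the separable version (again assuming
TensorFlow 2.x with eager execution); here the pointwise filter also has to be
reshaped to the 4D shape the TF op expects:</p>
<div class="highlight"><pre>
import numpy as np
import tensorflow as tf

inp = np.random.randn(6, 6, 3)
w_depth = np.random.randn(3, 3, 3)
w_pointwise = np.random.randn(3, 16)

ours = separable_conv2d(inp, w_depth, w_pointwise)

tf_result = tf.nn.separable_conv2d(
    inp[np.newaxis, ...],                       # (1, 6, 6, 3)
    w_depth[..., np.newaxis],                   # (3, 3, 3, 1)
    w_pointwise[np.newaxis, np.newaxis, ...],   # (1, 1, 3, 16)
    strides=[1, 1, 1, 1],
    padding='SAME')

print(np.allclose(ours, tf_result.numpy()[0]))   # expected: True
</pre></div>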
<p>Depthwise separable convolutions have become popular in DNN models recently, for
two reasons:</p>
<ol class="arabic simple">
<li>They have fewer parameters than &quot;regular&quot; convolutional layers, and thus are
less prone to overfitting.</li>
<li>With fewer parameters, they also require fewer operations to compute, and thus
are cheaper and faster.</li>
</ol>
<p>Let's examine the difference in the number of parameters first. We'll start
with some definitions:</p>
<ul class="simple">
<li><tt class="docutils literal">S</tt>: spatial dimension - width and height, assuming square inputs.</li>
<li><tt class="docutils literal">F</tt>: filter width and height, assuming square filter.</li>
<li><tt class="docutils literal">inC</tt>: number of input channels.</li>
<li><tt class="docutils literal">outC</tt>: number of output channels.</li>
</ul>
<p>We also assume <tt class="docutils literal">SAME</tt> padding as discussed above, so that the spatial size
of the output matches the input.</p>
<p>In a regular convolution there are <tt class="docutils literal">F*F*inC*outC</tt> parameters, because every
filter is 3D and there's one such filter per output channel.</p>
<p>In depthwise separable convolutions there are <tt class="docutils literal">F*F*inC</tt> parameters for the
depthwise part, and then <tt class="docutils literal">inC*outC</tt> parameters for the mixing part. It should
be obvious that for a non-trivial <tt class="docutils literal">outC</tt>, the sum of these two is significantly
smaller than <tt class="docutils literal">F*F*inC*outC</tt>.</p>
<p>Now on to computational cost. For a regular convolution, we perform <tt class="docutils literal">F*F*inC</tt>
operations at each position of the input (to compute the 2D convolution over 3
dimensions). For the whole input, the number of computations is thus
<tt class="docutils literal">F*F*inC*S*S</tt> and taking all the output channels we get <tt class="docutils literal">F*F*inC*S*S*outC</tt>.</p>
<p>For depthwise separable convolutions we need <tt class="docutils literal">F*F*inC*S*S</tt> operations for
the depthwise part; then we need <tt class="docutils literal">S*S*inC*outC</tt> operations for the mixing
part. Let's use some real numbers to get a feel for the difference:</p>
<p>We'll assume <tt class="docutils literal">S=128</tt>, <tt class="docutils literal">F=3</tt>, <tt class="docutils literal">inC=3</tt>, <tt class="docutils literal">outC=16</tt>. For regular
convolution:</p>
<ul class="simple">
<li>Parameters: <tt class="docutils literal">3*3*3*16 = 432</tt></li>
<li>Computation cost: <tt class="docutils literal">3*3*3*128*128*16 = ~7e6</tt></li>
</ul>
<p>For depthwise separable convolution:</p>
<ul class="simple">
<li>Parameters: <tt class="docutils literal">3*3*3+3*16 = 75</tt></li>
<li>Computation cost: <tt class="docutils literal">3*3*3*128*128+128*128*3*16 = ~1.2e6</tt></li>
</ul>
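<p>Here's a short Python sketch that reproduces these numbers:</p>
<div class="highlight"><pre>
S, F, inC, outC = 128, 3, 3, 16

# Regular convolution
regular_params = F * F * inC * outC                        # 432
regular_ops = F * F * inC * S * S * outC                   # ~7e6

# Depthwise separable convolution
separable_params = F * F * inC + inC * outC                # 75
separable_ops = F * F * inC * S * S + S * S * inC * outC   # ~1.2e6

print(regular_params, separable_params)
print(regular_ops, separable_ops)
</pre></div>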
<hr class="docutils" />
<table class="docutils footnote" frame="void" id="id2" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id1">[1]</a></td><td>The term <em>separable</em> comes from image processing, where
<em>spatially separable convolutions</em> are sometimes used to save on
computation resources. A spatial convolution is separable when the 2D
convolution filter can be expressed as an outer product of two vectors.
This lets us compute some 2D convolutions more cheaply. In the case of
DNNs, the spatial filter is not necessarily separable but the channel
dimension is separable from the spatial dimensions.</td></tr>
</tbody>
</table>
</div>
Logistic regression2016-11-02T05:45:00-07:002016-11-02T05:45:00-07:00Eli Benderskytag:eli.thegreenplace.net,2016-11-02:/2016/logistic-regression/<p>This article covers logistic regression - arguably the simplest classification
model in machine learning; it starts with basic binary classification, and ends
up with some techniques for multinomial classification (selecting between
multiple possibilities). The final examples using the softmax function can also
be viewed as an example of a single-layer fully connected neural network.</p>
<p>This article is the theoretical part; in addition, there's quite a bit of
accompanying code <a class="reference external" href="https://github.com/eliben/deep-learning-samples/tree/master/logistic-regression">here</a>.
All the models discussed in the article are implemented from scratch in Python
using only Numpy.</p>
<div class="section" id="linear-model-for-binary-classification">
<h2>Linear model for binary classification</h2>
<p>Using a linear model for binary classification is very similar to <a class="reference external" href="http://eli.thegreenplace.net/2016/linear-regression/">linear
regression</a>, except that
we expect a binary (yes/no) answer rather than a numeric answer.</p>
<p>We want to come up with a parameter vector <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, such that for every
data vector <strong>x</strong> we can compute <a class="footnote-reference" href="#id10" id="id1">[1]</a>:</p>
<img alt="\[\hat{y}(x) = \theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n\]" class="align-center" src="https://eli.thegreenplace.net/images/math/ae682f9fda97c28c8e100c87aecad635c7c1d96c.png" style="height: 18px;" />
<p>And then make a binary decision based on the value of <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" />. A
simple way to make a decision is to say &quot;yes&quot; if <img alt="\hat{y}(x)\geq 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c30aad52f5af131a89f1a8805e25aa8e354795dc.png" style="height: 18px;" /> and
&quot;no&quot; otherwise. Note that this is arbitrary, as we could flip the condition for
&quot;yes&quot; and for &quot;no&quot;. We could also compare <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> to some value other
than zero, and the model would learn equally well <a class="footnote-reference" href="#id12" id="id2">[2]</a>.</p>
<p>Let's make this more concrete, also assigning numeric values to &quot;yes&quot; and &quot;no&quot;,
which will make some computations simpler later on. For &quot;yes&quot; we'll (again,
arbitrarily) select +1, and for &quot;no&quot; we'll go with -1. So, a linear model for
binary classification is parameterized by some <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, such that:</p>
<img alt="\[\hat{y}(x) = \theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n\]" class="align-center" src="https://eli.thegreenplace.net/images/math/ae682f9fda97c28c8e100c87aecad635c7c1d96c.png" style="height: 18px;" />
<p>And:</p>
<img alt="\[class(x)=\left\{\begin{matrix} +1 &amp;amp; \operatorname{if}\ \hat{y}(x)\geq 0\\ -1 &amp;amp; \operatorname{if}\ \hat{y}(x)&amp;lt; 0 \end{matrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/092debeba72a26bd76603bd3ce140fc798e5f692.png" style="height: 43px;" />
<p>It helps to see a graphical example of how this looks in practice. As usual,
we'll have to stick to low dimensionality if we want to visualize things, so
let's use 2D data points.</p>
<p>Since our data is in 2D, we need a 3D <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> (<img alt="\theta_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/ba6201ddbe2fd0bb66e0704ad8b3c6bdb36f37aa.png" style="height: 15px;" /> for the
bias). Let's pick <img alt="\theta=(4,-0.5, -1)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6cb259a86870d3bd0a5ad2f839d0515bfc70f0d7.png" style="height: 18px;" />. Plotting <img alt="\hat{y}(x)=\theta \cdot x" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e0a45fd444b0526e19a0f22fb3c264b026fb3bcf.png" style="height: 18px;" /> will give us a plane in 3D, but what we're really interested in is just
to know whether <img alt="\hat{y}(x) \geq 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/39ad82f3252b80454caa343952948440827f2961.png" style="height: 18px;" />. So we can draw this plane's
intersection with the x/y plane:</p>
<img alt="Line for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/binary-classification-line.png" />
<p>We can play with some sample points to see that everything &quot;to the right&quot; of
the line gives us <img alt="\hat{y}(x) &amp;gt; 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d686dc49d4c08e21f67c22cbb42aab2a1f3d3875.png" style="height: 18px;" />, and everything &quot;to the left&quot; of
it gives us <img alt="\hat{y}(x) &amp;lt; 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d8a7e77c45cecd8e4ba7c8f7d1f02944e9b55ecf.png" style="height: 18px;" /> <a class="footnote-reference" href="#id13" id="id3">[3]</a>.</p>
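<p>Here's a minimal NumPy sketch of this decision rule with the example parameters
above (the helper name is just for illustration):</p>
<div class="highlight"><pre>
import numpy as np

theta = np.array([4.0, -0.5, -1.0])

def classify(x1, x2):
    x = np.array([1.0, x1, x2])     # x_0 = 1 is the bias input
    y_hat = theta.dot(x)
    return +1 if y_hat &gt;= 0 else -1

print(classify(0, 0))    # y_hat = 4.0, so +1
print(classify(8, 8))    # y_hat = -8.0, so -1
</pre></div>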
</div>
<div class="section" id="loss-functions-for-binary-classification">
<h2>Loss functions for binary classification</h2>
<p>How do we find the right <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> for a classification problem? Similarly
to linear regression, we're going to define a &quot;loss function&quot; and then train a
classifier by minimizing this loss with gradient descent. However, here picking
a good loss function is not as simple - it turns out square loss doesn't work
very well, as we'll see soon.</p>
<p>Let's start by considering the most logical loss function to use for
classification - the number of misclassified data samples. This is called the
0/1 loss, and it's the true measure of how well a classifier works. Say we have
1000 samples, our classifier placed 960 of them in the right category, and
got the wrong answer for the other 40 samples. So the loss would be 40. A better
classifier may get it wrong only 35 times, so its loss would be smaller.</p>
<p>It will be helpful to plot loss functions, so let's add another definition we're
going to be using a lot here: the <em>margin</em>. For a given sample <strong>x</strong>, and its
correct classification <em>y</em>, the margin of classification is
<img alt="m=\hat{y}(x)y" class="valign-m4" src="https://eli.thegreenplace.net/images/math/fc8c312b137c8aafaaebd881836e4332cc14e61f.png" style="height: 18px;" />. Recall that <em>y</em> is either +1 or -1, so the margin
is either <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> or its negation, depending on the correct answer.
Note that the margin is positive when our guess is correct (both
<img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> and y have the same sign) and negative when our guess is
wrong. With this in hand, we define 0/1 loss as:</p>
<img alt="\[L_{01}(m) = \mathbb{I}(m \leq 0)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/e9731883ade0db9b166741b2ff53a8167a8e3ffd.png" style="height: 18px;" />
<p>Where <img alt="\mathbb{I}" class="valign-0" src="https://eli.thegreenplace.net/images/math/3dcdffb11a6b55b62a0c9e29d85dd9120f5945f4.png" style="height: 12px;" /> is an <em>indicator function</em> taking the value 1 when its
condition is true and the value 0 otherwise. Here is the plot of <img alt="L_{01}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3ed6799c7063de4663bdeab8fa126196f41bcd0f.png" style="height: 16px;" />
as a function of margin:</p>
<img alt="0/1 loss for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/binary-01-loss.png" />
<p>Unfortunately, the 0/1 loss is fairly hostile to gradient descent optimization,
since it's not convex. This is easy to see intuitively. Suppose we have some
<img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> that gives us a margin of -1.5. The 0/1 loss for this margin is
1, but how can we improve it? Small nudges to <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> will still give us
a margin very close to -1.5, which results in exactly the same loss. We don't
know which way to nudge <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> since either way we get the same outcome.
In other words, there's no slope to follow here.</p>
<p>That's not to say all is lost. Some work is being done on optimizing 0/1
losses for classification, but this is a bit outside the mainstream of machine
learning. Here's an <a class="reference external" href="http://jmlr.org/proceedings/papers/v28/nguyen13a.pdf">interesting paper</a> that discusses some
approaches. It's fascinating for computer science geeks since it uses
combinatorial search techniques. The rest of this post, however, will use 0/1
loss only as an idealized limit, trying other kinds of loss we can actually
run gradient descent with.</p>
<p>The first such loss that comes to mind is square loss, the same one we use in
linear regression. We'll define the square loss as a function of margin:</p>
<img alt="\[L_2(m) = (m - 1)^2\]" class="align-center" src="https://eli.thegreenplace.net/images/math/ea06356db44999485977e3a7e6ff5e97e617b1bb.png" style="height: 21px;" />
<p>The reason we do this is to get two desired outcomes at important points: at
<img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" /> we want the loss to be 0, since this is actually the correct
classification: we only get <img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" /> when either both <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/45f0241f56d9823eb2d24a228d7ffe62c5fdcdc2.svg" style="height: 16px;" type="image/svg+xml">y=1</object> and
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/c5f34fb4e66b84bde15d596cf76efd468983c4d5.svg" style="height: 17px;" type="image/svg+xml">\hat{y}=1</object> or when both <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ad8ddd3de86ba8af8476af79d20b151a251ec117.svg" style="height: 16px;" type="image/svg+xml">y=-1</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/4ae2d248963bac1702c9e5e1f1d0769126f0c479.svg" style="height: 17px;" type="image/svg+xml">\hat{y}=-1</object>.</p>
<p>Furthermore, to approximate the 0/1 loss, we want our loss at <img alt="m=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/5e49227d625a223efeaa8d7bc48bb0b87f878bff.png" style="height: 12px;" /> to be
1. Here's a plot of the square loss together with 0/1 loss:</p>
<img alt="0/1 loss and square loss for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/binary-01-with-square-loss.png" />
<p>A couple of problems are immediately apparent with the square loss:</p>
<ol class="arabic simple">
<li>It penalizes correct classification as well, in case the margin is very
positive. This is not something we want! Ideally, we want the loss to be
0 starting with <img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" /> and for all subsequent values of <em>m</em>.</li>
<li>It very strongly penalizes outliers. One sample that we misclassified badly
can shift the training too much.</li>
</ol>
<p>We could try to fix these problems by using clamping of some sort, but there is
another loss function which serves as a much better approximation to 0/1 loss.
It's called &quot;hinge loss&quot;:</p>
<img alt="\[L_h(m) = max(0, 1-m)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/dd883f12c7f609fe9256e0e6bb4cfdf319d07844.png" style="height: 18px;" />
<p>And its plot, along with the previously shown losses:</p>
<img alt="0/1 loss, square loss and hinge loss for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/binary-01-with-square-and-hinge-loss.png" />
<p>Note that the hinge loss also matches 0/1 loss on the two important points:
<img alt="m=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/5e49227d625a223efeaa8d7bc48bb0b87f878bff.png" style="height: 12px;" /> and <img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" />. It also has some nice properties:</p>
<ol class="arabic simple">
<li>It doesn't penalize correct classification after <img alt="m=1" class="valign-m1" src="https://eli.thegreenplace.net/images/math/002d212eace214d48ccf82c7bc33021b1d9cdb91.png" style="height: 13px;" />.</li>
<li>It penalizes incorrect classifications, but not as much as square loss.</li>
<li>It's convex (at least where it matters - where the loss is nonzero)! If we
get <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/5abbd129a48c53a04b0caa6eef4d760329f02149.svg" style="height: 14px;" type="image/svg+xml">m=-1.5</object> we can actually examine the loss in its very close
vicinity and find a slope we can use to improve the loss. So, unlike 0/1
loss, it's amenable to gradient descent optimization.</li>
</ol>
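<p>Here's a quick sketch of both losses as functions of the margin, matching the
formulas above:</p>
<div class="highlight"><pre>
import numpy as np

def square_loss(m):
    return (m - 1) ** 2

def hinge_loss(m):
    return np.maximum(0, 1 - m)

m = np.array([-1.5, 0.0, 1.0, 2.5])
print(square_loss(m))   # [6.25 1.   0.   2.25]
print(hinge_loss(m))    # [2.5  1.   0.   0.  ]
</pre></div>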
<p>There are other loss functions used to train binary classifiers, such as log
loss, but I will leave them out of this post.</p>
<p>This is a good place to mention that hinge loss leads naturally to <a class="reference external" href="https://en.wikipedia.org/wiki/Support_vector_machine#SVM_and_the_hinge_loss">SVMs</a>
(support vector machines), an interesting technique I'll leave for some other
time.</p>
</div>
<div class="section" id="finding-a-classifier-with-gradient-descent">
<h2>Finding a classifier with gradient descent</h2>
<p>With a loss function in hand, we can use <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent/">gradient descent</a> to find a
good classifier for some data. The procedure is very similar to what we've been
doing for linear regression:</p>
<p>Given a loss function, we compute the loss gradient with respect to each
<img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> and update <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> for the next step:</p>
<img alt="\[\theta_{j}=\theta_{j}-\eta\frac{\partial L}{\partial \theta_{j}}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/561a940034503fe1bb00e86c90ac130cb351d73b.png" style="height: 42px;" />
<p>Where <img alt="\eta" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2899aeb886ad0fa72652bffd5511e452aaf084ab.png" style="height: 12px;" /> is the learning rate.</p>
</div>
<div class="section" id="computing-gradients-for-our-loss-functions-with-regularization">
<h2>Computing gradients for our loss functions, with regularization</h2>
<p>The only remaining piece is computing the gradients for the square and
hinge loss functions we've defined. In addition, I'm going to add &quot;<img alt="L_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0d2398f5890edff3f40f1686fc3b51528209bf9b.png" style="height: 15px;" />
regularization&quot; to the loss as a means to prevent overfitting for the training
data. <a class="reference external" href="https://en.wikipedia.org/wiki/Regularization_(mathematics)">Regularization</a> is an important
component of the learning algorithm. <img alt="L_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0d2398f5890edff3f40f1686fc3b51528209bf9b.png" style="height: 15px;" /> regularization adds the sum of
the squares of all parameters to the loss, and thus &quot;tries&quot; to keep parameters
low. This way, we don't end up over-emphasizing one or a group of parameters
over the others.</p>
<p>Here is square loss with regularization <a class="footnote-reference" href="#id14" id="id4">[4]</a>:</p>
<img alt="\[L_2=\frac{1}{k}\sum_{i=1}^{k}(m^{(i)}-1)^2+\frac{\beta}{2}\sum_{j=0}^{n}\theta_{j}^2\]" class="align-center" src="https://eli.thegreenplace.net/images/math/a9735ff6606b3ad3454c3dfefc541c21b926d541.png" style="height: 56px;" />
<p>This is assuming we have <em>k</em> data points (<em>n+1</em> dimensional) and <em>n+1</em>
parameters (including the special 0th parameter representing the bias). The
total loss is the square loss averaged over all data points, plus the
regularization loss. <img alt="\beta" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6499d503bfc00cadae1440b191c52a8632e2f8c4.png" style="height: 16px;" /> is the regularization &quot;strength&quot; (another
hyper-parameter in the learning algorithm).</p>
<p>Let's start by computing the derivative of the margin. Using
superscripts for indexing data items, recall that:</p>
<img alt="\[m^{(i)}=\hat{y}^{(i)}y^{(i)}=(\theta_0 x_0^{(i)}+\cdots + \theta_n x_n^{(i)})y^{(i)}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/bce48f26ac61cbfd37c8bfbaad0004e5c30ccbbc.png" style="height: 26px;" />
<p>Therefore:</p>
<img alt="\[\frac{\partial m^{(i)}}{\partial \theta_j}=x_j^{(i)}y^{(i)}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/fd79e2321a3ee607dbf3840535d1a8a2327e2117.png" style="height: 47px;" />
<p>With this in hand, it's easy to compute the gradient of <img alt="L_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0d2398f5890edff3f40f1686fc3b51528209bf9b.png" style="height: 15px;" /> loss.</p>
<img alt="\[\frac{\partial L_2}{\partial \theta_j}=\frac{2}{k}\sum_{i=1}^{k}(m^{(i)}-1)x_{j}^{(i)}y^{(i)}+\beta\theta_j\]" class="align-center" src="https://eli.thegreenplace.net/images/math/2340ff828a85ab17aa5067b4985cf9da4fd5fae7.png" style="height: 54px;" />
<p>Now let's turn to hinge loss. The total loss for the data set with regularization
is:</p>
<img alt="\[L_h=\frac{1}{k}\sum_{i=1}^{k}max(0, 1-m^{(i)})+\frac{\beta}{2}\sum_{j=0}^{n}\theta_{j}^2\]" class="align-center" src="https://eli.thegreenplace.net/images/math/2ce4a6debf2650ea4c8a1ff24ce8e42f3d370a6e.png" style="height: 56px;" />
<p>The tricky part here is finding the derivative of the <img alt="max" class="valign-0" src="https://eli.thegreenplace.net/images/math/0706025b2bbcec1ed8d64822f4eccd96314938d0.png" style="height: 8px;" /> function
with respect to <img alt="\theta_j" class="valign-m6" src="https://eli.thegreenplace.net/images/math/56adcea6f10a3cd4a439536412c7fb690f803bc9.png" style="height: 18px;" />. I find it easier to reason about functions
like <img alt="max" class="valign-0" src="https://eli.thegreenplace.net/images/math/0706025b2bbcec1ed8d64822f4eccd96314938d0.png" style="height: 8px;" /> when the different cases are cleanly separated:</p>
<img alt="\[max(0,1-m^{(i)})=\left\{\begin{matrix} 1-m^{(i)} &amp;amp; \operatorname{if}\ m^{(i)}&amp;lt; 1\\ 0 &amp;amp; \operatorname{if}\ m^{(i)}\geq 1 \end{matrix}\right.\]" class="align-center" src="https://eli.thegreenplace.net/images/math/884d533e1ff8dd51ae43a229bc2f86bc72e82c2a.png" style="height: 46px;" />
<p>We already know the derivative of <img alt="m^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/0971cbdfca7ab3d5c094d8a8e75c77ccf66e4715.png" style="height: 17px;" /> with respect to
<img alt="\theta_j" class="valign-m6" src="https://eli.thegreenplace.net/images/math/56adcea6f10a3cd4a439536412c7fb690f803bc9.png" style="height: 18px;" />. So it's easy to derive this expression case-by-case:</p>
<img alt="\[\frac{\partial max(0,1-m^{(i)})}{\partial \theta_j}=\left\{\begin{matrix} -x_j^{(i)}y^{(i)} &amp;amp; \operatorname{if}\ m^{(i)}&amp;lt; 1\\ 0 &amp;amp; \operatorname{if}\ m^{(i)}\geq 1 \end{matrix}\right.\]" class="align-center" src="https://eli.thegreenplace.net/images/math/4feb3f18ab008352c513de8508c4e8f877510167.png" style="height: 54px;" />
<p>And the overall gradient of the hinge loss is:</p>
<img alt="\[\frac{\partial L_h}{\partial \theta_j}=\frac{1}{k}\sum_{i=1}^{k}\frac{\partial max(0,1-m^{(i)})}{\partial \theta_j}+\beta\theta_j\]" class="align-center" src="https://eli.thegreenplace.net/images/math/d3113e543be93630457f9501379fe0b6956d9342.png" style="height: 54px;" />
</div>
<div class="section" id="experiments-with-synthetic-data">
<h2>Experiments with synthetic data</h2>
<p>Let's see an example of learning a binary classifier in action. <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/simple_binary_classifier.py">This code sample</a>
generates some synthetic data in two dimensions and then uses the approach
described so far in the post to train a binary classifier. Here's a sample data
set:</p>
<img alt="Synthetic data for binary classification" class="align-center" src="https://eli.thegreenplace.net/images/2016/synthetic-data.png" />
<p>The data points for which the correct answer is positive (<em>y=1</em>) are the green
crosses; the ones for which the correct answer is negative (<em>y=-1</em>) are the red
dots. Note that I include a small number of negative outliers (red dots where
we'd expect only green crosses to be) to test the classifier on realistic,
imperfect data.</p>
<p>The sample code can use combinatorial search to find a &quot;best&quot; set of parameters
that results in the lowest 0/1 loss - the lowest number of misclassified data
items. Note that misclassifying some items in this data set is inevitable (with
a linear classifier), because of the outliers. Here is the contour line showing
how the classification decision is made with parameters found by doing the
combinatorial search:</p>
<img alt="Synthetic data for binary classification with only 0/1 loss" class="align-center" src="https://eli.thegreenplace.net/images/2016/synthetic-data-only-01-loss.png" />
<p>The 0/1 loss - number of misclassified data items - for this set of parameters
is 20 out of 400 data items (95% correct prediction rate).</p>
<p>Next, the code trains a classifier using square loss, and another using hinge
loss. I'm not using regularization for this data set, since with only 3
parameters there can't be too much selective bias between them; in other words,
<img alt="\beta=0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3bb1ac87ba8d8d0c95fd43b91640c0b96f8e72d9.png" style="height: 16px;" />.</p>
<p>A classifier trained with square loss misclassifies 32 items (92% success rate).
A classifier trained with hinge loss misclassifies 26 items (93.5% success rate,
much closer to the &quot;perfect&quot; rate). This is to be expected from the earlier
discussion - square loss very strongly penalizes outliers, which makes it more
skewed on this data <a class="footnote-reference" href="#id15" id="id5">[5]</a>. Here are the contour plots for all losses that
demonstrate this graphically:</p>
<img alt="Synthetic data for binary classification with all losses" class="align-center" src="https://eli.thegreenplace.net/images/2016/synthetic-data-all-losses.png" />
</div>
<div class="section" id="binary-classification-of-mnist-digits">
<h2>Binary classification of MNIST digits</h2>
<p>The <a class="reference external" href="https://en.wikipedia.org/wiki/MNIST_database">MNIST dataset</a> is the
&quot;hello world&quot; of machine learning these days. It's a database of grayscale
images representing handwritten digits, with a correct label for each of these
images.</p>
<p>MNIST is usually employed for the more general multinomial classification
problem - classifying a given data item into one of multiple classes (0 to 9 in
the case of MNIST). We'll address this in a later section.</p>
<p>Here, however, we can experiment with training a binary classifier on MNIST. The
idea is to train a classifier that recognizes some single label. For example, a
classifier answering the question &quot;is this an image of the digit 4&quot;. This is a
binary classification problem, since there are only two answers - &quot;yes&quot; and
&quot;no&quot;.</p>
<p><a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/mnist_binary_classifier.py">Here's a code sample</a>
that trains such a classifier, using the hinge loss function (since we've already
determined it gives better results than square loss for classification
problems).</p>
<p>It starts by converting the correct labels of MNIST from the numeric range 0-9
to +1 or -1 based on whether the label is 4:</p>
<div class="highlight"><pre><span></span> 0 -1
1 -1
4 1
9 -1
y = 3 ==&gt; -1
8 -1
5 -1
...
4 1
</pre></div>
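<p>As a sketch, this conversion is a one-liner in NumPy (assuming <tt class="docutils literal">labels</tt> holds
the original 0-9 digit labels):</p>
<div class="highlight"><pre>
import numpy as np

labels = np.array([0, 1, 4, 9, 3, 8, 5, 4])      # original MNIST digit labels
binary_labels = np.where(labels == 4, 1, -1)
print(binary_labels)                             # [-1 -1  1 -1 -1 -1 -1  1]
</pre></div>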
<p>Then all we have is a binary classification problem, albeit one that is
785-dimensional (one dimension for each of the 28x28=784 pixels in the input
images, plus one for bias). Visualizing the separating contours would be quite
challenging here, but we can now trust the math to know what's going on. Other
than this, the code for gradient descent is <em>exactly the same</em> as for the simple
2D synthetic data shown earlier.</p>
<p>My goal here is not to design a state-of-the-art machine learning architecture,
but to explain how the main parts work. So I didn't tune the model too much, but
it's possible to get 98% accuracy on this binary formulation of MNIST by tuning
the code a bit. While 98% sounds great, recall that we could get 90% just by
saying &quot;no&quot; to every digit :-) Feel free to play with the code to see if you
can get even higher numbers; I don't really expect record-beating numbers from
this model, though, since it's so simple.</p>
</div>
<div class="section" id="logistic-regression-predicting-probabilities">
<h2>Logistic regression - predicting probabilities</h2>
<p>So far the predictors we've been looking at were trained to return a binary
yes/no response; a more useful model would also tell us how sure it is. For
example &quot;what is the chance of rain tomorrow&quot;, rather than &quot;will there be rain,
yes or no&quot;? The probability gives additional information. &quot;90% chance of rain&quot;
vs. &quot;56% chance of rain&quot; gives us additional information over the binary &quot;yes&quot;
for both cases (assuming a 50% cutoff).</p>
<p>Moreover, note that the linear model we've trained actually provides more
information already, giving a numerical answer. We choose to cut it off at 0,
saying yes for positive and no for negative numbers. But some numbers are more
positive (or negative) than others!</p>
<p>Quick thought experiment: can we somehow interpret the response before cutoff as
probability? The main problem here is that probabilities must be in the range
[0, 1], while the linear model gives us an arbitrary real number. We may end up
with negative probabilities or probabilities over 1, neither of which makes much
sense. So we'll want to find some mathematical way to &quot;squish&quot; the result into
the valid [0, 1] range. A common way to do this is to use the logistic function:</p>
<img alt="\[S(z) = \frac{1}{1 + e^{-z}}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/62429be191903e2433ba80f92aaf1044568b831d.png" style="height: 38px;" />
<p>It's also known as the &quot;sigmoid&quot; function because of its S-like shape:</p>
<img alt="Sigmoid function" class="align-center" src="https://eli.thegreenplace.net/images/2016/sigmoid.png" />
<p>We're going to assign <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> into the <em>z</em> variable of the sigmoid,
to get the function:</p>
<img alt="\[S(x) = \frac{1}{1 + e^{-(\theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n)}}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/2b9f6770ff23ed08c38a9ab5c3b5972f5d002ddb.png" style="height: 39px;" />
<p>And now, the answer we get can be interpreted as a probability between 0 and 1
(without actually touching either asymptote) <a class="footnote-reference" href="#id16" id="id6">[6]</a>. We can train a model to get
as close to 1 as possible for training samples where the true answer is &quot;yes&quot;
and as close to 0 as possible for training samples where the true answer is
&quot;no&quot;. This is called &quot;logistic regression&quot; due to the use of the logistic
function.</p>
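<p>Here's a minimal NumPy sketch of such a model (the helper names are just for
illustration):</p>
<div class="highlight"><pre>
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def predict_proba(X, theta):
    """Probability that each row of X (with x_0 = 1 for the bias) is a '+1'."""
    return sigmoid(X.dot(theta))
</pre></div>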
</div>
<div class="section" id="training-logistic-regression-with-the-cross-entropy-loss">
<h2>Training logistic regression with the cross-entropy loss</h2>
<p>Earlier in this post, we've seen how a number of loss functions fare for the
binary classifier problem. It turns out that for logistic regression, a very
natural loss function exists that's called <a class="reference external" href="https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression">cross-entropy</a>
(also sometimes &quot;logistic loss&quot; or &quot;log loss&quot;). This loss function is derived
from probability and information theory, and its derivation is outside the scope
of this post (check out <a class="reference external" href="http://neuralnetworksanddeeplearning.com/chap3.html">Chapter 3 of Michael Nielsen's online book</a> for a nice intuitive
explanation for why this loss function makes sense).</p>
<p>The formulation of cross-entropy we're going to use here starts from the most
general:</p>
<img alt="\[C(x^{(i)})=-\sum_{t} p^{(i)}_t log(p(y^{(i)}=t|\theta))\]" class="align-center" src="https://eli.thegreenplace.net/images/math/a689c6537836933fae93c80a71cd52ff88703a78.png" style="height: 41px;" />
<p>Let's unravel this definition, step by step. The parenthesized superscript
<img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" /> denotes, as usual, the <em>ith</em> input sample. <em>t</em> runs over all the
possible outcomes; <img alt="p_t" class="valign-m4" src="https://eli.thegreenplace.net/images/math/aaf082725869f54161f39f7d9c39fff25c52ac94.png" style="height: 12px;" /> is the actual probability of outcome <em>t</em> and
inside the <em>log</em> we have the conditional probability of this outcome given the
regression parameters - in other words, this is the model's prediction <a class="footnote-reference" href="#id17" id="id7">[7]</a>.</p>
<p>To make this more concrete, in our case we have two possible outcomes in the
training data: either <img alt="y^{(i)}=+1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e3495884df85359610f062a6a6428fba7891bb8.png" style="height: 21px;" /> or <img alt="y^{(i)}=-1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8465f16030efd8eab0982f1e60b8ff292317cdbe.png" style="height: 21px;" />. Given any such
outcome, its &quot;actual&quot; probability is either 1 (when we get this outcome in the
training data) or 0 (when we don't). So for any given sample, one of the
two possible values of <em>t</em> has <img alt="p^{(i)}_t=0" class="valign-m5" src="https://eli.thegreenplace.net/images/math/eedcbf364060646a9b6abfccb8e9dda67a645ff0.png" style="height: 25px;" /> and the other has
<img alt="p^{(i)}=1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e44b2858aeb1c845d09a851cbea5fdc9c465199e.png" style="height: 21px;" />. Therefore, we get <a class="footnote-reference" href="#id18" id="id8">[8]</a>:</p>
<img alt="\[C(x^{(i)})=\left\{ \begin{matrix} -log(S(x^{(i)}) &amp;amp; \operatorname{if}\ y^{(i)}=+1 \\ -log(1-S(x^{(i)})) &amp;amp; \operatorname{if}\ y^{(i)}=-1 \end{matrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/97e3fd44d870673c7a74047b82e30c993a9bec59.png" style="height: 46px;" />
<p>The second possibility has <img alt="-log(1-S(x^{(i)}))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0b34c17378147a8a82db655998c07649ca71ed39.png" style="height: 21px;" /> because we define
<img alt="S(z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/61bc9efb9d2c99669df519617ee7daee7670e156.png" style="height: 18px;" /> to predict the probability of the answer being +1; therefore, the
probability of the answer being -1 is <img alt="1-S(z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d006e787dd01f802c9c5cb570e39a44cb133b2ce.png" style="height: 18px;" />.</p>
<p>This is the cross-entropy loss for a single sample <img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" />. To get the
total loss over a data set, we take the average sample loss, as usual:</p>
<img alt="\[C = \frac{1}{k}\sum_{i=1}^{k} C(x^{(i)})\]" class="align-center" src="https://eli.thegreenplace.net/images/math/642351dc03ee1f11eca503f558971282d5c700e7.png" style="height: 54px;" />
<p>Now let's compute the gradient of this loss function, so we can use it to
train a model. Starting with the +1 case, we have:</p>
<img alt="\[C_{+1} = -log(S(x^{(i)}))\]" class="align-center" src="https://eli.thegreenplace.net/images/math/986418ed0bf4c05742c9a412a0918ed00108d93d.png" style="height: 23px;" />
<p>Then:</p>
<img alt="\[\frac{\partial C_{+1}}{\partial \theta_j} = \frac{-1}{S(x^{(i)})}\frac{\partial S(x^{(i)})}{\partial \theta_j}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/c91fbcf0bb4630112d1efa5adbb8756c25512c68.png" style="height: 47px;" />
<p>Here it will be helpful to use the following identity, which can be easily
verified by going through the math <a class="footnote-reference" href="#id19" id="id9">[9]</a>:</p>
<img alt="\[S&amp;#x27;(z)=S(z)(1-S(z))\]" class="align-center" src="https://eli.thegreenplace.net/images/math/3d880e07d60096518b916e877cd6a8496c39bc37.png" style="height: 20px;" />
<p>Since in our case <img alt="S(x^{(i)})" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8a85ab5b49ac41fe751ac8b29e2f2e76f34650bb.png" style="height: 21px;" /> is actually <img alt="S(\hat{y}(x^{(i})))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/915144b3b0a3b41ff5d71f88e798c702386cfea8.png" style="height: 21px;" />
where <img alt="\hat{y}(x) = \theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7ad144258d3d91e1ada8fd7f94a7d0b0538faa2d.png" style="height: 18px;" />,
we can apply the chain rule:</p>
<img alt="\[\frac{\partial S(x^{(i)})}{\partial \theta_j}=S(x^{(i)})(1-S(x^{(i)}))x^{(i)}_j\]" class="align-center" src="https://eli.thegreenplace.net/images/math/6bb3f809b570699a74428c137ea715d97b08b58d.png" style="height: 47px;" />
<p>Substituting back into <img alt="\frac{\partial C_{+1}}{\partial \theta_j}" class="valign-m10" src="https://eli.thegreenplace.net/images/math/cc23ed0ff22b532e2ab3fec04117c8c968318629.png" style="height: 29px;" />, we get:</p>
<img alt="\[\begin{align*} \frac{\partial C_{+1}}{\partial \theta_j} &amp;amp;= \frac{-1}{S(x^{(i)})}S(x^{(i)})(1-S(x^{(i)}))x^{(i)}_j \\ &amp;amp;= (S(x^{(i)})-1)x^{(i)}_j \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/dc2a24be9f61f7066fbaeb48805bb59c51e445c0.png" style="height: 76px;" />
<p>Similarly, for <img alt="C_{-1}=-log(1-S(x^{(i)}))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/91460ff9e118d56fbbce2f0557bb9208a4d438a4.png" style="height: 21px;" /> we can compute:</p>
<img alt="\[\frac{\partial C_{-1}}{\partial \theta_j} = S(x^{(i)})x^{(i)}_j\]" class="align-center" src="https://eli.thegreenplace.net/images/math/5e5fd85290c89289aacb1486d9f706bd9fca8fdc.png" style="height: 42px;" />
<p>Putting it all together, we find that the contribution of <img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" /> to the
gradient of <img alt="\theta_j" class="valign-m6" src="https://eli.thegreenplace.net/images/math/56adcea6f10a3cd4a439536412c7fb690f803bc9.png" style="height: 18px;" /> is:</p>
<img alt="\[\frac{\partial C(x^{(i)})}{\partial \theta_j}=\left\{ \begin{matrix} (S(x^{(i)})-1)x^{(i)}_j &amp;amp; \operatorname{if}\ y^{(i)}=+1 \\ S(x^{(i)})x^{(i)}_j &amp;amp; \operatorname{if}\ y^{(i)}=-1 \end{matrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/21cf7ba3c128242b99272a3e47b5ab5c09cb24bf.png" style="height: 56px;" />
<p>Using these formulae, we can train a binary logistic classifier for MNIST that
gives us a probability of some input image being a 4, rather than a yes/no
answer. The <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/mnist_binary_classifier.py">binary MNIST code sample</a>
trains either a binary or a logistic classifier using a lot of shared
infrastructure.</p>
<p>The probability gives us more information than just a yes/no answer. Consider,
for example, the following image from the MNIST database. When I trained a binary
classifier with hinge loss to recognize the digit 4 for 1200 steps, it wrongly
predicted that this image is a 4:</p>
<img alt="Image of a 9 from MNIST" class="align-center" src="https://eli.thegreenplace.net/images/2016/mnist-test-9740.png" />
<p>The model clearly made a mistake here, but can we know <em>how</em> wrong it was? It
would be hard to know with a binary classifier that only gives us a yes/no
answer. However, when I run a logistic regression model on the same image, it
tells me it is 53% confident this is a 4. Since our cutoff for yes/no is 50%,
this is quite close to the threshold and thus I'd say the model didn't make
a huge mistake here.</p>
</div>
<div class="section" id="multiclass-logistic-regression">
<h2>Multiclass logistic regression</h2>
<p>The previous example is a great transition into the topic of multiclass logistic
regression. Most real-life problems have more than one possible answer and it
would be nice to train models to select the most suitable answer for any given
input.</p>
<p>Our input is still a vector <strong>x</strong>, but now instead of assigning +1 or -1 as the
answer, we'll be assigning one of a fixed set of classes. If there are T
classes, the answer will be a number in the closed range [0..T-1].</p>
<p>The good news is that we can use the building blocks developed in this post to
put together a multiclass classifier. There are many ways to do this; here I'll
focus on two: one-vs-all classification and softmax.</p>
</div>
<div class="section" id="one-vs-all-classification">
<h2>One-vs-all classification</h2>
<p>The one-vs-all (OvA) approach, also known as one-vs-rest (OvR), is a natural
extension of binary classification:</p>
<ol class="arabic simple">
<li>For each class <img alt="t\in[0..T-1]" class="valign-m5" src="https://eli.thegreenplace.net/images/math/6eda0dcb5f9805e0e0e4c3d0af82aacdf1295efd.png" style="height: 18px;" /> we train a logistic classifier where we
set <em>t</em> as the &quot;correct&quot; answer, and the other classes as the &quot;incorrect&quot;
answers (+1 and -1 respectively).</li>
<li>The result of each such classifier is the probability that an input sample
belongs to class <em>t</em>.</li>
<li>Given a new input, we run all <em>T</em> classifiers on it and the one that gives
us the highest probability is chosen as the true class of the input.</li>
</ol>
<p>As a completely synthetic example to make this clearer, suppose that <em>T=3</em>. We
take the training data and train 3 logistic regressions. In the first -
<img alt="C_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/33e4cbb170d6026eb67de894c0d01e8702fb065d.png" style="height: 15px;" />, we set 0 as the right answer, 1 and 2 as the wrong answers. In the
second - <img alt="C_1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c538a6221da718dd38230dcbb6e1a8fb40561f7a.png" style="height: 16px;" /> we set 1 as the right answer, 0 and 2 as the wrong answers.
Finally in the third - <img alt="C_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/e65b6ebf7cbd7ef19069cc4837331af9d119cfe6.png" style="height: 15px;" /> we set 2 as the right answer, 0 and 1 the
wrong answers.</p>
<p>Now, given a new input vector <strong>x</strong> we run <img alt="C_0(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/5ed83eb3961cbf4855ce46814719658cdc79e5f2.png" style="height: 18px;" />, <img alt="C_1(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/47b77ff17810fb0a0d4f6b86f50d403e8a59a7a7.png" style="height: 18px;" /> and
<img alt="C_2(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c063fddc4bcdd77e1131dc70ec5b578b5ec887ef.png" style="height: 18px;" />. Each of these gives us the probability of <strong>x</strong> belonging to
the respective class. If we put all the classifiers in a vector, we get:</p>
<img alt="\[C(x)=[C_0(x), C_1(x), C_2(x)]\]" class="align-center" src="https://eli.thegreenplace.net/images/math/9e4c42a11867dda976b1f7b1ac6aaa46b6625ee9.png" style="height: 19px;" />
<p>We pick the class where the probability is highest. Mathematically, we can use
the <a class="reference external" href="https://en.wikipedia.org/wiki/Arg_max">argmax function</a> for this purpose.
<em>argmax</em> returns the index of the maximal element in the given vector. For
example, given:</p>
<img alt="\[C(x)=[0.45, 0.42, 0.09]\]" class="align-center" src="https://eli.thegreenplace.net/images/math/983ab6a6770f41c06b3eb32f811678aab7f6fb5b.png" style="height: 19px;" />
<p>We get:</p>
<img alt="\[\underset{t \in [0..2]}{argmax}(C(x))=0\]" class="align-center" src="https://eli.thegreenplace.net/images/math/1fa779c771c2d0abcaca9a759ab2e99608842f82.png" style="height: 34px;" />
<p>Therefore, the chosen class is 0. These class/index numbers are just labels of
course. They can stand for anything depending on the problem domain: medical
condition names, digits and so on.</p>
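<p>As a sketch, the decision step is just an argmax over the per-class
probabilities (using the numbers from the example above):</p>
<div class="highlight"><pre>
import numpy as np

# Per-class probabilities for one input, as in the example above.
C_x = np.array([0.45, 0.42, 0.09])
print(np.argmax(C_x))    # 0 - the chosen class
</pre></div>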
<p>This approach doesn't require any additional math over what we've already
covered in this post.
<a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/mnist_multinomial_classifier.py">This multinomial MNIST classifier code sample</a>
implements it. The error rate it achieves is ~11%, similar to what
<a class="reference external" href="http://yann.lecun.com/exdb/publis/index.html#lecun-98">LeCun's 1998 paper</a>
achieved with a simple linear classifier. Much better than 11% can be done for
MNIST, even with a single-layer linear model. However, my model is very far from
the state of the art - there's no preprocessing, no artificially-enlarged training
set, no adaptive learning rate; I didn't even spend time tuning the
hyperparameters (regularization type and constants, learning rate, batch size
etc.) The goal here was just to demonstrate the basics of logistic regression,
not to compete for the state of the art in MNIST.</p>
</div>
<div class="section" id="softmax">
<h2>Softmax</h2>
<p>An alternative to OvA is to use the softmax function. I covered softmax <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/">in some
detail</a>
previously; just briefly, softmax is a function
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1dd52a52398e38c9549b289449de49ba5fbb98b7.svg" style="height: 19px;" type="image/svg+xml">S(\mathbf{a}):\mathbb{R}^{N}\rightarrow \mathbb{R}^{N}</object> such that:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/5470218612381816a8c9a897d43201757560e646.svg" style="height: 46px;" type="image/svg+xml">
\[S_j=\frac{e^{a_j}}{\sum_{k=1}^{N}e^{a_k}}
\qquad \forall j \in 1..N\]</object>
<p>It is very useful for multiclass classification, since it lets us generate
probabilities of the input belonging to one of <em>N</em> classes. Similarly to the
OvA case, here we have to train 10 different parameter vectors, one for each
digit. However, unlike OvA, the training doesn't happen separately for each
class - it all occurs at the same time. Instead of training a model to find a
single parameter vector each time, we train a parameter <em>matrix</em> once.</p>
<p>The model structure is as follows:</p>
<img alt="Model of softmax logistic regression" class="align-center" src="https://eli.thegreenplace.net/images/2016/softmax-logistic-model.png" />
<p>I've chosen the number of classes to be 10 to reflect MNIST, where we have 10
possible digits to assign to every input. In MNIST <em>N</em> is 785 (one for each of
the 784 pixels in the 28x28 image, plus one for bias). &quot;Logits&quot; is a common name
for the output of a fully connected layer (which is what we have with the
matrix-vector multiplication in the first stage); the logits are arbitrary real
numbers. The softmax function is responsible for squeezing them into the range of
probabilities (0, 1) and making sure they all add up to 1.</p>
<p>This diagram shows what happens to a single input as it goes through the model.
In a realistic program, there will be another dimension - the batch dimension,
used to vectorize the computation over a whole batch of inputs.</p>
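<p>For concreteness, here is a minimal sketch of this forward pass (my addition, not
taken from the linked sample); the batch size of 32 and the random data are arbitrary
placeholders, and only the shapes matter:</p>
<div class="highlight"><pre>import numpy as np

def softmax_rows(z):
    # Row-wise softmax for a 2D array of logits, shifted for numerical stability.
    shifted = z - np.max(z, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)

B, N, T = 32, 785, 10              # batch size, input size (with bias), classes
X = np.random.randn(B, N)          # a batch of inputs, one per row
W = np.random.randn(N, T)          # the parameter matrix

logits = X.dot(W)                  # shape (B, T), arbitrary real numbers
probs = softmax_rows(logits)       # shape (B, T), each row sums to 1.0
assert np.allclose(probs.sum(axis=1), 1.0)
</pre></div>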
<p>For training this model, we need a loss function. It turns out cross-entropy is
a very popular loss function to use for softmax. In the
<a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/">softmax post</a>
I also covered how to compute the gradient of cross-entropy on a softmax, so
we're all set to write some code: the <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/logistic-regression/mnist_softmax_classifier.py">full sample is here</a>.
Running it on MNIST for a couple of minutes produces a 9.5% error rate -
slightly better than the OvA approach, but very close. This is to be expected,
since OvA and softmax compute very similar results (finding the maximal
probability from a set of probabilities), just in a different way. Softmax
regression is much faster, however, since we can vectorize the training for all
10 digits in the same run.</p>
<hr class="docutils" />
<table class="docutils footnote" frame="void" id="id10" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id1">[1]</a></td><td>In this post I'm following many of the conventions established in
my post on <a class="reference external" href="http://eli.thegreenplace.net/2016/linear-regression/">linear regression</a>.
In particular, by construction <img alt="x_0=1" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0c1d7f319728a07a57d000f2379b5215e4130147.png" style="height: 15px;" /> so that <img alt="\theta_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/ba6201ddbe2fd0bb66e0704ad8b3c6bdb36f37aa.png" style="height: 15px;" />
is the bias.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id12" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id2">[2]</a></td><td>Why? Because we have the bias as part of the model, so any constant
offset can be absorbed into the learned bias.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id13" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id3">[3]</a></td><td>Note that this outcome is, once again, somewhat arbitrary. We could find
another plane that intersects the x/y axis on the same line, and get
a different classification. For example, if we flip the sign of all the
elements of <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, we get the same intersection line. In that
case, however, values &quot;to the right&quot; of the line give us
<img alt="\hat{y}(x) &amp;lt; 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d8a7e77c45cecd8e4ba7c8f7d1f02944e9b55ecf.png" style="height: 18px;" />. Since the labels we attach are arbitrary, this
really makes no difference. The only important thing is that we find
a line that separates &quot;true&quot; from &quot;false&quot; samples, and that we
remain consistent with our signs and labels throughout the process.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id14" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id4">[4]</a></td><td>Note that both the loss and the regularization are called <img alt="L_2" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0d2398f5890edff3f40f1686fc3b51528209bf9b.png" style="height: 15px;" />.
This is a bit confusing, but both are essentially 2nd norms. It's best
to ignore the name of the regularization factor and just refer to it as
&quot;regularization&quot;. I thought it was important to mention it initially, since
there are other kinds of regularization used for machine-learning
algorithms and I wanted to make it clear which one is being used here.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id15" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id5">[5]</a></td><td>As an exercise, play with the code to increase or decrease the number
of outliers (the code makes it easily controllable), and observe the
effects on the misclassification rates of the different loss functions.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id16" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id6">[6]</a></td><td>Note that using the logistic function on the model's output is strictly
a generalization of the binary classifier. We can still make a binary
interpretation of the result if we're so inclined, interpreting
<img alt="S(z) \geq 0.5" class="valign-m4" src="https://eli.thegreenplace.net/images/math/763035b41ff594d664c57d9fcc03c85808d0ccce.png" style="height: 18px;" /> as &quot;yes&quot; and otherwise as &quot;no&quot;. In terms of the
input to <img alt="S(z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/61bc9efb9d2c99669df519617ee7daee7670e156.png" style="height: 18px;" />, this means &quot;yes&quot; for
<img alt="z=\hat{y}(x) \geq 0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2599fe02308a2d43e5b29b2f9387ee45c5c67a1b.png" style="height: 18px;" />
which is exactly the formulation we've been using for the binary
classifier.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id17" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id7">[7]</a></td><td>In essence, cross entropy is computed between two probability
distributions. Here, one of them is the &quot;real&quot; distribution observed in
the <em>y</em> data. The other is what we predict given <em>X</em> data and our
regression parameters <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />. The observed real probability is
either 0 or 1 for any given data item, and the corresponding predicted
probability is our model's output. I also discussed cross-entropy in
the <a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/">post about softmax</a>.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id18" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id8">[8]</a></td><td>Many resources online condense this formula to a single line without
the condition: <img alt="C(x)=-ylog(S(x))-(1-y)log(1-S(x))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0f9309397d7ef59c72cdb2d861e5532292978ca6.png" style="height: 18px;" />. I'm avoiding
this formulation on purpose, because it requires the possible values of
<em>y</em> to be 0 and 1, not -1 and +1. Although it's possible to play with
constants a bit to reformulate the -1/+1 case in a similarly condensed
fashion, I find the version with the condition more explicit and thus
easier to follow, even if it requires a bit more typing.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id19" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id9">[9]</a></td><td>See also <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus/">my post</a>
about the chain rule, where this derivation is shown.</td></tr>
</tbody>
</table>
</div>
The Softmax function and its derivative2016-10-18T05:20:00-07:002016-10-18T05:20:00-07:00Eli Benderskytag:eli.thegreenplace.net,2016-10-18:/2016/the-softmax-function-and-its-derivative/<p>The softmax function takes an N-dimensional vector of arbitrary real values and
produces another N-dimensional vector with real values in the range (0, 1) that
add up to 1.0. It maps <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1dd52a52398e38c9549b289449de49ba5fbb98b7.svg" style="height: 19px;" type="image/svg+xml">S(\mathbf{a}):\mathbb{R}^{N}\rightarrow \mathbb{R}^{N}</object>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/cd593d87595e496072aebf5100dd87c37c889f25.svg" style="height: 86px;" type="image/svg+xml">
\[S(\mathbf{a}):\begin{bmatrix}
a_1\\
a_2\\
\cdots\\
a_N
\end{bmatrix}
\rightarrow
\begin{bmatrix}
S_1\\
S_2\\
\cdots\\
S_N
\end{bmatrix}\]</object>
<p>And the actual per-element formula is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/5470218612381816a8c9a897d43201757560e646.svg" style="height: 46px;" type="image/svg+xml">
\[S_j=\frac{e^{a_j}}{\sum_{k=1}^{N}e^{a_k}}
\qquad \forall j \in 1..N\]</object>
<p>It's easy to see that <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/cb8b5683be866b4c177c0c319e14085f25bec523.svg" style="height: 18px;" type="image/svg+xml">S_j</object> is always positive (because of the exponents);
moreover, since the numerator appears in the denominator summed up with some
other positive numbers, <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/5a34de9dd188a5a6f758bb0f7daabb58e03045ec.svg" style="height: 18px;" type="image/svg+xml">S_j&lt;1</object>. Therefore, it's in the range (0, 1).</p>
<p>For example, the 3-element vector <tt class="docutils literal">[1.0, 2.0, 3.0]</tt> gets transformed into
<tt class="docutils literal">[0.09, 0.24, 0.67]</tt>. The order of elements by relative size is
preserved, and they add up to 1.0. Let's tweak this vector slightly into:
<tt class="docutils literal">[1.0, 2.0, 5.0]</tt>. We get the output <tt class="docutils literal">[0.02, 0.05, 0.93]</tt>, which still
preserves these properties. Note that since the last element is farther away
from the first two, its softmax value dominates the overall slice of size
1.0 in the output. Intuitively, the softmax function is a &quot;soft&quot; version of the
maximum function. Instead of just selecting one maximal element, softmax breaks
the vector up into parts of a whole (1.0) with the maximal input element getting
a proportionally larger chunk, but the other elements getting some of it as well
<a class="footnote-reference" href="#id3" id="id1">[1]</a>.</p>
<div class="section" id="probabilistic-interpretation">
<h2>Probabilistic interpretation</h2>
<p>The properties of softmax (all output values are in the range (0, 1) and sum up
to 1.0) make it suitable for a probabilistic interpretation that's very useful
in machine learning. In particular, in multiclass classification tasks, we
often want to assign probabilities that our input belongs to one of a set of
output classes.</p>
<p>If we have N output classes, we're looking for an N-vector of probabilities that
sum up to 1; sound familiar?</p>
<p>We can interpret softmax as follows:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/4510f717b770547b90526c714355f4c81d1b4a50.svg" style="height: 19px;" type="image/svg+xml">
\[S_j=P(y=j|a)\]</object>
<p>Where <em>y</em> is the output class numbered <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/310debdf2f7fe03ad7888e95000c78a0efae5500.svg" style="height: 13px;" type="image/svg+xml">1..N</object>. <em>a</em> is any N-vector. The
most basic example is <a class="reference external" href="http://eli.thegreenplace.net/2016/logistic-regression/">multiclass logistic regression</a>, where an input
vector <em>x</em> is multiplied by a weight matrix <em>W</em>, and the result of this dot
product is fed into a softmax function to produce probabilities. This
architecture is explored in detail later in the post.</p>
<p>It turns out that - from a probabilistic point of view - softmax is optimal
for <a class="reference external" href="https://en.wikipedia.org/wiki/Maximum_likelihood_estimation">maximum-likelihood estimation</a> of the model's
parameters. This is beyond the scope of this post, though. See chapter 5 of
the <a class="reference external" href="http://www.deeplearningbook.org/">&quot;Deep Learning&quot; book</a> for more details.</p>
</div>
<div class="section" id="some-preliminaries-from-vector-calculus">
<h2>Some preliminaries from vector calculus</h2>
<p>Before diving into computing the derivative of softmax, let's start with some
preliminaries from vector calculus.</p>
<p>Softmax is fundamentally a vector function. It takes a vector as input and
produces a vector as output; in other words, it has multiple inputs and multiple
outputs. Therefore, we cannot just ask for &quot;the derivative of softmax&quot;; we
should instead specify:</p>
<ol class="arabic simple">
<li>Which component (output element) of softmax we're seeking to find the
derivative of.</li>
<li>Since softmax has multiple inputs, with respect to which input element the
partial derivative is computed.</li>
</ol>
<p>If this sounds complicated, don't worry. This is exactly why the notation of
vector calculus was developed. What we're looking for is the partial
derivatives:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/2eae0a040f9eb82a2cf0a596c926aca49a3cdb66.svg" style="height: 42px;" type="image/svg+xml">
\[\frac{\partial S_i}{\partial a_j}\]</object>
<p>This is the partial derivative of the i-th output w.r.t. the j-th input. A
shorter way to write it that we'll be using going forward is: <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/ca95d97dc85a733a280ccaab680d01727376e383.svg" style="height: 18px;" type="image/svg+xml">D_{j}S_i</object>.</p>
<p>Since softmax is a <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/91b745aec8f7c3a5501975b040a4aef477c31412.svg" style="height: 16px;" type="image/svg+xml">\mathbb{R}^{N}\rightarrow \mathbb{R}^{N}</object> function,
the most general derivative we compute for it is the Jacobian matrix:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7af5ba48ed18f62f0fa31b60ba35e8e94054931c.svg" style="height: 76px;" type="image/svg+xml">
\[DS=\begin{bmatrix}
D_1 S_1 &amp; \cdots &amp; D_N S_1 \\
\vdots &amp; \ddots &amp; \vdots \\
D_1 S_N &amp; \cdots &amp; D_N S_N
\end{bmatrix}\]</object>
<p>In ML literature, the term &quot;gradient&quot; is commonly used to stand in for the
derivative. Strictly speaking, gradients are only defined for scalar functions
(such as loss functions in ML); for vector functions like softmax it's imprecise
to talk about a &quot;gradient&quot;; the Jacobian is the fully general derivative of a
vector function, but in most places I'll just be saying &quot;derivative&quot;.</p>
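<p>Since we'll be deriving this Jacobian analytically below, it's handy to have a way to
sanity-check such a derivation numerically. Here's a small sketch (my addition) of a
generic finite-difference Jacobian for any vector function; it's only an approximation,
of course:</p>
<div class="highlight"><pre>import numpy as np

def numerical_jacobian(f, a, eps=1e-6):
    # Approximate the Jacobian of vector function f at point a.
    # Entry (i, j) estimates D_j f_i: the partial derivative of the
    # i-th output w.r.t. the j-th input.
    a = np.asarray(a, dtype=float)
    out = f(a)
    jac = np.zeros((out.size, a.size))
    for j in range(a.size):
        bumped = a.copy()
        bumped[j] += eps
        jac[:, j] = (f(bumped) - out) / eps
    return jac
</pre></div>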
</div>
<div class="section" id="derivative-of-softmax">
<h2>Derivative of softmax</h2>
<p>Let's compute <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/166e309484516e7fea86d27f36f42639ab73b471.svg" style="height: 18px;" type="image/svg+xml">D_j S_i</object> for arbitrary <em>i</em> and <em>j</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/dbee1c4ac839a1eef7202f447f754341eec98904.svg" style="height: 53px;" type="image/svg+xml">
\[D_j S_i=\frac{\partial S_i}{\partial a_j}=
\frac{\partial \frac{e^{a_i}}{\sum_{k=1}^{N}e^{a_k}}}{\partial a_j}\]</object>
<p>We'll be using the quotient rule of derivatives. For
<object class="valign-m9" data="https://eli.thegreenplace.net/images/math/25ee22368ab19a6e8608ac7417cf62e235794e54.svg" style="height: 29px;" type="image/svg+xml">f(x) = \frac{g(x)}{h(x)}</object>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/c0fd805caf8b7d8336e8c52f2759b3ce73295315.svg" style="height: 43px;" type="image/svg+xml">
\[f&#x27;(x) = \frac{g&#x27;(x)h(x) - h&#x27;(x)g(x)}{[h(x)]^2}\]</object>
<p>In our case, we have:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/167b7392a9d51fbc4016901d48995f091f627e3a.svg" style="height: 82px;" type="image/svg+xml">
\[\begin{align*}
g_i&amp;=e^{a_i} \\
h_i&amp;=\sum_{k=1}^{N}e^{a_k}
\end{align*}\]</object>
<p>Note that no matter which <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/c2d2e987a5cb0df2f497d2dba0da0960fb6fbcc0.svg" style="height: 14px;" type="image/svg+xml">a_j</object> we compute the derivative of <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/969951984c96d748d949ee5e5322f4c2dbb75087.svg" style="height: 16px;" type="image/svg+xml">h_i</object>
for, the answer will always be <object class="valign-0" data="https://eli.thegreenplace.net/images/math/a4c5fca09246e4e7c55473070976f788e032c514.svg" style="height: 12px;" type="image/svg+xml">e^{a_j}</object>. This is not the case for
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d141c63d6e5b4ff91ec2936c9b320454461258a0.svg" style="height: 12px;" type="image/svg+xml">g_i</object>, howewer. The derivative of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d141c63d6e5b4ff91ec2936c9b320454461258a0.svg" style="height: 12px;" type="image/svg+xml">g_i</object> w.r.t. <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/c2d2e987a5cb0df2f497d2dba0da0960fb6fbcc0.svg" style="height: 14px;" type="image/svg+xml">a_j</object> is
<object class="valign-0" data="https://eli.thegreenplace.net/images/math/a4c5fca09246e4e7c55473070976f788e032c514.svg" style="height: 12px;" type="image/svg+xml">e^{a_j}</object> only if <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8e4587fc82ce6377530643c5622b41e53cdf3dd3.svg" style="height: 16px;" type="image/svg+xml">i=j</object>, because only then <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d141c63d6e5b4ff91ec2936c9b320454461258a0.svg" style="height: 12px;" type="image/svg+xml">g_i</object> has
<object class="valign-m6" data="https://eli.thegreenplace.net/images/math/c2d2e987a5cb0df2f497d2dba0da0960fb6fbcc0.svg" style="height: 14px;" type="image/svg+xml">a_j</object> anywhere in it. Otherwise, the derivative is 0.</p>
<p>Going back to our <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/166e309484516e7fea86d27f36f42639ab73b471.svg" style="height: 18px;" type="image/svg+xml">D_j S_i</object>; we'll start with the <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8e4587fc82ce6377530643c5622b41e53cdf3dd3.svg" style="height: 16px;" type="image/svg+xml">i=j</object> case. Then,
using the quotient rule we have:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/d7489693552878c00ad6788a0c8987416cbb0796.svg" style="height: 53px;" type="image/svg+xml">
\[\frac{\partial \frac{e^{a_i}}{\sum_{k=1}^{N}e^{a_k}}}{\partial a_j}=
\frac{{}e^{a_i}\Sigma-e^{a_j}e^{a_i}}{\Sigma^2}\]</object>
<p>For simplicity <object class="valign-0" data="https://eli.thegreenplace.net/images/math/cb5615b3fcee824f137c372e351ccca3ff3a3292.svg" style="height: 12px;" type="image/svg+xml">\Sigma</object> stands for <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/2c3662fbb97e3b5c528e8b1cdf89e108bfeed206.svg" style="height: 23px;" type="image/svg+xml">\sum_{k=1}^{N}e^{a_k}</object>.
Reordering a bit:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/2634d0ab6532983a88a1f55a33cf6a6719a291ee.svg" style="height: 123px;" type="image/svg+xml">
\[\begin{align*}
\frac{\partial \frac{e^{a_i}}{\sum_{k=1}^{N}e^{a_k}}}{\partial a_j}&amp;=
\frac{e^{a_i}\Sigma-e^{a_j}e^{a_i}}{\Sigma^2}\\
&amp;=\frac{e^{a_i}}{\Sigma}\frac{\Sigma - e^{a_j}}{\Sigma}\\
&amp;=S_i(1-S_j)
\end{align*}\]</object>
<p>The final formula expresses the derivative in terms of <object class="valign-m3" data="https://eli.thegreenplace.net/images/math/3e218c43050832e5df45f69fb2c8b8a01f7f5a52.svg" style="height: 15px;" type="image/svg+xml">S_i</object> itself - a
common trick when functions with exponents are involved.</p>
<p>Similarly, we can do the <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/09eca402f8bc6311cca3a98625e29e75cc336d31.svg" style="height: 17px;" type="image/svg+xml">i\ne j</object> case:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/d788a4ff0e07827862aaf0ded5befbf1665d90cc.svg" style="height: 123px;" type="image/svg+xml">
\[\begin{align*}
\frac{\partial \frac{e^{a_i}}{\sum_{k=1}^{N}e^{a_k}}}{\partial a_j}&amp;=
\frac{0-e^{a_j}e^{a_i}}{\Sigma^2}\\
&amp;=-\frac{e^{a_j}}{\Sigma}\frac{e^{a_i}}{\Sigma}\\
&amp;=-S_j S_i
\end{align*}\]</object>
<p>To summarize:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/f776365373202f727625c0be825d55a2fde47882.svg" style="height: 43px;" type="image/svg+xml">
\[D_j S_i=\left\{\begin{matrix}
S_i(1-S_j) &amp; i=j\\
-S_j S_i &amp; i\ne j
\end{matrix}\right\]</object>
<p>I like seeing this explicit breakdown by cases, but if anyone is taking more
pride in being concise and clever than programmers, it's mathematicians. This
is why you'll find various &quot;condensed&quot; formulations of the same equation in the
literature. One of the most common ones is using the Kronecker delta function:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/ff38cb90472289e31bd7f79c1c85c455d7962cbb.svg" style="height: 43px;" type="image/svg+xml">
\[\delta_{ij}=\left\{\begin{matrix}
1 &amp; i=j\\
0 &amp; i\ne j
\end{matrix}\right\]</object>
<p>To write:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/6e4b626a68faabba991f9d1e83a12c74fcec0e63.svg" style="height: 19px;" type="image/svg+xml">
\[D_j S_i = S_i (\delta_{ij}-S_j)\]</object>
<p>Which is, of course, the same thing. There are a couple of other formulations
one sees in the literature:</p>
<ol class="arabic simple">
<li>Using the matrix formulation of the Jacobian directly to replace
<object class="valign-0" data="https://eli.thegreenplace.net/images/math/3a6a16552e246af497720ffdfe6091b42d2f8938.svg" style="height: 12px;" type="image/svg+xml">\delta</object> with <object class="valign-0" data="https://eli.thegreenplace.net/images/math/ca73ab65568cd125c2d27a22bbd9e863c10b675d.svg" style="height: 12px;" type="image/svg+xml">I</object> - the identity matrix, whose elements are
expressing <object class="valign-0" data="https://eli.thegreenplace.net/images/math/3a6a16552e246af497720ffdfe6091b42d2f8938.svg" style="height: 12px;" type="image/svg+xml">\delta</object> in matrix form.</li>
<li>Using &quot;1&quot; as the function name instead of the Kronecker delta, as follows:
<object class="valign-m6" data="https://eli.thegreenplace.net/images/math/a4fa3293a004c9dc1f5171ddb590ac9cb7178102.svg" style="height: 20px;" type="image/svg+xml">D_j S_i = S_i (1(i=j)-S_j)</object>. Here <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d9e260212cd116b69ffa42e9c9f824b2bcf6a217.svg" style="height: 18px;" type="image/svg+xml">1(i=j)</object> means the value 1
when <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/8e4587fc82ce6377530643c5622b41e53cdf3dd3.svg" style="height: 16px;" type="image/svg+xml">i=j</object> and the value 0 otherwise.</li>
</ol>
<p>The condensed notation comes in useful when we want to compute more complex
derivatives that depend on the softmax derivative; otherwise we'd have to
propagate the condition everywhere.</p>
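<p>In NumPy terms, the condensed formula turns into a compact expression: the Jacobian
of softmax at a point is <tt class="docutils literal">np.diag(S) - np.outer(S, S)</tt>, whose
entry (i, j) is exactly S_i(&delta;_ij - S_j). Here's a sketch (my addition); it can be
checked against the finite-difference Jacobian sketched earlier:</p>
<div class="highlight"><pre>import numpy as np

def softmax_jacobian(a):
    # Softmax of a (numerically stable), then its full Jacobian:
    # entry (i, j) is S_i * (delta_ij - S_j).
    a = np.asarray(a, dtype=float)
    exps = np.exp(a - np.max(a))
    s = exps / np.sum(exps)
    return np.diag(s) - np.outer(s, s)

J = softmax_jacobian([1.0, 2.0, 3.0])
# Each row sums to zero, since the softmax outputs always sum to 1.
print(np.allclose(J.sum(axis=1), 0.0))   # True
</pre></div>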
</div>
<div class="section" id="computing-softmax-and-numerical-stability">
<h2>Computing softmax and numerical stability</h2>
<p>A simple way of computing the softmax function on a given vector in Python is:</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Compute the softmax of vector x.&quot;&quot;&quot;</span>
<span class="n">exps</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">return</span> <span class="n">exps</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">exps</span><span class="p">)</span>
</pre></div>
<p>Let's try it with the sample 3-element vector we've used as an example earlier:</p>
<div class="highlight"><pre><span></span>In [146]: softmax([1, 2, 3])
Out[146]: array([ 0.09003057, 0.24472847, 0.66524096])
</pre></div>
<p>However, if we run this function with larger numbers (or large negative numbers)
we have a problem:</p>
<div class="highlight"><pre><span></span>In [148]: softmax([1000, 2000, 3000])
Out[148]: array([ nan, nan, nan])
</pre></div>
<p>The numerical range of the floating-point numbers used by Numpy
is limited. For <tt class="docutils literal">float64</tt>, the maximal representable number is on the order
of <object class="valign-m1" data="https://eli.thegreenplace.net/images/math/91d9772e2d01d53580c14ba9801ea3303f45cac7.svg" style="height: 16px;" type="image/svg+xml">10^{308}</object>. Exponentiation in the softmax function makes it possible to
easily overshoot this number, even for fairly modest-sized inputs.</p>
<p>A nice way to avoid this problem is by normalizing the inputs to be
not too large or too small, by observing that we can use an arbitrary constant
<em>C</em> as follows:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/21c627f153906b6de2c2723f4a20629a610945ba.svg" style="height: 46px;" type="image/svg+xml">
\[S_j=\frac{e^{a_j}}{\sum_{k=1}^{N}e^{a_k}}=\frac{Ce^{a_j}}{\sum_{k=1}^{N}Ce^{a_k}}\]</object>
<p>And then pushing the constant into the exponent, we get:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/c5b631159b49e84338269e0943e00da2fb7f5d21.svg" style="height: 51px;" type="image/svg+xml">
\[S_j=\frac{e^{a_j+log(C)}}{\sum_{k=1}^{N}e^{a_k+log(C)}}\]</object>
<p>Since <em>C</em> is just an arbitrary constant, we can instead write:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7ae51c811f1348f4762e3eee1a3cc9e8aad1890c.svg" style="height: 49px;" type="image/svg+xml">
\[S_j=\frac{e^{a_j+D}}{\sum_{k=1}^{N}e^{a_k+D}}\]</object>
<p>Where <em>D</em> is also an arbitrary constant. This formula is equivalent to the
original <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/cb8b5683be866b4c177c0c319e14085f25bec523.svg" style="height: 18px;" type="image/svg+xml">S_j</object> for any <em>D</em>, so we're free to choose a <em>D</em> that will make
our computation better numerically. A good choice is the maximum between all
inputs, negated:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/0433b741304b0b54a6e11be1602b63d4b6326e98.svg" style="height: 18px;" type="image/svg+xml">
\[D=-max(a_1, a_2, \cdots, a_N)\]</object>
<p>This will shift the inputs to a range close to zero, assuming the inputs
themselves are not too far from each other. Crucially, it shifts them all to be
negative (except the maximal <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/c2d2e987a5cb0df2f497d2dba0da0960fb6fbcc0.svg" style="height: 14px;" type="image/svg+xml">a_j</object> which turns into a zero). Negatives
with large magnitudes are exponentiated to values very close to zero rather than
overflowing to infinity, so we have a better chance of avoiding NaNs.</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">stablesoftmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Compute the softmax of vector x in a numerically stable way.&quot;&quot;&quot;</span>
<span class="n">shiftx</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">exps</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">shiftx</span><span class="p">)</span>
<span class="k">return</span> <span class="n">exps</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">exps</span><span class="p">)</span>
</pre></div>
<p>And now:</p>
<div class="highlight"><pre><span></span>In [150]: stablesoftmax([1000, 2000, 3000])
Out[150]: array([ 0., 0., 1.])
</pre></div>
<p>Note that this is still imperfect, since mathematically softmax would never
really produce a zero, but this is much better than NaNs; and since the distance
between the inputs is very large, we expect a result extremely close
to zero anyway.</p>
</div>
<div class="section" id="the-softmax-layer-and-its-derivative">
<h2>The softmax layer and its derivative</h2>
<p>A common use of softmax appears in machine learning, in particular in logistic
regression: the softmax &quot;layer&quot;, wherein we apply softmax to the output of a
fully-connected layer (matrix multiplication):</p>
<img alt="Generic softmax layer diagram" class="align-center" src="https://eli.thegreenplace.net/images/2016/softmax-layer-generic.png" />
<p>In this diagram, we have an input <em>x</em> with N features, and T possible
output classes. The weight matrix <em>W</em> is used to transform <em>x</em> into a vector
with T elements (called &quot;logits&quot; in ML folklore), and the softmax function is
used to &quot;collapse&quot; the logits into a vector of probabilities denoting the
probability of <em>x</em> belonging to each one of the T output classes.</p>
<p>How do we compute the derivative of this &quot;softmax layer&quot; (fully-connected matrix
multiplication followed by softmax)? Using the chain rule, of course! You'll
find any number of derivations of this derivative online, but I want to approach
it from first principles, by carefully applying the <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus/">multivariate chain rule</a> to the
Jacobians of the functions involved.</p>
<p>An important point before we get started: you may think that <em>x</em> is a natural
variable to compute the derivative for. But it's not. In fact, in machine learning
we usually want to find the best weight matrix <em>W</em>, and thus it is <em>W</em> we want
to update with every step of <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent">gradient descent</a>. Therefore,
we'll be computing the derivative of this layer w.r.t. <em>W</em>.</p>
<p>Let's start by rewriting this diagram as a composition of vector functions.
First, we have the matrix multiplication, which we denote <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a0e38e0d2b015bcbf88c39139b08982ae8b9529d.svg" style="height: 18px;" type="image/svg+xml">g(W)</object>. It maps
<object class="valign-m1" data="https://eli.thegreenplace.net/images/math/41cbe7438e5529bcab383579b09d611cd97f0444.svg" style="height: 16px;" type="image/svg+xml">\mathbb{R}^{NT}\rightarrow \mathbb{R}^{T}</object>, because the input (matrix
<em>W</em>) has <em>N times T</em> elements, and the output has T elements.</p>
<p>Next we have the softmax. If we denote the vector of logits as <object class="valign-0" data="https://eli.thegreenplace.net/images/math/b3931f1ce298c536432fd324b3a1ab4337120689.svg" style="height: 12px;" type="image/svg+xml">\lambda</object>,
we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9af2279e6f8c350d3e301ff7ed97ff2d23d2b478.svg" style="height: 19px;" type="image/svg+xml">S(\lambda):\mathbb{R}^{T}\rightarrow \mathbb{R}^{T}</object>. Overall,
we have the function composition:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/10e8a3123f66fe60ae76a3fe83b2a9b73ea3fa57.svg" style="height: 45px;" type="image/svg+xml">
\[\begin{align*}
P(W)&amp;=S(g(W)) \\
&amp;=(S\circ g)(W)
\end{align*}\]</object>
<p>By applying the multivariate chain rule, the Jacobian of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/f6dd867bfc20ac609f598f54ed834172e0985b0b.svg" style="height: 18px;" type="image/svg+xml">P(W)</object> is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/80f6a3c715eb405a68968e4c579d3a2b562cfab0.svg" style="height: 18px;" type="image/svg+xml">
\[DP(W)=D(S\circ g)(W)=DS(g(W))\cdot Dg(W)\]</object>
<p>We've computed the Jacobian of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/7f3a73c41d966d0cade30c5b1fadd35290358a15.svg" style="height: 18px;" type="image/svg+xml">S(a)</object> earlier in this post; what's
remaining is the Jacobian of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/a0e38e0d2b015bcbf88c39139b08982ae8b9529d.svg" style="height: 18px;" type="image/svg+xml">g(W)</object>. Since <em>g</em> is a very simple function,
computing its Jacobian is easy; the only complication is dealing with the
indices correctly. We have to keep track of which weight each derivative is for.
Since <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/50dd3f482e6e8490b6b54b110c2b8e9018c6a607.svg" style="height: 19px;" type="image/svg+xml">g(W):\mathbb{R}^{NT}\rightarrow \mathbb{R}^{T}</object>, its Jacobian has
<em>T</em> rows and <em>NT</em> columns:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/0d59698eb2307932fdb5a94b7f089da40688f368.svg" style="height: 76px;" type="image/svg+xml">
\[Dg=\begin{bmatrix}
D_1 g_1 &amp; \cdots &amp; D_{NT} g_1 \\
\vdots &amp; \ddots &amp; \vdots \\
D_1 g_T &amp; \cdots &amp; D_{NT} g_T
\end{bmatrix}\]</object>
<p>In a sense, the weight matrix <em>W</em> is &quot;linearized&quot; to a vector of length <em>NT</em>. If
you're familiar with the <a class="reference external" href="http://eli.thegreenplace.net/2015/memory-layout-of-multi-dimensional-arrays">memory layout of multi-dimensional arrays</a>,
it should be easy to understand how it's done. In our case, one simple thing we
can do is linearize it in row-major order, where the first row is consecutive,
followed by the second row, etc. Mathematically, <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/14147644eaa95a20bf61a81af56045475f386a83.svg" style="height: 18px;" type="image/svg+xml">W_{ij}</object> will get column
number <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ef7b2d987af3c0ceb75381d096c35e8c19085642.svg" style="height: 18px;" type="image/svg+xml">(i-1)N+j</object> in the Jacobian. To populate <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object>, let's recall
what <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/434575851c19a9826fb6be1ca130ffa3243a2a34.svg" style="height: 12px;" type="image/svg+xml">g_1</object> is:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/64a7924d431e1a8e82f753f1f04943ddd619fedb.svg" style="height: 16px;" type="image/svg+xml">
\[g_1=W_{11}x_1+W_{12}x_2+\cdots +W_{1N}x_N\]</object>
<p>Therefore:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/2a64f2f7fdb74ca1e0e3bf86da7e9874e8855928.svg" style="height: 177px;" type="image/svg+xml">
\[\begin{align*}
D_1g_1&amp;=x_1 \\
D_2g_1&amp;=x_2 \\
\cdots \\
D_Ng_1&amp;=x_N \\
D_{N+1}g_1&amp;=0 \\
\cdots \\
D_{NT}g_1&amp;=0
\end{align*}\]</object>
<p>If we follow the same approach to compute <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/eeb76bb8cb07245435e01abcd03dec71f9c051df.svg" style="height: 12px;" type="image/svg+xml">g_2...g_T</object>, we'll get the
Jacobian matrix:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/5b0d880f118ea950dd4c676a9aad2e481d83b0bf.svg" style="height: 76px;" type="image/svg+xml">
\[Dg=\begin{bmatrix}
x_1 &amp; x_2 &amp; \cdots &amp; x_N &amp; \cdots &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\
\vdots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \ddots &amp; \vdots \\
0 &amp; 0 &amp; \cdots &amp; 0 &amp; \cdots &amp; x_1 &amp; x_2 &amp; \cdots &amp; x_N
\end{bmatrix}\]</object>
<p>Looking at it differently, if we split the index of <em>W</em> to <em>i</em> and <em>j</em>, we get:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/3ca9791a8734377178476d2069bbb072b7e345ac.svg" style="height: 44px;" type="image/svg+xml">
\[\begin{align*}
D_{ij}g_t&amp;=\frac{\partial(W_{t1}x_1+W_{t2}x_2+\cdots+W_{tN}x_N)}{\partial W_{ij}}
&amp;= \left\{\begin{matrix}
x_j &amp; i = t\\
0 &amp; i \ne t
\end{matrix}\right.
\end{align*}\]</object>
<p>This goes into row <em>t</em>, column <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ef7b2d987af3c0ceb75381d096c35e8c19085642.svg" style="height: 18px;" type="image/svg+xml">(i-1)N+j</object> in the Jacobian matrix.</p>
<p>Finally, to compute the full Jacobian of the softmax layer, we just do a dot
product between <object class="valign-0" data="https://eli.thegreenplace.net/images/math/2ee0d2dca289c3eb54f4cc5e98db8d63e9b0794b.svg" style="height: 12px;" type="image/svg+xml">DS</object> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object>. Note that
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/be12618361f03651d2f459ce0fa3ac82aad3b766.svg" style="height: 19px;" type="image/svg+xml">P(W):\mathbb{R}^{NT}\rightarrow \mathbb{R}^{T}</object>, so the Jacobian
dimensions work out. Since <object class="valign-0" data="https://eli.thegreenplace.net/images/math/2ee0d2dca289c3eb54f4cc5e98db8d63e9b0794b.svg" style="height: 12px;" type="image/svg+xml">DS</object> is <em>TxT</em> and <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object> is <em>TxNT</em>, their
dot product <object class="valign-0" data="https://eli.thegreenplace.net/images/math/9f2059fa4172536236c9acfa22a911f918547e55.svg" style="height: 12px;" type="image/svg+xml">DP</object> is <em>TxNT</em>.</p>
<p>In the literature you'll see a much shorter derivation of the derivative of the
softmax layer. That's fine, since the two functions involved are simple and well
known. If we carefully compute a dot product between a row in <object class="valign-0" data="https://eli.thegreenplace.net/images/math/2ee0d2dca289c3eb54f4cc5e98db8d63e9b0794b.svg" style="height: 12px;" type="image/svg+xml">DS</object> and a
column in <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/699151b941880c8adf5d363048a97c6731482ed6.svg" style="height: 54px;" type="image/svg+xml">
\[D_{ij}P_t=\sum_{k=1}^{T}D_kS_t\cdot D_{ij}g_k\]</object>
<p><object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object> is mostly zeros, so the end result is simpler. The only <em>k</em> for which
<object class="valign-m6" data="https://eli.thegreenplace.net/images/math/fca24bbbbf8cac80ccc0253802b13d2749770585.svg" style="height: 18px;" type="image/svg+xml">D_{ij}g_k</object> is nonzero is when <object class="valign-0" data="https://eli.thegreenplace.net/images/math/f4b7e42a4b8c52f40eb9458e68e81c74d70c1c61.svg" style="height: 13px;" type="image/svg+xml">i=k</object>; then it's equal to
<object class="valign-m6" data="https://eli.thegreenplace.net/images/math/73058e43db0f4edc791b10f27f913cbc5d361ab6.svg" style="height: 14px;" type="image/svg+xml">x_j</object>. Therefore:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/7f5cbb15243987230b4fa5741769938a78c9c2f2.svg" style="height: 44px;" type="image/svg+xml">
\[\begin{align*}
D_{ij}P_t&amp;=D_iS_tx_j \\
&amp;=S_t(\delta_{ti}-S_i)x_j
\end{align*}\]</object>
<p>So it's entirely possible to compute the derivative of the softmax layer without
actual Jacobian matrix multiplication; and that's good, because matrix
multiplication is expensive! The reason we can avoid most computation is that
the Jacobian of the fully-connected layer is <em>sparse</em>.</p>
<p>That said, I still felt it's important to show how this derivative comes to life
from first principles based on the composition of Jacobians for the functions
involved. The advantage of this approach is that it works exactly the same for
more complex compositions of functions, where the &quot;closed form&quot; of the derivative
for each element is much harder to compute otherwise.</p>
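<p>To make the shapes and the sparsity tangible, here's a sketch (my addition) that builds
both Jacobians explicitly for small, arbitrarily chosen <em>N</em> and <em>T</em>, multiplies
them, and compares the result against the shortcut formula; note that I use 0-based indices
here, so W_ij maps to column iN+j:</p>
<div class="highlight"><pre>import numpy as np

def softmax(z):
    exps = np.exp(z - np.max(z))
    return exps / np.sum(exps)

N, T = 4, 3                        # small sizes so the matrices are easy to inspect
x = np.random.randn(N)
W = np.random.randn(T, N)

s = softmax(W.dot(x))              # the softmax layer output, shape (T,)

# Jacobian of softmax w.r.t. the logits: T x T.
DS = np.diag(s) - np.outer(s, s)

# Jacobian of g(W) = Wx w.r.t. W linearized in row-major order: T x NT.
# Row t has x in columns t*N .. t*N+N-1 and zeros elsewhere - it's sparse.
Dg = np.zeros((T, N * T))
for t in range(T):
    Dg[t, t * N:(t + 1) * N] = x

DP = DS.dot(Dg)                    # full Jacobian of the layer, T x NT

# The shortcut formula: D_ij P_t = S_t * (delta_ti - S_i) * x_j.
DP_direct = np.zeros((T, N * T))
for t in range(T):
    for i in range(T):
        for j in range(N):
            delta_ti = 1.0 if t == i else 0.0
            DP_direct[t, i * N + j] = s[t] * (delta_ti - s[i]) * x[j]

print(np.allclose(DP, DP_direct))  # True
</pre></div>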
</div>
<div class="section" id="softmax-and-cross-entropy-loss">
<h2>Softmax and cross-entropy loss</h2>
<p>We've just seen how the softmax function is used as part of a machine learning
network, and how to compute its derivative using the multivariate chain rule.
While we're at it, it's worth taking a look at a loss function that's
commonly used along with softmax for training a network: cross-entropy.</p>
<p><a class="reference external" href="https://en.wikipedia.org/wiki/Cross_entropy">Cross-entropy</a> has an
interesting probabilistic and information-theoretic interpretation, but here
I'll just focus on the mechanics. For two discrete probability distributions <em>p</em>
and <em>q</em>, the cross-entropy function is defined as:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/b26f68a12667ba254facf9815252f52ebf2238d9.svg" style="height: 38px;" type="image/svg+xml">
\[xent(p,q)=-\sum_{k}p(k)log(q(k))\]</object>
<p>Where <em>k</em> goes over all the possible values of the random variable the
distributions are defined for. Specifically, in our case there are <em>T</em> output
classes, so <em>k</em> would go from 1 to <em>T</em>.</p>
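<p>As a tiny sketch (my addition), this definition is a one-liner in NumPy; the
distributions below are made-up numbers just to show the mechanics:</p>
<div class="highlight"><pre>import numpy as np

def xent(p, q):
    # Cross-entropy between discrete distributions p (true) and q (predicted).
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return -np.sum(p * np.log(q))

# With a one-hot p, only the term for the hot index survives:
Y = np.array([0.0, 1.0, 0.0])
P = np.array([0.1, 0.7, 0.2])
print(xent(Y, P), -np.log(P[1]))   # both are about 0.357
</pre></div>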
<p>We start from the softmax output <em>P</em> - this is one probability distribution
<a class="footnote-reference" href="#id4" id="id2">[2]</a>. The other probability distribution is the &quot;correct&quot; classification
output, usually denoted by <em>Y</em>. This is a one-hot encoded vector of size <em>T</em>,
where all elements except one are 0.0, and one element is 1.0 - this element
marks the correct class for the data being classified. Let's rephrase the
cross-entropy loss formula for our domain:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/b02b400caa1de3f720f3c51b4891204a85a0d482.svg" style="height: 54px;" type="image/svg+xml">
\[xent(Y, P)=-\sum_{k=1}^{T}Y(k)log(P(k))\]</object>
<p><em>k</em> goes over all the output classes. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/1801d6549d7f256091d8d687062875facf870a80.svg" style="height: 18px;" type="image/svg+xml">P(k)</object> is the probability of the
class as predicted by the model. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/369b88be91e9aecb20f084f95946d171096ec2ad.svg" style="height: 18px;" type="image/svg+xml">Y(k)</object> is the &quot;true&quot; probability of the
class as provided by the data. Let's mark the sole index where <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/bf2a1a90dbf5ee8f3e1240a2aff2b64220f3e876.svg" style="height: 18px;" type="image/svg+xml">Y(k)=1.0</object>
by <em>y</em>. Since for all <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/e0e4ad3507e9dde8cc37658b436305ef9eb14ca0.svg" style="height: 17px;" type="image/svg+xml">k\ne y</object> we have <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/9d1a77958eb2fd853cb41001e41efcfa46a099d3.svg" style="height: 18px;" type="image/svg+xml">Y(k)=0</object>, the cross-entropy
formula can be simplified to:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/ca79e575abc3ff07571f9b7bd9ee477c4cac1b7a.svg" style="height: 18px;" type="image/svg+xml">
\[xent(Y, P)=-log(P(y))\]</object>
<p>Actually, let's make it a function of just <em>P</em>, treating <em>y</em> as a constant.
Moreover, since in our case <em>P</em> is a vector, we can express <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/033e08901a43a52bb55ac6d36bcb0cebb8781a4e.svg" style="height: 18px;" type="image/svg+xml">P(y)</object> as
the <em>y</em>-th element of <em>P</em>, or <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/12b5ad2733328bc7191f23d13e05c4e246bb8e26.svg" style="height: 18px;" type="image/svg+xml">P_y</object>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/e659a9fdd830a347c3aae214b31013eb52c59dc7.svg" style="height: 19px;" type="image/svg+xml">
\[xent(P)=-log(P_y)\]</object>
<p>The Jacobian of <em>xent</em> is a <em>1xT</em> matrix (a row vector), since the output is a
scalar and we have <em>T</em> inputs (the vector <em>P</em> has <em>T</em> elements):</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/2e515cd8235b0385a95e5cbfff5fbcca9a78c631.svg" style="height: 22px;" type="image/svg+xml">
\[Dxent=\begin{bmatrix}
D_1xent &amp; D_2xent &amp; \cdots &amp; D_Txent
\end{bmatrix}\]</object>
<p>Now recall that <em>P</em> can be expressed as a function of input weights:
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/ad179bfd313d392ad156b509370b8f407e7bd20a.svg" style="height: 18px;" type="image/svg+xml">P(W)=S(g(W))</object>. So we have another function composition:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/ab2f487f02c386d5f532900ffd0927c28ed23b7c.svg" style="height: 18px;" type="image/svg+xml">
\[xent(W)=(xent\circ P)(W)=xent(P(W))\]</object>
<p>And we can, once again, use the multivariate chain rule to find the gradient of
<em>xent</em> w.r.t. <em>W</em>:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/ef938751c387283a7be6461ab0c244ac09db85be.svg" style="height: 18px;" type="image/svg+xml">
\[Dxent(W)=D(xent\circ P)(W)=Dxent(P(W))\cdot DP(W)\]</object>
<p>Let's check that the dimensions of the Jacobian matrices work out. We already
computed <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3f90f5becd4cc377e50cd6885718feb039eabcc9.svg" style="height: 18px;" type="image/svg+xml">DP(W)</object>; it's <em>TxNT</em>. <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/676107d1b425649d04d82c75a37b391aa99edcf1.svg" style="height: 18px;" type="image/svg+xml">Dxent(P(W))</object> is <em>1xT</em>, so the
resulting Jacobian <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86f6a5ad8eb3128d2d86c826df3d8831403e64ac.svg" style="height: 18px;" type="image/svg+xml">Dxent(W)</object> is <em>1xNT</em>, which makes sense because the
whole network has one output (the cross-entropy loss - a scalar value) and <em>NT</em>
inputs (the weights).</p>
<p>Here again, there's a straightforward way to find a simple formula for
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86f6a5ad8eb3128d2d86c826df3d8831403e64ac.svg" style="height: 18px;" type="image/svg+xml">Dxent(W)</object>, since many elements in the matrix multiplication end up
cancelling out. Note that <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/bb805dc98dfe8b48ded94e4f27a90e74b64371e4.svg" style="height: 18px;" type="image/svg+xml">xent(P)</object> depends only on the <em>y</em>-th element of
<em>P</em>. Therefore, only <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/fb396ced0aaf5ee006e13bb7b0925ba833e01a12.svg" style="height: 18px;" type="image/svg+xml">D_{y}xent</object> is non-zero in the Jacobian:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/c2b2a7aa200023fd2988991212edc5053a85731e.svg" style="height: 22px;" type="image/svg+xml">
\[Dxent=\begin{bmatrix}
0 &amp; 0 &amp; D_{y}xent &amp; \cdots &amp; 0
\end{bmatrix}\]</object>
<p>And <object class="valign-m10" data="https://eli.thegreenplace.net/images/math/ba6bd8869680cb3dab4a5138b909d4f4155ae6a8.svg" style="height: 26px;" type="image/svg+xml">D_{y}xent=-\frac{1}{P_y}</object>. Going back to the full Jacobian
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/86f6a5ad8eb3128d2d86c826df3d8831403e64ac.svg" style="height: 18px;" type="image/svg+xml">Dxent(W)</object>, we multiply <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/3845e2788792dc92a7072833fa019ce1182f4dbc.svg" style="height: 18px;" type="image/svg+xml">Dxent(P)</object> by each column of <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/31bc3dde97a870d7b85f78efe4d178d38eae0fdb.svg" style="height: 18px;" type="image/svg+xml">D(P(W))</object>
to get each element in the resulting row-vector. Recall that the row vector
represents the whole weight matrix <em>W</em> &quot;linearized&quot; in row-major order. We'll
index into it with <em>i</em> and <em>j</em> for clarity (<object class="valign-m6" data="https://eli.thegreenplace.net/images/math/d82e04a1bce5f5f685c8b6ac356997c847fa95a5.svg" style="height: 18px;" type="image/svg+xml">D_{ij}</object> points to element
number <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/d62c595542e46b7fdf2d4fae8243232cde20dc17.svg" style="height: 16px;" type="image/svg+xml">iN+j</object> in the row vector):</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/b61c7d91efebf65b53f3dada643d86b63d06b6b5.svg" style="height: 54px;" type="image/svg+xml">
\[D_{ij}xent(W)=\sum_{k=1}^{T}D_{k}xent(P)\cdot D_{ij}P_k(W)\]</object>
<p>Since only the <em>y</em>-th element in <object class="valign-m4" data="https://eli.thegreenplace.net/images/math/2b703a6ad534070bbe698f8d8a3a1261b5bb4549.svg" style="height: 18px;" type="image/svg+xml">D_{k}xent(P)</object> is non-zero, we get the
following, also substituting the derivative of the softmax layer from earlier in
the post:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/d7823846aecfb3673906d65e8da6b290b7b2f608.svg" style="height: 68px;" type="image/svg+xml">
\[\begin{align*}
D_{ij}xent(W)&amp;=D_{y}xent(P)\cdot D_{ij}P_y(W) \\
&amp;=-\frac{1}{P_y}\cdot S_y(\delta_{yi}-S_i)x_j
\end{align*}\]</object>
<p>By our definition, <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/2ec0ba51607b94096ad077ab55cc181698494e1a.svg" style="height: 18px;" type="image/svg+xml">P_y=S_y</object>, so we get:</p>
<object class="align-center" data="https://eli.thegreenplace.net/images/math/e417398a544821300668c777d55ad489934d744c.svg" style="height: 96px;" type="image/svg+xml">
\[\begin{align*}
D_{ij}xent(W)&amp;=-\frac{1}{S_y}\cdot S_y(\delta_{yi}-S_i)x_j \\
&amp;=-(\delta_{yi}-S_i)x_j \\
&amp;=(S_i-\delta_{yi})x_j
\end{align*}\]</object>
<p>Once again, even though in this case the end result is nice and clean, it didn't
necessarily have to be so. The formula for <object class="valign-m6" data="https://eli.thegreenplace.net/images/math/b0cfb602e63642cc6146ca57731821d6a9866a1e.svg" style="height: 20px;" type="image/svg+xml">D_{ij}xent(W)</object> could end up
being a fairly involved sum (or sum of sums). The technique of multiplying
Jacobian matrices is oblivious to all this, as the computer can do all the sums
for us. All we have to do is compute the individual Jacobians, which is usually
easier because they are for simpler, non-composed functions. This is the beauty
and utility of the multivariate chain rule.</p>
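<p>Here's what the final formula looks like in code. This is a sketch of my own (with
arbitrary small sizes and random data), checking the analytic gradient against a
finite-difference approximation of the whole composed function:</p>
<div class="highlight"><pre>import numpy as np

def softmax(z):
    exps = np.exp(z - np.max(z))
    return exps / np.sum(exps)

def loss(W, x, y):
    # Cross-entropy of softmax(Wx) against the correct class index y.
    return -np.log(softmax(W.dot(x))[y])

N, T = 4, 3
x = np.random.randn(N)
W = np.random.randn(T, N)
y = 1                                  # index of the correct class

# Analytic gradient: D_ij xent(W) = (S_i - delta_yi) * x_j, as a T x N matrix.
s = softmax(W.dot(x))
delta_y = np.zeros(T)
delta_y[y] = 1.0
grad = np.outer(s - delta_y, x)

# Finite-difference check, element by element.
eps = 1e-6
grad_num = np.zeros_like(W)
for i in range(T):
    for j in range(N):
        Wp = W.copy()
        Wp[i, j] += eps
        grad_num[i, j] = (loss(Wp, x, y) - loss(W, x, y)) / eps

print(np.allclose(grad, grad_num, atol=1e-4))  # True
</pre></div>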
<hr class="docutils" />
<table class="docutils footnote" frame="void" id="id3" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id1">[1]</a></td><td>To play more with sample inputs and Softmax outputs, Michael Nielsen's
online book has a <a class="reference external" href="http://neuralnetworksanddeeplearning.com/chap3.html#softmax">nice interactive Javascript visualization</a> - check
it out.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id4" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id2">[2]</a></td><td>Take a moment to recall that, by definition, the output of the softmax
function is indeed a valid discrete probability distribution.</td></tr>
</tbody>
</table>
</div>
The Chain Rule of Calculus2016-10-10T06:24:00-07:002016-10-10T06:24:00-07:00Eli Benderskytag:eli.thegreenplace.net,2016-10-10:/2016/the-chain-rule-of-calculus/<p>The chain rule of derivatives is, in my opinion, the most important formula in
differential calculus. In this post I want to explain how the chain rule works
for single-variable and multivariate functions, with some interesting examples
along the way.</p>
<div class="section" id="preliminaries-composition-of-functions-and-differentiability">
<h2>Preliminaries: composition of functions and differentiability</h2>
<p>We denote a function <em>f</em> that maps from the domain <em>X</em> to the codomain <em>Y</em> as
<img alt="f:X \rightarrow Y" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e2f7fcdddf5b36735350a805eeb7cae36895ab1e.png" style="height: 16px;" />. With this <em>f</em> and given <img alt="g:Y \rightarrow Z" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e664f090a7bb62573ae65c910ef7c81e5f086cf6.png" style="height: 16px;" />, we
can define <img alt="g \circ f:X \rightarrow Z" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7f049e7749d289236edefaeb6399795a11afeb44.png" style="height: 16px;" /> as the composition of <em>g</em> and <em>f</em>.
It's defined for <img alt="\forall x \in X" class="valign-m1" src="https://eli.thegreenplace.net/images/math/76545a2a780098fe8c8d581192fa77deccae0848.png" style="height: 14px;" /> as:</p>
<img alt="\[(g \circ f)(x)=g(f(x))\]" class="align-center" src="https://eli.thegreenplace.net/images/math/8b9c8e67c9d2ec7fd3eefce043f380512f1230d3.png" style="height: 18px;" />
<p>In calculus we are usually concerned with the real number domain of some
dimensionality. In the single-variable case, we can think of <img alt="f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4a0a19218e082a343a1b17e5333409af9d98f0f5.png" style="height: 16px;" /> and
<img alt="g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/54fd1711209fb1c0781092374132c66e79e2241b.png" style="height: 12px;" /> as two regular real-valued functions:
<img alt="f:\mathbb{R} \rightarrow \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/62ac71ec4fa066b12854a09cddef9ba062924d68.png" style="height: 16px;" /> and
<img alt="g:\mathbb{R} \rightarrow \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/974c68b8e4454d31c7a2eb389c94bbbfd11ac9da.png" style="height: 16px;" />.</p>
<p>As an example, say <img alt="f(x)=x+1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/027c36c348c172740dd168c66fbfe75d8a8da0c3.png" style="height: 18px;" /> and <img alt="g(x)=x^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9b74ba3074b06d93dacb65e40b0082897aa85b3d.png" style="height: 19px;" />. Then:</p>
<img alt="\[(g \circ f)(x)=g(f(x))=g(x+1)=(x+1)^2\]" class="align-center" src="https://eli.thegreenplace.net/images/math/f80635cd447f9f82452529c9289d16811394ea6c.png" style="height: 21px;" />
<p>We can compose the functions the other way around as well:</p>
<img alt="\[(f \circ g)(x)=f(g(x))=f(x^2)=x^2+1\]" class="align-center" src="https://eli.thegreenplace.net/images/math/13c07f9e990c72b1edaf651fccec5c4ad7c0f155.png" style="height: 21px;" />
<p>Obviously, we shouldn't expect composition to be commutative.
It is, however, associative. <img alt="h \circ (g \circ f)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ef6897e4aad0050d8f69248de3ecd8aaa3ad51de.png" style="height: 18px;" /> and
<img alt="(h \circ g) \circ f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/03ac0c8bb4a409ff1ec1badfee9693280bb2f241.png" style="height: 18px;" /> are equivalent, and both end up being
<img alt="h(g(f(x)))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/bc1a23574da8c77a4fc40d5cbbad2c5e1e95da86.png" style="height: 18px;" /> for <img alt="\forall x \in X" class="valign-m1" src="https://eli.thegreenplace.net/images/math/76545a2a780098fe8c8d581192fa77deccae0848.png" style="height: 14px;" />.</p>
<p>To better handle compositions in one's head it sometimes helps to denote the
independent variable of the outer function (<em>g</em> in our case) by a different
letter (such as <img alt="g(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e7373233d49e18a0882e0dce41d9d6aa26964d6b.png" style="height: 18px;" />). For simple cases it doesn't matter, but I'll
be using this technique occasionally throughout the article. The important thing
to remember here is that the name of the independent variable is completely
arbitrary, and we should always be able to replace it by another name throughout
the formula without any semantic change.</p>
<p>The other preliminary I want to mention is <em>differentiability</em>. The function <em>f</em>
is differentiable at some point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> if the following limit exists:</p>
<img alt="\[\lim_{h \to 0}\frac{f(x_0+h)-f(x_0)}{h}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/34b3ce83a20775cf99b8d204d2b845dfde5727cc.png" style="height: 39px;" />
<p>This limit is then the derivative of <em>f</em> at the point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, or
<img alt="{f}&amp;#x27;(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e9b2a1134fcdc276843ee4b522039359117026ee.png" style="height: 18px;" />. Another way to express this is <img alt="\frac{d}{dx}f(x_0)" class="valign-m6" src="https://eli.thegreenplace.net/images/math/b0d6f765abf215972d5dbb982f77f1a83c233066.png" style="height: 22px;" />.
Note that <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> can be any arbitrary point on the real line. I sometimes
say something like &quot;<em>f</em> is differentiable at <img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" />&quot;. Here too,
<img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" /> is just a real value that happens to be the value of the function
<em>g</em> at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />.</p>
</div>
<div class="section" id="the-single-variable-chain-rule">
<h2>The single-variable chain rule</h2>
<p>The chain rule for single-variable functions states: if <em>g</em> is differentiable at
<img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> and <em>f</em> is differentiable at <img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" />, then <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" />
is differentiable at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> and its derivative is:</p>
<img alt="\[(f \circ g)&amp;#x27;(x_0)={f}&amp;#x27;(g(x_0)){g}&amp;#x27;(x_0)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/77fb8b77b35d687c20379179b0178ebdd9b2cee1.png" style="height: 20px;" />
<p>The proof of the chain rule is a bit tricky - I left it for the appendix.
However, we can get a better feel for it using some intuition and a couple of
examples.</p>
<p>First, the intuition. By definition:</p>
<img alt="\[{g}&amp;#x27;(x_0)=\lim_{h \to 0}\frac{g(x_0+h)-g(x_0)}{h}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/cdc3e4a3bced3a7527a15cd76a688d5cc1c06aab.png" style="height: 39px;" />
<p>Multiplying both sides by <em>h</em> we get <a class="footnote-reference" href="#id6" id="id1">[1]</a>:</p>
<img alt="\[{g}&amp;#x27;(x_0)h=\lim_{h \to 0}g(x_0+h)-g(x_0)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/daf52cabed3806986d4c8c29dd60e4ce4fa9247d.png" style="height: 29px;" />
<p>Therefore we can say that when <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> changes by some very small amount,
<img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" /> changes by <img alt="{g}&amp;#x27;(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/fbb4d7279a750f6d80eebeff2e2c25765b304f16.png" style="height: 18px;" /> times that small amount.</p>
<p>Similarly <img alt="{f}&amp;#x27;(a_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d04139c8f65536c3042f975a6966ed49f5f15832.png" style="height: 18px;" /> is the amount of change in the value of <em>f</em> for some
very small change from <img alt="a_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/4a5997da73aadd118038761e69d01e24586bf958.png" style="height: 11px;" />. However, since in our case we compose
<img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" />, we can say that <img alt="a_0=g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e198d0bc24284bd638c564e0b46edf975d5831d4.png" style="height: 18px;" />, evaluating
<img alt="f(g(x_0))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/04a8bf9b7bd565f95f2cb3e0fe6de123b247e3be.png" style="height: 18px;" />. Suppose we shift <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> by a small amount <em>h</em>. This
causes <img alt="g(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/9d8c0deeca951fab05e474395fbb9fab226cf1f2.png" style="height: 18px;" /> to shift by <img alt="{g}&amp;#x27;(x_0)h" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0e1e11ca765684cf07722c40de2bd86b208ca7c1.png" style="height: 18px;" />. So the input <img alt="a_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/4a5997da73aadd118038761e69d01e24586bf958.png" style="height: 11px;" />
of <em>f</em> shifted by <img alt="{g}&amp;#x27;(x_0)h" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0e1e11ca765684cf07722c40de2bd86b208ca7c1.png" style="height: 18px;" /> - this is still a small amount! Therefore,
the total change in the value of <em>f</em> should be <img alt="{f}&amp;#x27;(g(x_0)){g}&amp;#x27;(x_0)h" class="valign-m4" src="https://eli.thegreenplace.net/images/math/b761eb11c7502754575d0413e7ba040f4a106d0d.png" style="height: 18px;" /> <a class="footnote-reference" href="#id7" id="id2">[2]</a>.</p>
<p>Now, a couple of simple examples. Let's take the function <img alt="f(x)=(x+1)^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8db433d3f263ad489e31931ef4a3ddccbd7bece0.png" style="height: 19px;" />.
The idea is to think of this function as a composition of simpler functions.
In this case, one option is: <img alt="g(x)=x+1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8b2ec3a2221203b211c8a0ed975841cb508b193c.png" style="height: 18px;" /> and then <img alt="w(g(x))=g(x)^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2ed44118a1efadf34f5bf169d2ca450246519d1d.png" style="height: 19px;" />,
so the original <em>f</em> is now the composition <img alt="w \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4edc28332d30c68727a56fbd473126441850c4f0.png" style="height: 12px;" />.</p>
<p>The derivative of this composition is <img alt="{w}&amp;#x27;(g(x)){g}&amp;#x27;(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/05261e48f79f6e8b129bb26dee7fa8a07bcbf876.png" style="height: 18px;" />, or
<img alt="2(x+1)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/39b598cf32e125c7ae18b7623043d5f8133eba78.png" style="height: 18px;" /> since <img alt="{g}&amp;#x27;(x)=1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/36cc7eeced1b708dcf6166dcaae955f733f93ded.png" style="height: 18px;" />. Note that <em>w</em> is differentiable at any
point, so this derivative always exists.</p>
<p>Another example will use a longer chain of composition. Let's differentiate
<img alt="f(x)=sin[(x+1)^2]" class="valign-m5" src="https://eli.thegreenplace.net/images/math/3e3a23e0dd5d4ee105bcca545bddb058917e2c9c.png" style="height: 20px;" />. This is a composition of three functions:</p>
<img alt="\[\begin{align*} g(x)&amp;amp;=x+1\\ w(x)&amp;amp;=x^2\\ v(x)&amp;amp;=sin(x) \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/6981c04536025d8e43d07bf9b067252c2028feab.png" style="height: 73px;" />
<p>Function composition is associative, so <em>f</em> can be expressed as either
<img alt="v \circ (w \circ g)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1c2a8a63ec4fb6e489b0896b544e277823228906.png" style="height: 18px;" /> or <img alt="(v \circ w) \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/be47fface92aa5db8bade3049da31d065ef8244b.png" style="height: 18px;" />. Since we already
know what the derivative of <img alt="w \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4edc28332d30c68727a56fbd473126441850c4f0.png" style="height: 12px;" /> is, let's use the former:</p>
<img alt="\[\begin{align*} \frac{df(x)}{dx}=\frac{d v(w(g(x)))}{dx}&amp;amp;={v}&amp;#x27;(w(g(x))){w(g(x))}&amp;#x27;(x)\\ &amp;amp;=cos(w(g(x)))2(x+1)\\ &amp;amp;=2cos[(x+1)^2](x+1) \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/f63f9a07295583911873238c3ee6e84e8c3722ca.png" style="height: 93px;" />
</div>
<div class="section" id="the-chain-rule-as-a-computational-procedure">
<h2>The chain rule as a computational procedure</h2>
<p>As the last example demonstrates, the chain rule can be applied multiple times
in a single derivation. This makes the chain rule a powerful tool for computing
derivatives of very complex functions, which can be broken up into compositions
of simpler functions. I like to draw a parallel between this process and
programming; a function in a programming language can be seen as a computational
procedure - we have a set of input parameters and we produce outputs. On the
way, several transformations happen that can be expressed mathematically. These
transformations are composed, so their derivatives can be computed naturally
with the chain rule.</p>
<p>This may be somewhat abstract, so let's use another example. We'll compute the
derivative of the Sigmoid function - a very important function in machine
learning:</p>
<img alt="\[S(x)=\frac{1}{1+e^{-x}}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/9a39d0495ce32da5840b76adaf508a0349394c49.png" style="height: 38px;" />
<p>To make the equivalence between functions and computational procedures clearer,
let's think how we'd compute <em>S</em> in Python:</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">math</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">))</span>
</pre></div>
<p>This doesn't look much different, but that's just because Python is a high level
language with arbitrarily nested expressions. Its VM (or the CPU in general)
would execute this computation step by step. Let's break it up to be clearer,
assuming we can only apply a single operation at every step:</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="n">f</span> <span class="o">=</span> <span class="o">-</span><span class="n">x</span>
<span class="n">g</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">g</span>
<span class="n">v</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">w</span>
<span class="k">return</span> <span class="n">v</span>
</pre></div>
<p>I hope you're starting to see the resemblance to our chain rule examples at this
point. Sacrificing some rigor in the notation for the sake of expressiveness, we
can write:</p>
<img alt="\[S&amp;#x27;=v&amp;#x27;(w)w&amp;#x27;(g)g&amp;#x27;(f)f&amp;#x27;(x)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/b3029d842b915e7bf0ea1aa91372ab071dd8b80e.png" style="height: 20px;" />
<p>This is the chain rule applied to <img alt="v \circ (w \circ (g \circ f))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/caaf8ea9ee60bb84d61d422c6dee5d6cd173f0ab.png" style="height: 18px;" />. Solving
this is easy because every single derivative in the chain above is trivial:</p>
<img alt="\[\begin{align*} S&amp;#x27;&amp;amp;=v&amp;#x27;(w)w&amp;#x27;(g)g&amp;#x27;(f)(-1)\\ &amp;amp;=v&amp;#x27;(w)w&amp;#x27;(g)e^{-x}(-1)\\ &amp;amp;=v&amp;#x27;(w)(1)e^{-x}(-1)\\ &amp;amp;=\frac{-1}{(1+e^{-x})^2}e^{-x}(-1)\\ &amp;amp;=\frac{e^{-x}}{(1+e^{-x})^2} \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/b987461a2551ca622908f40f791519f3afe3b452.png" style="height: 171px;" />
<p>Now you may be thinking:</p>
<ol class="arabic simple">
<li>Every function computable by a program can be broken down to trivial steps
like our <tt class="docutils literal">sigmoid</tt> above.</li>
<li>Using the chain rule, we can easily find the derivative of such a sequence
of steps... therefore:</li>
<li>We can easily find the derivative of any function computable by a program!!</li>
</ol>
<p>And you'll be right. This is precisely the basis for the technique known as
<a class="reference external" href="https://en.wikipedia.org/wiki/Automatic_differentiation">automatic differentiation</a>, which is widely
used in scientific computing. The most notable use of automatic differentiation
in recent times is the backpropagation algorithm - an essential backbone of
modern machine learning. I personally find automatic differentiation
fascinating, and will write a more dedicated article about it in the future.</p>
</div>
<div class="section" id="multivariate-chain-rule-general-formulation">
<h2>Multivariate chain rule - general formulation</h2>
<p>So far this article has been looking at functions with a single input and
output: <img alt="f:\mathbb{R} \to \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2e28467c90f978580e43c376716981ec5906a01d.png" style="height: 16px;" />. In the most general case of
multi-variate calculus, we're dealing with functions that map from <em>n</em>
dimensions to <em>m</em> dimensions: <img alt="f:\mathbb{R}^{n} \to \mathbb{R}^{m}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/13f219789047343729036279bb11630db317d98d.png" style="height: 16px;" />.
Because every one of the <em>m</em> outputs of <em>f</em> can be considered a separate
function dependent on <em>n</em> variables, it's very natural to deal with such math
using vectors and matrices.</p>
<p>First let's define some notation. We'll consider the outputs of <em>f</em> to be
numbered from 1 to <em>m</em> as <img alt="f_1,f_2 \dots f_m" class="valign-m4" src="https://eli.thegreenplace.net/images/math/93b446c5209263534d09d617bbede21101d6536e.png" style="height: 16px;" />. For each such <img alt="f_i" class="valign-m4" src="https://eli.thegreenplace.net/images/math/68bd0dc647944d362ec8df628a22967b91d82c80.png" style="height: 16px;" />
we can compute its partial derivative by any of the <em>n</em> inputs as:</p>
<img alt="\[D_j f_i(a)=\frac{\partial f_i}{\partial a_j}(a)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/30881b5a92e45259714ba01c7a12fbf8f6c56109.png" style="height: 42px;" />
<p>Where <em>j</em> goes from 1 to <em>n</em> and <em>a</em> is a vector with <em>n</em> components. If <em>f</em>
is differentiable at <em>a</em> <a class="footnote-reference" href="#id8" id="id3">[3]</a> then the derivative of <em>f</em> at <em>a</em> is the <em>Jacobian
matrix</em>:</p>
<img alt="\[Df(a)=\begin{bmatrix} D_1 f_1(a) &amp;amp; \cdots &amp;amp; D_n f_1(a) \\ \vdots &amp;amp; &amp;amp; \vdots \\ D_1 f_m(a) &amp;amp; \cdots &amp;amp; D_n f_m(a) \\ \end{bmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/ab09367d48e9ef4d8bc2314a60313dec700193af.png" style="height: 76px;" />
<p>The multivariate chain rule states: given <img alt="g:\mathbb{R}^n \to \mathbb{R}^m" class="valign-m4" src="https://eli.thegreenplace.net/images/math/b4b7d25491897b053abf7e48688fada4a85368bd.png" style="height: 16px;" />
and <img alt="f:\mathbb{R}^m \to \mathbb{R}^p" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ac8a6cea4e02e885538fc3ef969c5733e84712f9.png" style="height: 16px;" /> and a point <img alt="a \in \mathbb{R}^n" class="valign-m1" src="https://eli.thegreenplace.net/images/math/43a85f2c59f396fe5c4e2c403a0453c463fcfb0d.png" style="height: 13px;" />,
if <em>g</em> is differentiable at <em>a</em> and <em>f</em> is differentiable at <img alt="g(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e7373233d49e18a0882e0dce41d9d6aa26964d6b.png" style="height: 18px;" /> then
the composition <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" /> is differentiable at <em>a</em> and its derivative
is:</p>
<img alt="\[D(f \circ g)(a)=Df(g(a)) \cdot Dg(a)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/00bdefa904bd34df2dfb50cc385e6497c4e5096e.png" style="height: 18px;" />
<p>Which is the matrix multiplication of <img alt="Df(g(a))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e567730c48bb2f95c258b630b4d6e997043e09ab.png" style="height: 18px;" /> and <img alt="Dg(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2575fc98e794a733a7aa6237fe67246a41e6c8c5.png" style="height: 18px;" /> <a class="footnote-reference" href="#id9" id="id4">[4]</a>.
Intuitively, the multivariate chain rule mirrors the single-variable one (and as
we'll soon see, the latter is just a special case of the former) with
derivatives replaced by derivative matrices. From linear algebra, we represent
linear transformations by matrices, and the composition of two linear transformations
is the product of their matrices. Therefore, since derivative matrices - like
derivatives in one dimension - are a linear approximation to the function, the
chain rule makes sense. This is a really nice connection between linear algebra
and calculus, though a full proof of the multivariate rule is very technical and
outside the scope of this article.</p>
</div>
<div class="section" id="multivariate-chain-rule-examples">
<h2>Multivariate chain rule - examples</h2>
<p>Since the chain rule deals with compositions of functions, it's natural to
present examples from the world of parametric curves and surfaces. For example,
suppose we define <img alt="f(x,y,z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c5d72ae6186c76bde08c693d4bfdb85e3201125d.png" style="height: 18px;" /> as a scalar function
<img alt="\mathbb{R}^3 \to \mathbb{R}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/1862a20e93e78e42aafd20106ceabe142def19f1.png" style="height: 16px;" /> giving the temperature
at some point in 3D. Now imagine that we're moving through this 3D space on
a curve defined by a function <img alt="g:\mathbb{R} \to \mathbb{R}^3" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e97099dd54f45a2a71a33d305c517ec97565909d.png" style="height: 19px;" /> which takes
time <em>t</em> and gives the coordinates <img alt="x(t),y(t),z(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4e2bdd3060e49f3494f68f99cb6d2204b2a19e1c.png" style="height: 18px;" /> at that time. We want
to compute how the temperature changes as a function of time <em>t</em> - how do we do
that? Recall that the temperature is not a direct function of time, but rather
is a function of location, while location <em>is</em> a function of time. Therefore,
we'll want to compose <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" />. Here's a concrete example:</p>
<img alt="\[g(t)=\begin{pmatrix} t\\ t^2\\ t^3 \end{pmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/cdaff94ebfb318ec24f472be470497e28a091c42.png" style="height: 65px;" />
<p>And:</p>
<img alt="\[f\begin{pmatrix} x \\ y \\ z \end{pmatrix}=x^2+xyz+5y\]" class="align-center" src="https://eli.thegreenplace.net/images/math/0a2fc40b06886d3b54628680192d71a3186d9fc7.png" style="height: 65px;" />
<p>If we reformulate <em>x</em>, <em>y</em> and <em>z</em> as functions of <em>t</em>:</p>
<img alt="\[f(x(t),y(t),z(t))=x(t)^2+x(t)y(t)z(t)+5y(5)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/17ff41a51f08ccd71948523bbfa6a6c742b3e81f.png" style="height: 21px;" />
<p>Composing <img alt="f \circ g" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1247a6ac0bc07bfdbd790831aa70b0b000bad2e4.png" style="height: 16px;" />, we get:</p>
<img alt="\[(f \circ g)(t)=f(g(t))=f(t,t^2,t^3)=t^2+t^6+5t^2=6t^2+t^6\]" class="align-center" src="https://eli.thegreenplace.net/images/math/63ad25f62a0e93b1f8175a627aac0a29a88a3cca.png" style="height: 21px;" />
<p>Since this is a simple function, we can find its derivative directly:</p>
<img alt="\[(f \circ g)&amp;#x27;(t)=12t+6t^5\]" class="align-center" src="https://eli.thegreenplace.net/images/math/d1025880b042d304efe08de37eeafde5a8d9231c.png" style="height: 21px;" />
<p>Now let's repeat this exercise using the multivariate chain rule. To compute
<img alt="D(f \circ g)(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c20cc5474ef67f0ec35bddccdc59b72742a864e1.png" style="height: 18px;" /> we need <img alt="Df(g(t))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ded52fd957c2b251c84052c335523b80a4e3c945.png" style="height: 18px;" /> and <img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />. Let's start
with <img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />. <img alt="g(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/851fb8b00904a32dff1c79d40158c7ec9d3d5254.png" style="height: 18px;" /> maps <img alt="\mathbb{R} \to \mathbb{R}^3" class="valign-m1" src="https://eli.thegreenplace.net/images/math/0354b4368db3496b963c21b446ad726b65a0ab90.png" style="height: 16px;" />,
so its Jacobian is a 3-by-1 matrix, or column vector:</p>
<img alt="\[Dg(t)=\begin{bmatrix} 1 \\ 2t \\ 3t^2 \end{bmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/492d3e9013352e0cd44e3c5721cd0535174fb318.png" style="height: 65px;" />
<p>To compute <img alt="Df(g(t))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ded52fd957c2b251c84052c335523b80a4e3c945.png" style="height: 18px;" /> let's first find <img alt="Df(x,y,z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/dab2e6dc478f82ef76bff84080623a27fe214dec.png" style="height: 18px;" />. Since
<img alt="f(x,y,z)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c5d72ae6186c76bde08c693d4bfdb85e3201125d.png" style="height: 18px;" /> maps <img alt="\mathbb{R}^3 \to \mathbb{R}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/1862a20e93e78e42aafd20106ceabe142def19f1.png" style="height: 16px;" />, its Jacobian is a
1-by-3 matrix, or row vector:</p>
<img alt="\[Df(x,y,z)=\begin{bmatrix} 2x+yz &amp;amp; xz+5 &amp;amp; xy \end{bmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/e8d650cac68d341d2c99c2641be3d238e516e51c.png" style="height: 22px;" />
<p>To apply the chain rule, we need <img alt="Df(g(t))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ded52fd957c2b251c84052c335523b80a4e3c945.png" style="height: 18px;" />:</p>
<img alt="\[Df(g(t))=\begin{bmatrix} 2t+t^5 &amp;amp; t^4+5 &amp;amp; t^3 \end{bmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/b061977c12dcc918a96473939f6dc01eb7ea7847.png" style="height: 22px;" />
<p>Finally, multiplying <img alt="Df(g(t))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ded52fd957c2b251c84052c335523b80a4e3c945.png" style="height: 18px;" /> by <img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />, we get:</p>
<img alt="\[\begin{align*} D(f \circ g)(t)=Df(g(t)) \cdot Dg(t)&amp;amp;=\begin{bmatrix} 2t+t^5 &amp;amp; t^4+5 &amp;amp; t^3 \end{bmatrix} \cdot \begin{bmatrix} 1 \\ 2t \\ 3t^2 \end{bmatrix}\\ &amp;amp;=2t+t^5+2t^6+10t+3t^5\\ &amp;amp;=12t+6t^5 \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/9c5a5fc3e8024f6d1f2364ad5d0433bb530d4987.png" style="height: 118px;" />
<p>Another interesting way to interpret this result for the case where
<img alt="f:\mathbb{R}^3 \to \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d307aff95a39ad62cc090e4d6e3bd73b1ffc2b14.png" style="height: 19px;" /> and <img alt="g:\mathbb{R} \to \mathbb{R}^3" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e97099dd54f45a2a71a33d305c517ec97565909d.png" style="height: 19px;" />
is to <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent">recall that</a>
the directional derivative of <em>f</em> in the direction of some vector <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />
is:</p>
<img alt="\[D_{\vec{v}}f=(\nabla f) \cdot \vec{v}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/49933775272512c4c8686d9f9692c8ea01e1c97d.png" style="height: 18px;" />
<p>In our case <img alt="(\nabla f)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/cf1f51ce22cf132c44f5cd65c1c6ada1cce0347f.png" style="height: 18px;" /> is the Jacobian of <em>f</em> (because of its
dimensionality). So if we take <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> to be the vector <img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />,
and evaluate the gradient at <img alt="g(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/851fb8b00904a32dff1c79d40158c7ec9d3d5254.png" style="height: 18px;" /> we get <a class="footnote-reference" href="#id10" id="id5">[5]</a>:</p>
<img alt="\[D_{\vec{Dg(t)}}f(t)=(\nabla f(g(t))) \cdot Dg(t)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/dc8e045fe902682ada36e08fa0099f95632b7ced.png" style="height: 24px;" />
<p>This gives us some additional intuition for the temperature change question. The
change in temperature as a function of time is the directional derivative of <em>f</em>
in the direction of the change in location (<img alt="Dg(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ec8c49e88582659c617e6563375355ede5fe1090.png" style="height: 18px;" />).</p>
<p>For additional examples of applying the chain rule, see
<a class="reference external" href="http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/">my post about softmax</a>.</p>
</div>
<div class="section" id="tricks-with-the-multivariate-chain-rule-derivative-of-products">
<h2>Tricks with the multivariate chain rule - derivative of products</h2>
<p>Earlier in the article we've seen how the chain rule helps find derivatives of
complicated functions by decomposing them into simpler functions. The
multivariate chain rule allows even more of that, as the following example
demonstrates. Suppose <img alt="h(x)=f(x)g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6b584f7e739b604fe2d90a983216090d25643ad1.png" style="height: 18px;" />. Then, the well-known <a class="reference external" href="https://en.wikipedia.org/wiki/Product_rule">product rule</a> of derivatives states that:</p>
<img alt="\[h&amp;#x27;(x)=f&amp;#x27;(x)g(x)+f(x)g&amp;#x27;(x)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/6c77a942dbee351e8229ce7771680b6a2f55c4aa.png" style="height: 20px;" />
<p>Proving this from first principles (the definition of the derivative as a limit)
isn't hard, but I want to show how it stems very easily from the multivariate
chain rule.</p>
<p>Let's begin by re-formulating <img alt="h(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2a1862ca703d9d9a76538d74b8f4b71df93bafab.png" style="height: 18px;" /> as a composition of two functions.
The first takes a vector <img alt="\vec{s}" class="valign-0" src="https://eli.thegreenplace.net/images/math/6a16290a6fe4bd5b30bf2cf959214e8fa4924959.png" style="height: 13px;" /> in <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" /> and maps it to
<img alt="\mathbb{R}" class="valign-0" src="https://eli.thegreenplace.net/images/math/0ed839b111fe0e3ca2b2f618b940893eaea88a57.png" style="height: 12px;" /> by computing the product of its two components:</p>
<img alt="\[p(\vec{s})=s_1 s_2\]" class="align-center" src="https://eli.thegreenplace.net/images/math/955d480267a38ec452bcdf2774dadc7652a757fa.png" style="height: 18px;" />
<p>The second is a vector-valued function that maps a number
<img alt="x \in \mathbb{R}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/ec7e4961c34351c48080f6190b6ec363af9adf25.png" style="height: 13px;" /> to <img alt="\mathbb{R}^2" class="valign-0" src="https://eli.thegreenplace.net/images/math/2b688757b3d0949451e1fa97e71ac5f5f284a5e4.png" style="height: 15px;" /> :</p>
<img alt="\[s(x)=\begin{pmatrix} f(x)\\ g(x) \end{pmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/f5c473fb1fb5ee47e59414a91dc484e182bc6210.png" style="height: 43px;" />
<p>We can compose <img alt="p \circ s" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3f1c954d3481a1a167ae311bc3c3980aaf1ee3a1.png" style="height: 12px;" />, producing a function that takes
a scalar and returns a scalar: <img alt="(p \circ s) : \mathbb{R} \to \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c827ac5b7598c117c157d0377b8c30a0f9a72b81.png" style="height: 18px;" />.
We get:</p>
<img alt="\[h(x)=(p \circ s)(x) = f(x)g(x)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/3cbae5f44d32653bd6bbc66e6ee8bb5e1a4dfe40.png" style="height: 18px;" />
<p>Since we're composing two multivariate functions, we can apply the multivariate
chain rule here:</p>
<img alt="\[\begin{align*} D(p \circ s) &amp;amp;= Dp(s(x)) \cdot Ds(x)\\ &amp;amp;=\begin{bmatrix} \frac{\partial p}{\partial s_1}(x) &amp;amp; \frac{\partial p}{\partial s_2}(x) \end{bmatrix}\cdot \begin{bmatrix} {s_1}&amp;#x27;(x)\\ {s_2}&amp;#x27;(x) \end{bmatrix}\\ &amp;amp;=\begin{bmatrix} s_2(x) &amp;amp; s_1(x) \end{bmatrix} \cdot \begin{bmatrix} {s_1}&amp;#x27;(x)\\ {s_2}&amp;#x27;(x) \end{bmatrix}\\ &amp;amp;={s_1}&amp;#x27;(x)s_2(x)+{s_2}&amp;#x27;(x)s_1(x) \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/ee8bd27a8257039f72c8751eb78626521f12a5fa.png" style="height: 147px;" />
<p>Since <img alt="s_1(x)=f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/dc67440057222e8222ae08269e4ba2a1e58acbb4.png" style="height: 18px;" /> and <img alt="s_2(x)=g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/89d252d983d49126f2f4a34fcf01fd6c882e4792.png" style="height: 18px;" />, this is exactly the product
rule.</p>
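<p>A quick numeric check of this derivation (a hypothetical snippet, not from
the article), with <em>f = sin</em> and <em>g = exp</em>:</p>
<div class="highlight"><pre>import math

f, fp = math.sin, math.cos        # f and its derivative
g, gp = math.exp, math.exp        # g and its derivative

def h(x):                         # h(x) = (p o s)(x) = f(x)g(x)
    return f(x) * g(x)

x, eps = 0.9, 1e-6
numeric = (h(x + eps) - h(x - eps)) / (2 * eps)
product_rule = fp(x) * g(x) + f(x) * gp(x)
print(abs(numeric - product_rule))    # ~0
</pre></div>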
</div>
<div class="section" id="connecting-the-single-variable-and-multivariate-chain-rules">
<h2>Connecting the single-variable and multivariate chain rules</h2>
<p>Given function <img alt="f(x) : \mathbb{R} \to \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/bd80387ffb4e5cd8702c12837a57f1806ea1d02b.png" style="height: 18px;" />, its Jacobian matrix
has a single entry:</p>
<img alt="\[Df(a)=\begin{bmatrix}D_{x}f(a)\end{bmatrix}= \begin{bmatrix}\frac{df}{dx}(a)\end{bmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/cc95d53415b32e6610c1a45bededb4fb584f0c64.png" style="height: 24px;" />
<p>Therefore, given two functions mapping <img alt="\mathbb{R} \to \mathbb{R}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/4aaeb3aafa05a9ad54c8d7da4e4aecad4dfac1cd.png" style="height: 13px;" />, the
derivative of their composition using the multivariate chain rule is:</p>
<img alt="\[D(f \circ g)(a)=Df(g(a))\cdot Dg(a)=f&amp;#x27;(g(a))g&amp;#x27;(a)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/98e554584c9d2d967b9a6759a64126093ef704ce.png" style="height: 20px;" />
<p>Which is precisely the single-variable chain rule. This results from matrix
multiplication between two 1x1 matrices, which ends up being just the product
of their single entries.</p>
</div>
<div class="section" id="appendix-proving-the-single-variable-chain-rule">
<h2>Appendix: proving the single-variable chain rule</h2>
<p>It turns out that many online resources (including Khan Academy) provide a
flawed proof for the chain rule. It's flawed due to a careless division by a
quantity that may be zero. This flaw can be corrected by making the proof
somewhat more complicated; I won't take that road here - for details see
Spivak's <em>Calculus</em>. Instead, I'll present a simpler proof inspired by the one I
found at <a class="reference external" href="http://math.rice.edu/~cjd/">Casey Douglas's site</a>.</p>
<p>We want to prove that:</p>
<img alt="\[(f \circ g)&amp;#x27;(x)={f}&amp;#x27;(g(x)){g}&amp;#x27;(x)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/29f4194c9af3777ae55a15dad972a145eb7797be.png" style="height: 20px;" />
<p>Note that previously we defined derivatives at some concrete point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />.
Here for the sake of brevity I'll just use <img alt="x" class="valign-0" src="https://eli.thegreenplace.net/images/math/11f6ad8ec52a2984abaafd7c3b516503785c2072.png" style="height: 8px;" /> as an arbitrary point,
assuming the derivative exists.</p>
<p>Let's start with the definition of <img alt="g&amp;#x27;(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/5fba18e1364e151399f95daac5ed63f09feba9b7.png" style="height: 18px;" />:</p>
<img alt="\[{g}&amp;#x27;(x)=\lim_{h \to 0}\frac{g(x+h)-g(x)}{h}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/c19f7ddc43c3046489d7e012c3f213403edf7e8a.png" style="height: 39px;" />
<p>We can reorder it as follows:</p>
<img alt="\[\lim_{h \to 0}\left [ \frac{g(x+h)-g(x)}{h} - g&amp;#x27;(x) \right ] = 0\]" class="align-center" src="https://eli.thegreenplace.net/images/math/74a651394036af8aeaba69650dba26ccb4f90ae7.png" style="height: 43px;" />
<p>Let's give the part in the brackets the name <img alt="\Delta g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ba109afb4d6264ec2d39fe025fcc5a1dbc58637f.png" style="height: 18px;" />.</p>
<p>Similarly, if the function <em>f</em> is differentiable at the point <img alt="a=g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/01bcc664dda02e9d98a5f37104ff028cf8fd0d62.png" style="height: 18px;" />,
we have:</p>
<img alt="\[f&amp;#x27;(a)=\lim_{k \to 0}\frac{f(a+k)-f(a)}{k}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/59daea2a46cd244229625131297a773820501571.png" style="height: 39px;" />
<p>We reorder:</p>
<img alt="\[\lim_{k \to 0}\left [ \frac{f(a+k)-f(a)}{k} - f&amp;#x27;(a) \right ] = 0\]" class="align-center" src="https://eli.thegreenplace.net/images/math/4600064fad365f360bd73063324a935a8b73266f.png" style="height: 43px;" />
<p>And call the part in the brackets <img alt="\Delta f(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/a85d05996c54eec8a9bef9a60f8e7e4f3231aa51.png" style="height: 18px;" />. The choice of the
variable that goes to zero - <em>k</em> instead of <em>h</em> - is arbitrary; it's chosen to
simplify the discussion that follows.</p>
<p>Let's reorder the definition of <img alt="\Delta g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ba109afb4d6264ec2d39fe025fcc5a1dbc58637f.png" style="height: 18px;" /> a bit:</p>
<img alt="\[g(x+h)=g(x)+[g&amp;#x27;(x)+\Delta g(x)]h\]" class="align-center" src="https://eli.thegreenplace.net/images/math/59e0263f8a2ebfc0fac9a2b51f42c651b359fe31.png" style="height: 21px;" />
<p>We can apply <em>f</em> to both sides:</p>
<img alt="\[\begin{equation} f(g(x+h))=f(g(x)+[g&amp;#x27;(x)+\Delta g(x)]h) \tag{1} \end{equation}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/3b82da9d6cad509490e687b9e86093791545ea81.png" style="height: 21px;" />
<p>By reordering the definition of <img alt="\Delta f(a)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/a85d05996c54eec8a9bef9a60f8e7e4f3231aa51.png" style="height: 18px;" /> we get:</p>
<img alt="\[\begin{equation} f(a+k)=f(a)+[f&amp;#x27;(a)+\Delta f(a)]k \tag{2} \end{equation}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/88c5b43f3ba89da3853be9342381aa8dd60e024f.png" style="height: 21px;" />
<p>Now taking the right-hand side of (1), we can look at it as <img alt="f(a+k)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/831a544de22e2a6a8997413b576a67391ba31f53.png" style="height: 18px;" />
since <img alt="a=g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/01bcc664dda02e9d98a5f37104ff028cf8fd0d62.png" style="height: 18px;" /> and we can define <img alt="k=[g&amp;#x27;(x)+\Delta g(x)]h" class="valign-m5" src="https://eli.thegreenplace.net/images/math/ed49f313283b8f266ffd1e9b4194c36f456a950d.png" style="height: 19px;" />. We still
have <em>k</em> going to zero when <em>h</em> goes to zero. Substituting these <em>a</em> and <em>k</em> into
(2) we get:</p>
<img alt="\[f(a+k)=f(g(x))+[f&amp;#x27;(g(x))+\Delta f(g(x))][g&amp;#x27;(x)+\Delta g(x)]h\]" class="align-center" src="https://eli.thegreenplace.net/images/math/275b3323c68b711b2458e4c748a887a368e32a40.png" style="height: 21px;" />
<p>So, starting from (1) again, we have:</p>
<img alt="\[\begin{align*} f(g(x+h))&amp;amp;=f(a+k) \\ &amp;amp;=f(g(x))+[f&amp;#x27;(g(x))+\Delta f(g(x))][g&amp;#x27;(x)+\Delta g(x)]h \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/82e67cf24d9eb3dad58e7d30cd89ba1c19e367fb.png" style="height: 45px;" />
<p>Subtracting <img alt="f(g(x))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/92cb7139e348ea05a69782b2cf7221bae86a2b03.png" style="height: 18px;" /> from both sides and dividing by <em>h</em> (which is legal,
since <em>h</em> is not zero - just very small) we get:</p>
<img alt="\[\frac{f(g(x+h))-f(g(x))}{h}=[f&amp;#x27;(g(x))+\Delta f(g(x))][g&amp;#x27;(x)+\Delta g(x)]\]" class="align-center" src="https://eli.thegreenplace.net/images/math/bfdfef3d46b471aa5d6803c3c5a6b5e26ffe3b37.png" style="height: 39px;" />
<p>Apply a limit to both sides:</p>
<img alt="\[\lim_{h \to 0} \frac{f(g(x+h))-f(g(x))}{h}= \lim_{h \to 0} [f&amp;#x27;(g(x))+\Delta f(g(x))][g&amp;#x27;(x)+\Delta g(x)]\]" class="align-center" src="https://eli.thegreenplace.net/images/math/0f5c316fcc2877f78b8a739898a31120471dd401.png" style="height: 39px;" />
<p>Now recall that both <img alt="\Delta f(g(x))" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6752886c71e1a95dc360dd4e5ea10dd0b6f76e84.png" style="height: 18px;" /> and <img alt="\Delta g(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ba109afb4d6264ec2d39fe025fcc5a1dbc58637f.png" style="height: 18px;" /> go to 0 when
<em>h</em> goes to 0. Taking this into account, we get:</p>
<img alt="\[\lim_{h \to 0} \frac{f(g(x+h))-f(g(x))}{h}= f&amp;#x27;(g(x))g&amp;#x27;(x)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/3954d8d23c8fb53d4cd1732d19939d650ef830ae.png" style="height: 39px;" />
<p><em>Q.E.D.</em></p>
<hr class="docutils" />
<table class="docutils footnote" frame="void" id="id6" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id1">[1]</a></td><td>Here, as in the rest of the post, I'm being careless with the usage of
<img alt="\lim" class="valign-0" src="https://eli.thegreenplace.net/images/math/6f5c7776306147fe3be3e4b8547a23c62eafddf4.png" style="height: 13px;" />, sometimes leaving its existence to be implicit. In general,
wherever <em>h</em> appears in a formula we know there's a
<img alt="\lim_{h \to 0}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0f10af054e5ddc3b9603098fec294e0247190efa.png" style="height: 17px;" /> there, whether explicitly or not.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id7" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id2">[2]</a></td><td>An alternative way to think about it is: suppose the functions
<em>f</em> and <em>g</em> are linear: <img alt="f(x)=ax+b" class="valign-m4" src="https://eli.thegreenplace.net/images/math/a85393d5068f5c4bc36ff7efed535a8f1a686848.png" style="height: 18px;" /> and <img alt="g(x)=cx+d" class="valign-m4" src="https://eli.thegreenplace.net/images/math/6d712cb582caa0e48a2b029ea4ae29a3e5e40f27.png" style="height: 18px;" />. Then
the chain rule is trivially true. But now recall what the derivative is.
The derivative at some point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> is the best linear approximation
for the function at that point. Therefore the chain rule is true for any
pair of differentiable functions - even when the functions are not
linear, we approximate their rate of change in an infinitesimal area
around <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> with a linear function.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id8" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id3">[3]</a></td><td>The condition for <em>f</em> being differentiable at <em>a</em> is stronger than simply
saying that all partial derivatives exist at <em>a</em>, but I won't spend more
time on this subtlety here.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id9" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id4">[4]</a></td><td>As an exercise, verify that the matrix dimensions of <img alt="Df" class="valign-m4" src="https://eli.thegreenplace.net/images/math/5c6bf530660cba6530e83a86f0ed49fe0821d179.png" style="height: 16px;" /> and
<object class="valign-m4" data="https://eli.thegreenplace.net/images/math/38b655437da0880bd70168fcbadb50ebdbf46ca5.svg" style="height: 16px;" type="image/svg+xml">Dg</object> make this multiplication valid.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id10" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id5">[5]</a></td><td>It shouldn't be surprising we get here, since the definition of the
directional derivative as the gradient <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent">was derived</a>
using the multivariate chain rule.</td></tr>
</tbody>
</table>
</div>
Linear regression2016-08-06T05:28:00-07:002016-08-06T05:28:00-07:00Eli Benderskytag:eli.thegreenplace.net,2016-08-06:/2016/linear-regression/<p>Linear regression is one of the most basic, and yet most useful approaches for
predicting a single quantitative (real-valued) variable given any number of
real-valued predictors. This article presents the basics of linear
regression for the &quot;simple&quot; (single-variable) case, as well as for the more
general multivariate case. <a class="reference external" href="https://github.com/eliben/deep-learning-samples/tree/master/linear-regression">Companion code …</a></p><p>Linear regression is one of the most basic, and yet most useful approaches for
predicting a single quantitative (real-valued) variable given any number of
real-valued predictors. This article presents the basics of linear
regression for the &quot;simple&quot; (single-variable) case, as well as for the more
general multivariate case. <a class="reference external" href="https://github.com/eliben/deep-learning-samples/tree/master/linear-regression">Companion code in Python</a>
implements the techniques described in the article on simulated and realistic
data sets. The code is self-contained, using only Numpy as a dependency.</p>
<div class="section" id="simple-linear-regression">
<h2>Simple linear regression</h2>
<p>The most basic kind of regression problem has a single <em>predictor</em> (the
input) and a single outcome. Given a list of input values
<img alt="x_i" class="valign-m3" src="https://eli.thegreenplace.net/images/math/34e03e6559b14df9fe5a97bbd2ed10109dfebbd3.png" style="height: 11px;" /> and corresponding output values <img alt="y_i" class="valign-m4" src="https://eli.thegreenplace.net/images/math/35c2ac2f82d0ff8f9011b596ed7e54bfcc55f471.png" style="height: 12px;" />, we have to find
parameters <em>m</em> and <em>b</em> such that the linear function:</p>
<img alt="\[\hat{y}(x) = mx + b\]" class="align-center" src="https://eli.thegreenplace.net/images/math/2dabbcda3b1953b08211f7e334698366d647d697.png" style="height: 18px;" />
<p>Is &quot;as close as possible&quot; to the observed outcome <em>y</em>. More concretely, suppose
we get this data <a class="footnote-reference" href="#id6" id="id1">[1]</a>:</p>
<img alt="Linear regression input data" class="align-center" src="https://eli.thegreenplace.net/images/2016/linreg-data.png" />
<p>We have to find a slope <em>m</em> and intercept <em>b</em> for a line that approximates this
data as well as possible. We evaluate how well some pair of <em>m</em> and <em>b</em>
approximates the data by defining a &quot;cost function&quot;. For linear regression, a
good cost function to use is the <a class="reference external" href="https://en.wikipedia.org/wiki/Mean_squared_error">Mean Square Error (MSE)</a> <a class="footnote-reference" href="#id7" id="id2">[2]</a>:</p>
<img alt="\[\operatorname{MSE}(m, b)=\frac{1}{n}\sum_{i=1}^n(\hat{y_i} - y_i)^2\]" class="align-center" src="https://eli.thegreenplace.net/images/math/e4b7b4ce3abd90f20144e6ab468b7870cedf3b07.png" style="height: 50px;" />
<p>Expanding <img alt="\hat{y_i}=m{x_i}+b" class="valign-m4" src="https://eli.thegreenplace.net/images/math/daecd48b7bb0a06ddd4326da5b87ee14fddaeb8e.png" style="height: 17px;" />, we get:</p>
<img alt="\[\operatorname{MSE}(m, b)=\frac{1}{n}\sum_{i=1}^n(m{x_i} + b - y_i)^2\]" class="align-center" src="https://eli.thegreenplace.net/images/math/3de1df776434b29620488aef327a9204757bc493.png" style="height: 50px;" />
<p>Let's turn this into Python code (<a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/linear-regression/simple_linear_regression.py">link to the full code sample</a>):</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">compute_cost</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Compute the MSE cost of a prediction based on m, b.</span>
<span class="sd"> x: inputs vector</span>
<span class="sd"> y: observed outputs vector</span>
<span class="sd"> m, b: regression parameters</span>
<span class="sd"> Returns: a scalar cost.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">yhat</span> <span class="o">=</span> <span class="n">m</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">b</span>
<span class="n">diff</span> <span class="o">=</span> <span class="n">yhat</span> <span class="o">-</span> <span class="n">y</span>
<span class="c1"># Vectorized computation using a dot product to compute sum of squares.</span>
<span class="n">cost</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">diff</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">diff</span><span class="p">)</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="c1"># Cost is a 1x1 matrix, we need a scalar.</span>
<span class="k">return</span> <span class="n">cost</span><span class="o">.</span><span class="n">flat</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</pre></div>
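<p>As a quick sanity check (a hypothetical snippet, not part of the linked code
sample): for data that lies exactly on a line, the true slope and intercept
should yield a cost of zero:</p>
<div class="highlight"><pre>import numpy as np

x = np.linspace(0, 10, 50).reshape(-1, 1)   # column vector of inputs
y = 2.0 * x + 1.0                            # outputs exactly on y = 2x + 1

print(compute_cost(x, y, m=2.0, b=1.0))      # 0.0 - a perfect fit
print(compute_cost(x, y, m=0.0, b=0.0))      # a much larger cost
</pre></div>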
<p>Now we're faced with a classical optimization problem: we have some parameters
(<em>m</em> and <em>b</em>) we can tweak, and some cost function <img alt="J(m, b)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d61807c64b6ab8087a11167224df4b5f818aeae3.png" style="height: 18px;" /> we want to
minimize. The topic of mathematical optimization is vast, but what ends up
working very well for machine learning is a fairly simple algorithm called
<em>gradient descent</em>.</p>
<p>Imagine plotting <img alt="J(m, b)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d61807c64b6ab8087a11167224df4b5f818aeae3.png" style="height: 18px;" /> as a 3-dimensional surface, and picking some
random point on it. Our goal is to find the lowest point on the surface, but we
have no idea where that is. A reasonable guess is to move a bit &quot;downwards&quot; from
our current location, and then repeat.</p>
<p>&quot;Downwards&quot; is exactly what &quot;gradient descent&quot; means. We make a small change to
our location (defined by <em>m</em> and <em>b</em>) in the direction in which <img alt="J(m, b)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d61807c64b6ab8087a11167224df4b5f818aeae3.png" style="height: 18px;" />
decreases most - the gradient <a class="footnote-reference" href="#id8" id="id3">[3]</a>. We then repeat this process until we reach
a minimum, hopefully global. In fact, since the linear regression cost function
is <em>convex</em> we will find the global minimum this way. But in the general case
this is not guaranteed, and many sophisticated extensions of gradient descent
exist that try to avoid local minima and maximize the chance of finding a global
one.</p>
<p>Back to our function, <img alt="\operatorname{MSE}(m, b)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e42329899ba53adedf0a7884b1844dba4f01bdee.png" style="height: 18px;" />. The gradient is defined
as the vector:</p>
<img alt="\[\nabla \operatorname{MSE}=\left \langle \frac{\partial \operatorname{MSE}}{\partial m}, \frac{\partial \operatorname{MSE}}{\partial b} \right \rangle\]" class="align-center" src="https://eli.thegreenplace.net/images/math/50b0404ea5a8f76da73caae5b8109dd384dbd18e.png" style="height: 43px;" />
<p>To find it, we have to compute the partial derivatives of MSE w.r.t. the
learning parameters <em>m</em> and <em>b</em>:</p>
<img alt="\[\begin{align*} \frac{\partial \operatorname{MSE}}{\partial m}&amp;amp;=\frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i)x_i\\ \frac{\partial \operatorname{MSE}}{\partial b}&amp;amp;=\frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i) \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/dbd383b0d7ee194a417b88ad117b451531758fe7.png" style="height: 108px;" />
<p>And then update <em>m</em> and <em>b</em> in each step of the learning with:</p>
<img alt="\[\begin{align*} m &amp;amp;= m-\eta \frac{\partial \operatorname{MSE}}{\partial m} \\ b &amp;amp;= b-\eta \frac{\partial \operatorname{MSE}}{\partial b} \\ \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/b0c7ff699fc61836051968db56224e6470b56d3c.png" style="height: 81px;" />
<p>Where <img alt="\eta" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2899aeb886ad0fa72652bffd5511e452aaf084ab.png" style="height: 12px;" /> is a customizable &quot;learning rate&quot;, a hyperparameter. Here is
the gradient descent loop in Python. Note that we examine the whole data set in
every step; for much larger data sets, SGD (Stochastic Gradient Descent) with
some reasonable mini-batch would make more sense, but for simple linear
regression problems the data size is rarely very big.</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">gradient_descent</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">nsteps</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Runs gradient descent optimization to fit a line y^ = x * m + b.</span>
<span class="sd"> x, y: input data and observed outputs.</span>
<span class="sd"> nsteps: how many steps to run the optimization for.</span>
<span class="sd"> learning_rate: learning rate of gradient descent.</span>
<span class="sd"> Yields &#39;nsteps + 1&#39; triplets of (m, b, cost) where m, b are the fit</span>
<span class="sd"> parameters for the given step, and cost is their cost vs the real y.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">n</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># Start with m and b initialized to 0s for the first try.</span>
<span class="n">m</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
<span class="k">yield</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">compute_cost</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nsteps</span><span class="p">):</span>
<span class="n">yhat</span> <span class="o">=</span> <span class="n">m</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">b</span>
<span class="n">diff</span> <span class="o">=</span> <span class="n">yhat</span> <span class="o">-</span> <span class="n">y</span>
<span class="n">dm</span> <span class="o">=</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="p">(</span><span class="n">diff</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">/</span> <span class="n">n</span>
<span class="n">db</span> <span class="o">=</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">diff</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">/</span> <span class="n">n</span>
<span class="n">m</span> <span class="o">-=</span> <span class="n">dm</span>
<span class="n">b</span> <span class="o">-=</span> <span class="n">db</span>
<span class="k">yield</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">compute_cost</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
</pre></div>
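<p>For reference, here's a minimal sketch of how this generator could be driven,
reusing the functions defined above. The data, learning rate and number of
steps below are made up for illustration and may need tuning for a different
data scale:</p>
<div class="highlight"><pre>import numpy as np

# Hypothetical data: noisy samples around the line y = 2.25*x + 6, kept as
# column vectors so that compute_cost's dot product yields a 1x1 matrix.
N = 200
x = np.random.uniform(low=0.0, high=10.0, size=(N, 1))
y = 2.25 * x + 6 + np.random.normal(scale=1.5, size=(N, 1))

for m, b, cost in gradient_descent(x, y, nsteps=50, learning_rate=0.01):
    print('m={:.4f} b={:.4f} cost={:.4f}'.format(m, b, cost))
</pre></div>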
<p>After running this for 30 steps, gradient descent converges and the parameters
barely change. Here's a 3D plot of the cost as a function of the regression
parameters, along with a contour plot of the same function. It's easy to see
this function is convex, as expected. This makes finding the global minimum
simple, since no matter where we start, the gradient will lead us directly to
it.</p>
<p>To help visualize this, I marked the cost for each successive training step on
the contour plot - you can see how the algorithm relentlessly converges to the
minimum.</p>
<img alt="Linear regression cost and contour" class="align-center" src="https://eli.thegreenplace.net/images/2016/linreg-cost-contour.png" />
<p>The final parameters learned by the regression are 2.2775 for <em>m</em> and
6.0028 for <em>b</em>, which is very close to the actual parameters I used to
generate this fake data.</p>
<p>Here's a visualization that shows how the regression line improves progressively
during learning:</p>
<img alt="Regression fit visualization" class="align-center" src="https://eli.thegreenplace.net/images/2016/regressionfit.gif" />
</div>
<div class="section" id="evaluating-how-good-the-fit-is">
<h2>Evaluating how good the fit is</h2>
<p>In statistics, there are many ways to evaluate how good a &quot;fit&quot; some model is
on the given data. One of the most popular ones is the <em>r-squared</em> test
(&quot;coefficient of determination&quot;). It measures the proportion of the total
variance in the output (<em>y</em>) that can be explained by the variation in <em>x</em>:</p>
<img alt="\[R^2 = 1 - \frac{\sum_{i=1}^n (y_i - (m{x_i} + b))^2}{n\cdot var(y)}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/2c989c7345d6901a0cf7c17f9b08762ef27c5148.png" style="height: 43px;" />
<p>This is trivial to translate to code:</p>
<div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">compute_rsquared</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="n">yhat</span> <span class="o">=</span> <span class="n">m</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">b</span>
<span class="n">diff</span> <span class="o">=</span> <span class="n">yhat</span> <span class="o">-</span> <span class="n">y</span>
<span class="n">SE_line</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">diff</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">diff</span><span class="p">)</span>
<span class="n">SE_y</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> <span class="o">*</span> <span class="n">y</span><span class="o">.</span><span class="n">var</span><span class="p">()</span>
<span class="k">return</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">SE_line</span> <span class="o">/</span> <span class="n">SE_y</span>
</pre></div>
<p>For our regression results, I get <em>r-squared</em> of 0.76, which isn't too bad. Note
that the data is very jittery, so it's natural the regression cannot explain all
the variance. As an interesting exercise, try to modify the code that generates
the data with different standard deviations for the random noise and see the
effect on <em>r-squared</em>.</p>
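<p>Here's a rough sketch of that experiment. It reuses compute_rsquared from
above and leans on numpy's polyfit for the line fit, so the snippet stands on
its own; the noise levels are arbitrary:</p>
<div class="highlight"><pre>import numpy as np

x = np.linspace(0, 10, 200)
for sigma in (0.5, 1.5, 4.0):
    # Regenerate the data with a different noise level each time.
    y = 2.25 * x + 6 + np.random.normal(scale=sigma, size=x.shape)
    m, b = np.polyfit(x, y, 1)    # least-squares fit of a line
    print('sigma={}: r-squared={:.3f}'.format(sigma, compute_rsquared(x, y, m, b)))
</pre></div>
<p>The noisier the data, the smaller the proportion of its variance the fitted
line can explain, so <em>r-squared</em> should drop as the noise grows.</p>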
</div>
<div class="section" id="an-analytical-solution-to-simple-linear-regression">
<h2>An analytical solution to simple linear regression</h2>
<p>Using the equations for the partial derivatives of MSE (shown above) it's
possible to find the minimum analytically, without having to resort to a
computational procedure (gradient descent). We compare the derivatives to zero:</p>
<img alt="\[\begin{align*} \frac{\partial \operatorname{MSE}}{\partial m}&amp;amp;=\frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i)x_i = 0\\ \frac{\partial \operatorname{MSE}}{\partial b}&amp;amp;=\frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i) = 0 \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/aef02f077919896478d0456619f934dcc5809142.png" style="height: 108px;" />
<p>And solve for <em>m</em> and <em>b</em>. To make the equations easier to follow, let's
introduce a bit of notation. <img alt="\bar{x}" class="valign-0" src="https://eli.thegreenplace.net/images/math/8eebe76c6f552df3f8b9480d5544fe47b1028322.png" style="height: 11px;" /> is the mean value of <em>x</em> across
all samples. Similarly <img alt="\bar{y}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/1e3bffc7f71c01acbc2c12e015be3086a06f824d.png" style="height: 15px;" /> is the mean value of <em>y</em>. So the sum
<img alt="\sum_{i=1}^n x_i" class="valign-m6" src="https://eli.thegreenplace.net/images/math/c42eb1b96dfa184fee1bc0f3a4b713b9c38b2a1a.png" style="height: 20px;" /> is actually <img alt="n\bar{x}" class="valign-0" src="https://eli.thegreenplace.net/images/math/ea6008aefff0c7d79044287c44e890b1fba97c22.png" style="height: 11px;" />. Now let's take the second
equation from above and see how to simplify it:</p>
<img alt="\[\begin{align*} \frac{\partial \operatorname{MSE}}{\partial b} &amp;amp;= \frac{2}{n}\sum_{i=i}^n(m{x_i}+b-y_i) \\ &amp;amp;= \frac{2}{n}(mn\bar{x}+nb-n\bar{y}) \\ &amp;amp;= 2m\bar{x} + 2b - 2\bar{y} = 0 \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/c97c0c9ca8a66d54974fc914fcf929085dc63879.png" style="height: 119px;" />
<p>Similarly, for the partial derivative by <em>m</em> we can reach:</p>
<img alt="\[\frac{\partial \operatorname{MSE}}{\partial m}= 2m\overline{x^2} + 2b\bar{x} - 2\overline{xy} = 0\]" class="align-center" src="https://eli.thegreenplace.net/images/math/d9545273e11c9e179794f943e2c972bf62c38113.png" style="height: 38px;" />
<p>In these equations, all quantities except <em>m</em> and <em>b</em> are constant. Solving them
for the unknowns <em>m</em> and <em>b</em>, we get <a class="footnote-reference" href="#id9" id="id4">[4]</a>:</p>
<img alt="\[m = \frac{\bar{x}\bar{y} - \overline{xy}}{\bar{x}^2 - \overline{x^2}} \qquad b = \bar{y} - \bar{x}\frac{\bar{x}\bar{y} - \overline{xy}}{\bar{x}^2 - \overline{x^2}}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/becd671e8c032d0568e33b986033c181ac5c133b.png" style="height: 38px;" />
<p>If we plug the data values we have for <em>x</em> and <em>y</em> in these equations, we get
2.2777 for <em>m</em> and 6.0103 for <em>b</em> - almost exactly the values we obtained
with regression <a class="footnote-reference" href="#id10" id="id5">[5]</a>.</p>
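<p>Translated directly to numpy, the closed-form solution could look like the
following small sketch (x and y are assumed to be 1-D arrays of samples):</p>
<div class="highlight"><pre>import numpy as np

def analytic_mb(x, y):
    """Computes m and b from the closed-form solution above."""
    xbar, ybar = x.mean(), y.mean()
    xybar = (x * y).mean()
    x2bar = (x ** 2).mean()
    m = (xbar * ybar - xybar) / (xbar ** 2 - x2bar)
    b = ybar - xbar * m
    return m, b
</pre></div>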
<p>Remember that by comparing the partial derivatives to zero we find a <em>critical
point</em>, which is not necessarily a minimum. We can use the <a class="reference external" href="https://en.wikipedia.org/wiki/Second_partial_derivative_test">second derivative
test</a> to find
what kind of critical point that is, by computing the Hessian of the cost:</p>
<img alt="\[H(m, b) = \begin{pmatrix} \operatorname{MSE}_{mm}(x, y) &amp;amp; \operatorname{MSE}_{mb}(x, y) \\ \operatorname{MSE}_{bm}(x, y) &amp;amp; \operatorname{MSE}_{bb}(x, y) \end{pmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/39c2e86ae1437d3b19bc8e77b66501486550d3bc.png" style="height: 43px;" />
<p>Plugging the numbers and running the test, we can indeed verify that the
critical point is a minimum.</p>
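<p>Here's a sketch of what this check could look like in code. The second
partial derivatives follow from differentiating the first-order partials shown
earlier one more time; note that they depend only on the input data, not on
<em>m</em> or <em>b</em>, since the cost is quadratic in the parameters (the
data array below is just a stand-in):</p>
<div class="highlight"><pre>import numpy as np

def mse_hessian(x):
    """Hessian of MSE(m, b): [[2*mean(x^2), 2*mean(x)], [2*mean(x), 2]]."""
    return np.array([[2 * (x ** 2).mean(), 2 * x.mean()],
                     [2 * x.mean(),        2.0]])

x = np.linspace(0, 10, 100)          # stand-in for the input samples
H = mse_hessian(x)
# Second derivative test: det(H) > 0 together with H[0, 0] > 0 means the
# critical point is a minimum.
print(np.linalg.det(H) > 0 and H[0, 0] > 0)
</pre></div>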
</div>
<div class="section" id="multiple-linear-regression">
<h2>Multiple linear regression</h2>
<p>The good thing about simple regression is that it's easy to visualize. The model
is trained using just two parameters, and visualizing the cost as a function of
these two parameters is possible since we get a 3D plot. Anything beyond that
becomes increasingly more difficult to visualize.</p>
<p>In simple linear regression, every <em>x</em> is just a number; so is every <em>y</em>. In
multiple linear regression this is no longer so, and each data point <em>x</em> is a
vector. The model parameters can also be represented by the vector
<img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />. To avoid confusion of indices and subscripts, let's agree that
we use subscripts to denote components of vectors, while parenthesized
superscripts are used to denote different samples. So <img alt="x_1^{(6)}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/d01999f5014c6aea058368231c0d2b958fa8a89e.png" style="height: 26px;" /> is the
second component of sample 6.</p>
<p>Our goal is to find the vector <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> such that the linear function:</p>
<img alt="\[\hat{y}(x) = \theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n\]" class="align-center" src="https://eli.thegreenplace.net/images/math/ae682f9fda97c28c8e100c87aecad635c7c1d96c.png" style="height: 18px;" />
<p>Is as close as possible to the actual <em>y</em> across all samples. Since working with
vectors is easier for this problem, we define <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> to always be equal to
1, so that the first term in the equation above denotes the intercept.
Expressing the regression coefficients as a vector:</p>
<img alt="\[\begin{pmatrix} \theta_0\\ \theta_1\\ ...\\ \theta_n \end{pmatrix}\in\mathbb{R}^{n+1}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/b16fd3d2b3041f13cb70199837a7c02c756078c7.png" style="height: 86px;" />
<p>We can now rewrite <img alt="\hat{y}(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/11533fb1b0218620907f5859e6e22aeb65c12cd8.png" style="height: 18px;" /> as:</p>
<img alt="\[\hat{y}(x) = \theta^T x\]" class="align-center" src="https://eli.thegreenplace.net/images/math/8156e2dc4e654f77a8664180c168829f6b4cdb0b.png" style="height: 21px;" />
<p>Where both <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> and <em>x</em> are column vectors with <em>n+1</em> elements, as
shown above. The mean square error (over <em>k</em> samples) now becomes:</p>
<img alt="\[\operatorname{MSE}=\frac{1}{k}\sum_{i=1}^k(\hat{y}(x^{(i)}) - y^{(i)})^2\]" class="align-center" src="https://eli.thegreenplace.net/images/math/1e0a7c0c85c1827b992671b88e89ba052d37a204.png" style="height: 54px;" />
<p>Now we have to find the partial derivative of this cost by each <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />.
Using the chain rule, it's easy to see that:</p>
<img alt="\[\frac{\partial \operatorname{MSE}}{\partial \theta_j} = \frac{2}{k}\sum_{i=1}^k(\hat{y}(x^{(i)}) - y^{(i)})x_j^{(i)}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/4c2fcfed81c294ef7313198debe3801f50bea92a.png" style="height: 54px;" />
<p>And use this to update the parameters in every training step. The code is
actually not much different from the simple regression case; here is a <a class="reference external" href="https://github.com/eliben/deep-learning-samples/blob/master/linear-regression/multiple_linear_regression.py">well
documented, completely worked out example</a>.
The code takes a realistic dataset from the <a class="reference external" href="http://archive.ics.uci.edu/ml/">UCI machine learning repository</a> with 4 predictors and a single outcome and
builds a regression model. 4 predictors plus one intercept give us a
5-dimensional <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, which is utterly impossible to visualize, so we
have to stick to math in order to analyze it.</p>
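<p>For reference, here's a minimal vectorized sketch of this update rule. The
names and shapes are illustrative: X is assumed to be a (k, n+1) matrix whose
first column is all ones (the intercept term), and y a vector of k observed
outputs:</p>
<div class="highlight"><pre>import numpy as np

def gradient_descent_multi(X, y, nsteps, learning_rate=0.01):
    """Fits theta by gradient descent; X must include the column of ones."""
    k = X.shape[0]
    theta = np.zeros(X.shape[1])
    for step in range(nsteps):
        yhat = X.dot(theta)                    # predictions for all k samples
        grad = 2.0 / k * X.T.dot(yhat - y)     # vector of partial derivatives
        theta -= learning_rate * grad
    return theta
</pre></div>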
</div>
<div class="section" id="an-analytical-solution-to-multiple-linear-regression">
<h2>An analytical solution to multiple linear regression</h2>
<p>Multiple linear regression also has an analytical solution. If we compute the
derivative of the cost by each <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, we'll end up with <em>n+1</em> equations
with the same number of variables, which we can solve analytically.</p>
<p>An elegant matrix formula that computes <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> from <em>X</em> and <em>y</em> is
called the Normal Equation:</p>
<img alt="\[\theta=(X^TX)^{-1}X^Ty\]" class="align-center" src="https://eli.thegreenplace.net/images/math/20baabd9d33dcd26003bc44c7d81ba39e1ad4caa.png" style="height: 21px;" />
<p>I've written about <a class="reference external" href="http://eli.thegreenplace.net/2014/derivation-of-the-normal-equation-for-linear-regression">deriving the normal equation</a>
previously, so I won't spend more time on it. The accompanying code computes
<img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> using the normal equation and compares the result with the
<img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> obtained from gradient descent.</p>
<p>As an exercise, you can double-check that the analytical solution for simple
linear regression (formulae for <em>m</em> and <em>b</em>) is just a special case of applying
the normal equation in two dimensions.</p>
<p>You may wonder: when should we use the analytical solution, and when is gradient
descent better? In general, whenever we can use the analytical solution - we
should. But it's not always feasible, computationally.</p>
<p>Consider a data set with <em>k</em> samples and <em>n</em> features. Then <em>X</em> is a <em>k x n</em>
matrix, and hence <img alt="X^TX" class="valign-0" src="https://eli.thegreenplace.net/images/math/5c817c84ec1f83b23494df6125edd091a7c413dd.png" style="height: 15px;" /> is a <em>n x n</em> matrix. Inverting a matrix is a
<img alt="O(n^3)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/62a87bfd600dc05059675e34b881c78648f53401.png" style="height: 19px;" /> operation, so for large <em>n</em>, finding <img alt="(X^TX)^{-1}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/57f592cee6ceac659262d97e61c64f9ca405d7f1.png" style="height: 19px;" /> can take
quite a bit of time. Moreover, keeping <img alt="X^TX" class="valign-0" src="https://eli.thegreenplace.net/images/math/5c817c84ec1f83b23494df6125edd091a7c413dd.png" style="height: 15px;" /> in memory can be
computationally infeasible if <img alt="X" class="valign-0" src="https://eli.thegreenplace.net/images/math/c032adc1ff629c9b66f22749ad667e6beadf144b.png" style="height: 12px;" /> is huge and sparse, but <img alt="X^TX" class="valign-0" src="https://eli.thegreenplace.net/images/math/5c817c84ec1f83b23494df6125edd091a7c413dd.png" style="height: 15px;" /> is
dense. In all these cases, iterative gradient descent is a more feasible
approach.</p>
<p>In addition, the moment we deviate from the linear regression a bit, such as
adding nonlinear terms, regularization, or some other model enhancement, the
analytical solutions no longer apply. Gradient descent keeps working just the
same, however, as long as we know how to compute the gradient of the new cost
function.</p>
<hr class="docutils" />
<table class="docutils footnote" frame="void" id="id6" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id1">[1]</a></td><td>This data was generated by using a slope of 2.25, intercept of 6 and
added Gaussian noise with a standard deviation of 1.5.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id7" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id2">[2]</a></td><td>Some resources use SSE - the Squared Sum Error, which is just the MSE
without the averaging. Yet others have <em>2n</em> in the denominator to
make the gradient derivation cleaner. None of this really matters in
practice. When finding the minimum analytically, we compare derivatives
to zero so constant factors cancel out. When running gradient descent,
all constant factors are subsumed into the learning rate which is
arbitrary.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id8" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id3">[3]</a></td><td>For a mathematical justification for <em>why</em> the gradient leads us in the
direction of most change, see <a class="reference external" href="http://eli.thegreenplace.net/2016/understanding-gradient-descent">this post</a>.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id9" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id4">[4]</a></td><td>An alternative way I've seen this equation written is to express <em>m</em>
as:</td></tr>
</tbody>
</table>
<img alt="\[\begin{align*} m &amp;amp;= \frac{\sum_{i=1}^n(x_i-\bar{x})(y_i-\bar{y})}{\sum_{i=1}^n(x_i-\bar{x})^2} \\ &amp;amp;= \frac{cov(x, y)}{var(x)} \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/53639f1f77080dbe8a6d3a8cd06e08a90de69a8e.png" style="height: 92px;" />
<table class="docutils footnote" frame="void" id="id10" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id5">[5]</a></td><td>Can you figure out why even the analytical solution is a little off from
the actual parameters used to generate this data?</td></tr>
</tbody>
</table>
</div>
Understanding gradient descent2016-08-05T05:38:00-07:002016-08-05T05:38:00-07:00Eli Benderskytag:eli.thegreenplace.net,2016-08-05:/2016/understanding-gradient-descent/<p>Gradient descent is a standard tool for optimizing complex functions iteratively
within a computer program. Its goal is: given some arbitrary function, find a
minimum. For some small subset of functions - those that are <em>convex</em> - there's
just a single minimum which also happens to be global. For most realistic
functions …</p><p>Gradient descent is a standard tool for optimizing complex functions iteratively
within a computer program. Its goal is: given some arbitrary function, find a
minimum. For some small subset of functions - those that are <em>convex</em> - there's
just a single minimum which also happens to be global. For most realistic
functions, there may be many minima, most of which are local. Making sure the
optimization finds the &quot;best&quot; minimum and doesn't get stuck in sub-optimal
minima is out of the scope of this article. Here we'll just be dealing with the
core gradient descent algorithm for finding <em>some</em> minimum from a given starting
point.</p>
<p>The main premise of gradient descent is: given some current location <em>x</em> in the
search space (the domain of the optimized function) we ought to update <em>x</em> for
the next step in the direction opposite to the gradient of the function computed
at <em>x</em>. But <em>why</em> is this the case? The aim of this article is to explain why,
mathematically.</p>
<p>This is also the place for a disclaimer: the examples used throughout the
article are trivial, low-dimensional, convex functions. We don't really <em>need</em>
an algorithmic procedure to find their global minimum - a quick computation
would do, or really just eyeballing the function's plot. In reality we will be
dealing with non-linear, 1000-dimensional functions where it's utterly
impossible to visualize anything, or solve anything analytically. The approach
works just the same there, however.</p>
<div class="section" id="building-intuition-with-single-variable-functions">
<h2>Building intuition with single-variable functions</h2>
<p>The gradient is formally defined for <em>multivariate</em> functions. However, to start
building intuition, it's useful to begin with the two-dimensional case, a
single-variable function <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" />.</p>
<p>In single-variable functions, the simple derivative plays the role of a
gradient. So &quot;gradient descent&quot; would really be &quot;derivative descent&quot;; let's see
what that means.</p>
<p>As an example, let's take the function <img alt="f(x)=(x-1)^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/b898d66867ea1e832ab5cda94453ab3a69bae865.png" style="height: 19px;" />. Here's its plot, in
red:</p>
<img alt="Plot of parabola with tangent lines" class="align-center" src="https://eli.thegreenplace.net/images/2016/plot-parabola-with-tangents.png" />
<p>I marked a couple of points on the plot, in blue, and drew the tangents to the
function at these points. Remember, our goal is to find the minimum of the
function. To do that, we'll start with a guess for an <em>x</em>, and continuously
update it to improve our guess based on some computation. How do we know how to
update <em>x</em>? The update has only two possible directions: increase <em>x</em> or
decrease <em>x</em>. We have to decide which of the two directions to take.</p>
<p>We do that based on the derivative of <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" />. The derivative at some point
<img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> is defined as the limit <a class="footnote-reference" href="#id5" id="id1">[1]</a>:</p>
<img alt="\[\frac{d}{dx}f(x_0)=\lim_{h \to 0}\frac{f(x_0+h)-f(x_0)}{h}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/bfd7f38f59e2ff0d548c19f8f780605b099ecaf7.png" style="height: 39px;" />
<p>Intuitively, this tells us what happens to <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /> when we add a very small
value to <em>x</em>. For example in the plot above, at <img alt="x_0=3" class="valign-m3" src="https://eli.thegreenplace.net/images/math/5fa44ff4e2c914452bf56041b4ef99ceb61592f9.png" style="height: 15px;" /> we have:</p>
<img alt="\[\begin{align*} \frac{d}{dx}f(3)&amp;amp;=\lim_{h \to 0}\frac{f(3+h)-f(3)}{h} \\ &amp;amp;=\lim_{h \to 0}\frac{(3+h-1)^2-(3-1)^2}{h} \\ &amp;amp;=\lim_{h \to 0}\frac{h^2+4h}{h} \\ &amp;amp;=\lim_{h \to 0}h+4=4 \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/e572beffc8415b4ba4c8c9419105863e3ce2082f.png" style="height: 168px;" />
<p>This means that the value of <img alt="\frac{df}{dx}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/45e7d07281bf1883224069f5b8d98a4bd6b21693.png" style="height: 23px;" /> at <img alt="x_0=3" class="valign-m3" src="https://eli.thegreenplace.net/images/math/5fa44ff4e2c914452bf56041b4ef99ceb61592f9.png" style="height: 15px;" /> - the <em>slope</em> of the function at that point - is 4; for
a very small positive change <em>h</em> to <em>x</em> at that point, the value of <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" />
will increase by <em>4h</em>. Therefore, to get closer to the minimum of
<img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /> we should rather <em>decrease</em> <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> a bit.</p>
<p>Let's take another example point, <img alt="x_0=-1" class="valign-m3" src="https://eli.thegreenplace.net/images/math/c84eef20ea61cf41b13fd1a157968eba20823c8e.png" style="height: 15px;" />. At that point, if we add a
little bit to <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /> will <em>decrease</em> by 4x that little
bit. So that's exactly what we should do to get closer to the minimum.</p>
<p>It turns out that in both cases, we should nudge <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> in the direction
opposite to the derivative at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />. That's the most basic idea behind
gradient descent - the derivative shows us the way to the minimum; or rather,
it shows us the way to the maximum and we then go in the opposite direction.
Given some initial guess <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, the next guess will be:</p>
<img alt="\[x_1=x_0-\eta\frac{d}{dx}f(x_0)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/d8666c1e2cf8740af45a228730f7c632fc00ed14.png" style="height: 37px;" />
<p>Where <img alt="\eta" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2899aeb886ad0fa72652bffd5511e452aaf084ab.png" style="height: 12px;" /> is what we call a &quot;learning rate&quot;, and is constant for each
given update. It's the reason why we don't care much about the magnitude of the
derivative at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, only its direction. In general, it makes sense to
keep the learning rate fairly small so we only make a tiny step at a time. This
makes sense mathematically, because the derivative at a point is defined as the
rate of change of <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" /> assuming an infinitesimal change in <em>x</em>. For
some large change in <em>x</em>, who knows where we will end up. It's easy to imagine cases
where we'll entirely overshoot the minimum by making too large a step <a class="footnote-reference" href="#id6" id="id2">[2]</a>.</p>
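<p>Here's what a few update steps look like in code for this example function;
the starting point and learning rate are arbitrary choices:</p>
<div class="highlight"><pre>def df(x):
    return 2 * (x - 1)     # derivative of f(x) = (x-1)^2

eta = 0.3                  # learning rate
x = 3.0                    # initial guess
for step in range(10):
    x = x - eta * df(x)
    print(x)               # approaches the minimum at x=1
</pre></div>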
</div>
<div class="section" id="multivariate-functions-and-directional-derivatives">
<h2>Multivariate functions and directional derivatives</h2>
<p>With functions of multiple variables, derivatives become more interesting. We
can't just say &quot;the derivative points to where the function is increasing&quot;,
because... which derivative?</p>
<p>Recall the formal definition of the derivative as the limit for a small step
<em>h</em>. When our function has many variables, which one should have the step added?
One at a time? All at once? In multivariate calculus, we use partial derivatives
as building blocks. Let's use a function of two variables - <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> as an
example throughout this section, and define the partial derivatives w.r.t. <em>x</em>
and <em>y</em> at some point <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" />:</p>
<img alt="\[\begin{align*} \frac{\partial }{\partial x}f(x_0,y_0)&amp;amp;=\lim_{h \to 0}\frac{f(x_0+h,y_0)-f(x_0,y_0)}{h} \\ \frac{\partial }{\partial y}f(x_0,y_0)&amp;amp;=\lim_{h \to 0}\frac{f(x_0,y_0+h)-f(x_0,y_0)}{h} \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/b58dd3cada7292828cf79f3ca8653a99fd94c1f9.png" style="height: 87px;" />
<p>When we have a single-variable function <img alt="f(x)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/3e03f4706048fbc6c5a252a85d066adf107fcc1f.png" style="height: 18px;" />, there's really only two
directions in which we can move from a given point <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> - left (decrease
<em>x</em>) or right (increase <em>x</em>). With two variables, the number of possible
directions is <em>infinite</em>, because we pick a direction to move on a 2D plane.
Hopefully this immediately pops up &quot;vectors&quot; in your head, since vectors are
the perfect tool to deal with such problems. We can represent the change from
the point <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" /> as the vector <img alt="\vec{v}=\langle a,b \rangle" class="valign-m5" src="https://eli.thegreenplace.net/images/math/4ef7c8a835491ba5ec6dc5f2b94ebff879938a21.png" style="height: 19px;" />
<a class="footnote-reference" href="#id7" id="id3">[3]</a>.
The <em>directional derivative</em> of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> along <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> at
<img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" /> is defined as its rate of change in the direction of the
vector at that point. Mathematically, it's defined as:</p>
<img alt="\[\begin{equation} D_{\vec{v}}f(x_0,y_0)=\lim_{h \to 0}\frac{f(x_0+ah,y_0+bh)-f(x_0,y_0)}{h} \tag{1} \end{equation}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/1af5afd7427f744daa0c75b05697b32f21b2f40c.png" style="height: 39px;" />
<p>The partial derivatives w.r.t. <em>x</em> and <em>y</em> can be seen as special cases of this
definition. The partial derivative <img alt="\frac{\partial f}{\partial x}" class="valign-m7" src="https://eli.thegreenplace.net/images/math/75a2ab078215106a1084cf5e262e98f32c1cc3b9.png" style="height: 25px;" /> is just
the directional derivative in the direction of the <em>x</em> axis. In vector-speak,
this is the directional derivative for
<img alt="\vec{v}=\langle a,b \rangle=\widehat{e_x}=\langle 1,0 \rangle" class="valign-m5" src="https://eli.thegreenplace.net/images/math/36b4fd6cf884fd12b36c605cb6ec7a7c9b4ee65f.png" style="height: 19px;" />, the
standard basis vector for <em>x</em>. Just plug <img alt="a=1,b=0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/7feadfc4043894ed6a3de2cced949a91bea9e5b2.png" style="height: 17px;" /> into (1) to see why.
Similarly, the partial derivative <img alt="\frac{\partial f}{\partial y}" class="valign-m9" src="https://eli.thegreenplace.net/images/math/5bc3d10d9714f8f7a95791fe29e497cf0ecbe3b0.png" style="height: 27px;" /> is the
directional derivative in the direction of the standard basis vector
<img alt="\widehat{e_y}=\langle 0,1 \rangle" class="valign-m6" src="https://eli.thegreenplace.net/images/math/3ce4793144c7bfd02d245b81f8bd44a595721196.png" style="height: 20px;" />.</p>
</div>
<div class="section" id="a-visual-interlude">
<h2>A visual interlude</h2>
<p>Functions of two variables <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> are the last frontier for meaningful
visualizations, for which we need 3D to plot the value of <img alt="f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4a0a19218e082a343a1b17e5333409af9d98f0f5.png" style="height: 16px;" /> for each
given <em>x</em> and <em>y</em>. Even in 3D, visualizing gradients is significantly harder
than in 2D, and yet we have to try since for anything above two variables all
hopes of visualization are lost.</p>
<p>Here's the function <img alt="f(x,y)=x^2+y^2" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d3eb0fc536d00e84cd63bb5af98b7e2bc01bde4f.png" style="height: 19px;" /> plotted in a small range around zero.
I drew the standard basis vectors <img alt="\widehat{x}=\widehat{e_x}" class="valign-m3" src="https://eli.thegreenplace.net/images/math/0ea0752aa73540ee1e464a42d5d1b2b9741d3eab.png" style="height: 17px;" /> and
<img alt="\widehat{y}=\widehat{e_y}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/c0bf47cb98b1f01e6b47992929694ec9da20f8f7.png" style="height: 20px;" /> <a class="footnote-reference" href="#id8" id="id4">[4]</a> and some combination of them
<img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />.</p>
<img alt="3D parabola with direction vector markers" class="align-center" src="https://eli.thegreenplace.net/images/2016/plot-3d-parabola.png" />
<p>I also marked the point on <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> where the vectors are based. The goal
is to help us keep in mind how the independent variables <em>x</em> and <em>y</em> change, and
how that affects <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />. We change <em>x</em> and <em>y</em> by adding some small
vector <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> to their current value. The result is &quot;nudging&quot;
<img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> in the direction of <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />. Remember our goal for this
article - find <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> such that this &quot;nudge&quot; gets us closer to a
minimum.</p>
</div>
<div class="section" id="finding-directional-derivatives-using-the-gradient">
<h2>Finding directional derivatives using the gradient</h2>
<p>As we've seen, the derivative of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> in the direction of
<img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> is defined as:</p>
<img alt="\[D_{\vec{v}}f(x_0,y_0)=\lim_{h \to 0}\frac{f(x_0+ah,y_0+bh)-f(x_0,y_0)}{h}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/9f2c62d64f016bd77712873294a0f5e64537b1ab.png" style="height: 39px;" />
<p>Looking at the 3D plot above, this is how much the value of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />
changes when we add <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> to the vector <img alt="\langle x_0,y_0 \rangle" class="valign-m5" src="https://eli.thegreenplace.net/images/math/f74aa2c6fda35535931fad69ec339eaef3693913.png" style="height: 19px;" />. But how do we actually compute this derivative? The limit definition
doesn't lend itself to easy analytical treatment for arbitrary functions. Sure, we
could plug <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" /> and <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> in there and do the
computation, but it would be nice to have an easier-to-use formula. Luckily,
with the help of the gradient of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> it becomes much easier.</p>
<p>The gradient is a vector value we compute from a scalar function. It's defined
as:</p>
<img alt="\[\nabla f=\left \langle \frac{\partial f}{\partial x},\frac{\partial f}{\partial y} \right \rangle\]" class="align-center" src="https://eli.thegreenplace.net/images/math/03eab64984be412b6db132c2534bbecc006af47c.png" style="height: 43px;" />
<p>It turns out that given a vector <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />, the directional derivative
<img alt="D_{\vec{v}}f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/03a3931c968b3b6f26e82958785539d74db94293.png" style="height: 16px;" /> can be expressed as the following dot product:</p>
<img alt="\[D_{\vec{v}}f=(\nabla f) \cdot \vec{v}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/49933775272512c4c8686d9f9692c8ea01e1c97d.png" style="height: 18px;" />
<p>If this looks like a mental leap too big to trust, please read the Appendix
section at the bottom. Otherwise, feel free to verify that the two are
equivalent with a couple of examples. For instance, try to find the derivative
in the direction of <img alt="\vec{v}=\langle \frac{1}{\sqrt{2}},\frac{1}{\sqrt{2}} \rangle" class="valign-m11" src="https://eli.thegreenplace.net/images/math/d614069c5beaf6fb858de40fa492a7b523a683d9.png" style="height: 27px;" />
at <img alt="(x_0,y_0)=(-1.5,0.25)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/61355565f13944faf85baec62c5fc1a682b0b5d5.png" style="height: 18px;" />. You should get <img alt="\frac{-2.5}{\sqrt{2}}" class="valign-m11" src="https://eli.thegreenplace.net/images/math/0c22fc563236a48f94882876c68f6edc0c3fb4da.png" style="height: 27px;" /> using
both methods.</p>
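<p>Here's a small numeric sketch of that verification, comparing the limit
definition (with a small <em>h</em>) to the gradient dot product for this
example:</p>
<div class="highlight"><pre>import numpy as np

def f(x, y):
    return x ** 2 + y ** 2

x0, y0 = -1.5, 0.25
a, b = 1 / np.sqrt(2), 1 / np.sqrt(2)    # components of the direction vector v

# Limit definition, approximated with a small h.
h = 1e-6
numeric = (f(x0 + a * h, y0 + b * h) - f(x0, y0)) / h

# Gradient dot product: grad f = (2x, 2y).
analytic = np.dot([2 * x0, 2 * y0], [a, b])

print(numeric, analytic)    # both close to -2.5/sqrt(2), about -1.7678
</pre></div>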
</div>
<div class="section" id="direction-of-maximal-change">
<h2>Direction of maximal change</h2>
<p>We're almost there! Now that we have a relatively simple way of computing any
directional derivative from the partial derivatives of a function, we can
figure out which direction to take to get the maximal change in the value of
<img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />.</p>
<p>We can rewrite:</p>
<img alt="\[D_{\vec{v}}f=(\nabla f) \cdot \vec{v}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/49933775272512c4c8686d9f9692c8ea01e1c97d.png" style="height: 18px;" />
<p>As:</p>
<img alt="\[D_{\vec{v}}f=\left \| \nabla f \right \| \left \| \vec{v} \right \| cos(\theta)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/8227de3117c60690ced3153cdc38d9bccd960fba.png" style="height: 19px;" />
<p>Where <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> is the angle between the two vectors. Now, recall that
<img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> is normalized so its magnitude is 1. Therefore, we only care
about the <em>direction</em> of <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> w.r.t. the gradient. When is this
equation maximized? When <img alt="\theta=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/a1dffbe89f1ec5a919198de979fca459eb7fdf84.png" style="height: 12px;" />, because then <img alt="cos(\theta)=1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/66a6eb87ec7f340e2e24bd46cdf02ab050013aac.png" style="height: 18px;" />.
Since a cosine can never be larger than 1, that's the best we can have.</p>
<p>So <img alt="\theta=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/a1dffbe89f1ec5a919198de979fca459eb7fdf84.png" style="height: 12px;" /> gives us the largest positive change in <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />. To
get <img alt="\theta=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/a1dffbe89f1ec5a919198de979fca459eb7fdf84.png" style="height: 12px;" />, <img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> has to point in the same direction as the
gradient. Similarly, for <img alt="\theta=180^{\circ}" class="valign-m1" src="https://eli.thegreenplace.net/images/math/f35bd3cc416e154fddabe833458147c566028a8c.png" style="height: 13px;" /> we get
<img alt="cos(\theta)=-1" class="valign-m4" src="https://eli.thegreenplace.net/images/math/65b96d5ab442e325098894e80d655263a24b14d6.png" style="height: 18px;" /> and therefore the largest <em>negative</em> change in
<img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />. So if we want to decrease <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> the most,
<img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" /> has to point in the opposite direction of the gradient.</p>
</div>
<div class="section" id="gradient-descent-update-for-multivariate-functions">
<h2>Gradient descent update for multivariate functions</h2>
<p>To sum up, given some starting point <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" />, to nudge it in the
direction of the minimum of <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" />, we first compute the gradient of
<img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> at <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" />. Then, we update (using vector notation):</p>
<img alt="\[\langle x_1,y_1 \rangle=\langle x_0,y_0 \rangle-\eta \nabla{f(x_0,y_0)}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/66a0a92b6ff9a4c0d2162a41484ab17115f57bd7.png" style="height: 19px;" />
<p>Generalizing to multiple dimensions, let's say we have the function
<img alt="f:\mathbb{R}^n\rightarrow \mathbb{R}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/5b4aba3ea35b9daec583b61ecb5a556ae28103e3.png" style="height: 16px;" /> taking the n-dimensional vector
<img alt="\vec{x}=(x_1,x_2 \dots ,x_n)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e8ece11f27b7cf7726e6ea055cfb0718761733e0.png" style="height: 18px;" />. We define the gradient update at step <em>k</em>
to be:</p>
<img alt="\[\vec{x}_{(k)}=\vec{x}_{(k-1)} - \eta \nabla{f(\vec{x}_{(k-1)})}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/265d53b7832258e30f00049a1772e9f213140628.png" style="height: 20px;" />
<p>Previously, for the single-variable case we said that the derivative points
the way to the minimum. Now we can say that while there are many ways to
get to the minimum (eventually), the gradient points us to the <em>fastest</em> way
from any given point.</p>
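<p>As a minimal sketch, here's this update applied to the example function
f(x,y) = x^2 + y^2 from earlier; the starting point and learning rate are
arbitrary:</p>
<div class="highlight"><pre>import numpy as np

def grad_f(v):
    return 2 * v                     # gradient of x^2 + y^2 is (2x, 2y)

eta = 0.2                            # learning rate
v = np.array([-1.5, 0.25])           # starting point (x0, y0)
for step in range(20):
    v = v - eta * grad_f(v)
print(v)                             # close to the minimum at (0, 0)
</pre></div>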
</div>
<div class="section" id="appendix-directional-derivative-definition-and-gradient">
<h2>Appendix: directional derivative definition and gradient</h2>
<p>This is an optional section for those who don't like taking mathematical
statements for granted. Now it's time to prove the equation shown earlier in
the article, and on which its main result is based:</p>
<img alt="\[D_{\vec{v}}f=(\nabla f) \cdot \vec{v}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/49933775272512c4c8686d9f9692c8ea01e1c97d.png" style="height: 18px;" />
<p>As usual with proofs, it really helps to start by working through an example or
two to build up some intuition into why the equation works. Feel free to do that
if you'd like, using any function, starting point and direction vector
<img alt="\vec{v}" class="valign-0" src="https://eli.thegreenplace.net/images/math/39a3a59a8f524cf72620db07b9ba7cdce9fc9391.png" style="height: 13px;" />.</p>
<p>Suppose we define a function <img alt="w(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0382ffc90ae7b4c24238f68a32bebd14bc53c8d7.png" style="height: 18px;" /> as follows:</p>
<img alt="\[w(t)=f(x,y)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/dc37eb3cf47966d7338e561faffeffbb291085c5.png" style="height: 18px;" />
<p>Where <img alt="x=x(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/97aeb925cf8f501cc8836794ee06fb357b9d9a83.png" style="height: 18px;" /> and <img alt="y=y(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ebacc26a97fccf1aa96e1b59f21fcb2ca66c8924.png" style="height: 18px;" /> defined as:</p>
<img alt="\[\begin{align*} x(t)&amp;amp;=x_0+at \\ y(t)&amp;amp;=y_0+bt \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/27988a5772de0fe761873494e88f7cad887ede85.png" style="height: 45px;" />
<p>In these definitions, <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />, <img alt="y_0" class="valign-m4" src="https://eli.thegreenplace.net/images/math/2bb5817d0f3bf8490a8c7b1343f84f9635e683a3.png" style="height: 12px;" />, <em>a</em> and <em>b</em> are constants, so
both <img alt="x(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/62b10cd9e1396c7ea33fd211e67de2fb29019cfc.png" style="height: 18px;" /> and <img alt="y(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ed8576b7227103b62d3648e7d1bbdff4052b27ff.png" style="height: 18px;" /> are truly functions of a single variable.
Using <a class="reference external" href="http://eli.thegreenplace.net/2016/the-chain-rule-of-calculus">the chain rule</a>, we know that:</p>
<img alt="\[\frac{dw}{dt}=\frac{\partial f}{\partial x}\frac{dx}{dt}+\frac{\partial f}{\partial y}\frac{dy}{dt}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/d5f4f13aeba35328cd2bea9b247842acb7524724.png" style="height: 41px;" />
<p>Substituting the derivatives of <img alt="x(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/62b10cd9e1396c7ea33fd211e67de2fb29019cfc.png" style="height: 18px;" /> and <img alt="y(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/ed8576b7227103b62d3648e7d1bbdff4052b27ff.png" style="height: 18px;" />, we get:</p>
<img alt="\[\frac{dw}{dt}=a\frac{\partial f}{\partial x}+b\frac{\partial f}{\partial y}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/829069469d88717c9d95e3f788ed9e0c6cbeebc6.png" style="height: 41px;" />
<p>One more step, the significance of which will become clear shortly. Specifically,
the derivative of <img alt="w(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0382ffc90ae7b4c24238f68a32bebd14bc53c8d7.png" style="height: 18px;" /> at <img alt="t=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/31056375cdff6a052261f18ceb3afe466731302a.png" style="height: 12px;" /> is:</p>
<img alt="\[\begin{equation} \frac{d}{dt}w(0)=a\frac{\partial}{\partial x}f(x_0,y_0)+b\frac{\partial}{\partial y}f(x_0,y_0) \tag{2} \end{equation}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/ea579cad8f6c62a817f2253e1d596178ea673d37.png" style="height: 41px;" />
<p>Now let's see how to compute the derivative of <img alt="w(t)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/0382ffc90ae7b4c24238f68a32bebd14bc53c8d7.png" style="height: 18px;" /> at <img alt="t=0" class="valign-0" src="https://eli.thegreenplace.net/images/math/31056375cdff6a052261f18ceb3afe466731302a.png" style="height: 12px;" /> using
the formal limit definition:</p>
<img alt="\[\begin{align*} \frac{d}{dt}w(0)&amp;amp;=\lim_{h \to 0}\frac{w(h)-w(0)}{h} \\ &amp;amp;=\lim_{h \to 0}\frac{f(x_0+ah,b_0+bh)-f(x_0,y_0)}{h} \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/10a224da7b7ab2424b9f88edcbfe17f273f3bd8b.png" style="height: 84px;" />
<p>But the latter is precisely the definition of the directional derivative in
equation (1). Therefore, we can say that:</p>
<img alt="\[\frac{d}{dt}w(0)=D_{\vec{v}}f(x_0,y_0)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/4f377110022c468e46cbdb32bfb11a072d11b330.png" style="height: 37px;" />
<p>From this and (2), we get:</p>
<img alt="\[\frac{d}{dt}w(0)=D_{\vec{v}}f(x_0,y_0)=a\frac{\partial}{\partial x}f(x_0,y_0)+b\frac{\partial}{\partial y}f(x_0,y_0)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/d259ebf3697f480823a40247ce7191f9e954a584.png" style="height: 41px;" />
<p>This derivation is not special to the point <img alt="(x_0,y_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/f8b63792829adeff8314a72fa87be1a770dfca85.png" style="height: 18px;" /> - it works just as
well for any point where <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> has partial derivatives w.r.t. <em>x</em> and
<em>y</em>; therefore, for any point <img alt="(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d330d6e65470cb03e76e092ee47971f9e931f759.png" style="height: 18px;" /> where <img alt="f(x,y)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/720aabe593c880dc58881240e567ecda2b89bdf4.png" style="height: 18px;" /> is
differentiable:</p>
<img alt="\[\begin{align*} D_{\vec{v}}f(x,y)&amp;amp;=a\frac{\partial}{\partial x}f(x,y)+b\frac{\partial}{\partial y}f(x,y) \\ &amp;amp;=\left \langle \frac{\partial f}{\partial x},\frac{\partial f}{\partial y} \right \rangle \cdot \langle a,b \rangle \\ &amp;amp;=(\nabla f) \cdot \vec{v} \qedhere \end{align*}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/7c306dfedd474d99a62894e6258cea186d8be428.png" style="height: 115px;" />
<hr class="docutils" />
<table class="docutils footnote" frame="void" id="id5" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id1">[1]</a></td><td>The notation <img alt="\frac{d}{dx}f(x_0)" class="valign-m6" src="https://eli.thegreenplace.net/images/math/b0d6f765abf215972d5dbb982f77f1a83c233066.png" style="height: 22px;" /> means: the value of the
derivative of <img alt="f" class="valign-m4" src="https://eli.thegreenplace.net/images/math/4a0a19218e082a343a1b17e5333409af9d98f0f5.png" style="height: 16px;" /> w.r.t. <em>x</em>, evaluated at <img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" />. Another
way to say the same would be <img alt="f{}&amp;#x27;(x_0)" class="valign-m4" src="https://eli.thegreenplace.net/images/math/e11c4ee90d42c3261aec6ef9c71893411b11cf34.png" style="height: 18px;" />.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id6" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id2">[2]</a></td><td>That said, in some advanced variations of gradient descent we actually
want to probe different areas of the function early on in the process,
so a larger step makes sense (remember, realistic functions have many
local minima and we want to find the best one). Further along in the
optimization process, when we've settled on a general area of the
function, we want the learning rate to be small so we actually get to the
minimum. This approach is called <em>annealing</em> and I'll leave it for some
future article.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id7" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id3">[3]</a></td><td>To avoid tracking vector magnitudes, from now on in the article we'll
be dealing with <em>normalized</em> direction vectors. That is, we always assume
that <img alt="\left \| \vec{v} \right \|=1" class="valign-m5" src="https://eli.thegreenplace.net/images/math/d68cb9ca8e7b5fd7fe4a7c4548ed5d98b63292eb.png" style="height: 19px;" />.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id8" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id4">[4]</a></td><td>Yes, <img alt="\widehat{y}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/8cf4f01720ca8008752c182a8d3443aa2b174442.png" style="height: 18px;" /> is actually going in the opposite direction so
it's <img alt="-\widehat{e_y}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/160a7a02c9645a3948812151b7a0cf38eb29c562.png" style="height: 20px;" />, but that really doesn't change anything.
It was easier to draw :)</td></tr>
</tbody>
</table>
</div>
The Normal Equation and matrix calculus2015-05-27T06:19:00-07:002015-05-27T06:19:00-07:00Eli Benderskytag:eli.thegreenplace.net,2015-05-27:/2015/the-normal-equation-and-matrix-calculus/<p>A few months ago I wrote <a class="reference external" href="http://eli.thegreenplace.net/2014/derivation-of-the-normal-equation-for-linear-regression">a post</a>
on formulating the Normal Equation for linear regression. A crucial part in the
formulation is using <a class="reference external" href="http://en.wikipedia.org/wiki/Matrix_calculus">matrix calculus</a> to compute a scalar-by-vector
derivative. I didn't spend much time explaining how this step works, instead
remarking:</p>
<blockquote>
Deriving by a vector may feel uncomfortable, but there's nothing to worry
about. Recall that here we only use matrix notation to conveniently
represent a system of linear formulae. So we derive by each component of the
vector, and then combine the resulting derivatives into a vector again.</blockquote>
<p>According to the comments received on the post, folks didn't find this
convincing and asked for more details. One commenter even said that &quot;matrix
calculus feels handwavy&quot;, something which I fully agree with. The reason matrix
calculus feels handwavy is that it's not as commonly encountered as &quot;regular&quot;
calculus, and hence its identities and intuitions are not as familiar. However,
there's really not that much to it, as I want to show here.</p>
<p>Let's get started with a simple example, which I'll use to demonstrate the
principles. Say we have the function:</p>
<img alt="\[f(v)=a^Tv\]" class="align-center" src="https://eli.thegreenplace.net/images/math/94f87149715376908db65a00f793836a4b2092a9.png" style="height: 21px;" />
<p>Where <strong>a</strong> and <strong>v</strong> are vectors with <em>n</em> components <a class="footnote-reference" href="#id4" id="id1">[1]</a>. We want to compute
its derivative by <strong>v</strong>. But wait, while a &quot;regular&quot; derivative by a scalar is
clearly defined (using limits), what does deriving by a vector mean? It simply
means that we derive by each component of the vector separately, and then
combine the results into a new vector <a class="footnote-reference" href="#id5" id="id2">[2]</a>. In other words:</p>
<img alt="\[\frac{\partial f}{\partial v}=\begin{pmatrix}\frac{\partial f}{\partial v_1}\\[1em] \frac{\partial f}{\partial v_2}\\ ...\\ \frac{\partial f}{\partial v_n}\\[1em] \end{pmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/13d227107c5323f47460ad077504fda60726d933.png" style="height: 131px;" />
<p>Let's see how this works out for our function <em>f</em>. It may be more convenient to
rewrite it by using components rather than vector notation:</p>
<img alt="\[f(v)=a^Tv=a_1v_1+a_2v_2+...+a_nv_n\]" class="align-center" src="https://eli.thegreenplace.net/images/math/e9e17e44bb85d825f304b09247a7f3cfbe11f64e.png" style="height: 21px;" />
<p>Computing the derivatives by each component, we'll get:</p>
<img alt="\[\begin{matrix}\frac{\partial f}{\partial v_1}=a_1\\[1em] \frac{\partial f}{\partial v_2}=a_2\\ ...\\ \frac{\partial f}{\partial v_n}=a_n\\[1em] \end{matrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/768563e5b7e2e8cddd00830e9b945419f598e4bb.png" style="height: 114px;" />
<p>So we have a sequence of partial derivatives, which we combine into a vector:</p>
<img alt="\[\frac{\partial f}{\partial v}=\begin{pmatrix}a_1\\ ...\\ a_n\\ \end{pmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/b13cc64568603d73240709c1fb49cfcc7f2a2b62.png" style="height: 65px;" />
<p>Or, in other words <img alt="\frac{\partial f}{\partial v}=a" class="valign-m7" src="https://eli.thegreenplace.net/images/math/1f3eaea99f7fab11ac1b70dc8b618635a9ed4c91.png" style="height: 25px;" />.</p>
<p>This example demonstrates the algorithm for computing scalar-by-vector
derivatives:</p>
<ol class="arabic simple">
<li>Figure out what the dimensions of all vectors and matrices are.</li>
<li>Expand the vector equations into their full form (a multiplication of two
vectors is either a scalar or a matrix, depending on their orientation, etc.).
Note that the end result will be a scalar.</li>
<li>Compute the derivative of the scalar by each component of the variable vector
separately.</li>
<li>Combine the derivatives into a vector.</li>
</ol>
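<p>Here is a small numerical illustration of these steps (a sketch with arbitrary
sizes and random data): estimate each partial derivative of <em>f</em> with finite
differences, collect them into a vector, and verify that the result equals
<strong>a</strong>:</p>
<pre class="code python literal-block">
# Finite-difference check that the derivative of f(v) = a^T v w.r.t. v is a.
import numpy as np

rng = np.random.default_rng(0)
n = 5
a = rng.normal(size=n)
v = rng.normal(size=n)

def f(v):
    return a.dot(v)                  # f(v) = a^T v, a scalar

eps = 1e-6
grad = np.zeros(n)
for i in range(n):                   # step 3: derive by each component separately
    dv = np.zeros(n)
    dv[i] = eps
    grad[i] = (f(v + dv) - f(v)) / eps

# Step 4: the collected vector of partials should (numerically) equal a.
print(np.allclose(grad, a, atol=1e-5))   # True
</pre>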
<p>Similarly to regular calculus, matrix and vector calculus rely on a set of
identities to make computations more manageable. We can either go the hard way
(computing the derivative of each function from basic principles using limits),
or the easy way - applying the plethora of convenient identities that were
developed to make this task simpler. The identity for computing the derivative
of <img alt="a^Tv" class="valign-0" src="https://eli.thegreenplace.net/images/math/ea7bffcd29c6bad40e358ad7313102670fb1a9cf.png" style="height: 15px;" /> shown above plays the role of <img alt="\frac{d}{dx}ax=a" class="valign-m6" src="https://eli.thegreenplace.net/images/math/999f262480b3690892d0af5651b96160d924997e.png" style="height: 22px;" /> in
regular calculus.</p>
<p>Now we have the tools to understand how the vector derivatives in the
<a class="reference external" href="http://eli.thegreenplace.net/2014/derivation-of-the-normal-equation-for-linear-regression">normal equation article</a>
were computed. As a reminder, this is the matrix form of the cost function <em>J</em>:</p>
<img alt="\[J(\theta)=\theta^TX^TX\theta-2(X\theta)^Ty+y^Ty\]" class="align-center" src="https://eli.thegreenplace.net/images/math/2864b88546c007a79dc92271f5e01487ba608e43.png" style="height: 21px;" />
<p>And we're interested in computing <img alt="\frac{\partial J}{\partial \theta}" class="valign-m7" src="https://eli.thegreenplace.net/images/math/27ffac3eede7fce0b342abf8fc10d29f24c68263.png" style="height: 24px;" />.
The equation for <em>J</em> consists of three terms added together. The last one
<img alt="y^Ty" class="valign-m4" src="https://eli.thegreenplace.net/images/math/81015d6225923cec985bef47ca151ef1cb654c92.png" style="height: 19px;" /> doesn't contribute to the derivative because it doesn't depend on
the variable. Let's start looking at the second (since it's simpler than the
first) - and give it a name, for convenience:</p>
<img alt="\[P(\theta)=2(X\theta)^Ty\]" class="align-center" src="https://eli.thegreenplace.net/images/math/35d3ddf05898e8bc2085030aa399ce98318674f9.png" style="height: 21px;" />
<p>We'll start by recalling what all the dimensions are. <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> is a vector
of n components. <img alt="y" class="valign-m4" src="https://eli.thegreenplace.net/images/math/95cb0bfd2977c761298d9624e4b4d4c72a39974a.png" style="height: 12px;" /> is a vector of m components. <img alt="X" class="valign-0" src="https://eli.thegreenplace.net/images/math/c032adc1ff629c9b66f22749ad667e6beadf144b.png" style="height: 12px;" /> is a m-by-n
matrix.</p>
<p>Let's see what <em>P</em> expands to <a class="footnote-reference" href="#id6" id="id3">[3]</a>:</p>
<img alt="\[P(\theta)=2\left [ \begin{pmatrix} x_1_1 &amp;amp; x_1_2 &amp;amp; ... &amp;amp; x_1_n\\ x_2_1 &amp;amp; ... &amp;amp; ... &amp;amp; x_2_n\\ ...\\ x_m_1 &amp;amp; ... &amp;amp; ... &amp;amp; x_m_n\\ \end{pmatrix}\begin{pmatrix} \theta_1\\ \theta_2\\ ...\\ \theta_n\\ \end{pmatrix} \right ]^T\begin{pmatrix} y_1\\ y_2\\ ...\\ y_m\\ \end{pmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/a7873ed04e274b30852e0f8d9450b5abc200ac17.png" style="height: 91px;" />
<p>Computing the matrix-by-vector multiplication inside the parens:</p>
<img alt="\[P(x)=2\left [ \begin{pmatrix} x_1_1\theta_1+...+x_1_n\theta_n\\ x_2_1\theta_1+...+x_2_n\theta_n\\ ...\\ x_m_1\theta_1+...+x_m_n\theta_n \end{pmatrix} \right ]^T\begin{pmatrix} y_1\\ y_2\\ ...\\ y_m\\ \end{pmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/6b9b8e2335579f352a19ef3da609be2e8b2d9925.png" style="height: 91px;" />
<p>And finally, multiplying the two vectors together:</p>
<img alt="\[P(x)=2(x_1_1\theta_1+...+x_1_n\theta_n)y_1+2(x_2_1\theta_1+...+x_2_n\theta_n)y_2+...+2(x_m_1\theta_1+...+x_m_n\theta_n)y_m\]" class="align-center" src="https://eli.thegreenplace.net/images/math/3271758ac98b149969516dd809fd35b90aacf056.png" style="height: 18px;" />
<p>Working with such formulae makes you appreciate why mathematicians long ago
came up with shorthand notations like &quot;sigma&quot; summation:</p>
<img alt="\[P(x)=2\sum_{r=1}^{m}y_r(x_r_1\theta_1+...+x_r_n\theta_n)=2\sum_{r=1}^{m}y_r\sum_{c=1}^{n}x_r_c\theta_c\]" class="align-center" src="https://eli.thegreenplace.net/images/math/6c71eb575ab3fafbc7b268be33d0d17a37bb1553.png" style="height: 50px;" />
<p>OK, so we've finally completed step 2 of the algorithm - we have the scalar
equation for <em>P</em>. Now it's time to compute its derivative by each
<img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />:</p>
<img alt="\[\begin{matrix} \frac{\partial P}{\partial \theta_1}=2(x_1_1y_1+...+x_m_1y_m)\\[1em] \frac{\partial P}{\partial \theta_2}=2(x_1_2y_1+...+x_m_2y_m)\\ ...\\ \frac{\partial P}{\partial \theta_n}=2(x_1_ny_1+...+x_m_ny_m) \end{matrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/889eb3c4e50b4fbdf5380c4d4e31ac4c0c09dddd.png" style="height: 111px;" />
<p>Now comes the most interesting part. If we treat
<img alt="\frac{\partial P}{\partial \theta}" class="valign-m7" src="https://eli.thegreenplace.net/images/math/3c653fa292156c8914f1463fcb6869633d37487c.png" style="height: 24px;" /> as a vector of n components, we can
rewrite this set of equations using a matrix-by-vector multiplication:</p>
<img alt="\[\frac{\partial P}{\partial \theta}=2X^Ty\]" class="align-center" src="https://eli.thegreenplace.net/images/math/7f75aa0f038ca73c58e95ef604ffb54468a18ae2.png" style="height: 38px;" />
<p>Take a moment to convince yourself this is true. It's just collecting the
individual components of <strong>X</strong> into a matrix and the individual components of
<strong>y</strong> into a vector. Since <strong>X</strong> is an m-by-n matrix and <strong>y</strong> is an m-by-1 column
vector, the dimensions work out and the result is an n-by-1 column vector.</p>
<p>So we've just computed the second term of the vector derivative of <em>J</em>. In the
process, we've discovered a useful vector derivative identity for a matrix <strong>X</strong>
and vectors <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> and <strong>y</strong>:</p>
<img alt="\[\frac{\partial (X\theta)^T y}{\partial \theta}=X^Ty\]" class="align-center" src="https://eli.thegreenplace.net/images/math/bf7325787bc464f067372a6d4ed612ea514d29b6.png" style="height: 41px;" />
<p>OK, now let's get back to the full definition of <em>J</em> and see how to compute the
derivative of its first term. We'll give it the name <em>Q</em>:</p>
<img alt="\[Q(\theta)=\theta^TX^TX\theta\]" class="align-center" src="https://eli.thegreenplace.net/images/math/0031acbab8dba6cef63f2605a15a0b7bc826766a.png" style="height: 21px;" />
<p>This derivation is somewhat more complex, since <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> appears twice in
the equation. Here's the equation again with all the matrices and vectors fully
laid out (note that I've already done the transposes):</p>
<img alt="\[Q(\theta)=(\theta_1...\theta_n)\begin{pmatrix}x_1_1 &amp;amp; x_2_1 &amp;amp; ... &amp;amp; x_m_1\\ x_1_2 &amp;amp; ... &amp;amp; ... &amp;amp; x_m_2\\ ...\\ x_1_n &amp;amp; ... &amp;amp; ... &amp;amp; x_m_n\\ \end{pmatrix}\begin{pmatrix}x_1_1 &amp;amp; x_1_2 &amp;amp; ... &amp;amp; x_1_n\\ x_2_1 &amp;amp; ... &amp;amp; ... &amp;amp; x_2_n\\ ...\\ x_m_1 &amp;amp; ... &amp;amp; ... &amp;amp; x_m_n\\ \end{pmatrix}\begin{pmatrix} \theta_1\\ \theta_2\\ ...\\ \theta_n\\ \end{pmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/b3f9b4ffe1853d6610f9814fc820d1a71825a06e.png" style="height: 87px;" />
<p>I'll just multiply the two matrices in the middle together. The result is an
&quot;<strong>X</strong> squared&quot; matrix, which is n-by-n. The element in row <em>r</em> and column <em>c</em>
of this square matrix is:</p>
<img alt="\[\sum_{i=1}^{m}x_i_rx_i_c\]" class="align-center" src="https://eli.thegreenplace.net/images/math/f8628d68855e03195fb4fd01806c8655beaf7b30.png" style="height: 50px;" />
<p>Note that &quot;<strong>X</strong> squared&quot; is a symmetric matrix (this fact will be important
later on). For simplicity of notation, we'll call its elements
<img alt="X^{2}_{rc}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/c565201908a5c75f62849e7c1634b65e0930824c.png" style="height: 19px;" />. Multiplying by the <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> vector on the right we
get:</p>
<img alt="\[Q(\theta)=(\theta_1...\theta_n)\begin{pmatrix}X^{2}_{11}\theta_1+...+X^{2}_{1n}\theta_n\\[1em] X^{2}_{21}\theta_1+...+X^{2}_{2n}\theta_n\\ ...\\ X^{2}_{n1}\theta_1+...+X^{2}_{nn}\theta_n\end{pmatrix}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/5821cf256f6cf6debbdac48d6e9bbe698baa0a11.png" style="height: 107px;" />
<p>And left-multiplying by <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" />, we get the fully unwrapped formula for
<em>Q</em>:</p>
<img alt="\[Q(\theta)=\theta_1(X^{2}_{11}\theta_1+...+X^{2}_{1n}\theta_n)+\theta_2(X^{2}_{21}\theta_1+...+X^{2}_{2n}\theta_n)+...+\theta_n(X^{2}_{n1}\theta_1+...+X^{2}_{nn}\theta_n)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/0451f9fa7c61ff3a61be8c1836c15667cd916330.png" style="height: 22px;" />
<p>Once again, it's now time to compute the derivatives. Let's focus on
<img alt="\frac{\partial Q}{\partial \theta_1}" class="valign-m9" src="https://eli.thegreenplace.net/images/math/5161830b1f644a3c2d1a650ccd6405e0fe5940aa.png" style="height: 27px;" />, from which we can infer the rest:</p>
<img alt="\[\frac{\partial Q}{\partial \theta_1}=(2\theta_1X^{2}_{11}+\theta_2X^{2}_{12}+...+\theta_nX^{2}_{1n})+\theta_2X^{2}_{21}+...+\theta_nX^{2}_{n1}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/f99e5e7024b4d13b0a767b98653b6ccc22fa1abd.png" style="height: 41px;" />
<p>Using the fact that <strong>X</strong> squared is symmetric, we know that
<img alt="X^{2}_{12}=X^{2}_{21}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/c14595d1000ad9a8da5be7f37da801eadfdfb698.png" style="height: 21px;" /> and so on. Therefore:</p>
<img alt="\[\frac{\partial Q}{\partial \theta_1}=2\theta_1X^{2}_{11}+2\theta_2X^{2}_{12}+...+2\theta_nX^{2}_{1n}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/832b294f472a23e500616db08d9d6832770af6a3.png" style="height: 40px;" />
<p>The partial derivatives by other <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> components are similar.
Collecting the sequence of partial derivatives back into a vector equation, we
get:</p>
<img alt="\[\frac{\partial Q}{\partial \theta}=2X^2\theta=2X^TX\theta\]" class="align-center" src="https://eli.thegreenplace.net/images/math/541124d49fa78dcf92a15b14643b2ebc4187eaaf.png" style="height: 38px;" />
<p>Now back to <em>J</em>. Recall that for convenience we broke <em>J</em> up into three parts:
<em>P</em>, <em>Q</em> and <img alt="y^Ty" class="valign-m4" src="https://eli.thegreenplace.net/images/math/81015d6225923cec985bef47ca151ef1cb654c92.png" style="height: 19px;" />; the latter doesn't depend on <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> so it
doesn't play a role in the derivative. Collecting our results from this post, we
then get:</p>
<img alt="\[\frac{\partial J}{\partial \theta}=\frac{\partial Q}{\partial \theta}-\frac{\partial P}{\partial \theta}=2X^TX\theta-2X^Ty\]" class="align-center" src="https://eli.thegreenplace.net/images/math/9c3d0d108ada3bfc7290c2328c8e6171bc01d7de.png" style="height: 38px;" />
<p>Which is exactly the equation we were expecting to see.</p>
<p>To conclude - if matrix calculus feels handwavy, it's because its identities are
less familiar. In a sense, it's handwavy in the same way
<img alt="\frac{dx^2}{dx}=2x" class="valign-m6" src="https://eli.thegreenplace.net/images/math/5fa725ae5b10a9249e9480d595770cf34accf533.png" style="height: 24px;" /> is handwavy. We remember the identity so we don't
have to recalculate it every time from first principles. Once you get some
experience with matrix calculus, parts of equations start looking familiar and
you no longer need to engage in the long and tiresome computations demonstrated
here. It's perfectly fine to just remember that the derivative of
<img alt="\theta^TX\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/7616542d90e084c74423b2a9d93b7a3a6cadcd00.png" style="height: 15px;" /> with a symmetric <strong>X</strong> is <img alt="2X\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/7fa6bcc17eae56f6f3f4a6fdcadae3cb3ee2c5d7.png" style="height: 12px;" />. See the
&quot;identities&quot; section of the <a class="reference external" href="http://en.wikipedia.org/wiki/Matrix_calculus">wikipedia article on matrix calculus</a> for many more examples.</p>
<hr class="docutils" />
<table class="docutils footnote" frame="void" id="id4" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id1">[1]</a></td><td>A few words on notation: by default, a vector <strong>v</strong> is a <em>column</em> vector.
To get its row version, we transpose it. Moreover, in the vector
derivative equations that follow I'm using <a class="reference external" href="http://en.wikipedia.org/wiki/Matrix_calculus#Layout_conventions">denominator layout notation</a>. This
is not super-important though; as the Wikipedia article suggests, many
mathematical papers and writings aren't consistent about this and it's
perfectly possible to understand the derivations regardless.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id5" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id2">[2]</a></td><td>Yes, this is exactly like computing a gradient of a multivariate
function.</td></tr>
</tbody>
</table>
<table class="docutils footnote" frame="void" id="id6" rules="none">
<colgroup><col class="label" /><col /></colgroup>
<tbody valign="top">
<tr><td class="label"><a class="fn-backref" href="#id3">[3]</a></td><td>Take a minute to convince yourself that the dimensions of this equation
work out and the result is a scalar.</td></tr>
</tbody>
</table>
Derivation of the Normal Equation for linear regression2014-12-22T20:50:00-08:002014-12-22T20:50:00-08:00Eli Benderskytag:eli.thegreenplace.net,2014-12-22:/2014/derivation-of-the-normal-equation-for-linear-regression/<p>I was going through the Coursera &quot;Machine Learning&quot; course, and in the section
on multivariate linear regression something caught my eye. Andrew Ng presented
the <a class="reference external" href="http://en.wikipedia.org/w/index.php?title=Normal_equation&amp;redirect=no">Normal Equation</a> as an
analytical solution to the linear regression problem with a least-squares cost
function. He mentioned that in some cases (such as for small feature sets) using
it is more effective than applying gradient descent; unfortunately, he left its
derivation out.</p>
<p>Here I want to show how the normal equation is derived.</p>
<p>First, some notation. The symbols here follow the machine learning course
rather than the exposition of the normal equation on Wikipedia and other sites -
semantically it's all the same, just the symbols are different.</p>
<p>Given the hypothesis function:</p>
<img alt="\[h_{\theta}(x)=\theta_0x_0+\theta_1x_1+\cdots+\theta_nx_n\]" class="align-center" src="https://eli.thegreenplace.net/images/math/dd8fad9bf111e83d47252d51dd037a6c6c3136aa.png" style="height: 18px;" />
<p>We'd like to minimize the least-squares cost:</p>
<img alt="\[J(\theta_{0...n})=\frac{1}{2m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)})-y^{(i)})^2\]" class="align-center" src="https://eli.thegreenplace.net/images/math/c1abe0768f4deb31ed97f37d760236c94439a780.png" style="height: 50px;" />
<p>Where <img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" /> is the <tt class="docutils literal">i</tt>-th sample (from a set of <tt class="docutils literal">m</tt> samples) and
<img alt="y^{(i)}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/d34414117d493106f731939df6bb7f1762365d3f.png" style="height: 21px;" /> is the <tt class="docutils literal">i</tt>-th expected result.</p>
<p>To proceed, we'll represent the problem in matrix notation; this is natural,
since we essentially have a system of linear equations here. The regression
coefficients <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> we're looking for are the vector:</p>
<img alt="\[\begin{pmatrix} \theta_0\\ \theta_1\\ ...\\ \theta_n \end{pmatrix}\in\mathbb{R}^{n+1}\]" class="align-center" src="https://eli.thegreenplace.net/images/math/b16fd3d2b3041f13cb70199837a7c02c756078c7.png" style="height: 86px;" />
<p>Each of the <tt class="docutils literal">m</tt> input samples is similarly a column vector with <tt class="docutils literal">n+1</tt> rows,
<img alt="x_0" class="valign-m3" src="https://eli.thegreenplace.net/images/math/efbda784ad565c1c5201fdc948a570d0426bc6e6.png" style="height: 11px;" /> being 1 for convenience. So we can now rewrite the hypothesis
function as:</p>
<img alt="\[h_{\theta}(x)=\theta^Tx\]" class="align-center" src="https://eli.thegreenplace.net/images/math/be661047c89f6a48c7bc563b81207949c251de6a.png" style="height: 21px;" />
<p>When this is summed over all samples, we can dip further into matrix notation.
We'll define the &quot;design matrix&quot; <tt class="docutils literal">X</tt> (uppercase X) as a matrix of <tt class="docutils literal">m</tt> rows,
in which the <tt class="docutils literal">i</tt>-th row is the <tt class="docutils literal">i</tt>-th sample (the vector <img alt="x^{(i)}" class="valign-0" src="https://eli.thegreenplace.net/images/math/233014006c0adbee71ec71ba3a70f22ad1b906a1.png" style="height: 17px;" />). With
this, we can rewrite the least-squares cost as follows, replacing the explicit
sum by matrix multiplication:</p>
<img alt="\[J(\theta)=\frac{1}{2m}(X\theta-y)^T(X\theta-y)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/db5e3da78e25c18c8fc88f1291c1ac13a2645388.png" style="height: 36px;" />
<p>Now, using some matrix transpose identities, we can simplify this a bit. I'll
throw the <img alt="\frac{1}{2m}" class="valign-m6" src="https://eli.thegreenplace.net/images/math/7a2a3f6dba54b64f0e88e18c40e0f68c523713ea.png" style="height: 22px;" /> part away since we're going to compare a
derivative to zero anyway:</p>
<img alt="\[J(\theta)=((X\theta)^T-y^T)(X\theta-y)\]" class="align-center" src="https://eli.thegreenplace.net/images/math/c1368de1a0634c3fbeb92d67f368f253943d089f.png" style="height: 21px;" />
<img alt="\[J(\theta)=(X\theta)^TX\theta-(X\theta)^Ty-y^T(X\theta)+y^Ty\]" class="align-center" src="https://eli.thegreenplace.net/images/math/e41fc822adccf1f865b02100f5671e265e7b30bc.png" style="height: 21px;" />
<p>Note that <img alt="X\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/52f2de6065bdc187b876c5696041f3c716c446f5.png" style="height: 12px;" /> is a vector, and so is <tt class="docutils literal">y</tt>. So when we multiply
one by another, it doesn't matter what the order is (as long as the dimensions
work out). So we can further simplify:</p>
<img alt="\[J(\theta)=\theta^TX^TX\theta-2(X\theta)^Ty+y^Ty\]" class="align-center" src="https://eli.thegreenplace.net/images/math/2864b88546c007a79dc92271f5e01487ba608e43.png" style="height: 21px;" />
<p>Recall that here <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> is our unknown. To find where the above
function has a minimum, we will derive by <img alt="\theta" class="valign-0" src="https://eli.thegreenplace.net/images/math/cb005d76f9f2e394a770c2562c2e150a413b3216.png" style="height: 12px;" /> and compare to 0.
Deriving by a vector may feel uncomfortable, but there's nothing to worry about.
Recall that here we only use matrix notation to conveniently represent a system
of linear formulae. So we derive by each component of the vector, and then
combine the resulting derivatives into a vector again. The result is:</p>
<img alt="\[\frac{\partial J}{\partial \theta}=2X^TX\theta-2X^{T}y=0\]" class="align-center" src="https://eli.thegreenplace.net/images/math/9b142c00e031c9db7f575b0542e86261732a4689.png" style="height: 38px;" />
<p>Or:</p>
<img alt="\[X^TX\theta=X^{T}y\]" class="align-center" src="https://eli.thegreenplace.net/images/math/ab453f9f1f7bd4b1d646b9712fbe0b2fbe01740f.png" style="height: 21px;" />
<p>[<em>Update 27-May-2015</em>: I've written <a class="reference external" href="http://eli.thegreenplace.net/2015/the-normal-equation-and-matrix-calculus/">another post</a>
that explains in more detail how these derivatives are computed.]</p>
<p>Now, assuming that the matrix <img alt="X^TX" class="valign-0" src="https://eli.thegreenplace.net/images/math/5c817c84ec1f83b23494df6125edd091a7c413dd.png" style="height: 15px;" /> is invertible, we can multiply both
sides by <img alt="(X^TX)^{-1}" class="valign-m4" src="https://eli.thegreenplace.net/images/math/57f592cee6ceac659262d97e61c64f9ca405d7f1.png" style="height: 19px;" /> and get:</p>
<img alt="\[\theta=(X^TX)^{-1}X^Ty\]" class="align-center" src="https://eli.thegreenplace.net/images/math/20baabd9d33dcd26003bc44c7d81ba39e1ad4caa.png" style="height: 21px;" />
<p>Which is the normal equation.</p>
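<p>To tie things together, here is a short end-to-end sketch (synthetic data, not
from the text above) that fits a line with the normal equation and compares the
result with NumPy's built-in least-squares solver. In practice it's numerically
preferable to solve the linear system directly rather than explicitly invert
X^TX; the inverse is used below only to match the formula:</p>
<pre class="code python literal-block">
# Fit y = theta_0 + theta_1 * x with the normal equation.
import numpy as np

rng = np.random.default_rng(4)
m = 50
x = rng.uniform(-2, 2, size=m)
y = 1.5 + 0.8 * x + rng.normal(scale=0.1, size=m)   # noisy line

X = np.column_stack([np.ones(m), x])                # x_0 = 1 column for the intercept

theta = np.linalg.inv(X.T @ X) @ X.T @ y            # the normal equation
theta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None) # reference solution

print(theta)          # close to [1.5, 0.8]
print(theta_lstsq)    # essentially the same values
</pre>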