Gradient descent is perhaps the simplest optimization algorithm used for learning. In deep learning, it is the foundational algorithm upon which modern learning algorithms are built.

Yet in introductory machine learning material, it's hardly mentioned. I quite frequently get asked about gradient descent, and despite its simplicity, there's a good deal of confusion about its basic properties. In this article, we'll turn gradient descent inside out.

Let’s start by considering how to optimize some arbitrary function that we know almost nothing about. Once we know how to do that, we’ll introduce the machine learning context.

Find a minimum: brute force

Suppose I give you a differentiable function $f: \mathbb{R}^n \to \mathbb{R}$ and ask you to find a minimum. That is, solve

$$
\boldsymbol{v}^* = \arg\min _ {\boldsymbol{v} \in \mathbb{R}^n} f(\boldsymbol{v}). \tag{p}\label{p}
$$

I haven’t told you anything about this function, so you have no way of knowing how it behaves. How do you minimize it?

Somewhat dumbly, we could pick a random input $\boldsymbol{v}$ as our initial guess, then pick another point $\boldsymbol{d}$ in the neighborhood of $\boldsymbol{v}$ as a challenger. We then compare $f(\boldsymbol{v})$ to $f(\boldsymbol{d})$, and if $f(\boldsymbol{d})$ is lower, we make $\boldsymbol{d}$ our new candidate $\boldsymbol{v}$. Otherwise we draw a new $\boldsymbol{d}$ and repeat the process. Eventually, this process will find a local minimum. But guessing at random is profoundly unsatisfying and rather inefficient. In particular, at every iteration we learn something new about $f$, but this strategy completely ignores that information. A minimal sketch of the idea follows below.
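Here is a minimal sketch of this random-neighbor search in Python; the helper name `random_search`, the test function, the neighborhood radius, and the iteration budget are all made up for the example, not prescribed by the method.

```python
import numpy as np

def random_search(f, v0, radius=0.1, n_iter=10_000, seed=0):
    """Minimize f by proposing a random neighbor of the current
    candidate and keeping it whenever it lowers the function value."""
    rng = np.random.default_rng(seed)
    v = np.asarray(v0, dtype=float)
    fv = f(v)
    for _ in range(n_iter):
        d = v + radius * rng.standard_normal(v.shape)  # challenger near v
        fd = f(d)
        if fd < fv:              # keep the challenger only if it improves
            v, fv = d, fd
    return v, fv

# Example: a quadratic bowl with its minimum at (1, -2).
f = lambda v: (v[0] - 1.0) ** 2 + (v[1] + 2.0) ** 2
v_best, f_best = random_search(f, v0=np.zeros(2))
print(v_best, f_best)  # close to [1, -2] and 0
```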

First-order descent

Suppose I tell you $f$ is one-dimensional and looks like this:

[Figure: a one-dimensional function $f(u)$, with the starting point in blue, a local minimum in orange, and the global minimum in green.]

Starting at the blue point, where would you go? Obviously to the right, that is, we’d want to increase $u$. Once we realize $f$ has this shape, it’s easy to see that we’d keep increasing $u$ by some small amount until increasing $u$ further would lead to an increase in $f$. This strategy would take us to the orange local minimum in roughly half as many iterations as randomly drawing a point to the left or right at each step, as we did above.

What we have done here is to incorporate information about how $f$ behaves around the blue point, in particular, information about the local slope of $f$. Since $f$ is differentiable, the mathematical analogue is to evaluate the derivative at each candidate point and move in the direction that would decrease $f$. In this way, we use the gradient to descend towards a local minimum. Note that this strategy will not lead us to the global minimum (the green point), since the slope of $f$ in the neighborhood of the orange point implies that moving further to the right would lead to a higher function value.

But given that we know nothing about $f$, finding some minimum is pretty good! When we use the gradient (if this is new to you, read “derivative”), we are effectively evaluating every direction we could go in and picking the one that looks the most promising. Compared to our first brute-force method, this is roughly equivalent to making $n$ draws $\boldsymbol{d}_1, \boldsymbol{d}_2, \ldots, \boldsymbol{d}_n$ around $\boldsymbol{v}$ such that each $\boldsymbol{d}_i$ only differs from $\boldsymbol{v}$ in the $i$-th dimension. We then simply go in the direction that decreases $f$ the most.
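To make the coordinate-wise comparison concrete, here is a small sketch that estimates the gradient by perturbing one coordinate at a time (central differences); the step size `h` and the test function are illustrative assumptions, and in practice we would of course compute the gradient analytically.

```python
import numpy as np

def numerical_gradient(f, v, h=1e-6):
    """Estimate the gradient of f at v by perturbing one coordinate
    at a time and comparing function values (central differences)."""
    v = np.asarray(v, dtype=float)
    grad = np.zeros_like(v)
    for i in range(v.size):
        e = np.zeros_like(v)
        e[i] = h                               # perturb only dimension i
        grad[i] = (f(v + e) - f(v - e)) / (2 * h)
    return grad

f = lambda v: (v[0] - 1.0) ** 2 + (v[1] + 2.0) ** 2
print(numerical_gradient(f, [0.0, 0.0]))       # approximately [-2, 4]
```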

Of course, this hinges on the idea that following the gradient actually minimizes $f$. Indeed, you can convince yourself of this by proving that, for any given $\boldsymbol{v}$, the direction $\boldsymbol{d}$ of maximal descent from $\boldsymbol{v}$ within the unit ball is given by

$$
\boldsymbol{d} = - \frac{\nabla f(\boldsymbol{v})}{\lVert \nabla f(\boldsymbol{v}) \rVert}. \tag{d}\label{d}
$$

Restricting $\boldsymbol{d}$ to lie within the unit ball is merely for convenience: the idea is to find the optimal direction in some small neighborhood of $\boldsymbol{v}$. In fact, we typically use a slightly modified version of $\eqref{d}$, where $ \boldsymbol{d} = - \alpha \nabla f(\boldsymbol{v})$ for some small $\alpha \in (0, 1)$. This formulation lets us regulate, via $\alpha$ (the learning rate), how far to move in the optimal direction.

Time to solve our original problem $\eqref{p}$. Let’s pick some starting point $\boldsymbol{v}_0$ at random. But now, rather than drawing a new candidate point at random, we use $\eqref{d}$ to give us the best direction to move in. We repeat this process for a given number of steps or until the benefit from updating our candidate solution is sufficiently low. This gives us a sequence of points $ \{ \boldsymbol{v}_t \} _ {t \in \mathbb{N}} $ as follows:

$$
\boldsymbol{v} _ {t+1} = \boldsymbol{v} _ t - \alpha \nabla f(\boldsymbol{v} _ t). \tag{s}\label{s}
$$

This is the canonical gradient update rule. An important property of this rule is that, for $\alpha$ sufficiently small, every step is guaranteed to decrease $f$, so the procedure homes in on a minimum. To see this, write out the first-order Taylor series expansion of $f(\boldsymbol{v} _ {t+1})$ around $\boldsymbol{v} _ t$:

$$
f(\boldsymbol{v} _ {t+1}) = f(\boldsymbol{v} _ t) + \nabla f(\boldsymbol{v} _ t)^{\top} (\boldsymbol{v} _ {t+1} - \boldsymbol{v} _ t) + o(\lVert \boldsymbol{v} _ {t+1} - \boldsymbol{v} _ t \rVert).
$$

The $o(\cdot)$ term is the remainder in Peano form. Plug in $\eqref{s}$ for $\boldsymbol{v} _ {t+1} - \boldsymbol{v} _ t$ to get

$$
f(\boldsymbol{v} _ {t+1}) - f(\boldsymbol{v} _ t) = -\alpha \lVert \nabla f(\boldsymbol{v} _ t) \rVert^2 + o(\alpha \lVert \nabla f(\boldsymbol{v} _ t) \rVert).
$$

If you’re unfamiliar with the remainder term, what you need to know is that as $\alpha \to 0$ it vanishes faster than $\alpha$ itself, so for $\alpha$ sufficiently small the first term will dominate. When this happens, the right-hand side is negative and

$$
f(\boldsymbol{v} _ {t+1}) < f(\boldsymbol{v} _ t),
$$

so the sequence steadily decreases $f$ and $\boldsymbol{v} _ {t}$ converges towards a (local) minimizer $\boldsymbol{v}^* $. That’s gradient descent. There are many other versions of gradient descent that bring various benefits and drawbacks. A common one is to use an adaptive learning rate, so instead of one $\alpha$ we have a sequence $\{\alpha_t\}_{ t \in \mathbb{N}}$ of learning rates, one for each update. We could also use information in the second derivative of $f$, i.e. the Hessian, to guide the update. This gives rise to second-order descent methods. They typically converge in fewer iterations, but suffer from higher computational costs and can be quite sensitive to the starting point. Classical machine learning often leverages approximate second-order optimization strategies, but deep learning relies predominantly on first-order methods like gradient descent.
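Before moving on, here is a minimal sketch of the basic update rule $\eqref{s}$ in Python; the fixed learning rate, stopping rule, and test gradient are illustrative choices.

```python
import numpy as np

def gradient_descent(grad, v0, alpha=0.1, n_iter=1_000, tol=1e-8):
    """Repeatedly step in the direction of the negative gradient,
    stopping once the update becomes negligibly small."""
    v = np.asarray(v0, dtype=float)
    for _ in range(n_iter):
        step = alpha * grad(v)
        v = v - step
        if np.linalg.norm(step) < tol:   # benefit of updating is tiny
            break
    return v

# Gradient of f(v) = (v1 - 1)^2 + (v2 + 2)^2.
grad = lambda v: np.array([2 * (v[0] - 1.0), 2 * (v[1] + 2.0)])
print(gradient_descent(grad, v0=np.zeros(2)))  # converges to [1, -2]
```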

Gradient descent for machine learning

So far, $f$ has been some arbitrary function. It’s time we relate this to a real machine learning problem. The simplest and most intuitive case is learning a linear relationship between an input and an output. Suppose there is a data generating process $F$ that takes two inputs and returns an output, and that we have observed a set of $N$ datapoints $\{(\boldsymbol{x}^{(1)}, y^{(1)}), \ldots, (\boldsymbol{x}^{(N)}, y^{(N)})\}$ that looks like this:

[Figure: scatter plot of the observed datapoints, showing an approximately linear relationship between the inputs and the output.]

This looks like a fairly linear relationship, but we don’t know what it is. Hence, we want to approximately learn the true data generating process $F$. Let’s specify a parametric statistical model that we think $F$ belongs to,

$$
\hat{y} = \boldsymbol{w}^{\top}\boldsymbol{x}, \qquad y = {\boldsymbol{w}^ * }^{\top}\boldsymbol{x}.
$$

Here, $\boldsymbol{w} = (w_1, w_2)$ are parameters and $\boldsymbol{w}^* $ are the true parameters we wish to learn. We need to learn a set of parameters $\hat{\boldsymbol{w}}$ from the data that we expect to be as close to the true parameters as possible. To this end, we measure how far a candidate model is from the true model with a loss function, $l$. We learn $\hat{\boldsymbol{w}}$ by minimizing the expected loss:

$$
\hat{\boldsymbol{w}} = \arg\min _ {\boldsymbol{w}} \mathbb{E}\left[ l(\boldsymbol{w}^{\top} \boldsymbol{x}, {\boldsymbol{w}^ * }^{\top} \boldsymbol{x}) \right].
$$

This objective is known as risk minimization. To minimize the expected risk, we need to approximate the expectation with the data that we have. This approximation is the empirical risk objective, $J(\boldsymbol{w}) = \frac1N \sum_{i=1}^N l(\boldsymbol{w}^\top \boldsymbol{x}^{(i)}, {\boldsymbol{w}^ *}^\top \boldsymbol{x}^{(i)} )$. Finally, we need to specify the loss function. We want a function that is monotonically increasing as $\boldsymbol{w}^\top \boldsymbol{x}$ moves away from ${\boldsymbol{w}^ *}^\top \boldsymbol{x}$. An intuitive choice would be a metric distance, specifically a norm. A popular choice is the squared Euclidean distance, which is what we’ll use: $ l(\boldsymbol{w}, \boldsymbol{w}^ * ) = (\boldsymbol{w}^{\top}\boldsymbol{x} - {\boldsymbol{w} ^ * }^{\top}\boldsymbol{x})^2$. This choice might seem a bit ad hoc, but you can actually derive it by imposing a Gaussian distribution over $y$ given $\boldsymbol{x}$. But since we’re more interested in gradient descent here, let’s gloss over that detail. One detail that might strike you as harder to avoid is that our loss uses the true parameters, which we don’t know. But since ${\boldsymbol{w} ^ * }^{\top}\boldsymbol{x} = y$, we can simply substitute the observed output into the loss to get

$$
J(\boldsymbol{w}) = \frac{1}{N} \sum _ {i=1}^N \left( \boldsymbol{w}^{\top}\boldsymbol{x}^{(i)} - y^{(i)} \right)^2.
$$

The cost function $J$ plays the role of $f$ in our original problem $\eqref{p}$, so we can readily apply gradient descent. Applying $\eqref{s}$, we update our estimated parameters according to

$$
\boldsymbol{w} _ {t+1} = \boldsymbol{w} _ t - \alpha \nabla J(\boldsymbol{w} _ t) = \boldsymbol{w} _ t - \frac{2\alpha}{N} \sum _ {i=1}^N \left( \boldsymbol{w} _ t^{\top}\boldsymbol{x}^{(i)} - y^{(i)} \right) \boldsymbol{x}^{(i)}.
$$
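As a sketch of how this plays out in code, here is a small example that generates synthetic linear data and learns $\boldsymbol{w}$ with the update above; the particular true parameters, noise level, learning rate, and iteration count are assumptions made for illustration and do not correspond to the figures below.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data from a noisy linear data generating process.
w_true = np.array([2.0, -3.0])
X = rng.normal(size=(200, 2))                 # inputs x^(i)
y = X @ w_true + 0.1 * rng.normal(size=200)   # outputs y^(i)

def grad_J(w):
    """Gradient of J(w) = (1/N) * sum_i (w'x^(i) - y^(i))^2."""
    residual = X @ w - y
    return 2.0 / len(y) * (X.T @ residual)

w = np.zeros(2)            # initial guess w_0
alpha = 0.1                # learning rate
for _ in range(100):       # 100 gradient updates
    w = w - alpha * grad_J(w)

print(w)                   # close to w_true = [2, -3]
```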

Iterating this process $100$ times, we are able to learn the true data generating process. In the figure below, you can see how we steadily make progress as we gradually move our parameters towards the ground truth (the dashed lines). On the left is the cost function; its dashed line gives the error of the true model (incurred due to noise).

[Figure: left, the cost $J(\boldsymbol{w}_t)$ over iterations with the true model’s error as a dashed line; right, the parameter estimates $w_1$ and $w_2$ converging to their true values (dashed lines).]

On the right you see the evolution of our parameter estimates. Note that $w_2$ is first pushed above its true value, and only converges once $w_1$ approaches its own true value. This is an inherent weakness of gradient descent: since it only considers what’s optimal in a small neighborhood around $\boldsymbol{w}_t$, it is susceptible to taking long detours before converging on some minimum.

Speeding up convergence: momentum

While gradient descent is a pretty good algorithm for finding at least a local minimum, it is also really slow. The reason is that the gradient only tells us which direction to go in, not how far. For this reason, we are forced to take small steps, or risk being thrown off completely by a large gradient. But a small learning rate causes us to take several steps in the same direction when it would have been much more efficient to cover the same distance in a single step. We might be tempted to increase the learning rate a little, but this can lead to even slower convergence if we move our parameters so much that they start oscillating between two modes. In our example above, increasing the learning rate does not lead to faster convergence: we move the parameters too much with each iteration and overshoot the minimum instead, which gives rise to the following “zig-zag” pattern.

[Figure: the parameter estimates zig-zagging around the minimum when the learning rate is too large.]
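The overshooting mechanism is easy to reproduce in isolation. As a one-dimensional sketch (the quadratic and the learning rates are illustrative assumptions): for $f(u) = u^2$ the update $\eqref{s}$ becomes $u_{t+1} = (1 - 2\alpha)u_t$, so any $\alpha > 0.5$ makes the iterate flip sign at every step, and $\alpha > 1$ makes it diverge.

```python
def descend(alpha, u=1.0, n_iter=10):
    """Gradient descent on f(u) = u^2, whose gradient is 2u."""
    trace = [u]
    for _ in range(n_iter):
        u = u - alpha * 2 * u
        trace.append(u)
    return trace

print(descend(alpha=0.1))  # smooth, monotone decay towards 0
print(descend(alpha=0.9))  # zig-zags: sign flips every step, but still converges
print(descend(alpha=1.1))  # zig-zags and diverges
```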

What we want instead is a method that allows us to take large steps when we’re repeatedly going in the same direction, but reverts to small steps when we start moving around. This is what momentum is designed to achieve. In particular, we record previous directions and allow them to accumulate. When several iterations have gradients pointing in the same direction, momentum builds up, allowing larger steps. When the gradients start pointing in different directions, momentum unravels. Formally, we modify $\eqref{s}$ as follows:

$$
\begin{aligned}
\boldsymbol{z} _ {t+1} &= \mu \boldsymbol{z} _ t - \alpha \nabla f(\boldsymbol{v} _ t) \\
\boldsymbol{v} _ {t+1} &= \boldsymbol{v} _ t + \boldsymbol{z} _ {t+1}
\end{aligned}
\tag{m}\label{m}
$$

with $\boldsymbol{z} _ 0 = \boldsymbol{0}$.

We now have a new parameter, $\mu$, that governs how much of our previous state to incorporate into the update. If you combine these equations, you’ll notice that the difference between $\eqref{m}$ and $\eqref{s}$ amounts to the extra term $\mu \boldsymbol{z} _ t $. If $\mu=0$, we recover the standard gradient update rule. In practice, $\mu$ is typically in $(0.8, 1)$. When we use momentum, the velocity $\boldsymbol{z} _ {t+1} $ grows in magnitude when successive gradients point in the same direction. Repeatedly going in the same direction will therefore induce larger steps with each iteration. When we start going in a different direction, the velocity is gradually reduced, causing us to take smaller steps again. In our case, setting $\mu=0.8$ allows us to converge in half as many iterations and along a much smoother path.
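A minimal sketch of the momentum update $\eqref{m}$ on the same toy quadratic as before (the hyperparameters are illustrative):

```python
import numpy as np

def gradient_descent_momentum(grad, v0, alpha=0.1, mu=0.8, n_iter=1_000):
    """Gradient descent with momentum: the velocity z accumulates
    past gradients and decays at rate mu."""
    v = np.asarray(v0, dtype=float)
    z = np.zeros_like(v)
    for _ in range(n_iter):
        z = mu * z - alpha * grad(v)   # accumulate velocity
        v = v + z                      # take the momentum step
    return v

grad = lambda v: np.array([2 * (v[0] - 1.0), 2 * (v[1] + 2.0)])
print(gradient_descent_momentum(grad, v0=np.zeros(2)))  # converges to [1, -2]
```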

[Figure: convergence of the parameter estimates with momentum ($\mu=0.8$), following a much smoother path to the minimum.]

One thing to notice here is that once we reach the “bottom” of the bowl, we need a few iterations to unwind the momentum. Ideally, we’d like to build up momentum at a linear rate (additively) but unravel it at an exponential rate (multiplicatively) to avoid this behavior. There are several versions of momentum that behave this way, and in practice those are the versions used. Popular alternatives include Nesterov momentum, RMSProp, and Adam. For an informal survey of popular variations, see this e-print on arXiv.
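As one example of such a variant, here is a hedged sketch of Nesterov momentum in the common “look-ahead” formulation, where the gradient is evaluated at $\boldsymbol{v} + \mu\boldsymbol{z}$ so the velocity can correct itself before it overshoots; the hyperparameters and test gradient are illustrative assumptions.

```python
import numpy as np

def nesterov_momentum(grad, v0, alpha=0.1, mu=0.9, n_iter=1_000):
    """Nesterov momentum: evaluate the gradient at the look-ahead point
    v + mu * z, which dampens the velocity before it overshoots."""
    v = np.asarray(v0, dtype=float)
    z = np.zeros_like(v)
    for _ in range(n_iter):
        z = mu * z - alpha * grad(v + mu * z)  # look-ahead gradient
        v = v + z
    return v

grad = lambda v: np.array([2 * (v[0] - 1.0), 2 * (v[1] + 2.0)])
print(nesterov_momentum(grad, v0=np.zeros(2)))  # converges to [1, -2]
```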

That concludes our introduction to gradient descent. You’ve seen how we can derive it as a tool for finding a minimum of an arbitrary function, and how we can apply it to learn the parameters of a machine learning model. There are many versions of gradient descent, and many more that incorporate momentum. To gain further understanding of momentum, this article at Distill is highly recommended.