Gradient descent

Gradient descent is a simple yet powerful iterative algorithm used to find a local minimum of a differentiable function.

"Classical" example

Let's start with something simple - the "classic" parabola example. To work through it we need two things - the function and its derivative.

As a function I'm going to use: f(x) = x^2

And we can easily calculate its derivative: f'(x) = 2x

In the plot above we can see our parabola and its derivative (red line). The derivative is nothing more than the instantaneous rate of change of the function at a given point. In this particular case, when the derivative of our function is equal to 0, we are at the global minimum of the function.
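A plot like this can be recreated with a few lines of matplotlib; the sketch below is one possible version (the plotting range is an assumption):

```python
import numpy as np
import matplotlib.pyplot as plt

# The parabola and its derivative
f = lambda x: x ** 2
df = lambda x: 2 * x

x = np.linspace(-5, 5, 200)  # assumed plotting range

plt.plot(x, f(x), label="f(x) = x^2")
plt.plot(x, df(x), color="red", label="f'(x) = 2x")
plt.axhline(0, color="gray", linewidth=0.5)
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()
```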

Now let's implement the gradient descent algorithm.

The algorithm is very simple and can be divided into several steps:

  1. Define a starting x coordinate from which we want to descend to a local/global minimum

  2. Calculate the derivative at this point

  3. Subtract the derivative at this point from the current point. To prevent divergence, the derivative is first multiplied by a small number "alpha", the learning rate.

  4. Repeat from step 2

So the main gradient descent formula can be written as: x_{next} = x_{start} - \alpha \cdot f'(x_{start})
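A minimal Python sketch of these steps could look like the following (the starting point, learning rate, iteration count, and stopping tolerance are illustrative assumptions):

```python
def gradient_descent(df, x_start, alpha=0.1, n_iterations=100, tolerance=1e-6):
    """Descend along a 1D function using its derivative df."""
    x = x_start
    for _ in range(n_iterations):
        step = alpha * df(x)       # scale the derivative by the learning rate
        x = x - step               # move against the slope
        if abs(step) < tolerance:  # stop once the updates become negligible
            break
    return x

# Example: descend on f(x) = x^2, whose derivative is 2x
df = lambda x: 2 * x
x_min = gradient_descent(df, x_start=4.0)
print(x_min)  # approaches 0, the global minimum of the parabola
```

If the learning rate is too large, each update overshoots the minimum and the iteration diverges, which is exactly why the derivative is scaled by alpha in step 3.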

Gradient descent for a multivariable function

Now let's take something more interesting - a multivariable function with 2 input variables and 1 output.

As in the previous example, to successfully descend to a local minimum we need to calculate derivatives, but because we have a multivariable function, we need to calculate partial derivatives with respect to X and Y:

1. Partial derivative with respect to X:

2. Partial derivative with respect to Y:

These partial derivatives are the components of the gradient - the vector that points in the direction of steepest ascent of the function: \nabla f = \left( \frac{\partial f}{\partial x}, \frac{\partial f}{\partial y} \right)
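If working out the partial derivatives by hand is inconvenient, the gradient can also be approximated numerically with central differences; the sketch below illustrates this for an assumed example function f(x, y) = x^2 + y^2:

```python
def numerical_gradient(f, x, y, h=1e-6):
    """Approximate the gradient (df/dx, df/dy) with central differences."""
    df_dx = (f(x + h, y) - f(x - h, y)) / (2 * h)
    df_dy = (f(x, y + h) - f(x, y - h)) / (2 * h)
    return df_dx, df_dy

# Assumed example function, used only for illustration
f = lambda x, y: x ** 2 + y ** 2

print(numerical_gradient(f, 1.0, 2.0))  # roughly (2.0, 4.0)
```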

So, if we take the opposite of that vector, it will show us the direction of steepest descent. First, let's plot the function to see what it looks like.
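As an illustration, the sketch below assumes the example function f(x, y) = x^2 + y^2 (the actual function used in this example may differ), plots its surface, and then follows the negative gradient from an arbitrary starting point:

```python
import numpy as np
import matplotlib.pyplot as plt

# Assumed example function and its partial derivatives
f = lambda x, y: x ** 2 + y ** 2
df_dx = lambda x, y: 2 * x
df_dy = lambda x, y: 2 * y

# Surface plot of the function
X, Y = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100))
ax = plt.figure().add_subplot(projection="3d")
ax.plot_surface(X, Y, f(X, Y), cmap="viridis", alpha=0.8)

# Gradient descent: repeatedly step opposite to the gradient
x, y = 4.0, -3.0  # assumed starting point
alpha = 0.1       # learning rate
for _ in range(100):
    x, y = x - alpha * df_dx(x, y), y - alpha * df_dy(x, y)

ax.scatter(x, y, f(x, y), color="red", s=40)  # ends near the minimum at (0, 0)
plt.show()
print(x, y)
```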
