Gradient descent
Gradient descent is a simple but powerful algorithm used to find a local minimum of a function.
Let's start with something simple - the "classic" parabola example. To build it we need two things: the function itself and its derivative.
As a function I'm going to use:
And we can easily calculate its derivative:
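To make this concrete, here is a minimal sketch of these two pieces in Python, assuming the classic parabola f(x) = x^2 (the exact function used for the plot may differ):

```python
def f(x):
    # Assumed "classic" parabola; the post's actual function may differ.
    return x ** 2

def df(x):
    # Its derivative: d/dx of x^2 is 2x.
    return 2 * x
```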
In the plot above we can see our parabola and its derivative (red line). The derivative is nothing more than the instantaneous rate of change of the function at a given point. In this particular case, when the derivative of our function is equal to 0, we are at the global minimum of the function.
Now let's start implementing the gradient descent algorithm.
The algorithm is very simple and can be divided into several steps:
1. Define a starting x coordinate from which we want to descend to a local/global minimum
2. Calculate the derivative at this point
3. Subtract the derivative at this point from the current point. To prevent divergence we need to multiply the derivative by a small number "alpha", the learning rate.
4. Repeat from step 2
So the main gradient descent formula can be written as:
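$$x_{n+1} = x_n - \alpha \, f'(x_n)$$

where alpha is the learning rate.

A minimal sketch of these steps in Python, reusing the parabola derivative from above (the starting point, alpha, and number of iterations here are just example values, not the ones from the original plot):

```python
def gradient_descent(df, start_x, alpha=0.1, n_steps=100):
    """Follow the negative derivative to approach a local minimum."""
    x = start_x                # step 1: starting x coordinate
    for _ in range(n_steps):
        grad = df(x)           # step 2: derivative at the current point
        x = x - alpha * grad   # step 3: step against the derivative, scaled by alpha
    return x                   # step 4: the loop repeats steps 2-3

print(gradient_descent(lambda x: 2 * x, start_x=5.0))  # converges toward x = 0
```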
Now let's take something more interesting - a multivariable function with 2 input variables and 1 output.
As in the previous example, to successfully descend to a local minimum we need to calculate derivatives, but because we have a multivariable function, we need to calculate partial derivatives with respect to x and y:
These partial derivatives are the components of the gradient - the vector that shows the direction of steepest ascent of the function:
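$$\nabla f(x, y) = \left(\frac{\partial f}{\partial x},\ \frac{\partial f}{\partial y}\right)$$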
So, if we take the opposite of that vector, it will show us the direction of steepest descent; a short code sketch of this idea is below. Then, let's plot the function to see what it looks like.
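Here is that sketch (using f(x, y) = x^2 + y^2 purely as a stand-in, since the actual function is introduced with the plot). Two-variable gradient descent applies the same update to each coordinate, using its partial derivative:

```python
def dfdx(x, y):
    return 2 * x    # partial derivative of the stand-in f(x, y) = x^2 + y^2 with respect to x

def dfdy(x, y):
    return 2 * y    # partial derivative with respect to y

def gradient_descent_2d(start, alpha=0.1, n_steps=100):
    """Step against the gradient (dfdx, dfdy) to approach a local minimum."""
    x, y = start
    for _ in range(n_steps):
        x, y = x - alpha * dfdx(x, y), y - alpha * dfdy(x, y)
    return x, y

print(gradient_descent_2d((3.0, -4.0)))  # converges toward (0, 0)
```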