Logistic Regression
Last updated
Logistic regression estimates the probability of an event occurring, such as voting or not voting, based on a given dataset of independent variables.
Here we are going to look at the binary classification case, but it is straightforward to generalize the algorithm to multi-class classification.
Assume that we have $k$ predictors $X_1, \dots, X_k$ and a binary response variable $Y \in \{0, 1\}$.
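In the logistic regression algorithm, the relationship between the predictors and the logit of the probability of a positive outcome is assumed to be linear:

$$\operatorname{logit}\big(P(Y = 1 \mid X)\big) = w_1 X_1 + w_2 X_2 + \dots + w_k X_k + b$$

where $w_1, \dots, w_k$ are the linear weights and $b$ is the intercept.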
Now what is the logit function? It is the log of the odds:

$$\operatorname{logit}(p) = \log\left(\frac{p}{1 - p}\right)$$
We see that the logit function maps a probability value from the open interval $(0, 1)$ to the whole real line $(-\infty, +\infty)$.
The inverse of the logit is the logistic curve [also called the sigmoid function], which we denote $\sigma$:

$$\sigma(z) = \operatorname{logit}^{-1}(z) = \frac{1}{1 + e^{-z}}$$
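To make the two functions concrete, here is a minimal JAX sketch of the logit and its inverse. The names `logit` and `sigmoid` are our own; in practice `jax.nn.sigmoid` already provides the latter. Applying one after the other should recover the original probabilities up to float32 rounding.

```python
import jax.numpy as jnp


def logit(p):
    """Log-odds: maps a probability in (0, 1) to the real line."""
    return jnp.log(p / (1.0 - p))


def sigmoid(z):
    """Logistic (sigmoid) function: maps a real number back into (0, 1)."""
    return 1.0 / (1.0 + jnp.exp(-z))


p = jnp.array([0.1, 0.5, 0.9])
z = logit(p)          # negative, zero and positive log-odds
p_back = sigmoid(z)   # recovers p up to floating-point rounding
print(z, p_back)
```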
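If we denote by $w = (w_1, \dots, w_k)^\top$ the weight vector, $x = (x_1, \dots, x_k)^\top$ the observed values of the predictors, and $y$ the associated class value, we have:

$$\operatorname{logit}\big(P(y = 1 \mid x)\big) = w^\top x + b$$

And thus:

$$P(y = 1 \mid x) = \sigma(w^\top x + b)$$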
For a given set of weights $w$ and intercept $b$, the predicted probability of a positive outcome is therefore $\hat{p} = \sigma(w^\top x + b)$.
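This probability can be turned into a predicted class label $\hat{y}$ using a threshold value $t$ [typically $t = 0.5$]:

$$\hat{y} = \begin{cases} 1 & \text{if } \hat{p} \geq t \\ 0 & \text{otherwise} \end{cases}$$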
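A minimal sketch of the prediction step, assuming the observations are stacked row-wise in a matrix `X` of shape `(n, k)`. The function names `predict_proba` and `predict`, the toy weights, and the 0.5 default threshold are illustrative choices, not part of the original text.

```python
import jax
import jax.numpy as jnp


def predict_proba(w, b, X):
    """P(y = 1 | x) = sigmoid(w^T x + b) for each row x of X."""
    return jax.nn.sigmoid(X @ w + b)


def predict(w, b, X, threshold=0.5):
    """Turn probabilities into 0/1 class labels with a decision threshold."""
    return (predict_proba(w, b, X) >= threshold).astype(jnp.int32)


# Hypothetical weights, intercept and inputs, for illustration only.
w = jnp.array([2.0, -1.0])
b = -0.5
X = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [0.5, 0.0]])
print(predict_proba(w, b, X))  # logits 1.5, -1.5, 0.5
print(predict(w, b, X))
```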
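Now we assume that we have $n$ observations $(x_i, y_i)$, $i = 1, \dots, n$, and that the $y_i$ are independently Bernoulli distributed:

$$y_i \sim \operatorname{Bernoulli}(p_i), \qquad p_i = \sigma(w^\top x_i + b)$$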
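This generative assumption can be simulated directly, which is handy for testing the rest of the pipeline on data where the answer is known. The sketch below draws standard normal predictors and independent Bernoulli labels; the "true" parameters `w_true` and `b_true` are hypothetical values picked for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical "true" parameters, used only to simulate data.
key = jax.random.PRNGKey(0)
key_x, key_y = jax.random.split(key)

n, k = 1000, 2
w_true = jnp.array([1.5, -2.0])
b_true = 0.3

X = jax.random.normal(key_x, (n, k))                    # predictors x_i
p = jax.nn.sigmoid(X @ w_true + b_true)                 # p_i = sigma(w^T x_i + b)
y = jax.random.bernoulli(key_y, p).astype(jnp.float32)  # independent y_i ~ Bernoulli(p_i)
```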
The likelihood that we would like to maximize given the samples is:

$$\mathcal{L}(w, b) = \prod_{i=1}^{n} p_i^{y_i} (1 - p_i)^{1 - y_i}$$
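Evaluating this product directly is fragile in floating point. In the sketch below, with hypothetical values (2000 samples, each with $p_i = 0.6$ and $y_i = 1$), the product underflows to zero while the sum of logs remains perfectly representable:

```python
import jax.numpy as jnp

# Hypothetical predicted probabilities p_i and labels y_i for n = 2000 samples.
n = 2000
p = jnp.full(n, 0.6)  # every p_i = 0.6
y = jnp.ones(n)       # every y_i = 1

# Likelihood: product over i of p_i^y_i * (1 - p_i)^(1 - y_i)
likelihood = jnp.prod(p**y * (1.0 - p) ** (1.0 - y))

# Log-likelihood: the same quantity written as a sum of logs
log_likelihood = jnp.sum(y * jnp.log(p) + (1.0 - y) * jnp.log(1.0 - p))

print(likelihood)      # 0.6**2000 is far below the float range, so this prints 0.0
print(log_likelihood)  # about 2000 * log(0.6), i.e. roughly -1021.65
```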
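For numerical stability, we prefer to work with the log-likelihood, which turns the product of many small probabilities into a sum, and we scale it by $1/n$. We also take its negative, to get a minimization problem:

$$J(w, b) = -\frac{1}{n} \sum_{i=1}^{n} \Big[ y_i \log p_i + (1 - y_i) \log(1 - p_i) \Big]$$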
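A possible JAX version of $J$, assuming the signature `cost(w, b, X, y)` with the weights and intercept as the first two positional arguments (our choice of names). Rather than computing $p_i$ and then taking logs, it uses `jax.nn.log_sigmoid`, since $\log(1 - \sigma(z)) = \log \sigma(-z)$; this avoids `log(0)` when $\sigma(z)$ rounds to exactly 0 or 1.

```python
import jax
import jax.numpy as jnp


def cost(w, b, X, y):
    """Scaled negative log-likelihood J(w, b) = -(1/n) sum_i [y_i log p_i + (1 - y_i) log(1 - p_i)]."""
    z = X @ w + b
    log_p = jax.nn.log_sigmoid(z)     # log(sigmoid(z)), computed stably
    log_1mp = jax.nn.log_sigmoid(-z)  # log(1 - sigmoid(z)), computed stably
    return -jnp.mean(y * log_p + (1.0 - y) * log_1mp)


X = jnp.array([[1.0, 2.0], [3.0, -1.0]])
y = jnp.array([1.0, 0.0])
print(cost(jnp.zeros(2), 0.0, X, y))  # all p_i = 0.5, so J = log 2
```

With all-zero parameters every $p_i$ is 0.5, so the cost equals $\log 2 \approx 0.693$, a useful baseline when monitoring training.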
Regularization is a very useful method to handle collinearity [high correlation among features], filter out noise from data, and ultimately prevent overfitting.
We could compute the gradient of this logistic regression cost function analytically. However, we won't, because we are lazy and want JAX to do it for us! Admittedly, JAX is more relevant when applied to a very complex function whose analytical derivative is very hard or impractical to work out by hand, such as the cost function of a deep neural network.
So let's differentiate this cost function with respect to its first and second positional arguments, the weights $w$ and the intercept $b$, using JAX's grad function.
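With a hypothetical `cost(w, b, X, y)` signature, the `argnums` parameter of `jax.grad` selects which positional arguments to differentiate, and the returned function yields a tuple of gradients. As a sanity check only, the sketch compares the result with the standard closed-form gradient: $\frac{1}{n} X^\top (p - y)$ for $w$ and the mean of $p - y$ for $b$.

```python
import jax
import jax.numpy as jnp


def cost(w, b, X, y):
    """Scaled negative log-likelihood (hypothetical signature: weights first, intercept second)."""
    z = X @ w + b
    return -jnp.mean(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))


# argnums=(0, 1): differentiate with respect to w (1st) and b (2nd).
grad_cost = jax.grad(cost, argnums=(0, 1))

# Toy data and parameters, for illustration only.
X = jnp.array([[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]])
y = jnp.array([1.0, 0.0, 1.0])
w = jnp.array([0.1, -0.2])
b = 0.0

dw, db = grad_cost(w, b, X, y)

# Sanity check against the closed form: (1/n) X^T (p - y) and mean(p - y).
p = jax.nn.sigmoid(X @ w + b)
print(dw, X.T @ (p - y) / y.shape[0])
print(db, jnp.mean(p - y))
```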
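A great feature of this cost function is that it is differentiable and convex: any local minimum is a global one, so a gradient-based algorithm with a suitable learning rate should find it. Now let's also introduce some $\ell_2$-regularization to improve the model:

$$J_r(w, b) = J(w, b) + \frac{\lambda}{2} \lVert w \rVert_2^2$$

with $\lambda > 0$.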
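A minimal sketch of the regularized cost, assuming the penalty takes the form $\frac{\lambda}{2}\lVert w \rVert_2^2$ with the intercept $b$ left unpenalized (a common convention, but an assumption here); `lam` is our name for $\lambda$. Differentiating the penalty simply adds $\lambda w$ to the gradient with respect to $w$, which the last lines check numerically.

```python
import jax
import jax.numpy as jnp


def cost(w, b, X, y):
    z = X @ w + b
    return -jnp.mean(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))


def regularized_cost(w, b, X, y, lam=0.1):
    """J_r(w, b) = J(w, b) + (lam / 2) * ||w||_2^2; the intercept b is not penalized."""
    return cost(w, b, X, y) + 0.5 * lam * jnp.sum(w**2)


# Toy data and weights, for illustration only.
X = jnp.array([[1.0, 2.0], [3.0, -1.0]])
y = jnp.array([1.0, 0.0])
w = jnp.array([1.0, 2.0])
print(regularized_cost(w, 0.0, X, y, 0.1) - cost(w, 0.0, X, y))  # penalty 0.05 * (1 + 4), about 0.25

# The penalty's gradient with respect to w is lam * w.
dw_reg = jax.grad(regularized_cost)(w, 0.0, X, y, 0.1)
dw = jax.grad(cost)(w, 0.0, X, y)
print(dw_reg - dw)  # about 0.1 * w = [0.1, 0.2]
```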
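So we need to minimize $J_r$. For that, we are going to apply gradient descent. This method requires the gradient of the cost function, which drives the parameter updates with a learning rate $\eta > 0$:

$$w \leftarrow w - \eta \, \nabla_w J_r(w, b), \qquad b \leftarrow b - \eta \, \frac{\partial J_r}{\partial b}(w, b)$$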
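Putting the pieces together, here is a sketch of full-batch gradient descent on the regularized cost. The function name `fit`, the learning rate, $\lambda$, the number of iterations, and the synthetic data are arbitrary illustrative choices; `jax.jit` is used only to speed up the repeated gradient evaluations. A more complete version would typically add a convergence check instead of a fixed iteration count.

```python
import jax
import jax.numpy as jnp


def regularized_cost(w, b, X, y, lam):
    z = X @ w + b
    nll = -jnp.mean(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))
    return nll + 0.5 * lam * jnp.sum(w**2)


# Gradient with respect to (w, b), compiled with jit.
grad_fn = jax.jit(jax.grad(regularized_cost, argnums=(0, 1)))


def fit(X, y, lam=1e-3, learning_rate=0.5, n_iter=500):
    """Plain full-batch gradient descent on J_r, starting from w = 0, b = 0."""
    w = jnp.zeros(X.shape[1])
    b = 0.0
    for _ in range(n_iter):
        dw, db = grad_fn(w, b, X, y, lam)
        w = w - learning_rate * dw
        b = b - learning_rate * db
    return w, b


# Synthetic data drawn from the model, for illustration only.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (1000, 2))
p_true = jax.nn.sigmoid(X @ jnp.array([1.5, -2.0]) + 0.3)
y = jax.random.bernoulli(key_y, p_true).astype(jnp.float32)

w, b = fit(X, y)
accuracy = jnp.mean((jax.nn.sigmoid(X @ w + b) >= 0.5) == (y == 1.0))
print(w, b, accuracy)
```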