Logistic Regression

Logistic regression estimates the probability of an event occurring, such as voting or not voting, from a set of independent variables.

Inner Workings

Here we are going to look at the binary classification case, but it is straightforward to generalize the algorithm to multi-class classification.

Assume that we have $k$ predictors $\{X_i\}_{i=1}^k$, with $X_i \in \mathbb{R}$, and a binary response variable $Y \in \{0, 1\}$.

In the logistic regression algorithm, the relationship between the predictors and the logit of the probability of a positive outcome $Y=1$ is assumed to be linear:

$$\operatorname{logit}\left(P(Y=1 \mid \mathbf{w})\right) = c + \sum_{i=1}^k w_i X_i$$

where $\{w_i\}_{i=1}^k \in \mathbb{R}^k$ are the linear weights and $c \in \mathbb{R}$ is the intercept.

Now, what is the logit function? It is the log of the odds:

$$\operatorname{logit}(p) = \ln\left(\frac{p}{1-p}\right)$$

We see that the logit function maps a probability value in $(0, 1)$ to $\mathbb{R}$: it tends to $-\infty$ as $p \to 0$ and to $+\infty$ as $p \to 1$.

The inverse of the logit is the logistic curve [also called the sigmoid function], which we denote by $\sigma$:

$$\sigma(r) = \frac{1}{1+e^{-r}}$$
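A quick numerical check that the two functions undo each other, written as a minimal JAX sketch. The function names `logit` and `sigmoid` are our own choice; JAX also provides a built-in `jax.nn.sigmoid`.

```python
import jax.numpy as jnp

def logit(p):
    # log of the odds: maps a probability in (0, 1) to the real line
    return jnp.log(p / (1.0 - p))

def sigmoid(r):
    # logistic curve: maps a real number back into (0, 1)
    return 1.0 / (1.0 + jnp.exp(-r))

p = jnp.array([0.1, 0.5, 0.9])
r = logit(p)
print(r)           # approximately [-2.197, 0.0, 2.197]
print(sigmoid(r))  # recovers [0.1, 0.5, 0.9] up to float32 rounding
```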

If we denote by $\mathbf{w} = [c; w_1; \dots; w_k]^T$ the weight vector, by $\mathbf{x} = [1; x_1; \dots; x_k]^T$ the observed values of the predictors (the leading 1 absorbs the intercept $c$), and by $y$ the associated class value, we have:

$$\operatorname{logit}\left(P(y=1 \mid \mathbf{w})\right) = \mathbf{w}^T \mathbf{x}$$

And thus:

$$P(y=1 \mid \mathbf{w}) = \sigma(\mathbf{w}^T \mathbf{x}) \equiv \sigma_{\mathbf{w}}(\mathbf{x})$$

For a given set of weights $\mathbf{w}$, the probability of a positive outcome is $\sigma_{\mathbf{w}}(\mathbf{x})$.

This probability can be turned into a predicted class label $\hat{y}$ using a threshold value:

$$\hat{y} = \begin{cases} 1 & \text{if } \sigma_{\mathbf{w}}(\mathbf{x}) \geq 0.5, \\ 0 & \text{otherwise.} \end{cases}$$
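A minimal JAX sketch of this prediction step, assuming the predictors are stored row-wise in an $(n, k)$ array and that the intercept is handled by prepending a column of ones, which mirrors the augmented vector $\mathbf{x} = [1; x_1; \dots; x_k]^T$. The helper names `predict_proba` and `predict`, as well as the weights and inputs, are hypothetical and chosen only for illustration.

```python
import jax.numpy as jnp

def sigmoid(r):
    return 1.0 / (1.0 + jnp.exp(-r))

def predict_proba(w, X):
    # w = [c, w_1, ..., w_k]; X has shape (n, k).
    # Prepend a column of ones so that w^T x includes the intercept c.
    X1 = jnp.hstack([jnp.ones((X.shape[0], 1)), X])
    return sigmoid(X1 @ w)

def predict(w, X, threshold=0.5):
    # y_hat = 1 if sigma_w(x) >= threshold, else 0
    return (predict_proba(w, X) >= threshold).astype(jnp.int32)

# Hypothetical weights and inputs, for illustration only
w = jnp.array([-1.0, 2.0, 0.5])  # c = -1, w_1 = 2, w_2 = 0.5
X = jnp.array([[0.0, 0.0],       # w^T x = -1.0
               [1.0, 0.0],       # w^T x =  1.0
               [0.2, 1.0]])      # w^T x = -0.1
print(predict_proba(w, X))       # approximately [0.269, 0.731, 0.475]
print(predict(w, X))             # [0 1 0]
```

Thresholding $\sigma_{\mathbf{w}}(\mathbf{x})$ at 0.5 is the same as checking the sign of $\mathbf{w}^T \mathbf{x}$, since $\sigma(0) = 0.5$ and $\sigma$ is increasing.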

Cost function

Now we assume that we have $n$ observations whose responses are independent Bernoulli random variables:

$$\left\{ \left(\mathbf{x}^{(1)}, y^{(1)}\right), \left(\mathbf{x}^{(2)}, y^{(2)}\right), \dots, \left(\mathbf{x}^{(n)}, y^{(n)}\right) \right\}$$

Given these samples, the likelihood that we would like to maximize is:

$$L(\mathbf{w}) = \prod_{i=1}^n P\left(y^{(i)} \mid \mathbf{x}^{(i)}; \mathbf{w}\right) = \prod_{i=1}^n \sigma_{\mathbf{w}}\left(\mathbf{x}^{(i)}\right)^{y^{(i)}} \left(1 - \sigma_{\mathbf{w}}\left(\mathbf{x}^{(i)}\right)\right)^{1-y^{(i)}}$$
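Evaluating this product directly is fragile in floating point. The small JAX illustration below uses hypothetical function names and toy data and assumes JAX's default 32-bit floats. With $\mathbf{w} = 0$, every $\sigma_{\mathbf{w}}(\mathbf{x}^{(i)})$ equals 0.5, so for $n = 200$ samples the likelihood is $0.5^{200} \approx 6 \times 10^{-61}$. That is far below the smallest positive float32 value (about $1.4 \times 10^{-45}$), while its logarithm is an unremarkable $200 \ln 0.5 \approx -138.6$.

```python
import jax.numpy as jnp

def sigmoid(r):
    return 1.0 / (1.0 + jnp.exp(-r))

def likelihood(w, X1, y):
    # L(w): product over samples of sigma^y * (1 - sigma)^(1 - y)
    s = sigmoid(X1 @ w)
    return jnp.prod(s**y * (1.0 - s)**(1.0 - y))

def log_likelihood(w, X1, y):
    # log L(w): the same quantity written as a sum of logarithms
    s = sigmoid(X1 @ w)
    return jnp.sum(y * jnp.log(s) + (1.0 - y) * jnp.log(1.0 - s))

n = 200
X1 = jnp.ones((n, 2))                # intercept column plus one constant feature (toy data)
y = jnp.zeros(n).at[::2].set(1.0)    # alternating 0/1 labels
w = jnp.zeros(2)                     # w = 0 gives sigma = 0.5 for every sample

print(likelihood(w, X1, y))          # 0.0: 0.5**200 underflows in float32
print(log_likelihood(w, X1, y))      # about -138.63, i.e. 200 * ln(0.5)
```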

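For numerical stability, we prefer to work with the log-likelihood scaled by $1/n$: a product of many probabilities quickly underflows to zero in floating point, whereas a sum of their logarithms does not. We also take its negative to obtain a minimization problem: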

$$J(\mathbf{w}) = -\frac{1}{n} \sum_{i=1}^n \left[ y^{(i)} \log\left(\sigma_{\mathbf{w}}\left(\mathbf{x}^{(i)}\right)\right) + \left(1 - y^{(i)}\right) \log\left(1 - \sigma_{\mathbf{w}}\left(\mathbf{x}^{(i)}\right)\right) \right] \tag{8}$$
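Equation (8) translates almost line for line into JAX. The sketch below is an illustration, not the original author's code: the names `cost`, `X1`, `w` and `y` are assumptions. Instead of evaluating $\log \sigma$ and $\log(1 - \sigma)$ literally, it uses `jax.nn.log_sigmoid` together with the identity $1 - \sigma(z) = \sigma(-z)$. The result is mathematically identical but avoids taking $\log(0)$ when $|z|$ is large.

```python
import jax
import jax.numpy as jnp

def cost(w, X1, y):
    """Mean negative log-likelihood J(w), as in equation (8).

    X1: (n, k+1) design matrix whose first column is all ones
    w:  weight vector [c, w_1, ..., w_k]
    y:  labels in {0, 1}
    """
    z = X1 @ w
    # log(sigma(z)) and log(1 - sigma(z)) = log(sigma(-z)), computed without
    # forming sigma(z) first, so large |z| does not lead to log(0)
    return -jnp.mean(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))

# Toy data, intercept column first
X1 = jnp.array([[1.0, 0.5], [1.0, -1.5], [1.0, 2.0]])
y = jnp.array([1.0, 0.0, 1.0])
print(cost(jnp.zeros(2), X1, y))  # ln 2 ≈ 0.6931, since every sigma is 0.5 when w = 0
```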

A great feature of this cost function is that it is differentiable and convex, so any local minimum is also a global minimum, and a gradient-based algorithm with a suitable step size should find it. Now let's also introduce some $\ell_2$ regularization to improve the model:

$$J_r(\mathbf{w}) = J(\mathbf{w}) + \frac{\lambda}{2} \mathbf{w}^T \mathbf{w}, \quad \lambda \geq 0.$$

For $\lambda > 0$, $J_r$ is strictly convex, so its minimum exists and is unique.
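Adding the penalty is a one-line change. In this sketch the names are hypothetical and `lmbd` stands for $\lambda$. Following the formula above, $\mathbf{w}$ includes the intercept $c$, so $c$ is penalized too. Many implementations leave the intercept out of the penalty, which here would mean using `w[1:]` in the dot product.

```python
import jax
import jax.numpy as jnp

def cost(w, X1, y):
    # J(w): mean negative log-likelihood, stable form using 1 - sigma(z) = sigma(-z)
    z = X1 @ w
    return -jnp.mean(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))

def cost_reg(w, X1, y, lmbd):
    # J_r(w) = J(w) + (lambda / 2) * w^T w, with lambda >= 0
    return cost(w, X1, y) + 0.5 * lmbd * jnp.dot(w, w)

# Toy data and weights, for illustration only
X1 = jnp.array([[1.0, 0.5], [1.0, -1.5], [1.0, 2.0]])
y = jnp.array([1.0, 0.0, 1.0])
w = jnp.array([1.0, 2.0])
print(cost_reg(w, X1, y, 0.1) - cost(w, X1, y))  # penalty alone: 0.05 * (1 + 4) ≈ 0.25
```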

Regularization is a very useful method to handle collinearity [high correlation among features], reduce the model's sensitivity to noise in the data, and ultimately prevent overfitting.

Learning the weights

So we need to minimize $J_r(\mathbf{w})$. For that, we are going to apply gradient descent, which requires the gradient of the cost function, $\nabla_{\mathbf{w}} J_r(\mathbf{w})$.
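For reference, this is the standard gradient descent update. The learning rate $\eta > 0$ and the iteration index $t$ are notation introduced here; they are not defined elsewhere in this section.

$$\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} - \eta \, \nabla_{\mathbf{w}} J_r\left(\mathbf{w}^{(t)}\right), \qquad \nabla_{\mathbf{w}} J_r(\mathbf{w}) = \nabla_{\mathbf{w}} J(\mathbf{w}) + \lambda \mathbf{w}$$

Iterations typically stop after a fixed budget or once the gradient norm is small. The $\lambda \mathbf{w}$ term shows that the regularization simply shrinks the weights toward zero at every step.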

Compute the gradient

We could compute the gradient of this logistic regression cost function analytically. However, we won't, because we are lazy and want JAX to do it for us! To be fair, automatic differentiation with JAX is most valuable for complex functions whose derivatives are very hard or impractical to work out by hand, such as the cost function of a deep neural network.

So let's differentiate this cost function with respect to its first and second positional arguments using JAX's grad function (passing argnums=(0, 1) makes grad return a tuple with one gradient per selected argument).
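A minimal sketch of this step, assuming a hypothetical cost function `cost_reg(c, w, X, y, lmbd)` whose first two positional arguments are the intercept $c$ and the weights $(w_1, \dots, w_k)$. The section does not show the actual signature. Because $J_r$ penalizes the full vector $[c; w_1; \dots; w_k]$, the penalty includes $c^2$. The `fit` helper, its learning rate and step count, and the toy data are illustrative assumptions, not tuned or original values.

```python
import jax
import jax.numpy as jnp

# Hypothetical signature: the intercept c and the weights w come first, so
# argnums=(0, 1) differentiates with respect to exactly those two arguments.
def cost_reg(c, w, X, y, lmbd):
    z = c + X @ w  # c + sum_i w_i x_i for every sample
    J = -jnp.mean(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))
    # (lambda / 2) * w^T w for the full vector [c; w_1; ...; w_k]
    return J + 0.5 * lmbd * (c**2 + jnp.dot(w, w))

# With a tuple argnums, jax.grad returns a tuple (dJ/dc, dJ/dw)
grad_cost = jax.jit(jax.grad(cost_reg, argnums=(0, 1)))

def fit(X, y, lmbd=0.01, lr=0.5, n_steps=500):
    # Plain gradient descent; lr and n_steps are illustrative, not tuned values
    c, w = 0.0, jnp.zeros(X.shape[1])
    for _ in range(n_steps):
        dc, dw = grad_cost(c, w, X, y, lmbd)
        c, w = c - lr * dc, w - lr * dw
    return c, w

# Toy, linearly separable data (for illustration only)
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (200, 2))
y = (X[:, 0] - X[:, 1] > 0).astype(jnp.float32)

c, w = fit(X, y)
acc = jnp.mean(((c + X @ w) >= 0) == (y == 1))  # sigma >= 0.5  <=>  z >= 0
print(f"training accuracy: {float(acc):.3f}")
```

As a sanity check, the `dw` returned by `grad_cost` should agree with the analytical gradient $\frac{1}{n} \mathbf{X}^T (\sigma(\mathbf{z}) - \mathbf{y}) + \lambda \mathbf{w}$ up to float32 rounding. This is exactly the derivative that JAX spares us from writing by hand.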
