Cross-entropy loss and its optimization [WIP]
1. Background
Computing the cross-entropy loss becomes significantly more challenging for LLMs than for smaller models. This is primarily due to the extremely large logit and label matrices involved in the calculation, which lead to high computational cost and memory usage. Recently, several optimization strategies have been proposed to address this issue, starting from a PyTorch GitHub issue:
- https://github.com/pytorch/pytorch/issues/124480
- https://github.com/mgmalek/efficient_cross_entropy
- Liger Kernel: github, arxiv
- Cut Your Losses in Large-Vocabulary Language Models: arxiv
All these approaches share a common goal: avoiding the full materialization of the logit matrix. They achieve this by:
- chunking the logit matrix
- computing the gradient of the logits in place
In this blog post, I will dive into the cross-entropy loss and these optimization strategies.
2. Softmax Cross-Entropy
2.1. Forward Pass
Let's begin by understanding the forward pass of the cross-entropy loss.
Consider:
- An input vector $z \in \mathbb{R}^{C}$ representing the logits (unnormalized scores) produced by the model for each of the $C$ classes.
- A true label $y \in \{1, \dots, C\}$ indicating the correct class, with one-hot encoding $y_i = \mathbb{1}[i = y]$.
The softmax function converts the logits into probabilities:
$$p_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}$$
Here, $p_i$ represents the probability of the input belonging to class $i$.
The cross-entropy loss for a single instance is then defined as:
$$L = -\sum_{i=1}^{C} y_i \log p_i$$
Expanding this, we get:
$$L = -\log p_y = -z_y + \log \sum_{j=1}^{C} e^{z_j}$$
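As a quick sanity check, here is a minimal PyTorch sketch of this forward pass (the class count and the true class index are illustrative, not from the original post). It computes $L = -z_y + \log \sum_j e^{z_j}$ directly and compares it against `torch.nn.functional.cross_entropy`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 5                                 # number of classes (illustrative)
z = torch.randn(C)                    # logits for a single instance
y = torch.tensor(2)                   # true class index (illustrative)

# L = -z_y + log(sum_j exp(z_j)); logsumexp keeps this numerically stable
loss_manual = -z[y] + torch.logsumexp(z, dim=0)

# Reference: PyTorch's cross_entropy expects a batch dimension
loss_ref = F.cross_entropy(z.unsqueeze(0), y.unsqueeze(0))

print(loss_manual.item(), loss_ref.item())  # the two values should match
```

Using `logsumexp` instead of summing raw exponentials keeps the computation stable even when some logits are large.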
2.2. Backward Pass
In general, the gradient of the loss with respect to the logits follows from the chain rule:
$$\frac{\partial L}{\partial z_k} = \sum_{i=1}^{C} \frac{\partial L}{\partial p_i} \frac{\partial p_i}{\partial z_k} = -\sum_{i=1}^{C} \frac{y_i}{p_i} \frac{\partial p_i}{\partial z_k}$$
2.2.1. Step 1: Compute $\frac{\partial p_i}{\partial z_k}$
The result is:
$$\frac{\partial p_i}{\partial z_k} = p_i \left( \delta_{ik} - p_k \right)$$
where $\delta_{ik}$ is the Kronecker delta. The full derivation for the case $i = k$ is:
$$\frac{\partial p_i}{\partial z_i} = \frac{e^{z_i} \sum_{j} e^{z_j} - e^{z_i} e^{z_i}}{\left( \sum_{j} e^{z_j} \right)^2} = p_i (1 - p_i)$$
And for $i \neq k$:
$$\frac{\partial p_i}{\partial z_k} = \frac{-\, e^{z_i} e^{z_k}}{\left( \sum_{j} e^{z_j} \right)^2} = -p_i p_k$$
2.2.2. Step 2: Compute $\frac{\partial L}{\partial z_k}$
So,
$$\frac{\partial L}{\partial z_k} = -\sum_{i=1}^{C} \frac{y_i}{p_i} \, p_i \left( \delta_{ik} - p_k \right) = -y_k + p_k \sum_{i=1}^{C} y_i = p_k - y_k$$
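To make the result concrete, the short sketch below (again with illustrative sizes) checks the closed-form gradient $p - y$ against PyTorch autograd:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 5
z = torch.randn(C, requires_grad=True)
y = torch.tensor(3)

loss = F.cross_entropy(z.unsqueeze(0), y.unsqueeze(0))
loss.backward()

p = torch.softmax(z.detach(), dim=0)
y_onehot = torch.zeros(C)
y_onehot[y] = 1.0

print(z.grad)        # gradient computed by autograd
print(p - y_onehot)  # closed-form gradient p - y; should match z.grad
```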
2.3. Gradient in Matrix Form
For batch computations, it's efficient to represent gradients in matrix form.
Given:
- $P \in \mathbb{R}^{B \times C}$: Matrix of predicted probabilities for a batch of size $B$.
- $Z \in \mathbb{R}^{B \times C}$: Matrix of logits.
- $Y \in \mathbb{R}^{B \times C}$: One-hot encoded true labels.
The gradient with respect to the logits is:
$$\frac{\partial L}{\partial Z} = P - Y$$
where each row holds the per-instance gradient $p - y$. Normalized by batch size, the overall gradient of the averaged loss $\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} L_i$ is:
$$\frac{\partial \mathcal{L}}{\partial Z} = \frac{1}{B} \left( P - Y \right)$$
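The matrix form can be verified the same way. The snippet below (batch size and class count are illustrative) compares $\frac{1}{B}(P - Y)$ with the gradient PyTorch computes for the mean-reduced loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C = 4, 10                          # illustrative batch size and class count
Z = torch.randn(B, C, requires_grad=True)
y = torch.randint(0, C, (B,))

loss = F.cross_entropy(Z, y)          # mean reduction over the batch
loss.backward()

P = torch.softmax(Z.detach(), dim=-1)
Y = F.one_hot(y, num_classes=C).float()
grad_manual = (P - Y) / B             # (P - Y) / B in matrix form

print(torch.allclose(Z.grad, grad_manual, atol=1e-6))  # expected: True
```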
3. Linear-Softmax-Cross-Entropy
In practice, the cross-entropy loss is preceded by a final linear (fully connected) layer and a softmax activation. If we can fuse the linear layer, the softmax, and the loss computation, we may avoid the full materialization of the logit matrix.
- Input before the final linear layer: $X \in \mathbb{R}^{B \times d}$, where $d$ is the hidden dimension.
- Linear weights: $W \in \mathbb{R}^{d \times C}$.
- Linear bias: $b \in \mathbb{R}^{C}$.
- Labels: $y \in \{1, \dots, C\}^{B}$, representing the true classes for each instance in the batch.
3.1. Forward Pass
The input is first transformed linearly using the weights and bias:
$$Z = XW + b \in \mathbb{R}^{B \times C}$$
Softmax:
$$P_{ij} = \frac{e^{Z_{ij}}}{\sum_{k=1}^{C} e^{Z_{ik}}}$$
Cross-entropy loss is computed for each instance and then averaged over the batch:
$$\mathcal{L} = -\frac{1}{B} \sum_{i=1}^{B} \log P_{i, y_i}$$
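A minimal sketch of this fused forward pass, assuming the $Z = XW + b$ convention above (all sizes are illustrative); it reproduces `F.cross_entropy` applied to the logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, d, C = 4, 16, 10                   # illustrative: batch, hidden dim, classes
X = torch.randn(B, d)
W = torch.randn(d, C)
b = torch.randn(C)
y = torch.randint(0, C, (B,))

Z = X @ W + b                                          # logits, shape (B, C)
log_P = Z - torch.logsumexp(Z, dim=-1, keepdim=True)   # log-softmax
loss = -log_P[torch.arange(B), y].mean()               # average over the batch

loss_ref = F.cross_entropy(Z, y)
print(loss.item(), loss_ref.item())   # should match
```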
3.2. Backward Pass
Gradient of the logits $Z$:
$$\frac{\partial \mathcal{L}}{\partial Z} = \frac{1}{B} \left( P - Y \right)$$
Gradient of $W$:
$$\frac{\partial \mathcal{L}}{\partial W} = X^{\top} \frac{\partial \mathcal{L}}{\partial Z}$$
Gradient of $b$:
$$\frac{\partial \mathcal{L}}{\partial b} = \sum_{i=1}^{B} \left( \frac{\partial \mathcal{L}}{\partial Z} \right)_{i,:}$$
Gradient of the input $X$:
$$\frac{\partial \mathcal{L}}{\partial X} = \frac{\partial \mathcal{L}}{\partial Z} W^{\top}$$
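These four formulas can be checked in a few lines. The sketch below (illustrative sizes, same $Z = XW + b$ convention) computes them manually and compares against autograd:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, d, C = 4, 16, 10                   # illustrative sizes
X = torch.randn(B, d, requires_grad=True)
W = torch.randn(d, C, requires_grad=True)
b = torch.randn(C, requires_grad=True)
y = torch.randint(0, C, (B,))

Z = X @ W + b
loss = F.cross_entropy(Z, y)
loss.backward()

with torch.no_grad():
    P = torch.softmax(Z, dim=-1)
    Y = F.one_hot(y, num_classes=C).float()
    dZ = (P - Y) / B                  # dL/dZ
    dW = X.T @ dZ                     # dL/dW = X^T dZ
    db = dZ.sum(dim=0)                # dL/db: sum dZ over the batch
    dX = dZ @ W.T                     # dL/dX = dZ W^T

print(torch.allclose(W.grad, dW, atol=1e-6),
      torch.allclose(b.grad, db, atol=1e-6),
      torch.allclose(X.grad, dX, atol=1e-6))   # expected: True True True
```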
3.3. Summary of Gradients
| Parameter | Formula | Dimensions |
| --- | --- | --- |
| $\frac{\partial \mathcal{L}}{\partial Z}$ | $\frac{1}{B}(P - Y)$ | $B \times C$ |
| $\frac{\partial \mathcal{L}}{\partial W}$ | $X^{\top} \frac{\partial \mathcal{L}}{\partial Z}$ | $d \times C$ |
| $\frac{\partial \mathcal{L}}{\partial b}$ | $\sum_{i=1}^{B} \left( \frac{\partial \mathcal{L}}{\partial Z} \right)_{i,:}$ | $C$ |
| $\frac{\partial \mathcal{L}}{\partial X}$ | $\frac{\partial \mathcal{L}}{\partial Z} W^{\top}$ | $B \times d$ |
4. Optimization Strategies
- Chunking the logit matrix: Chunking over the batch avoids materializing the full logit matrix at once. The logit matrix is divided into chunks along the batch dimension, the cross-entropy loss is computed for each chunk, and the final loss is the sum of the per-chunk contributions.
- Computing the gradient of the logits in place: Because $\frac{\partial \mathcal{L}}{\partial Z} = \frac{1}{B}(P - Y)$ is known in closed form, each chunk's logit buffer can be overwritten with its gradient during the forward pass; the gradient of the input is then obtained by multiplying this gradient with the weight matrix. Both ideas are combined in the sketch after this list.
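To make the two ideas concrete, here is a rough, CPU-only PyTorch sketch, not the actual Triton/CUDA kernels used by the projects above: it loops over batch chunks, never holds more than one chunk of logits, and overwrites each chunk's logit buffer with its gradient $\frac{1}{B}(P - Y)$. The function name, chunk size, and shapes are all illustrative.

```python
import torch

def chunked_linear_cross_entropy(X, W, b, y, chunk_size=128):
    """Sketch of a fused linear + softmax + cross-entropy pass that never
    materializes the full (B, C) logit matrix: logits exist for one batch
    chunk at a time, and each chunk's logit buffer is overwritten in place
    by its gradient. Returns the mean loss and the analytic gradients."""
    B = X.shape[0]
    loss = torch.zeros((), dtype=X.dtype)
    dX = torch.empty_like(X)
    dW = torch.zeros_like(W)
    db = torch.zeros_like(b)

    for start in range(0, B, chunk_size):
        end = min(start + chunk_size, B)
        Xc, yc = X[start:end], y[start:end]
        rows = torch.arange(end - start)

        Zc = Xc @ W + b                                    # logits for this chunk only
        lse = torch.logsumexp(Zc, dim=-1)
        loss += (lse - Zc[rows, yc]).sum() / B             # chunk's contribution to the mean loss

        # Overwrite Zc in place with dL/dZ_chunk = (P_chunk - Y_chunk) / B
        Zc -= Zc.max(dim=-1, keepdim=True).values          # shift for numerical stability
        Zc.exp_()
        Zc /= Zc.sum(dim=-1, keepdim=True)                 # Zc now holds the softmax probabilities
        Zc[rows, yc] -= 1.0                                # subtract the one-hot labels
        Zc /= B

        dX[start:end] = Zc @ W.T                           # dL/dX for this chunk
        dW += Xc.T @ Zc                                    # accumulate dL/dW
        db += Zc.sum(dim=0)                                # accumulate dL/db

    return loss, dX, dW, db

# Usage check against a plain PyTorch forward pass (all sizes illustrative)
torch.manual_seed(0)
B, d, C = 6, 8, 20
X, W, b = torch.randn(B, d), torch.randn(d, C), torch.randn(C)
y = torch.randint(0, C, (B,))
loss, dX, dW, db = chunked_linear_cross_entropy(X, W, b, y, chunk_size=2)
print(loss.item(), torch.nn.functional.cross_entropy(X @ W + b, y).item())  # should match
```

Production implementations such as the Liger Kernel apply the same chunking and in-place gradient trick inside fused GPU kernels rather than a Python loop, which is where the memory and speed benefits actually materialize.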