
Cross-entropy loss and its optimization [WIP]

1. Background
2. Softmax Cross-Entropy
3. Linear-Softmax-Cross-Entropy
4. Optimization Strategies

1. Background

Computing the cross-entropy loss becomes significantly more challenging for LLMs. This is primarily due to the extremely large logit (and one-hot label) matrices involved in the calculation, whose size scales with batch/sequence length times vocabulary size, leading to high computational cost and memory usage. Recently, several optimization strategies have been proposed to address this issue, starting from a PyTorch GitHub issue.

All these approaches share a common goal: avoiding the full materialization of the logit matrix. They achieve this by:

  1. chunking the logit matrix
  2. computing the gradient of the logits in place

In this post, I will dive into the cross-entropy loss and its optimization strategies.

2. Softmax Cross-Entropy

2.1. Forward Pass

Let's begin by understanding the forward pass of the cross-entropy loss.

Consider:

  • An input vector $\boldsymbol{x} \in \mathbb{R}^{d}$ representing the logits (unnormalized scores) produced by the model for each class.
  • A true label $y \in \{0, 1, \ldots, d-1\}$ indicating the correct class.

The softmax function converts the logits into probabilities:

𝒑𝑖=π‘’π’™π‘–βˆ‘π‘˜=1π‘‘π‘’π’™π‘˜

Here, $\boldsymbol{p}_i$ represents the probability of the input belonging to class $i$.

The cross-entropy loss for a single instance is then defined as:

$$L = -\log(\boldsymbol{p}_y)$$

Expanding this, we get:

$$L = -\log(\boldsymbol{p}_y) = -\log\left(\frac{e^{\boldsymbol{x}_y}}{\sum_{k=1}^{d} e^{\boldsymbol{x}_k}}\right) = -\log\left(e^{\boldsymbol{x}_y}\right) + \log\left(\sum_{k=1}^{d} e^{\boldsymbol{x}_k}\right) = -\boldsymbol{x}_y + \log\left(\sum_{k=1}^{d} e^{\boldsymbol{x}_k}\right)$$
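As a quick sanity check, here is a minimal PyTorch sketch (names and sizes are illustrative, not from any particular library) that evaluates the last expression, $-\boldsymbol{x}_y + \log\sum_k e^{\boldsymbol{x}_k}$, using `logsumexp` for numerical stability and compares it with `F.cross_entropy`:

```python
import torch
import torch.nn.functional as F

def cross_entropy_forward(x: torch.Tensor, y: int) -> torch.Tensor:
    # L = -x_y + log(sum_k exp(x_k)), computed via logsumexp for stability
    return -x[y] + torch.logsumexp(x, dim=0)

x = torch.randn(10)   # logits for d = 10 classes
y = 3                 # true class index

loss = cross_entropy_forward(x, y)
ref = F.cross_entropy(x.unsqueeze(0), torch.tensor([y]))
print(torch.allclose(loss, ref))  # True
```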

2.2. Backward Pass

In what follows, write $\boldsymbol{z}$ for the logit vector (the $\boldsymbol{x}$ above) and $\boldsymbol{t} \in \{0,1\}^{d}$ for the one-hot encoding of the label $y$. By the chain rule, the gradient of the loss with respect to the logits is

$$\frac{\partial L}{\partial \boldsymbol{z}_i} = \sum_{j} \frac{\partial L}{\partial \boldsymbol{p}_j} \frac{\partial \boldsymbol{p}_j}{\partial \boldsymbol{z}_i}$$

2.2.1. Step 1: Compute $\frac{\partial \boldsymbol{p}_j}{\partial \boldsymbol{z}_i}$

The result is:

βˆ‚π’‘π‘—βˆ‚π’›π‘–={𝒑𝑗(1βˆ’π’‘π‘—)if  𝑗=π‘–βˆ’π’‘π‘—π’‘π‘–if  𝑗≠𝑖

The full derivation for the case $j = i$ is:

βˆ‚π’‘π‘—βˆ‚π’›π‘—=βˆ‚(π‘’π’›π‘—βˆ‘π‘˜=1π‘π‘’π’›π‘˜)βˆ‚π’›π‘—=(βˆ‘π‘˜=1π‘π‘’π’›π‘˜)β‹…π‘’π’›π‘—βˆ’π‘’π’›π‘—π‘’π’›π‘—(βˆ‘π‘˜=1π‘π‘’π’›π‘˜)2=(π‘’π’›π‘—βˆ‘π‘˜=1π‘π‘’π’›π‘˜)(1βˆ’π‘’π’›π‘—βˆ‘π‘˜=1π‘π‘’π’›π‘˜)=𝒑𝑗(1βˆ’π’‘π‘—)

And for $j \neq i$:

βˆ‚π’‘π‘—βˆ‚π’›π‘–=βˆ‚(π‘’π’›π‘—βˆ‘π‘˜=1π‘π‘’π’›π‘˜)βˆ‚π’›π‘–=βˆ’π‘’π’›π‘—β‹…π‘’π’›π‘–(βˆ‘π‘˜=1π‘π‘’π’›π‘˜)2=βˆ’π’‘π‘—π’‘π‘–

2.2.2. Step 2: Compute $\frac{\partial L}{\partial \boldsymbol{z}_i}$

βˆ‚πΏβˆ‚π’›π‘–=βˆ‘π‘—=1π‘βˆ‚(βˆ’π’•π‘—log𝒑𝑗)βˆ‚π’›π‘–=βˆ’βˆ‘π‘—=1π‘π’•π‘—βˆ‚(log𝒑𝑗)βˆ‚π’›π‘–=βˆ’βˆ‘π‘—=1𝑁𝒕𝑗1π’‘π‘—βˆ‚π’‘π‘—βˆ‚π’›π‘–=βˆ’π’•π‘–π’‘π‘–βˆ‚π’‘π‘–βˆ‚π’›π‘–βˆ’βˆ‘π‘—=1,π‘—β‰ π‘–π‘π’•π‘—π’‘π‘—βˆ‚π’‘π‘—βˆ‚π’›π‘–=βˆ’π’•π‘–π’‘π‘–π’‘π‘–(1βˆ’π’‘π‘–)βˆ’βˆ‘π‘—=1,𝑗≠𝑖𝑁𝒕𝑗𝒑𝑗(βˆ’π’‘π‘—π’‘π‘–)=βˆ’π’•π‘–+𝒕𝑖𝒑𝑖+βˆ‘π‘—=1,𝑗≠𝑖𝑁𝒕𝑗𝒑𝑖=βˆ’π’•π‘–+βˆ‘π‘—=1𝑁𝒕𝑗𝒑𝑖=βˆ’π’•π‘–+π’‘π‘–βˆ‘π‘—=1𝑁𝒕𝑗=βˆ’π’•π‘–+𝒑𝑖=π’‘π‘–βˆ’π’•π‘–

So,

βˆ‚πΏβˆ‚π’›=π’‘βˆ’π’•

2.3. Gradient in Matrix Form

For batch computations, it's efficient to represent gradients in matrix form.

Given:

  • π‘·βˆˆβ„π‘›Γ—π‘‘: Matrix of predicted probabilities for a batch of size 𝑛.
  • π’βˆˆβ„π‘›Γ—π‘‘: Matrix of logits.
  • π’€βˆˆβ„π‘›Γ—π‘‘: One-hot encoded true labels.

For each row, the softmax Jacobian is

$$\frac{\partial \boldsymbol{P}_{i,j}}{\partial \boldsymbol{Z}_{i,k}} = \boldsymbol{P}_{i,j} \left(\delta_{j,k} - \boldsymbol{P}_{i,k}\right)$$

and, applying the per-instance result above row by row, the gradient of the summed per-instance losses with respect to the logits is

$$\frac{\partial L}{\partial \boldsymbol{Z}} = \boldsymbol{P} - \boldsymbol{Y}$$

Normalized by batch size, the overall gradient of the loss is:

βˆ‚πΏβˆ‚π’=1𝑛(π‘·βˆ’π’€)

3. Linear-Softmax-Cross-Entropy

The cross-entropy loss is typically computed on the output of a final linear (fully connected) layer followed by a softmax activation. If we can fuse the linear layer with the softmax cross-entropy, we may avoid the full materialization of the logit matrix.

  • Input before the final linear layer: $\boldsymbol{X} \in \mathbb{R}^{n \times d_{\text{in}}}$
  • Linear weights: $\boldsymbol{W} \in \mathbb{R}^{d_{\text{in}} \times d_{\text{out}}}$
  • Linear bias: $\boldsymbol{b} \in \mathbb{R}^{d_{\text{out}}}$
  • Labels: $y_i \in \{0, 1, \ldots, d_{\text{out}}-1\}$ for $i = 1, \ldots, n$, representing the true class of each instance in the batch.

3.1. Forward Pass

The input $\boldsymbol{X}$ is first transformed linearly using the weights and bias:

$$\boldsymbol{Z} = \boldsymbol{X}\boldsymbol{W} + \boldsymbol{b}$$

Softmax:

$$\boldsymbol{P}_{i,j} = \frac{e^{\boldsymbol{Z}_{i,j}}}{\sum_{k=1}^{d_{\text{out}}} e^{\boldsymbol{Z}_{i,k}}}$$

Cross-entropy loss is computed for each instance and then averaged over the batch:

$$L_i = -\log\left(\boldsymbol{P}_{i, y_i}\right), \qquad L = \frac{1}{n} \sum_{i=1}^{n} L_i$$
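A minimal sketch of this forward pass in PyTorch (shapes and names are illustrative), checked against `F.cross_entropy`, which fuses the softmax and the log internally for numerical stability:

```python
import torch
import torch.nn.functional as F

n, d_in, d_out = 4, 16, 32
X = torch.randn(n, d_in)
W = torch.randn(d_in, d_out)
b = torch.randn(d_out)
y = torch.randint(0, d_out, (n,))

Z = X @ W + b                                    # logits, [n, d_out]
P = torch.softmax(Z, dim=1)                      # probabilities, [n, d_out]
L = -torch.log(P[torch.arange(n), y]).mean()     # mean cross-entropy

print(torch.allclose(L, F.cross_entropy(Z, y), atol=1e-6))  # True
```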

3.2. Backward Pass

Gradient of $\boldsymbol{Z}$:

$$\frac{\partial L}{\partial \boldsymbol{Z}} = \frac{1}{n} (\boldsymbol{P} - \boldsymbol{Y})$$

Gradient of $\boldsymbol{W}$:

$$\frac{\partial L}{\partial \boldsymbol{W}} = \boldsymbol{X}^{T} \frac{\partial L}{\partial \boldsymbol{Z}}$$

Gradient of $\boldsymbol{b}$:

$$\frac{\partial L}{\partial \boldsymbol{b}} = \sum_{i=1}^{n} \frac{\partial L}{\partial \boldsymbol{Z}_i}$$

Gradient of the input $\boldsymbol{X}$:

$$\frac{\partial L}{\partial \boldsymbol{X}} = \frac{\partial L}{\partial \boldsymbol{Z}} \boldsymbol{W}^{T}$$

3.3. Summary of Gradients

| Quantity | Formula | Dimensions |
|---|---|---|
| $\boldsymbol{Z}$ | $\boldsymbol{Z} = \boldsymbol{X}\boldsymbol{W} + \boldsymbol{b}$ | $[n, d_{\text{out}}]$ |
| $\boldsymbol{P}$ | $\boldsymbol{P} = \mathrm{softmax}(\boldsymbol{Z})$ | $[n, d_{\text{out}}]$ |
| $L$ | $L = -\frac{1}{n} \sum_{i} \log(\boldsymbol{P}_{i, y_i})$ | scalar |
| $d\boldsymbol{Z}$ | $d\boldsymbol{Z} = \frac{1}{n} (\boldsymbol{P} - \boldsymbol{Y})$ | $[n, d_{\text{out}}]$ |
| $d\boldsymbol{W}$ | $d\boldsymbol{W} = \boldsymbol{X}^{T} d\boldsymbol{Z}$ | $[d_{\text{in}}, d_{\text{out}}]$ |
| $d\boldsymbol{b}$ | $d\boldsymbol{b} = \mathrm{sum}(d\boldsymbol{Z})$ | $[d_{\text{out}}]$ |
| $d\boldsymbol{X}$ | $d\boldsymbol{X} = d\boldsymbol{Z}\, \boldsymbol{W}^{T}$ | $[n, d_{\text{in}}]$ |
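To check the table end to end, here is a minimal sketch (names and sizes are illustrative) that computes these gradients by hand and compares them with PyTorch autograd:

```python
import torch
import torch.nn.functional as F

n, d_in, d_out = 4, 16, 32
X = torch.randn(n, d_in, requires_grad=True)
W = torch.randn(d_in, d_out, requires_grad=True)
b = torch.randn(d_out, requires_grad=True)
y = torch.randint(0, d_out, (n,))

# Forward + autograd backward
Z = X @ W + b
F.cross_entropy(Z, y).backward()

# Manual backward, following the summary table
with torch.no_grad():
    P = torch.softmax(Z, dim=1)
    Y = F.one_hot(y, num_classes=d_out).float()
    dZ = (P - Y) / n
    dW = X.T @ dZ
    db = dZ.sum(dim=0)
    dX = dZ @ W.T

print(torch.allclose(dW, W.grad, atol=1e-5),
      torch.allclose(db, b.grad, atol=1e-5),
      torch.allclose(dX, X.grad, atol=1e-5))  # True True True
```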

4. Optimization Strategies

  • Chunking the logit matrix: the logit matrix is split into chunks along the batch (token) dimension, the cross-entropy loss is computed for each chunk, and the chunk losses are accumulated into the final loss. This way the full logit matrix is never materialized at once.

  • Computing the gradient of the logits in place: since $d\boldsymbol{Z} = \frac{1}{n}(\boldsymbol{P} - \boldsymbol{Y})$ has a closed form, it can be computed during the forward pass and written over the buffer that holds the chunk's logits; the gradients of the input and the weights then follow as $d\boldsymbol{X} = d\boldsymbol{Z}\,\boldsymbol{W}^{T}$ and $d\boldsymbol{W} = \boldsymbol{X}^{T} d\boldsymbol{Z}$. Both ideas are combined in the sketch below.
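The following is a minimal PyTorch sketch of these two ideas; the function name, chunk size, and in-place steps are illustrative and are not taken from any of the projects listed below. It assumes the inputs are plain tensors not tracked by autograd: each chunk's logits are computed, turned into that chunk's loss contribution, and then overwritten in place with $d\boldsymbol{Z}$, so only one chunk of logits lives in memory at a time.

```python
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(X, W, b, y, chunk_size=1024):
    """Illustrative sketch: fused linear + softmax cross-entropy, chunked over the batch.

    Returns the mean loss and manually computed gradients (dX, dW, db) without
    materializing the full [n, d_out] logit matrix. Inputs must not require
    grad, because the logit buffer is modified in place.
    """
    n = X.shape[0]
    loss = X.new_zeros(())
    dX = torch.zeros_like(X)
    dW = torch.zeros_like(W)
    db = torch.zeros_like(b)

    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        Xc, yc = X[start:end], y[start:end]

        # Forward for this chunk only
        Zc = Xc @ W + b                                    # [chunk, d_out]
        loss = loss + F.cross_entropy(Zc, yc, reduction="sum") / n

        # Turn the logit buffer into dZ = (P - Y) / n, entirely in place
        Zc -= Zc.max(dim=1, keepdim=True).values           # stability shift
        Zc.exp_()
        Zc /= Zc.sum(dim=1, keepdim=True)                  # Zc now holds P
        Zc[torch.arange(end - start), yc] -= 1.0           # P - Y
        Zc /= n                                            # Zc now holds dZ

        # Fold the chunk's logit gradient into the input/parameter gradients
        dX[start:end] = Zc @ W.T
        dW += Xc.T @ Zc
        db += Zc.sum(dim=0)

    return loss, dX, dW, db

# Check against autograd on a small problem
n, d_in, d_out = 64, 32, 128
X = torch.randn(n, d_in, requires_grad=True)
W = torch.randn(d_in, d_out, requires_grad=True)
b = torch.randn(d_out, requires_grad=True)
y = torch.randint(0, d_out, (n,))

loss, dX, dW, db = chunked_linear_cross_entropy(X.detach(), W.detach(), b.detach(), y, chunk_size=16)
ref_loss = F.cross_entropy(X @ W + b, y)
ref_loss.backward()
print(torch.allclose(loss, ref_loss, atol=1e-5),
      torch.allclose(dX, X.grad, atol=1e-5),
      torch.allclose(dW, W.grad, atol=1e-5),
      torch.allclose(db, b.grad, atol=1e-5))  # True True True True
```

The libraries listed below implement the same two ideas with fused Triton/CUDA kernels rather than a Python loop, which is where the real speedups come from.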

4.1. efficient_cross_entropy

4.2. Liger Kernel

4.3. Cut Your Losses in Large-Vocabulary Language Models