Cross-entropy loss and its optimization [WIP]

1. Background

Computing the cross-entropy loss becomes significantly more challenging for LLMs, primarily because of the extremely large logit and label matrices involved: for $n$ tokens and a vocabulary of size $V$, the logit matrix alone has $n \times V$ entries, which leads to high computational cost and memory usage. Recently, several optimization strategies have been proposed to address this issue, starting from a PyTorch GitHub issue.

All these approaches share a common goal: avoiding the full materialization of the logit matrix. They achieve this by:

  1. chunking the logit matrix (see the sketch after this list)
  2. computing the gradient of the logits in place
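
As a rough illustration of the first idea, here is a minimal sketch (not taken from any of the libraries discussed later; the helper name `chunked_linear_cross_entropy` and the default chunk size are made up) that computes the loss chunk by chunk, so only a chunk-sized slice of the logit matrix exists at any one time. As written, autograd would still save each chunk's logits for the backward pass, so the real implementations additionally compute the gradient of the logits in place inside a custom autograd function, which is the second idea above.

```python
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, weight, labels, chunk_size=1024):
    """Illustrative sketch: mean cross-entropy over all rows of `hidden`
    without materializing the full [n, vocab] logit matrix at once.
    hidden: [n, d_in], weight: [vocab, d_in], labels: [n]."""
    n = hidden.shape[0]
    total = hidden.new_zeros(())
    for start in range(0, n, chunk_size):
        h = hidden[start:start + chunk_size]              # [chunk, d_in]
        logits = h @ weight.T                             # only a [chunk, vocab] slice
        total = total + F.cross_entropy(
            logits, labels[start:start + chunk_size], reduction="sum")
    return total / n
```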

In this post, I will dive into the cross-entropy loss and its optimization strategies.

2. Softmax Cross-Entropy

2.1. Forward Pass

Let’s begin by understanding the forward pass of the cross-entropy loss.

Consider a single instance with logit vector $\boldsymbol{x} \in \mathbb{R}^{d}$ and target class label $y \in \{1, \dots, d\}$.

The softmax function converts the logits into probabilities:

𝒑𝑖=π‘’π’™π‘–βˆ‘π‘˜=1π‘‘π‘’π’™π‘˜

Here, $\boldsymbol{p}_i$ represents the probability of the input belonging to class $i$.

The cross-entropy loss for a single instance is then defined as:

$$
L = -\log(\boldsymbol{p}_y)
$$

Expanding this, we get:

$$
L = -\log(\boldsymbol{p}_y)
= -\log\!\left(\frac{e^{\boldsymbol{x}_y}}{\sum_{k=1}^{d} e^{\boldsymbol{x}_k}}\right)
= -\log\left(e^{\boldsymbol{x}_y}\right) + \log\!\left(\sum_{k=1}^{d} e^{\boldsymbol{x}_k}\right)
= -\boldsymbol{x}_y + \log\!\left(\sum_{k=1}^{d} e^{\boldsymbol{x}_k}\right)
$$
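
To make the last form concrete, here is a minimal sketch for a single instance that computes $-\boldsymbol{x}_y + \log\sum_k e^{\boldsymbol{x}_k}$ in a numerically stable way (the maximum logit is subtracted before exponentiating); the logits and the target class below are arbitrary example values:

```python
import torch
import torch.nn.functional as F

def cross_entropy_forward(x, y):
    """L = -x[y] + log(sum_k exp(x[k])), computed stably by
    subtracting the max logit before exponentiating."""
    m = x.max()
    lse = m + torch.log(torch.exp(x - m).sum())
    return -x[y] + lse

x = torch.tensor([2.0, 1.0, 0.1])   # arbitrary logits
y = 1                               # arbitrary target class
print(cross_entropy_forward(x, y))
print(F.cross_entropy(x.unsqueeze(0), torch.tensor([y])))  # reference value
```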

2.2. Backward Pass

In general, the gradient of the loss with respect to the logits $\boldsymbol{z}$ (the same vector written as $\boldsymbol{x}$ in the previous section) is given by the chain rule:

βˆ‚πΏβˆ‚π’›π‘–=βˆ‚πΏβˆ‚π’‘π‘—βˆ‚π’‘π‘—βˆ‚π’›π‘–

2.2.1. Step 1: Compute $\partial \boldsymbol{p}_j / \partial \boldsymbol{z}_i$

The result is:

βˆ‚π’‘π‘—βˆ‚π’›π‘–={𝒑𝑗(1βˆ’π’‘π‘—)if  𝑗=π‘–βˆ’π’‘π‘—π’‘π‘–if  𝑗≠𝑖

The full derivation for the case $j = i$ is:

βˆ‚π’‘π‘—βˆ‚π’›π‘—=βˆ‚(π‘’π’›π‘—βˆ‘π‘˜=1π‘π‘’π’›π‘˜)βˆ‚π’›π‘—=(βˆ‘π‘˜=1π‘π‘’π’›π‘˜)β‹…π‘’π’›π‘—βˆ’π‘’π’›π‘—π‘’π’›π‘—(βˆ‘π‘˜=1π‘π‘’π’›π‘˜)2=(π‘’π’›π‘—βˆ‘π‘˜=1π‘π‘’π’›π‘˜)(1βˆ’π‘’π’›π‘—βˆ‘π‘˜=1π‘π‘’π’›π‘˜)=𝒑𝑗(1βˆ’π’‘π‘—)

And for $j \neq i$:

βˆ‚π’‘π‘—βˆ‚π’›π‘–=βˆ‚(π‘’π’›π‘—βˆ‘π‘˜=1π‘π‘’π’›π‘˜)βˆ‚π’›π‘–=βˆ’π‘’π’›π‘—β‹…π‘’π’›π‘–(βˆ‘π‘˜=1π‘π‘’π’›π‘˜)2=βˆ’π’‘π‘—π’‘π‘–

2.2.2. Step 2: Compute $\partial L / \partial \boldsymbol{z}_i$

βˆ‚πΏβˆ‚π’›π‘–=βˆ‘π‘—=1π‘βˆ‚(βˆ’π’•π‘—log𝒑𝑗)βˆ‚π’›π‘–=βˆ’βˆ‘π‘—=1π‘π’•π‘—βˆ‚(log𝒑𝑗)βˆ‚π’›π‘–=βˆ’βˆ‘π‘—=1𝑁𝒕𝑗1π’‘π‘—βˆ‚π’‘π‘—βˆ‚π’›π‘–=βˆ’π’•π‘–π’‘π‘–βˆ‚π’‘π‘–βˆ‚π’›π‘–βˆ’βˆ‘π‘—=1,π‘—β‰ π‘–π‘π’•π‘—π’‘π‘—βˆ‚π’‘π‘—βˆ‚π’›π‘–=βˆ’π’•π‘–π’‘π‘–π’‘π‘–(1βˆ’π’‘π‘–)βˆ’βˆ‘π‘—=1,𝑗≠𝑖𝑁𝒕𝑗𝒑𝑗(βˆ’π’‘π‘—π’‘π‘–)=βˆ’π’•π‘–+𝒕𝑖𝒑𝑖+βˆ‘π‘—=1,𝑗≠𝑖𝑁𝒕𝑗𝒑𝑖=βˆ’π’•π‘–+βˆ‘π‘—=1𝑁𝒕𝑗𝒑𝑖=βˆ’π’•π‘–+π’‘π‘–βˆ‘π‘—=1𝑁𝒕𝑗=βˆ’π’•π‘–+𝒑𝑖=π’‘π‘–βˆ’π’•π‘–

So,

βˆ‚πΏβˆ‚π’›=π’‘βˆ’π’•

2.3. Gradient in Matrix Form

For batch computations, it’s efficient to represent gradients in matrix form.

Given a batch of logits $\boldsymbol{Z} \in \mathbb{R}^{n \times d}$, the row-wise softmax probabilities $\boldsymbol{P} \in \mathbb{R}^{n \times d}$, and the one-hot label matrix $\boldsymbol{Y} \in \{0, 1\}^{n \times d}$ (row $i$ has a $1$ in column $y_i$):

Each row of $\boldsymbol{P}$ is the softmax of the corresponding row of $\boldsymbol{Z}$, so the row-wise Jacobian is

$$
\frac{\partial \boldsymbol{P}_{i,j}}{\partial \boldsymbol{Z}_{i,k}} = \boldsymbol{P}_{i,j}\left(\delta_{j,k} - \boldsymbol{P}_{i,k}\right)
$$

and therefore, for the sum of the per-instance losses, the gradient with respect to the logits is

$$
\frac{\partial L}{\partial \boldsymbol{Z}} = \boldsymbol{P} - \boldsymbol{Y}
$$

Normalized by batch size, the overall gradient of the loss is:

βˆ‚πΏβˆ‚π’=1𝑛(π‘·βˆ’π’€)

3. Linear-Softmax-Cross-Entropy

In a typical classification or language-modeling head, the cross-entropy loss is preceded by a linear (fully connected) layer followed by a softmax activation. If we fuse the linear layer with the softmax and cross-entropy computation, we may be able to avoid the full materialization of the logit matrix.

3.1. Forward Pass

In the linear transformation, the input $\boldsymbol{X} \in \mathbb{R}^{n \times d_{\text{in}}}$ is transformed using the weight matrix $\boldsymbol{W} \in \mathbb{R}^{d_{\text{in}} \times d_{\text{out}}}$ and the bias $\boldsymbol{b} \in \mathbb{R}^{d_{\text{out}}}$:

$$
\boldsymbol{Z} = \boldsymbol{X} \boldsymbol{W} + \boldsymbol{b}
$$

Softmax:

$$
\boldsymbol{P}_{i,j} = \frac{e^{\boldsymbol{Z}_{i,j}}}{\sum_{k=1}^{d_{\text{out}}} e^{\boldsymbol{Z}_{i,k}}}
$$

Cross-entropy loss is computed for each instance and then averaged over the batch:

$$
L_i = -\log\left(\boldsymbol{P}_{i, y_i}\right), \qquad L = \frac{1}{n} \sum_{i=1}^{n} L_i
$$
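
Putting the three steps together, here is a small sketch of the forward pass; the helper name `linear_softmax_ce_forward` and the toy dimensions are made up for illustration, and the log-probabilities are computed via `logsumexp` for numerical stability instead of materializing $\boldsymbol{P}$ explicitly:

```python
import torch
import torch.nn.functional as F

def linear_softmax_ce_forward(X, W, b, y):
    """X: [n, d_in], W: [d_in, d_out], b: [d_out], y: [n] integer labels."""
    Z = X @ W + b                                        # logits, [n, d_out]
    logP = Z - torch.logsumexp(Z, dim=1, keepdim=True)   # stable log-softmax
    L_i = -logP[torch.arange(X.shape[0]), y]             # per-instance losses
    return L_i.mean(), Z

n, d_in, d_out = 4, 8, 10                                # arbitrary toy sizes
X, W, b = torch.randn(n, d_in), torch.randn(d_in, d_out), torch.randn(d_out)
y = torch.randint(0, d_out, (n,))

L, Z = linear_softmax_ce_forward(X, W, b, y)
print(L, F.cross_entropy(Z, y))                          # should match
```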

3.2. Backward Pass

Gradient of $\boldsymbol{Z}$:

$$
\frac{\partial L}{\partial \boldsymbol{Z}} = \frac{1}{n}(\boldsymbol{P} - \boldsymbol{Y})
$$

Gradient of $\boldsymbol{W}$:

$$
\frac{\partial L}{\partial \boldsymbol{W}} = \boldsymbol{X}^T \frac{\partial L}{\partial \boldsymbol{Z}}
$$

Gradient of $\boldsymbol{b}$ (summing over the batch dimension):

$$
\frac{\partial L}{\partial \boldsymbol{b}} = \sum_{i=1}^{n} \frac{\partial L}{\partial \boldsymbol{Z}_i}
$$

Gradient of the input $\boldsymbol{X}$:

$$
\frac{\partial L}{\partial \boldsymbol{X}} = \frac{\partial L}{\partial \boldsymbol{Z}} \boldsymbol{W}^T
$$

3.3. Summary of Gradients

| Quantity | Formula | Dimensions |
|----------|---------|------------|
| $\boldsymbol{Z}$ | $\boldsymbol{Z} = \boldsymbol{X}\boldsymbol{W} + \boldsymbol{b}$ | $[n, d_{\text{out}}]$ |
| $\boldsymbol{P}$ | $\boldsymbol{P} = \mathrm{softmax}(\boldsymbol{Z})$ | $[n, d_{\text{out}}]$ |
| $L$ | $L = -\frac{1}{n}\sum_{i=1}^{n} \log(\boldsymbol{P}_{i, y_i})$ | scalar |
| $d\boldsymbol{Z}$ | $d\boldsymbol{Z} = \frac{1}{n}(\boldsymbol{P} - \boldsymbol{Y})$ | $[n, d_{\text{out}}]$ |
| $d\boldsymbol{W}$ | $d\boldsymbol{W} = \boldsymbol{X}^T\, d\boldsymbol{Z}$ | $[d_{\text{in}}, d_{\text{out}}]$ |
| $d\boldsymbol{b}$ | $d\boldsymbol{b} = \mathrm{sum}(d\boldsymbol{Z}, \text{axis}=0)$ | $[d_{\text{out}}]$ |
| $d\boldsymbol{X}$ | $d\boldsymbol{X} = d\boldsymbol{Z}\, \boldsymbol{W}^T$ | $[n, d_{\text{in}}]$ |

4. Optimization Strategies

4.1. efficient_cross_entropy

4.2. Liger Kernel

4.3. Cut Your Losses in Large-Vocabulary Language Models