Cross-entropy loss and its optimization [WIP]
  • 12/12/2024
  • Xiaotian Han

  • background

    Computing the cross-entropy loss becomes significantly more challenging for LLMs. This is primarily due to the extremely large logit and label matrices involved, which lead to high computational cost and memory usage. Recently, several optimization strategies have been proposed to address this issue, starting from a PyTorch GitHub issue:

    1. https://github.com/pytorch/pytorch/issues/124480
    2. https://github.com/mgmalek/efficient_cross_entropy
    3. liger kernel [github] [arxiv]
    4. Cut Your Losses in Large-Vocabulary Language Models [arxiv]

    All these approaches share a common goal: avoiding the full materialization of the logit matrix. They achieve this by:

    1. chunking the logit matrix
    2. computing the gradient of the logits in place

    In this blog, I will dive into the cross-entropy loss and its optimization strategies.

    softmax cross-entropy

    forward pass

    Let’s begin by understanding the forward pass of the cross-entropy loss.

    Consider:

    • An input vector \(\mathbf{x} \in \mathbb{R}^d\) representing the logits (unnormalized scores) produced by the model for each class.
    • A true label \(y \in \{0, 1, \dots, d-1\}\) indicating the correct class.

    The softmax function converts the logits into probabilities:

    \[\mathbf{p}_{i} = \frac{e^{\mathbf{x}_{i}}}{\sum_{k=1}^{d} e^{\mathbf{x}_{k}}}\]

    Here, \(\mathbf{p}_i\) represents the probability of the input belonging to class \(i\).

    The cross-entropy loss for a single instance is then defined as:

    \[L = -\log(\mathbf{p}_{y})\]

    Expanding this, we get:

    \[L = -\log(\mathbf{p}_{y}) = -\log(\frac{e^{\mathbf{x}_{y}}}{\sum_{k=1}^{d} e^{\mathbf{x}_{k}}}) = -\log(e^{\mathbf{x}_{y}}) + \log(\sum_{k=1}^{d} e^{\mathbf{x}_{k}}) = -\mathbf{x}_{y} + \log(\sum_{k=1}^{d} e^{\mathbf{x}_{k}})\]
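    As a quick sanity check, here is a minimal sketch of this forward pass for a single logit vector (PyTorch is assumed for all code examples in this post; the shapes and variable names are arbitrary). It uses the log-sum-exp form above, with the usual max-subtraction trick for numerical stability:

    ```python
    import torch

    def cross_entropy_forward(x: torch.Tensor, y: int) -> torch.Tensor:
        # L = -x_y + log(sum_k exp(x_k)); subtract max(x) for numerical stability
        m = x.max()
        lse = m + torch.log(torch.exp(x - m).sum())
        return -x[y] + lse

    x = torch.randn(8)   # logits for d = 8 classes (hypothetical)
    y = 3                # true class index
    print(cross_entropy_forward(x, y))
    print(torch.nn.functional.cross_entropy(x.unsqueeze(0), torch.tensor([y])))  # should match
    ```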

    backward pass

    In general (writing \(\mathbf{z}\) for the logits and \(\mathbf{t}\) for the one-hot target vector), the gradient of the loss with respect to the logits follows from the chain rule:

    \[\frac{\partial L}{\partial \mathbf{z}_i} = \sum_{j=1}^{N} \frac{\partial L}{\partial \mathbf{p}_j} \frac{\partial \mathbf{p}_j}{\partial \mathbf{z}_i}\]
    step 1: Compute \(\frac{\partial \mathbf{p}_j}{\partial \mathbf{z}_i}\)
    \[\frac{\partial \mathbf{p}_j}{\partial \mathbf{z}_i} = \begin{cases} \displaystyle \frac{\partial}{\partial \mathbf{z}_j}\!\left(\frac{e^{\mathbf{z}_j}}{\sum_{k=1}^{N} e^{\mathbf{z}_k}}\right) = \frac{\sum_{k=1}^{N} e^{\mathbf{z}_k} \cdot e^{\mathbf{z}_j} \;-\; e^{\mathbf{z}_j}\, e^{\mathbf{z}_j}}{\left(\sum_{k=1}^{N} e^{\mathbf{z}_k}\right)^2} = \frac{e^{\mathbf{z}_j}}{\sum_{k=1}^{N} e^{\mathbf{z}_k}} \left[ 1 - \frac{e^{\mathbf{z}_j}}{\sum_{k=1}^{N} e^{\mathbf{z}_k}} \right] = \mathbf{p}_j(1 - \mathbf{p}_j) & j = i, \\[0.8em] \displaystyle \frac{\partial}{\partial \mathbf{z}_i}\!\left(\frac{e^{\mathbf{z}_j}}{\sum_{k=1}^{N} e^{\mathbf{z}_k}}\right) = \frac{\sum_{k=1}^{N} e^{\mathbf{z}_k} \cdot \frac{\partial e^{\mathbf{z}_j}}{\partial \mathbf{z}_i} \;-\; e^{\mathbf{z}_j} \cdot e^{\mathbf{z}_i}}{\left(\sum_{k=1}^{N} e^{\mathbf{z}_k}\right)^2} = -\frac{e^{\mathbf{z}_j}\, e^{\mathbf{z}_i}}{\left(\sum_{k=1}^{N} e^{\mathbf{z}_k}\right)^2} = -\,\mathbf{p}_j\, \mathbf{p}_i & j \neq i. \end{cases}\]
    step 2: Compute \(\frac{\partial L}{\partial \mathbf{z}_i}\)
    \[\begin{aligned} \frac{\partial L}{\partial \mathbf{z}_i} &= \sum_{j=1}^{N} \frac{\partial (-\mathbf{t}_j \log \mathbf{p}_j)}{\partial \mathbf{z}_i} \\ &= - \sum_{j=1}^{N} \mathbf{t}_j \frac{\partial (\log \mathbf{p}_j)}{\partial \mathbf{z}_i} \\ &= - \sum_{j=1}^{N} \mathbf{t}_j \frac{1}{\mathbf{p}_j} \frac{\partial \mathbf{p}_j}{\partial \mathbf{z}_i} \\ &= - \frac{\mathbf{t}_i}{\mathbf{p}_i} \frac{\partial \mathbf{p}_i}{\partial \mathbf{z}_i} - \sum_{\substack{j=1 \\ j \neq i}}^{N} \frac{\mathbf{t}_j}{\mathbf{p}_j} \frac{\partial \mathbf{p}_j}{\partial \mathbf{z}_i} \\ &= - \frac{\mathbf{t}_i}{\mathbf{p}_i} \mathbf{p}_i(1 - \mathbf{p}_i) - \sum_{\substack{j=1 \\ j \neq i}}^{N} \frac{\mathbf{t}_j}{\mathbf{p}_j} (-\mathbf{p}_j \mathbf{p}_i) \\ &= -\mathbf{t}_i + \mathbf{t}_i \mathbf{p}_i + \sum_{\substack{j=1 \\ j \neq i}}^{N} \mathbf{t}_j \mathbf{p}_i \\ &= -\mathbf{t}_i + \sum_{j=1}^{N} \mathbf{t}_j \mathbf{p}_i \\ &= -\mathbf{t}_i + \mathbf{p}_i \sum_{j=1}^{N} \mathbf{t}_j \\ &= - \mathbf{t}_i + \mathbf{p}_i \\ &= \mathbf{p}_i - \mathbf{t}_i \end{aligned}\]

    So, in vector form (the second-to-last step uses \(\sum_{j=1}^{N} \mathbf{t}_j = 1\) for the one-hot target):

    \[\frac{\partial L}{\partial \mathbf{z}} = \mathbf{p} - \mathbf{t}\]
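    A quick way to convince ourselves of this result is to compare the closed-form gradient \(\mathbf{p} - \mathbf{t}\) with PyTorch autograd on a random logit vector (a small sketch; the names and shapes are arbitrary):

    ```python
    import torch

    z = torch.randn(8, requires_grad=True)    # logits
    y = 2                                      # true class
    t = torch.zeros(8); t[y] = 1.0             # one-hot target

    loss = torch.nn.functional.cross_entropy(z.unsqueeze(0), torch.tensor([y]))
    loss.backward()

    p = torch.softmax(z.detach(), dim=0)
    print(torch.allclose(z.grad, p - t, atol=1e-6))   # expected: True
    ```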

    gradient in matrix form

    for batch computations, it’s efficient to represent gradients in matrix form.

    Given:

    • \(\mathbf{P} \in \mathbb{R}^{n \times d}\): Matrix of predicted probabilities for a batch of size \(n\).
    • \(\mathbf{Z} \in \mathbb{R}^{n \times d}\): Matrix of logits.
    • \(\mathbf{Y} \in \mathbb{R}^{n \times d}\): One-hot encoded true labels.

    The gradient with respect to the logits is:

    \[\frac{\partial \mathbf{P}_{ij}}{\partial \mathbf{Z}_{ik}} = \mathbf{P}_{ij} (\delta_{jk} - \mathbf{P}_{ik})\] \[\frac{\partial L}{\partial \mathbf{Z}} = \mathbf{P} - \mathbf{Y}\]

    Normalizing by the batch size, the overall gradient of the loss is:

    \[\frac{\partial L}{\partial \mathbf{Z}} = \frac{1}{n}(\mathbf{P} - \mathbf{Y}) \in [n, d]\]
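    The same check in batched form, including the \(\frac{1}{n}\) factor that comes from averaging the loss over the batch (again a small sketch with arbitrary shapes):

    ```python
    import torch

    n, d = 4, 10
    Z = torch.randn(n, d, requires_grad=True)             # logits
    y = torch.randint(0, d, (n,))                         # class indices
    Y = torch.nn.functional.one_hot(y, d).float()         # one-hot labels

    loss = torch.nn.functional.cross_entropy(Z, y)        # mean over the batch
    loss.backward()

    P = torch.softmax(Z.detach(), dim=1)
    print(torch.allclose(Z.grad, (P - Y) / n, atol=1e-6))  # expected: True
    ```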

    linear - softmax - cross-entropy

    In LLMs, the cross-entropy loss is typically preceded by a linear (fully connected) layer followed by a softmax activation. If we can fuse the linear layer with the softmax and the loss, we may avoid the full materialization of the logit matrix.

    • input before the final linear layer: \(\mathbf{X} \in [n, d_{in}]\)
    • linear weights: \(\mathbf{W} \in [d_{in}, d_{out}]\)
    • linear bias: \(\mathbf{b} \in [d_{out}]\)
    • labels: \(y \in \{0, 1, \dots, d_{out}-1\}^n\), one true class for each instance in the batch.

    forward pass

    Linear transformation: the input \(\mathbf{X}\) is transformed using the weights and bias:

    \[\mathbf{Z} = \mathbf{X} \mathbf{W} + \mathbf{b} \in [n, d_{out}]\]

    softmax

    \[\mathbf{P}_{ij} = \frac{e^{\mathbf{Z}_{ij}}}{\sum_{k=1}^{d_{out}} e^{\mathbf{Z}_{ik}}} \in [n, d_{out}]\]

    cross-entropy loss is computed for each instance and then averaged over the batch:

    \[L_i = -\log(\mathbf{P}_{i, y_i})\] \[L = \frac{1}{n} \sum_{i=1}^{n} L_i\]
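    Putting the forward pieces together, here is a minimal sketch (the shapes \(n\), \(d_{in}\), \(d_{out}\) and the random data are just for illustration) that computes the loss from \(\mathbf{X}\), \(\mathbf{W}\), \(\mathbf{b}\) and compares it against PyTorch's built-in cross-entropy:

    ```python
    import torch

    n, d_in, d_out = 4, 16, 32
    X = torch.randn(n, d_in)
    W = torch.randn(d_in, d_out)
    b = torch.randn(d_out)
    y = torch.randint(0, d_out, (n,))

    Z = X @ W + b                                          # [n, d_out] logits
    log_P = Z - torch.logsumexp(Z, dim=1, keepdim=True)    # log-softmax, numerically stable
    L = -log_P[torch.arange(n), y].mean()                  # average cross-entropy over the batch

    # should match PyTorch's built-in loss
    print(torch.allclose(L, torch.nn.functional.cross_entropy(Z, y), atol=1e-6))
    ```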

    backward pass

    gradient of \(\mathbf{Z}\)

    \[\frac{\partial L}{\partial \mathbf{Z}} = \frac{1}{n}(\mathbf{P} - \mathbf{Y}), \quad ( [n, d_{out}] = [n, d_{out}] - [n, d_{out}] )\]

    gradient of \(\mathbf{W}\)

    \[\frac{\partial L}{\partial \mathbf{W}} = \mathbf{X}^\top \frac{\partial L}{\partial \mathbf{Z}} \quad ([d_{in}, d_{out}] = [n, d_{in}]^\top [n, d_{out}])\]

    gradient of \(\mathbf{b}\)

    \[\frac{\partial L}{\partial \mathbf{b}} = \sum_{i=1}^{n} \frac{\partial L}{\partial \mathbf{Z}_{i}} \quad ([d_{out}] = \text{sum of the } n \text{ rows of } [n, d_{out}])\]

    gradient of input \(\mathbf{X}\)

    \[\frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \mathbf{Z}} \mathbf{W}^\top \quad ([n, d_{in}] = [n, d_{out}] [d_{in}, d_{out}]^\top)\]
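    These backward formulas can likewise be checked against autograd (same hypothetical shapes as above):

    ```python
    import torch

    n, d_in, d_out = 4, 16, 32
    X = torch.randn(n, d_in, requires_grad=True)
    W = torch.randn(d_in, d_out, requires_grad=True)
    b = torch.randn(d_out, requires_grad=True)
    y = torch.randint(0, d_out, (n,))

    Z = X @ W + b
    loss = torch.nn.functional.cross_entropy(Z, y)
    loss.backward()

    # manual gradients from dZ = (P - Y) / n
    P = torch.softmax(Z.detach(), dim=1)
    Y = torch.nn.functional.one_hot(y, d_out).float()
    dZ = (P - Y) / n
    print(torch.allclose(W.grad, X.detach().T @ dZ, atol=1e-5))   # dW = X^T dZ
    print(torch.allclose(b.grad, dZ.sum(dim=0), atol=1e-5))       # db = sum over the batch
    print(torch.allclose(X.grad, dZ @ W.detach().T, atol=1e-5))   # dX = dZ W^T
    ```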

    summary of gradients

    | Parameter | Formula | Dimensions |
    |---|---|---|
    | \(\mathbf{Z}\) | \(\mathbf{Z} = \mathbf{X}\mathbf{W} + \mathbf{b}\) | \([n, d_{out}]\) |
    | \(\mathbf{P}\) | \(\mathbf{P} = \text{softmax}(\mathbf{Z})\) | \([n, d_{out}]\) |
    | \(L\) | \(L = -\frac{1}{n} \sum \log(\mathbf{P}_{i, y_i})\) | Scalar |
    | \(d\mathbf{Z}\) | \(d\mathbf{Z} = \frac{1}{n}(\mathbf{P} - \mathbf{Y})\) | \([n, d_{out}]\) |
    | \(d\mathbf{W}\) | \(d\mathbf{W} = \mathbf{X}^\top d\mathbf{Z}\) | \([d_{in}, d_{out}]\) |
    | \(db\) | \(db = \text{sum}(d\mathbf{Z})\) | \([d_{out}]\) |
    | \(d\mathbf{X}\) | \(d\mathbf{X} = d\mathbf{Z}\, \mathbf{W}^\top\) | \([n, d_{in}]\) |

    optimization strategies

    • chunking the logit matrix
      • Chunking over the batch dimension avoids materializing the full logit matrix: the logits are computed one batch chunk at a time, the cross-entropy loss is computed for each chunk, and the final loss is the sum of the chunk losses.
    • computing the gradient of the logits in place
      • The gradient of the logit matrix is computed in place (the logit buffer is overwritten with its own gradient), and the gradient of the input is then obtained by multiplying it with the weight matrix (see the sketch after this list).
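    To make both ideas concrete, here is a minimal sketch of a chunked, fused linear + cross-entropy computation. It is not the implementation used by any of the libraries above, just an illustration in the notation of this post: each batch chunk materializes only a \([chunk, d_{out}]\) logit block, overwrites that block with its own gradient, and immediately folds it into \(d\mathbf{W}\), \(d\mathbf{b}\), and \(d\mathbf{X}\):

    ```python
    import torch

    def chunked_linear_cross_entropy(X, W, b, y, chunk_size=2):
        """Fused linear + cross-entropy that never materializes the full [n, d_out] logit matrix."""
        n = X.shape[0]
        loss = X.new_zeros(())
        dW = torch.zeros_like(W)
        db = torch.zeros_like(b)
        dX = torch.zeros_like(X)
        for s in range(0, n, chunk_size):
            Xc, yc = X[s:s + chunk_size], y[s:s + chunk_size]
            rows = torch.arange(len(yc))
            Zc = Xc @ W + b                                   # [chunk, d_out] logits for this chunk only
            Zc -= torch.logsumexp(Zc, dim=1, keepdim=True)    # Zc now holds log-probabilities
            loss += -Zc[rows, yc].sum() / n
            Zc.exp_()                                         # Zc now holds probabilities P
            Zc[rows, yc] -= 1.0                               # subtract the one-hot labels
            Zc /= n                                           # Zc now holds dZc = (P - Y) / n
            dW += Xc.T @ Zc
            db += Zc.sum(dim=0)
            dX[s:s + chunk_size] = Zc @ W.T
        return loss, dW, db, dX

    # check against autograd on small, hypothetical shapes
    n, d_in, d_out = 6, 8, 20
    X = torch.randn(n, d_in, requires_grad=True)
    W = torch.randn(d_in, d_out, requires_grad=True)
    b = torch.randn(d_out, requires_grad=True)
    y = torch.randint(0, d_out, (n,))
    ref = torch.nn.functional.cross_entropy(X @ W + b, y)
    ref.backward()
    loss, dW, db, dX = chunked_linear_cross_entropy(X.detach(), W.detach(), b.detach(), y)
    print(torch.allclose(loss, ref.detach(), atol=1e-5),
          torch.allclose(dW, W.grad, atol=1e-5),
          torch.allclose(db, b.grad, atol=1e-5),
          torch.allclose(dX, X.grad, atol=1e-5))
    ```

    The peak extra memory is one \([chunk, d_{out}]\) block instead of \([n, d_{out}]\), at the cost of looping over the batch; this is the basic trade-off the optimized implementations below build on.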

    efficient_cross_entropy

    liger kernel

    cut your losses in large-vocabulary language models