
Softmax and its triton implementation

1. Background
2. Gradient of softmax (vector form)
3. softmax - batch form
4. Implementation
5. Results: speed comparison
6. Notations

1. Background

The softmax function is a fundamental operation in deep learning that converts vectors of real numbers into probability distributions. This blog post explores the softmax function's implementation and optimization using Triton, a programming framework for efficient GPU computations.

TL;DR

  • dive into softmax, from math to implementation, from vector to matrix.
  • torch and triton implementations, with reference code and speed comparison.

The softmax function transforms an input vector into a probability distribution where all elements sum to 1.

1.1. softmax - vector form

$$\boldsymbol{o}_i = \mathrm{softmax}(\boldsymbol{x})_i = \frac{e^{\boldsymbol{x}_i}}{\sum_{j=1}^{d} e^{\boldsymbol{x}_j}}$$

where:

  • $\boldsymbol{x} \in \mathbb{R}^d$: input vector.
  • $\boldsymbol{o} \in \mathbb{R}^d$: output vector, a probability distribution.
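
For a quick sanity check, here is a minimal sketch of the vector form in PyTorch (the name softmax_vec is ours, and it skips the numerical stabilization discussed in Section 4):

import torch

# Minimal vector-form softmax (illustrative; no max-subtraction yet).
def softmax_vec(x: torch.Tensor) -> torch.Tensor:
    e = torch.exp(x)    # e^{x_j} for every element
    return e / e.sum()  # normalize so the entries sum to 1

x = torch.tensor([1.0, 2.0, 3.0])
print(softmax_vec(x))  # tensor([0.0900, 0.2447, 0.6652])
print(torch.allclose(softmax_vec(x), torch.softmax(x, dim=0)))  # True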

2. Gradient of softmax (vector form)

We will compute the gradient $\frac{\partial L}{\partial \boldsymbol{x}}$ given $\frac{\partial L}{\partial \boldsymbol{o}}$, where $L$ is the loss function and $\boldsymbol{o}$ is the softmax output.

2.1. Jacobian matrix

Since softmax is a vector-valued function, its Jacobian matrix collects all partial derivatives:

โˆ‚๐’โˆ‚๐’™=๐‘ฑ=(โˆ‚๐’1โˆ‚๐’™1โˆ‚๐’1โˆ‚๐’™2โ€ฆโˆ‚๐’1โˆ‚๐’™๐‘‘โˆ‚๐’2โˆ‚๐’™1โˆ‚๐’2โˆ‚๐’™2โ€ฆโˆ‚๐’2โˆ‚๐’™๐‘‘ย โ‹ฎย โ‹ฎย โ‹ฑย โ‹ฎโˆ‚๐’๐‘‘โˆ‚๐’™1โˆ‚๐’๐‘‘โˆ‚๐’™2โ€ฆโˆ‚๐’๐‘‘โˆ‚๐’™๐‘‘)

For softmax, the derivative has two cases:

  1. when ๐‘–=๐‘—, consider ๐’๐‘–=๐‘’๐’™๐‘–โˆ‘๐‘—=1๐‘‘๐‘’๐’™๐‘—, the derivative is:

    โˆ‚๐’๐‘–โˆ‚๐’™๐‘–=๐’๐‘–(1โˆ’๐’๐‘–)
  2. similarly, when ๐‘–โ‰ ๐‘—:

    โˆ‚๐’๐‘–โˆ‚๐’™๐‘—=โˆ’๐’๐‘–๐’๐‘—

Thus, (๐‘–,๐‘—)-th element in Jacobian matrix will be:

๐‘ฑ๐‘–๐‘—=๐’๐‘–(๐›ฟ๐‘–๐‘—โˆ’๐’๐‘—)

where ๐‘ฑ has shape [๐‘‘ร—๐‘‘] and ๐›ฟ๐‘–๐‘— is the Kronecker delta, which is 1 if ๐‘–=๐‘— and 0 otherwise.

In matrix form, the Jacobian of the softmax is:

$$\boldsymbol{J} = \mathrm{diag}(\boldsymbol{o}) - \boldsymbol{o}\boldsymbol{o}^{T}$$

where:

  • $\boldsymbol{o}$ is the output of softmax, with shape $[d]$.
  • $\mathrm{diag}(\boldsymbol{o})$ is the diagonal matrix of $\boldsymbol{o}$, with shape $[d \times d]$.
  • $\boldsymbol{o}\boldsymbol{o}^{T}$ is the outer product of $\boldsymbol{o}$ with itself, with shape $[d \times d]$.
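
As a sanity check, this sketch of ours compares the closed form against PyTorch's autograd Jacobian (torch.autograd.functional.jacobian is used as the reference):

import torch

# Compare the closed-form Jacobian diag(o) - o o^T with autograd's Jacobian.
x = torch.tensor([1.0, 2.0, 3.0])
o = torch.softmax(x, dim=0)

J_closed = torch.diag(o) - torch.outer(o, o)  # [d, d]
J_auto = torch.autograd.functional.jacobian(lambda v: torch.softmax(v, dim=0), x)

print(torch.allclose(J_closed, J_auto, atol=1e-6))  # True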

2.2. gradient of $\frac{\partial L}{\partial \boldsymbol{x}}$

Given โˆ‚๐ฟโˆ‚๐’, we can compute โˆ‚๐ฟโˆ‚๐’™ using the Jacobian matrix:

โˆ‚๐ฟโˆ‚๐’™=โˆ‚๐’โˆ‚๐’™โ‹…โˆ‚๐ฟโˆ‚๐’=๐‘ฑ๐‘‡โ‹…โˆ‚๐ฟโˆ‚๐’

where โˆ‚๐ฟโˆ‚๐’ has shape [๐‘‘], ๐‘ฑ๐‘‡ has shape [๐‘‘ร—๐‘‘], and โˆ‚๐ฟโˆ‚๐’™ has shape [๐‘‘].

2.3. avoid explicit Jacobian

For the ๐‘–-th element of โˆ‚๐ฟโˆ‚๐’™, we can decompose the computation to:

โˆ‚๐ฟโˆ‚๐’™๐‘–=๐’๐‘–(โˆ‚๐ฟโˆ‚๐’๐‘–โˆ’โˆ‘๐‘—=1๐‘‘๐’๐‘—โˆ‚๐ฟโˆ‚๐’๐‘—)

This leads to an efficient vector form:

๐‘ grad=(๐’โˆ—โˆ‚๐ฟโˆ‚๐’)sumโˆ‚๐ฟโˆ‚๐’™=๐’โˆ—(โˆ‚๐ฟโˆ‚๐’โˆ’๐‘ grad)
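
The following sketch (ours) checks that this Jacobian-free form matches the explicit Jacobian product from the previous subsection:

import torch

# Jacobian-free gradient: O(d) memory instead of a d x d Jacobian.
o = torch.softmax(torch.tensor([1.0, 2.0, 3.0]), dim=0)
dL_do = torch.tensor([0.1, 0.2, 0.7])

s_grad = (o * dL_do).sum()    # scalar
dL_dx = o * (dL_do - s_grad)  # [d]

J = torch.diag(o) - torch.outer(o, o)
print(torch.allclose(dL_dx, J.T @ dL_do, atol=1e-6))  # True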

3. softmax - batch form

$\boldsymbol{X}$: a batch of input vectors,

$$\boldsymbol{X} \in \mathbb{R}^{N \times d}$$

where:

  • $N$ is the batch size.
  • $d$ is the vector dimension.

3.1. forward pass

๐‘ฌ=๐‘’๐‘ฟ๐’”=โˆ‘๐‘—=1๐‘‘๐‘’๐‘ฟ๐‘–๐‘—๐‘ถ=๐‘ฌ๐’”

where ๐‘ฌโˆˆโ„๐‘ร—๐‘‘, ๐’”โˆˆโ„๐‘ร—1, ๐‘ถโˆˆโ„๐‘ร—๐‘‘.

3.2. backward pass

We have the gradient with respect to the softmax output:

$$\frac{\partial L}{\partial \boldsymbol{O}} \in \mathbb{R}^{N \times d}$$

We compute the gradient:

$$\boldsymbol{s}_{\mathrm{grad}} = \left( \boldsymbol{O} * \frac{\partial L}{\partial \boldsymbol{O}} \right)_{\mathrm{row\_sum}} \in \mathbb{R}^{N \times 1}$$

where $\boldsymbol{O}$ has shape $[N \times d]$ and $\frac{\partial L}{\partial \boldsymbol{O}}$ has shape $[N \times d]$.

$$\frac{\partial L}{\partial \boldsymbol{X}} = \boldsymbol{O} * \left( \frac{\partial L}{\partial \boldsymbol{O}} - \boldsymbol{s}_{\mathrm{grad}} \right)$$

where $\frac{\partial L}{\partial \boldsymbol{X}} \in \mathbb{R}^{N \times d}$, $\boldsymbol{O} \in \mathbb{R}^{N \times d}$, and $\boldsymbol{s}_{\mathrm{grad}} \in \mathbb{R}^{N \times 1}$ is broadcast to $\mathbb{R}^{N \times d}$.

4. Implementation

In practice, we subtract the maximum value from each row before applying exp() to prevent numerical overflow.

4.1. real forward pass

For input ๐‘ฟโˆˆโ„๐‘ร—๐‘‘:

๐‘ฟmaxย =max(๐‘ฟ)โˆˆโ„๐‘ร—1๐‘ฌ=๐‘’๐‘ฟโˆ’๐‘ฟย max๐’”=โˆ‘๐‘—=1๐‘‘๐‘’๐‘ฟ๐‘–๐‘—โˆ’๐‘ฟย max๐‘ถ=๐‘ฌ๐’”

4.2. real backward pass

we have โˆ‚๐ฟโˆ‚๐‘ถโˆˆโ„๐‘ร—๐‘‘ and cached ๐‘ถโˆˆโ„๐‘ร—๐‘‘

๐’”gradย =(๐‘ถโˆ—โˆ‚๐ฟโˆ‚๐‘ถ)ย row_sumโˆ‚๐ฟโˆ‚๐‘ฟ=๐‘ถโˆ—(โˆ‚๐ฟโˆ‚๐‘ถโˆ’๐’”ย grad)
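
A short sketch (the values are ours) of why the shift matters: the naive exponential overflows in float32, while the shifted version stays finite and yields the same probabilities, since the common factor cancels in the normalization.

import torch

# Naive exp overflows for large logits; the max-shifted version stays finite.
x = torch.tensor([1000.0, 1001.0, 1002.0])
print(torch.exp(x))         # tensor([inf, inf, inf])
e = torch.exp(x - x.max())  # tensor([0.1353, 0.3679, 1.0000])
print(e / e.sum())          # tensor([0.0900, 0.2447, 0.6652])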

4.3. a real example

Below is a concrete example showing how softmax and its backward pass are implemented in PyTorch and Triton.

The forward pass is as follows:

๐‘‹=(1.02.03.01.03.05.0)๐‘‹maxย =(3.05.0)๐‘‹โˆ’๐‘‹ย maxย =(โˆ’2.0โˆ’1.00.0โˆ’4.0โˆ’2.00.0)๐ธ=๐‘’๐‘‹โˆ’๐‘‹ย max=(๐‘’โˆ’2.0๐‘’โˆ’1.0๐‘’0.0๐‘’โˆ’4.0๐‘’โˆ’2.0๐‘’0.0)๐ธ=(0.13530.36791.00000.01830.13531.0000)๐‘†=(1.50321.1536)๐‘‚=๐ธ๐‘†=(0.09000.24470.66520.01590.11730.8668)

The backward pass is as follows:

$$dO = \begin{pmatrix} 0.1 & 0.2 & 0.7 \\ 0.2 & 0.3 & 0.5 \end{pmatrix}, \qquad s_{\mathrm{grad}} = \begin{pmatrix} 0.5236 \\ 0.4718 \end{pmatrix}$$

$$dX = O * (dO - s_{\mathrm{grad}}) = \begin{pmatrix} -0.0381 & -0.0792 & 0.1173 \\ -0.0043 & -0.0202 & 0.0245 \end{pmatrix}$$

4.4. native pytorch implementation

import torch
import torch.nn.functional as F

# Custom Forward Pass (Numerically Stable Softmax)
def softmax_forward(X):
    X_max = torch.max(X, dim=1, keepdim=True)[0]  # Shape: (N, 1)
    E = torch.exp(X - X_max)  # Shape: (N, d)
    S = torch.sum(E, dim=1, keepdim=True)  # Shape: (N, 1)
    O = E / S  # Shape: (N, d)
    return O

# Custom Backward Pass (Gradient Calculation)
def softmax_backward(dL_dO, O):
    s_grad = torch.sum(O * dL_dO, dim=1, keepdim=True)  # Shape: (N, 1)
    dL_dX = O * (dL_dO - s_grad)  # Shape: (N, d)
    return dL_dX

# Example Inputs
X = torch.tensor([[1.0, 2.0, 3.0], [1.0, 3.0, 5.0]], requires_grad=True)
dL_dO = torch.tensor([[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]])

# Custom Implementation - Forward
O_custom = softmax_forward(X)

# PyTorch Implementation - Forward
O_pytorch = F.softmax(X, dim=1)

# Verify Forward Output
print("Custom Softmax Output:\n", O_custom)
print("PyTorch Softmax Output:\n", O_pytorch)
print("Forward Pass Match:", torch.allclose(O_custom, O_pytorch))

# Custom Implementation - Backward
dL_dX_custom = softmax_backward(dL_dO, O_custom)

# PyTorch Automatic Gradient Calculation
O_pytorch.backward(dL_dO) # Computes gradient using PyTorch autograd
dL_dX_pytorch = X.grad

# Verify Backward Output
print("\nCustom Gradient w.r.t Input:\n", dL_dX_custom)
print("PyTorch Gradient w.r.t Input:\n", dL_dX_pytorch)
print("Backward Pass Match:", torch.allclose(dL_dX_custom, dL_dX_pytorch))

output:

Custom Softmax Output:
 tensor([[0.0900, 0.2447, 0.6652],
        [0.0159, 0.1173, 0.8668]], grad_fn=<DivBackward0>)
PyTorch Softmax Output:
 tensor([[0.0900, 0.2447, 0.6652],
        [0.0159, 0.1173, 0.8668]], grad_fn=<SoftmaxBackward0>)
Forward Pass Match: True

Custom Gradient w.r.t Input:
 tensor([[-0.0381, -0.0792,  0.1173],
        [-0.0043, -0.0202,  0.0245]], grad_fn=<MulBackward0>)
PyTorch Gradient w.r.t Input:
 tensor([[-0.0381, -0.0792,  0.1173],
        [-0.0043, -0.0202,  0.0245]])
Backward Pass Match: True
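
If the custom backward should be picked up by autograd rather than called by hand as above, one option (a sketch of ours, not part of the reference code) is to wrap the two functions in a torch.autograd.Function:

import torch

# Hook the custom forward/backward into autograd.
class CustomSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X):
        O = softmax_forward(X)    # custom forward defined above
        ctx.save_for_backward(O)  # cache O; the backward pass only needs O and dL/dO
        return O

    @staticmethod
    def backward(ctx, dL_dO):
        O, = ctx.saved_tensors
        return softmax_backward(dL_dO, O)

X = torch.tensor([[1.0, 2.0, 3.0], [1.0, 3.0, 5.0]], requires_grad=True)
CustomSoftmax.apply(X).backward(torch.tensor([[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]))
print(X.grad)  # matches the gradients printed above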

4.5. triton implementation

from typing import Optional

import torch
import triton
import triton.language as tl


@triton.jit
def softmax_fwd_kernel(
    X,
    O,
    D: tl.constexpr,
    B: tl.constexpr
):
    # one program per row; B is the block size (next power of 2 >= D)
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < D  # mask out-of-range columns

    X_max = tl.max(tl.load(X + i_n * D + o_d, mask=m_d, other=-float('inf')), 0)
    E = tl.exp(tl.load(X + i_n * D + o_d, mask=m_d, other=-float('inf')) - X_max)
    S = tl.sum(E, 0)
    P = E / S

    tl.store(O + i_n * D + o_d, P.to(O.dtype.element_ty), mask=m_d)


@triton.jit
def softmax_bwd_kernel(
    O,
    dO,
    dX,
    D: tl.constexpr,
    B: tl.constexpr
):
    # one program per row, mirroring the forward kernel
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < D

    P = tl.load(O + i_n * D + o_d, mask=m_d, other=0.)
    dP = tl.load(dO + i_n * D + o_d, mask=m_d, other=0.)
    s_grad = tl.sum(P * dP, 0)  # row-wise sum of O * dO
    dX_row = P * (dP - s_grad)

    tl.store(dX + i_n * D + o_d, dX_row.to(dX.dtype.element_ty), mask=m_d)


def softmax_fwd(
    X: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    shape = X.shape
    X = X.view(-1, X.shape[-1])

    N, D = X.shape
    B = triton.next_power_of_2(D)

    O = torch.empty_like(X, dtype=dtype)
    softmax_fwd_kernel[(N,)](
        X=X,
        O=O,
        D=D,
        B=B
    )
    return O.view(*shape)


def softmax_bwd(
    O: torch.Tensor,
    dO: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    shape = O.shape
    O = O.view(-1, O.shape[-1])
    dX = torch.empty_like(O, dtype=dtype)

    N, D = O.shape
    B = triton.next_power_of_2(D)
    softmax_bwd_kernel[(N,)](
        O=O,
        dO=dO,
        dX=dX,
        D=D,
        B=B
    )
    return dX.view(*shape)

# Test code to verify correctness
import torch.nn.functional as F

# Example inputs
X = torch.tensor([[1.0, 2.0, 3.0], [1.0, 3.0, 5.0]], requires_grad=True, device='cuda')
dP = torch.tensor([[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]], device='cuda')

# Forward pass
P_triton = softmax_fwd(X)
P_torch = F.softmax(X, dim=1)

# Verify forward pass
print( "P_triton:\n", P_triton)
print( "P_torch:\n", P_torch)
print("Forward Pass Match:", torch.allclose(P_triton, P_torch))

# Backward pass

dX_triton = softmax_bwd(P_triton, dP)
P_torch.backward(dP)
dX_torch = X.grad

# Verify backward pass
print( "dX_triton:\n", dX_triton)
print( "dX_torch:\n", dX_torch)
print("Backward Pass Match:", torch.allclose(dX_triton, dX_torch))

output:

P_triton:
 tensor([[0.0900, 0.2447, 0.6652],
        [0.0159, 0.1173, 0.8668]], device='cuda:0')
P_torch:
 tensor([[0.0900, 0.2447, 0.6652],
        [0.0159, 0.1173, 0.8668]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Forward Pass Match: True
dX_triton:
 tensor([[-0.0381, -0.0792,  0.1173],
        [-0.0043, -0.0202,  0.0245]], device='cuda:0')
dX_torch:
 tensor([[-0.0381, -0.0792,  0.1173],
        [-0.0043, -0.0202,  0.0245]], device='cuda:0')
Backward Pass Match: True
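
Similarly, a sketch of ours wraps the Triton kernels in a torch.autograd.Function so they can be dropped in wherever F.softmax is used and compared end to end:

import torch

# Expose the Triton forward/backward kernels to autograd.
class TritonSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X):
        O = softmax_fwd(X)        # Triton forward wrapper defined above
        ctx.save_for_backward(O)  # cache O for the backward kernel
        return O

    @staticmethod
    def backward(ctx, dO):
        O, = ctx.saved_tensors
        return softmax_bwd(O, dO)

X1 = torch.randn(16, 128, device='cuda', requires_grad=True)
X2 = X1.detach().clone().requires_grad_()
dO = torch.randn(16, 128, device='cuda')

TritonSoftmax.apply(X1).backward(dO)
torch.softmax(X2, dim=-1).backward(dO)
print(torch.allclose(X1.grad, X2.grad, atol=1e-6))  # True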

5. Results: speed comparison

The performance comparison between PyTorch and Triton implementations reveals:

Figureย 1: forward pass
Figureย 2: backward pass

Results show:

  • forward pass: the Triton implementation is stable across batch sizes, while the PyTorch implementation is faster for most batch sizes but shows fluctuations for a few.
  • backward pass: the Triton implementation outperforms the PyTorch implementation across most batch sizes. (The comparison may not be entirely fair, as Triton caches the output $\boldsymbol{O}$, whereas how PyTorch handles intermediate values is unclear.) A minimal benchmarking sketch follows below.
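
For reference, timings in this spirit can be collected with triton.testing.do_bench; the sketch below (ours) times the two forward implementations and the standalone Triton backward kernel for a single shape, so it does not reproduce the exact figures above.

import torch
import triton

# Time the Triton and PyTorch forward passes, and the Triton backward kernel.
def bench(N, D):
    X = torch.randn(N, D, device='cuda')
    dO = torch.randn(N, D, device='cuda')
    O = softmax_fwd(X)

    t_fwd_triton = triton.testing.do_bench(lambda: softmax_fwd(X))
    t_fwd_torch = triton.testing.do_bench(lambda: torch.softmax(X, dim=-1))
    t_bwd_triton = triton.testing.do_bench(lambda: softmax_bwd(O, dO))
    return t_fwd_triton, t_fwd_torch, t_bwd_triton

print(bench(4096, 1024))  # times in milliseconds per call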

6. Notations

| symbol | shape | definition |
| --- | --- | --- |
| $\boldsymbol{x}$ | $d$ | Input vector |
| $\boldsymbol{o}$ | $d$ | Output vector (probability distribution) |
| $L$ | scalar | Loss function |
| $\boldsymbol{J}$ | $d \times d$ | Jacobian matrix |
| $\boldsymbol{X}$ | $N \times d$ | Batch of input vectors (matrix) |
| $\boldsymbol{O}$ | $N \times d$ | Batch output probabilities |
| $\frac{\partial L}{\partial \boldsymbol{O}}$ | $N \times d$ | Gradient w.r.t. output probabilities |
| $\frac{\partial L}{\partial \boldsymbol{X}}$ | $N \times d$ | Gradient w.r.t. input vectors |
| $\boldsymbol{s}_{\mathrm{grad}}$ | $N \times 1$ | Summation of gradients, $\boldsymbol{s}_{\mathrm{grad}} = \left( \boldsymbol{O} * \frac{\partial L}{\partial \boldsymbol{O}} \right)_{\mathrm{row\_sum}}$ |

Note:

  • Symbols like $x$, $\boldsymbol{x}$, $\boldsymbol{X}$ represent scalars, vectors, or matrices, where uppercase denotes batch (matrix) forms.
  • $\boldsymbol{X}_{:,i}$ denotes a column vector, $\boldsymbol{X}_{i,:}$ denotes a row vector, and $\boldsymbol{X}_{i,j}$ denotes the $(i,j)$-th element.
  • $\boldsymbol{x}_i$ denotes the $i$-th element of a vector.