
Attention and its gradient

1. Background
2. Attention
3. Backprop Derivation
4. PyTorch implementation

1. Background

As of now, the official FlashAttention implementation does not support a bias term, and FlexAttention in PyTorch is working on supporting it. In this blog, I will show how to implement a minimal flashattn with a trainable bias term.
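For context, here is a minimal sketch of how an additive bias could be expressed through FlexAttention's score_mod hook, assuming a recent PyTorch build that ships torch.nn.attention.flex_attention (device and version requirements vary, and, as noted above, gradient support for a captured trainable bias is still being worked on):

import torch
from torch.nn.attention.flex_attention import flex_attention

n, h, l, d = 2, 4, 128, 64
Q, K, V = (torch.randn(n, h, l, d) for _ in range(3))
bias = torch.randn(n, h, l, l)  # additive bias on the pre-softmax scores

# score_mod receives the raw score and the (batch, head, query, key) indices
# and returns the modified score; here it simply adds the matching bias entry.
def add_bias(score, b, head, q_idx, kv_idx):
    return score + bias[b, head, q_idx, kv_idx]

out = flex_attention(Q, K, V, score_mod=add_bias)
print(out.shape)  # torch.Size([2, 4, 128, 64])
# Note: whether gradients flow back into `bias` through score_mod depends on the PyTorch version.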

2. Attention

Attention with a gradient-enabled (trainable) bias term is defined as:

$$O = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + B\right)V$$

where

  • $B$ is the bias term and its shape is $(n, h, l, l)$
  • The shape of $Q, K, V$ is $(n, h, l, d)$
  • $n$ is the batch size, $h$ is the number of heads, $l$ is the sequence length, and $d$ is the hidden dimension.

The gradient of 𝑩 is accumulated during the training process.
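Since the only change from standard scaled dot-product attention is the additive term inside the softmax, the definition can be sanity-checked against PyTorch's built-in scaled_dot_product_attention, which accepts a float attn_mask that is added to the scores (a minimal sketch):

import torch
import torch.nn.functional as F

n, h, l, d = 2, 4, 8, 16
Q, K, V = (torch.randn(n, h, l, d) for _ in range(3))
B = torch.randn(n, h, l, l)

# Reference: softmax(QK^T / sqrt(d) + B) V
ref = torch.softmax(Q @ K.transpose(-2, -1) / d ** 0.5 + B, dim=-1) @ V

# A float attn_mask is added to the attention scores before the softmax,
# so it plays exactly the role of the bias B here.
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=B)

print(torch.allclose(ref, out, atol=1e-5))  # expected: True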

3. Backprop Derivation

Let

$$S = \frac{QK^\top}{\sqrt{d}} + B, \qquad A = \mathrm{softmax}(S), \qquad O = AV = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + B\right)V$$

We are given the gradient of the output, $dO$, with shape $[n, h, l, d]$. In the following, each $(n, h)$ slice is treated as a separate matrix multiply (a batched matmul, written bmm below).

3.1. Gradient of 𝑽 and 𝑨

Since

$$O = A\,V \qquad \big([n, h, l, d] = [n, h, l, l] \times [n, h, l, d]\big),$$

we get

$$dA = dO \;\mathrm{bmm}\; V^\top, \qquad \big([n, h, l, l] = [n, h, l, d] \times [n, h, d, l]\big)$$
$$dV = A^\top \;\mathrm{bmm}\; dO, \qquad \big([n, h, l, d] = [n, h, l, l] \times [n, h, l, d]\big)$$
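These are the ordinary matmul backward rules, so they can be checked directly against autograd on a small example (a sketch; here A is just a random matrix rather than a softmax output, since the rule does not depend on how A was produced):

import torch

n, h, l, d = 2, 4, 8, 16
A = torch.randn(n, h, l, l, requires_grad=True)
V = torch.randn(n, h, l, d, requires_grad=True)
dO = torch.randn(n, h, l, d)

O = A @ V            # [n, h, l, d]
O.backward(dO)

dA = dO @ V.transpose(-2, -1)   # [n, h, l, l] = [n, h, l, d] x [n, h, d, l]
dV = A.transpose(-2, -1) @ dO   # [n, h, l, d] = [n, h, l, l] x [n, h, l, d]

print(torch.allclose(A.grad, dA, atol=1e-5))  # expected: True
print(torch.allclose(V.grad, dV, atol=1e-5))  # expected: True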

3.2. Gradient of 𝑺

The gradient of $S$ follows from the chain rule:

$$(dS)_{ijkl} = \sum_{m,n} dA_{ijmn}\,\frac{\partial A_{ijmn}}{\partial S_{ijkl}}$$

where $\partial A / \partial S$ is the Jacobian of the softmax function and has shape $(n, h, l, l, l, l)$. The indices $i, j, k, l$ address the target tensor $dS$, while $m, n$ are the summation (contraction) indices. Because softmax is applied row-wise within each $(n, h)$ slice, most entries of this Jacobian are zero, which is what makes the efficient form below possible.
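To make the contraction concrete, here is a small sketch that builds the softmax Jacobian of each $(n, h)$ slice with torch.autograd.functional.jacobian and contracts it with $dA$ explicitly; this is only feasible for tiny shapes, which is exactly why the efficient form below is needed:

import torch

n, h, l = 1, 2, 4
S = torch.randn(n, h, l, l)
dA = torch.randn(n, h, l, l)

dS = torch.zeros_like(S)
for i in range(n):
    for j in range(h):
        # Jacobian of softmax on the (l, l) slice: shape (l, l, l, l)
        J = torch.autograd.functional.jacobian(lambda s: torch.softmax(s, dim=-1), S[i, j])
        # (dS)_{kl} = sum_{m,n} dA_{mn} * dA_{mn}/dS_{kl}
        dS[i, j] = torch.einsum('mn,mnkl->kl', dA[i, j], J)

# Compare with autograd's backward through the batched softmax
S_req = S.clone().requires_grad_(True)
torch.softmax(S_req, dim=-1).backward(dA)
print(torch.allclose(dS, S_req.grad, atol=1e-5))  # expected: True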

For efficiency, we can rewrite the above equation as:

$$dS = A \odot \Big(dA - \big(dA \odot A\big)\,\mathbf{1}\Big) \qquad \big([n, h, l, l] = [n, h, l, l] \odot ([n, h, l, l] - [n, h, l, l] \;\mathrm{bmm}\; [n, h, l, 1])\big)$$

where $\mathbf{1}$ has shape $[n, h, l, 1]$ and is an all-ones vector that sums $dA \odot A$ over the last dimension; the resulting $[n, h, l, 1]$ row sums broadcast back over the last dimension, and $\odot$ denotes element-wise multiplication.
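A quick numerical check of this closed form against autograd's softmax backward (a sketch with arbitrary small shapes):

import torch

n, h, l = 2, 4, 8
S = torch.randn(n, h, l, l, requires_grad=True)
dA = torch.randn(n, h, l, l)

A = torch.softmax(S, dim=-1)
A.backward(dA)

# dS = A ⊙ (dA − (dA ⊙ A) 1); the row sums (shape [n, h, l, 1]) broadcast back
ones = torch.ones(n, h, l, 1)
dS = A.detach() * (dA - (dA * A.detach()) @ ones)

print(torch.allclose(dS, S.grad, atol=1e-5))  # expected: True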

3.3. Gradient of 𝑩

Since $B$ enters $S$ additively, the gradient of $B$ is the same as the gradient of $S$:

$$dB = dS$$

(If $B$ were broadcast over the batch or head dimensions, $dS$ would additionally have to be summed over those dimensions; here $B$ has the full $(n, h, l, l)$ shape, so no reduction is needed.)

3.4. Gradient of 𝑸, 𝑲

Since $S = QK^\top / \sqrt{d} + B$, the gradients of $Q$ and $K$ are:

$$dQ = \frac{dS\,K}{\sqrt{d}}, \qquad dK = \frac{dS^\top Q}{\sqrt{d}}$$

3.5. All gradients

$$\begin{aligned}
dA &= dO \;\mathrm{bmm}\; V^\top \\
dV &= A^\top \;\mathrm{bmm}\; dO \\
dS &= A \odot \big(dA - (dA \odot A)\,\mathbf{1}\big) \\
dB &= dS \\
dQ &= dS\,K / \sqrt{d} \\
dK &= dS^\top Q / \sqrt{d}
\end{aligned}$$

4. PyTorch implementation

import torch

def forward(Q, K, V, B, d):
    # S = QK^T / sqrt(d) + B, A = softmax(S), O = AV
    S = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d, dtype=torch.float32)) + B
    A = torch.softmax(S, dim=-1)
    O = torch.matmul(A, V)
    return O, A, S

@torch.no_grad()
def compute_gradients(Q, K, V, B, d, dO):
    # Recompute the forward pass
    S = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d, dtype=torch.float32)) + B
    A = torch.softmax(S, dim=-1)
    O = torch.matmul(A, V)

    # Gradients of A and V
    dA = torch.matmul(dO, V.transpose(-2, -1))
    dV = torch.matmul(A.transpose(-2, -1), dO)

    # Gradient of S via the softmax vector-Jacobian product (VJP)
    dS = dA * A - (A * dA).sum(dim=-1, keepdim=True) * A

    # Gradient of B (same as dS)
    dB = dS.clone()

    # Gradients of Q and K (note the 1/sqrt(d) scaling)
    dQ = torch.matmul(dS, K) / torch.sqrt(torch.tensor(d, dtype=torch.float32))
    dK = torch.matmul(dS.transpose(-2, -1), Q) / torch.sqrt(torch.tensor(d, dtype=torch.float32))

    return dQ, dK, dV, dB


# Example usage
n, h, l, d = 2, 4, 8, 16
torch.manual_seed(0)
Q = torch.randn(n, h, l, d, requires_grad=True)
K = torch.randn(n, h, l, d, requires_grad=True)
V = torch.randn(n, h, l, d, requires_grad=True)
B = torch.randn(n, h, l, l, requires_grad=True)
dO = torch.randn(n, h, l, d)

O, A, S = forward(Q, K, V, B, d)
dQ, dK, dV, dB = compute_gradients(Q, K, V, B, d, dO)

# Verify correctness with autograd
O.backward(dO, retain_graph=True)

# Compare one slice of the autograd gradients with the manual ones
print(V.grad[0][0][0])
print(dV[0][0][0])

print(B.grad[0][0][0])
print(dB[0][0][0])

print(Q.grad[0][0][0])
print(dQ[0][0][0])

assert torch.allclose(V.grad, dV, atol=1e-5), "dV mismatch"
assert torch.allclose(B.grad, dB, atol=1e-5), "dB mismatch"
assert torch.allclose(Q.grad, dQ, atol=1e-5), "dQ mismatch"
assert torch.allclose(K.grad, dK, atol=1e-5), "dK mismatch"

print("Autograd verification passed.")

print("O:", O.shape)
print("dQ:", dQ.shape)
print("dK:", dK.shape)
print("dV:", dV.shape)
print("dB:", dB.shape)

Output:

tensor([-0.9583, -0.7990, -0.7401,  0.4045, -1.1326, -0.8535,  0.9846,  0.8070,
-0.6478, -0.0538, 0.6266, 1.0380, -0.9200, 0.5653, 0.9200, -0.0638])
tensor([-0.9583, -0.7990, -0.7401, 0.4045, -1.1326, -0.8535, 0.9846, 0.8070,
-0.6478, -0.0538, 0.6266, 1.0380, -0.9200, 0.5653, 0.9200, -0.0638])
tensor([-8.4880e-02, -6.7330e-01, -5.2291e-04, 3.3246e-02, -2.7012e-02,
5.0888e-01, 2.4558e-01, -1.9837e-03])
tensor([-8.4880e-02, -6.7330e-01, -5.2293e-04, 3.3246e-02, -2.7012e-02,
5.0888e-01, 2.4558e-01, -1.9838e-03])
tensor([-0.1274, -0.2580, 0.2316, 0.1266, -0.3056, 0.0579, -0.2824, 0.2191,
-0.0199, 0.2176, -0.0755, -0.1700, 0.1564, 0.2221, -0.0909, 0.0172])
tensor([-0.1274, -0.2580, 0.2316, 0.1266, -0.3056, 0.0579, -0.2824, 0.2191,
-0.0199, 0.2176, -0.0755, -0.1700, 0.1564, 0.2221, -0.0909, 0.0172])
Autograd verification passed.
O: torch.Size([2, 4, 8, 16])
dQ: torch.Size([2, 4, 8, 16])
dK: torch.Size([2, 4, 8, 16])
dV: torch.Size([2, 4, 8, 16])
dB: torch.Size([2, 4, 8, 8])
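Beyond spot-checking a few slices, the whole forward pass can also be verified against finite differences with torch.autograd.gradcheck in double precision (a sketch; gradcheck is slow, so keep the shapes tiny):

import torch

def attention_fp64(Q, K, V, B):
    d = Q.shape[-1]
    S = Q @ K.transpose(-2, -1) / torch.sqrt(torch.tensor(d, dtype=torch.float64)) + B
    return torch.softmax(S, dim=-1) @ V

n, h, l, d = 1, 2, 3, 4
inputs = tuple(
    t.double().requires_grad_(True)
    for t in (torch.randn(n, h, l, d), torch.randn(n, h, l, d),
              torch.randn(n, h, l, d), torch.randn(n, h, l, l))
)

# Checks analytical gradients for Q, K, V, B against numerical finite differences.
print(torch.autograd.gradcheck(attention_fp64, inputs))  # expected: True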