Attention and its gradient
1. Background
As of now, the official flashattn implementation does not support a bias term, and FlexAttention in torch is working on adding support for one. In this blog, I will show how to implement a minimal flashattn with a trainable bias term.
2. Attention
The attention with a gradient-enabled bias term is defined as:

$$
S = \frac{QK^\top}{\sqrt{d}} + B, \qquad A = \mathrm{softmax}(S), \qquad O = AV,
$$

where
- $B$ is the bias term and its shape is $(n, h, l, l)$,
- the shape of $Q$, $K$, and $V$ is $(n, h, l, d)$,
- $n$ is the batch size, $h$ is the number of heads, $l$ is the sequence length, and $d$ is the hidden dimension.

The gradient of $B$ is accumulated during the training process.
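To make the shapes concrete, here is a minimal sketch of the forward pass; the sizes are chosen only for illustration:

import torch

n, h, l, d = 2, 4, 8, 16                      # batch, heads, sequence length, hidden dim
Q = torch.randn(n, h, l, d)
K = torch.randn(n, h, l, d)
V = torch.randn(n, h, l, d)
B = torch.randn(n, h, l, l)                   # bias: one score offset per (query, key) pair

S = Q @ K.transpose(-2, -1) / d ** 0.5 + B    # scores: (n, h, l, l)
A = torch.softmax(S, dim=-1)                  # attention weights: (n, h, l, l)
O = A @ V                                     # output: (n, h, l, d)
print(O.shape)                                # torch.Size([2, 4, 8, 16])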
3. Backprop Derivation
Let
$$
S = \frac{QK^\top}{\sqrt{d}} + B, \qquad A = \mathrm{softmax}(S), \qquad O = AV.
$$
We already have the gradient of $O$, which is $dO = \frac{\partial L}{\partial O}$, where $L$ is the training loss. Below, $dX$ denotes $\frac{\partial L}{\partial X}$ for any tensor $X$.
3.1. Gradient of $V$ and $A$
Since
$$
O = AV,
$$
we get
$$
dV = A^\top dO, \qquad dA = dO\, V^\top.
$$
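These two products can be sanity-checked against autograd on small stand-in tensors (just a sketch, separate from the full implementation in section 4):

import torch

A = torch.rand(5, 7, requires_grad=True)      # stand-in attention weights
V = torch.randn(7, 3, requires_grad=True)
dO = torch.randn(5, 3)                        # upstream gradient for O = A V

O = A @ V
O.backward(dO)

assert torch.allclose(V.grad, A.T @ dO)       # dV = A^T dO
assert torch.allclose(A.grad, dO @ V.T)       # dA = dO V^T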
3.2. Gradient of $S$
It is easy to get the gradient of $S$ based on the chain rule:
$$
dS_{ij} = \sum_{k} dA_{ik}\,\frac{\partial A_{ik}}{\partial S_{ij}} = \sum_{k} dA_{ik}\, A_{ik}\,(\delta_{kj} - A_{ij}),
$$
where $\frac{\partial A_{ik}}{\partial S_{ij}}$ is the Jacobian of the softmax function and has size $l \times l$ for each query row. Here $i, j$ are the indices of the target tensor $dS$, $k$ is the summation index specifying contraction over that dimension, and $\sum_k$ explicitly indicates summation over the index $k$.
For efficiency, we can rewrite the above equation as:
$$
dS = A \odot (dA - D), \qquad D_i = \sum_{k} A_{ik}\, dA_{ik},
$$
where $D$ is a row-wise summation vector that normalizes the contributions of each row.
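The efficient form can also be checked numerically; the snippet below compares it against autograd for a single matrix of stand-in scores:

import torch

S = torch.randn(6, 6, requires_grad=True)     # stand-in scores
dA = torch.randn(6, 6)                        # upstream gradient for A = softmax(S)

A = torch.softmax(S, dim=-1)
A.backward(dA)

# Efficient vector-Jacobian product: dS_ij = A_ij * (dA_ij - sum_k A_ik * dA_ik)
with torch.no_grad():
    D = (A * dA).sum(dim=-1, keepdim=True)    # row-wise sums
    dS = A * (dA - D)

assert torch.allclose(S.grad, dS, atol=1e-6)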
3.3. Gradient of $B$
Since $B$ enters $S$ through an element-wise addition, the gradient of $B$ is the same as the gradient of $S$, which is:
$$
dB = dS.
$$
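Because the addition is element-wise, autograd confirms that the bias receives the score gradient unchanged (a tiny sketch):

import torch

pre = torch.randn(4, 4)                       # stand-in for Q K^T / sqrt(d), no grad needed
B = torch.randn(4, 4, requires_grad=True)
dS = torch.randn(4, 4)                        # upstream gradient for S

S = pre + B
S.backward(dS)

assert torch.allclose(B.grad, dS)             # dB = dS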
3.4. Gradient of $Q$, $K$
From $S = \frac{QK^\top}{\sqrt{d}} + B$, the gradient of $Q$ and $K$ is:
$$
dQ = \frac{dS\, K}{\sqrt{d}}, \qquad dK = \frac{dS^\top Q}{\sqrt{d}}.
$$
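Again, a small autograd check of these two formulas (a sketch on 2-D stand-in tensors, without the batch and head dimensions):

import torch

d = 16
Q = torch.randn(8, d, requires_grad=True)
K = torch.randn(8, d, requires_grad=True)
dS = torch.randn(8, 8)                        # upstream gradient for S = Q K^T / sqrt(d)

S = Q @ K.T / d ** 0.5
S.backward(dS)

assert torch.allclose(Q.grad, dS @ K / d ** 0.5, atol=1e-6)      # dQ = dS K / sqrt(d)
assert torch.allclose(K.grad, dS.T @ Q / d ** 0.5, atol=1e-6)    # dK = dS^T Q / sqrt(d)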
3.5. All gradients
Putting everything together, the full backward pass is:
$$
dV = A^\top dO, \quad dA = dO\, V^\top, \quad dS = A \odot (dA - D), \quad dB = dS, \quad dQ = \frac{dS\, K}{\sqrt{d}}, \quad dK = \frac{dS^\top Q}{\sqrt{d}}.
$$
4. PyTorch implementation
import torch

def forward(Q, K, V, B, d):
    # S = Q K^T / sqrt(d) + B, A = softmax(S), O = A V
    scale = torch.sqrt(torch.tensor(d, dtype=torch.float32))
    S = torch.matmul(Q, K.transpose(-2, -1)) / scale + B
    A = torch.softmax(S, dim=-1)
    O = torch.matmul(A, V)
    return O, A, S

@torch.no_grad()
def compute_gradients(Q, K, V, B, d, dO):
    # Recompute the forward pass
    scale = torch.sqrt(torch.tensor(d, dtype=torch.float32))
    S = torch.matmul(Q, K.transpose(-2, -1)) / scale + B
    A = torch.softmax(S, dim=-1)
    O = torch.matmul(A, V)
    # Gradients of A and V (from O = A V)
    dA = torch.matmul(dO, V.transpose(-2, -1))
    dV = torch.matmul(A.transpose(-2, -1), dO)
    # Gradient of S via the softmax vector-Jacobian product (VJP):
    # dS_ij = A_ij * (dA_ij - sum_k A_ik * dA_ik)
    dS = dA * A - (A * dA).sum(dim=-1, keepdim=True) * A
    # Gradient of B (same as dS, since S = Q K^T / sqrt(d) + B)
    dB = dS.clone()
    # Gradients of Q and K
    dQ = torch.matmul(dS, K) / scale
    dK = torch.matmul(dS.transpose(-2, -1), Q) / scale
    return dQ, dK, dV, dB
# Example usage
n, h, l, d = 2, 4, 8, 16
torch.manual_seed(0)
Q = torch.randn(n, h, l, d, requires_grad=True)
K = torch.randn(n, h, l, d, requires_grad=True)
V = torch.randn(n, h, l, d, requires_grad=True)
B = torch.randn(n, h, l, l, requires_grad=True)
dO = torch.randn(n, h, l, d)
O, A, S = forward(Q, K, V, B, d)
dQ, dK, dV, dB = compute_gradients(Q, K, V, B, d, dO)
# Verify correctness with autograd
O.backward(dO)
print(V.grad[0][0][0])
print(dV[0][0][0])
print(B.grad[0][0][0])
print(dB[0][0][0])
print(Q.grad[0][0][0])
print(dQ[0][0][0])
assert torch.allclose(V.grad, dV, atol=1e-5), "dV mismatch"
assert torch.allclose(B.grad, dB, atol=1e-5), "dB mismatch"
assert torch.allclose(Q.grad, dQ, atol=1e-5), "dQ mismatch"
assert torch.allclose(K.grad, dK, atol=1e-5), "dK mismatch"
print("Autograd verification passed.")
print("O:", O.shape)
print("dQ:", dQ.shape)
print("dK:", dK.shape)
print("dV:", dV.shape)
print("dB:", dB.shape)
Output:
tensor([-0.9583, -0.7990, -0.7401, 0.4045, -1.1326, -0.8535, 0.9846, 0.8070,
-0.6478, -0.0538, 0.6266, 1.0380, -0.9200, 0.5653, 0.9200, -0.0638])
tensor([-0.9583, -0.7990, -0.7401, 0.4045, -1.1326, -0.8535, 0.9846, 0.8070,
-0.6478, -0.0538, 0.6266, 1.0380, -0.9200, 0.5653, 0.9200, -0.0638])
tensor([-8.4880e-02, -6.7330e-01, -5.2291e-04, 3.3246e-02, -2.7012e-02,
5.0888e-01, 2.4558e-01, -1.9837e-03])
tensor([-8.4880e-02, -6.7330e-01, -5.2293e-04, 3.3246e-02, -2.7012e-02,
5.0888e-01, 2.4558e-01, -1.9838e-03])
tensor([-0.1274, -0.2580, 0.2316, 0.1266, -0.3056, 0.0579, -0.2824, 0.2191,
-0.0199, 0.2176, -0.0755, -0.1700, 0.1564, 0.2221, -0.0909, 0.0172])
tensor([-0.1274, -0.2580, 0.2316, 0.1266, -0.3056, 0.0579, -0.2824, 0.2191,
-0.0199, 0.2176, -0.0755, -0.1700, 0.1564, 0.2221, -0.0909, 0.0172])
Autograd verification passed.
O: torch.Size([2, 4, 8, 16])
dQ: torch.Size([2, 4, 8, 16])
dK: torch.Size([2, 4, 8, 16])
dV: torch.Size([2, 4, 8, 16])
dB: torch.Size([2, 4, 8, 8])
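If the bias gradient should accumulate automatically during training, a natural next step is to wrap the manual forward and backward passes in a custom torch.autograd.Function. The sketch below reuses the formulas derived above; the class name AttentionWithBias is my own choice, not an existing API:

import torch

class AttentionWithBias(torch.autograd.Function):
    """Attention with a trainable bias term, using the manually derived backward pass."""

    @staticmethod
    def forward(ctx, Q, K, V, B, d):
        scale = d ** 0.5
        S = Q @ K.transpose(-2, -1) / scale + B
        A = torch.softmax(S, dim=-1)
        O = A @ V
        ctx.save_for_backward(Q, K, V, A)
        ctx.scale = scale
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V, A = ctx.saved_tensors
        dA = dO @ V.transpose(-2, -1)
        dV = A.transpose(-2, -1) @ dO
        dS = A * (dA - (A * dA).sum(dim=-1, keepdim=True))   # softmax VJP
        dB = dS
        dQ = dS @ K / ctx.scale
        dK = dS.transpose(-2, -1) @ Q / ctx.scale
        return dQ, dK, dV, dB, None                           # no gradient for d

# Usage: B.grad accumulates like any other parameter.
n, h, l, d = 2, 4, 8, 16
Q, K, V = (torch.randn(n, h, l, d, requires_grad=True) for _ in range(3))
B = torch.randn(n, h, l, l, requires_grad=True)
O = AttentionWithBias.apply(Q, K, V, B, d)
O.sum().backward()
print(B.grad.shape)   # torch.Size([2, 4, 8, 8])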