Attention and its gradient
  • 10/20/2024
  • Xiaotian Han

  • background

    As of now, the official FlashAttention implementation does not support a bias term, and FlexAttention in PyTorch is still working on supporting one. In this blog, I will show how to implement a minimal FlashAttention with a trainable bias term.

    TL;DR
    • A gradient-enabled bias term is required by most protein language models, such as the Evoformer.
    • The trainable bias term needs to accumulate its gradient.

    Math

    The attention with gradient-enabled bias term is defined as:

    \[\mathbf{O} = \text{softmax}\left(\frac{\mathbf{QK}^\top}{\sqrt{d}} + {\color{red}\mathbf{B}}\right)\mathbf{V}\]

    where

    • \({\color{red}\mathbf{B}}\) is the bias term and its shape is \((n, h, l, l)\),
    • the shape of \(\mathbf{Q}, \mathbf{K}, \mathbf{V}\) is \((n, h, l, d)\),
    • \(n\) is the batch size, \(h\) is the number of heads, \(l\) is the sequence length, and \(d\) is the hidden dimension.

    The gradient of \(\mathbf{B}\) is accumulated during the training process.
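
    For reference, eager-mode PyTorch can already express this without a custom kernel: torch.nn.functional.scaled_dot_product_attention accepts a floating-point attn_mask that is added to the attention scores, and autograd will produce a gradient for it. A minimal sketch, assuming the selected attention backend supports gradients for attn_mask (or falls back to the non-fused math path); shapes follow the notation above:

    import torch
    import torch.nn.functional as F

    n, h, l, d = 2, 4, 8, 16
    q = torch.randn(n, h, l, d)
    k = torch.randn(n, h, l, d)
    v = torch.randn(n, h, l, d)
    bias = torch.randn(n, h, l, l, requires_grad=True)  # trainable bias B

    # a float attn_mask is added to QK^T / sqrt(d) before the softmax
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    out.sum().backward()
    print(bias.grad.shape)  # torch.Size([2, 4, 8, 8])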

    backprop derivation

    Let

    \[\begin{aligned} \mathbf{S} &= \frac{\mathbf{QK}^\top}{\sqrt{d}} + \mathbf{B} \\ \mathbf{A} &= \text{softmax}\left( \mathbf{S} \right) = \text{softmax}\left(\frac{\mathbf{QK}^\top}{\sqrt{d}} + \mathbf{B}\right) \\ \mathbf{O} &= \mathbf{AV} = \text{softmax}\left( \mathbf{S} \right)\mathbf{V} = \text{softmax}\left(\frac{\mathbf{QK}^\top}{\sqrt{d}} + \mathbf{B}\right)\mathbf{V} \end{aligned}\]

    We already have the gradient of \(\mathbf{O}\), namely \(\frac{\partial \mathcal{L}}{\partial \mathbf{O}} \quad ([n,\,h,\,l,\,d])\).

    In the following, we think of each \((n,h)\) slice as a separate matrix multiply.

    gradient of \(\mathbf{V}\) and \(\mathbf{A}\)

    Since \(\mathbf{O} = \mathbf{AV} \quad ([n,h,l,d] = [n,h,l,l] \times [n,h,l,d])\), we get

    \[\frac{\partial \mathcal{L}}{\partial \mathbf{A}}=\frac{\partial \mathcal{L}}{\partial \mathbf{O}}\;\mathrm{bmm}\;\mathbf{V}^\top, \quad ([n,h,l,l] = [n,h,l,d] \times [n,h,d,l])\] \[\frac{\partial \mathcal{L}}{\partial \mathbf{V}}= \mathbf{A}^\top\;\mathrm{bmm}\;\frac{\partial \mathcal{L}}{\partial \mathbf{O}}, \quad ([n,h,l,d] = [n,h,l,l] \times [n,h,l,d])\]
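
    These two matmul gradients can be checked against autograd on a single \((n,h)\) slice; the tensor names below are local to this sketch:

    import torch

    l, d = 8, 16
    A = torch.rand(l, l, requires_grad=True)   # stands in for one attention slice
    V = torch.randn(l, d, requires_grad=True)
    dO = torch.randn(l, d)

    O = A @ V
    O.backward(dO)

    assert torch.allclose(A.grad, dO @ V.T, atol=1e-6)  # dL/dA = dL/dO V^T
    assert torch.allclose(V.grad, A.T @ dO, atol=1e-6)  # dL/dV = A^T dL/dO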

    gradient of \(\mathbf{S}\)

    The gradient of \(\mathbf{S}\) follows from the chain rule:

    \[\left(\frac{\partial \mathcal{L}}{\partial \mathbf{S}}\right)_{ijkl} = \sum_{m,n} \frac{\partial \mathbf{A}_{ijmn}}{\partial \mathbf{S}_{ijkl}} \left(\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\right)_{ijmn}\]

    where \(\frac{\partial \mathbf{A}}{\partial \mathbf{S}}\) is the Jacobian of the softmax function and has shape \((n,h,l,l,l,l)\). The indices \(i, j, k, l\) address the target tensor \(\frac{\partial \mathcal{L}}{\partial \mathbf{S}}\), while \(m, n\) are contraction indices: \(\sum_{m,n}\) explicitly indicates the summation over them.
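
    For a single row, this is just the familiar softmax Jacobian \(\mathrm{diag}(\mathbf{a}) - \mathbf{a}\mathbf{a}^\top\) contracted with the upstream gradient. A naive sketch (names local to this snippet):

    import torch

    l = 8
    s = torch.randn(l)
    dA = torch.randn(l)

    a = torch.softmax(s, dim=-1)
    J = torch.diag(a) - torch.outer(a, a)   # softmax Jacobian: diag(a) - a a^T
    dS_naive = J @ dA                       # contract the Jacobian with dL/dA

    # matches the fused formula given below
    dS_fused = a * dA - (a * dA).sum() * a
    assert torch.allclose(dS_naive, dS_fused, atol=1e-6)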

    For efficiency, we never materialize the full Jacobian and instead rewrite the above equation as:

    \[\frac{\partial \mathcal{L}}{\partial \mathbf{S}} = \mathbf{A}\odot\frac{\partial \mathcal{L}}{\partial \mathbf{A}} - \left(\left(\mathbf{A}\odot\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\right)\;\mathrm{bmm}\;\mathbf{1}\right)\odot \mathbf{A} \quad ([n,h,l,l] = [n,h,l,l] \odot [n,h,l,l] - ([n,h,l,l] \; \mathrm{bmm} \; [n,h,l,1]) \odot [n,h,l,l])\]

    where \(\mathbf{1} \in [n,h,l,1]\) is an all-ones tensor; the batched product with it sums each row of \(\mathbf{A}\odot\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\), and the resulting \([n,h,l,1]\) tensor is broadcast over the last dimension.
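
    This fused form can be checked directly against autograd's softmax backward (names local to this sketch):

    import torch

    l = 8
    S = torch.randn(l, l, requires_grad=True)
    dA = torch.randn(l, l)

    A = torch.softmax(S, dim=-1)
    A.backward(dA)

    dS = A * dA - (A * dA).sum(dim=-1, keepdim=True) * A
    assert torch.allclose(S.grad, dS, atol=1e-6)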

    gradient of \(\mathbf{B}\)

    Since \(\mathbf{B}\) enters \(\mathbf{S}\) additively, the gradient of \(\mathbf{B}\) is the same as the gradient of \(\mathbf{S}\):

    \[\frac{\partial \mathcal{L}}{\partial \mathbf{B}} = \frac{\partial \mathcal{L}}{\partial \mathbf{S}}\]
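
    This is also where the accumulation mentioned in the TL;DR comes in: if \(\mathbf{B}\) is shared (broadcast) across a dimension, for example a pair bias reused for every batch element as in Evoformer-style attention, its gradient is \(\frac{\partial \mathcal{L}}{\partial \mathbf{S}}\) summed over the broadcast dimensions. A small sketch (the broadcast shape here is an assumption for illustration):

    import torch

    n, h, l = 2, 4, 8
    B = torch.randn(1, h, l, l, requires_grad=True)   # bias broadcast over the batch dim
    S_rest = torch.randn(n, h, l, l)                  # stands in for QK^T / sqrt(d)
    dS = torch.randn(n, h, l, l)

    S = S_rest + B                                    # B broadcasts to (n, h, l, l)
    S.backward(dS)

    # autograd accumulates (sums) the gradient over the broadcast dimension
    assert torch.allclose(B.grad, dS.sum(dim=0, keepdim=True), atol=1e-6)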

    gradient of \(\mathbf{Q}\), \(\mathbf{K}\)

    The gradients of \(\mathbf{Q}\) and \(\mathbf{K}\) are (note the \(1/\sqrt{d}\) scale and the transpose on \(\frac{\partial \mathcal{L}}{\partial \mathbf{S}}\) for \(\mathbf{K}\)):

    \[\begin{aligned} \frac{\partial \mathcal{L}}{\partial \mathbf{Q}} &= \frac{1}{\sqrt{d}}\,\frac{\partial \mathcal{L}}{\partial \mathbf{S}}\cdot \mathbf{K} \\ \frac{\partial \mathcal{L}}{\partial \mathbf{K}} &= \frac{1}{\sqrt{d}}\left(\frac{\partial \mathcal{L}}{\partial \mathbf{S}}\right)^\top\cdot \mathbf{Q} \end{aligned}\]
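
    A quick single-slice check of these two formulas, including the \(1/\sqrt{d}\) factor (names local to this sketch):

    import math
    import torch

    l, d = 8, 16
    Q = torch.randn(l, d, requires_grad=True)
    K = torch.randn(l, d, requires_grad=True)
    dS = torch.randn(l, l)

    S = Q @ K.T / math.sqrt(d)
    S.backward(dS)

    assert torch.allclose(Q.grad, dS @ K / math.sqrt(d), atol=1e-6)    # dL/dQ
    assert torch.allclose(K.grad, dS.T @ Q / math.sqrt(d), atol=1e-6)  # dL/dK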

    all gradients

    \[\begin{aligned} \frac{\partial \mathcal{L}}{\partial \mathbf{Q}} &= \frac{1}{\sqrt{d}}\,\frac{\partial \mathcal{L}}{\partial \mathbf{S}}\cdot \mathbf{K} \\ \frac{\partial \mathcal{L}}{\partial \mathbf{K}} &= \frac{1}{\sqrt{d}}\left(\frac{\partial \mathcal{L}}{\partial \mathbf{S}}\right)^\top\cdot \mathbf{Q} \\ \frac{\partial \mathcal{L}}{\partial \mathbf{V}} &= \mathbf{A}^\top\;\mathrm{bmm}\;\frac{\partial \mathcal{L}}{\partial \mathbf{O}} \\ \frac{\partial \mathcal{L}}{\partial \mathbf{A}} &= \frac{\partial \mathcal{L}}{\partial \mathbf{O}}\;\mathrm{bmm}\;\mathbf{V}^\top \\ \frac{\partial \mathcal{L}}{\partial \mathbf{S}} &= \mathbf{A}\odot\frac{\partial \mathcal{L}}{\partial \mathbf{A}} - \left(\left(\mathbf{A}\odot\frac{\partial \mathcal{L}}{\partial \mathbf{A}}\right)\;\mathrm{bmm}\;\mathbf{1}\right)\odot\mathbf{A} \\ \frac{\partial \mathcal{L}}{\partial \mathbf{B}} &= \frac{\partial \mathcal{L}}{\partial \mathbf{S}} \end{aligned}\]

    pytorch implementation

    import torch
    
    def forward(Q, K, V, B, d):
        S = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d, dtype=torch.float32)) + B
        A = torch.softmax(S, dim=-1)
        O = torch.matmul(A, V)
        return O, A, S
    
    @torch.no_grad()
    def compute_gradients(Q, K, V, B, d, dO):
        # Compute forward pass
        S = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d, dtype=torch.float32)) + B
        A = torch.softmax(S, dim=-1)
        O = torch.matmul(A, V)
    
        # Gradient of V and A
        dA = torch.matmul(dO, V.transpose(-2, -1))
        dV = torch.matmul(A.transpose(-2, -1), dO)
    
        # Gradient of S via the softmax vector-Jacobian product (VJP)
        dS = dA * A - (A * dA).sum(dim=-1, keepdim=True) * A
    
        # Gradient of B (same as dS)
        dB = dS.clone()
    
        # Gradient of Q and K
        dQ = torch.matmul(dS, K) / torch.sqrt(torch.tensor(d, dtype=torch.float32))
        dK = torch.matmul(dS.transpose(-2, -1), Q) / torch.sqrt(torch.tensor(d, dtype=torch.float32))
    
        return dQ, dK, dV, dB
    
    
    # Example usage
    n, h, l, d = 2, 4, 8, 16
    torch.manual_seed(0)
    Q = torch.randn(n, h, l, d, requires_grad=True)
    K = torch.randn(n, h, l, d, requires_grad=True)
    V = torch.randn(n, h, l, d, requires_grad=True)
    B = torch.randn(n, h, l, l, requires_grad=True)
    dO = torch.randn(n, h, l, d)
    
    O, A, S = forward(Q, K, V, B, d)
    dQ, dK, dV, dB = compute_gradients(Q, K, V, B, d, dO)
    
    # Verify correctness with autograd
    O.backward(dO, retain_graph=True)
    
    print(V.grad[0, 0, 0])
    print(dV[0, 0, 0])
    
    print(B.grad[0, 0, 0])
    print(dB[0, 0, 0])
    
    print(Q.grad[0, 0, 0])
    print(dQ[0, 0, 0])
    
    assert torch.allclose(V.grad, dV, atol=1e-5), "dV mismatch"
    assert torch.allclose(B.grad, dB, atol=1e-5), "dB mismatch"
    assert torch.allclose(Q.grad, dQ, atol=1e-5), "dQ mismatch"
    assert torch.allclose(K.grad, dK, atol=1e-5), "dK mismatch"
    
    
    print("Autograd verification passed.")
    
    print("O:", O.shape)
    print("dQ:", dQ.shape)
    print("dK:", dK.shape)
    print("dV:", dV.shape)
    print("dB:", dB.shape)
    

    output:

    tensor([-0.9583, -0.7990, -0.7401,  0.4045, -1.1326, -0.8535,  0.9846,  0.8070,
            -0.6478, -0.0538,  0.6266,  1.0380, -0.9200,  0.5653,  0.9200, -0.0638])
    tensor([-0.9583, -0.7990, -0.7401,  0.4045, -1.1326, -0.8535,  0.9846,  0.8070,
            -0.6478, -0.0538,  0.6266,  1.0380, -0.9200,  0.5653,  0.9200, -0.0638])
    tensor([-8.4880e-02, -6.7330e-01, -5.2291e-04,  3.3246e-02, -2.7012e-02,
             5.0888e-01,  2.4558e-01, -1.9837e-03])
    tensor([-8.4880e-02, -6.7330e-01, -5.2293e-04,  3.3246e-02, -2.7012e-02,
             5.0888e-01,  2.4558e-01, -1.9838e-03])
    tensor([-0.1274, -0.2580,  0.2316,  0.1266, -0.3056,  0.0579, -0.2824,  0.2191,
            -0.0199,  0.2176, -0.0755, -0.1700,  0.1564,  0.2221, -0.0909,  0.0172])
    tensor([-0.1274, -0.2580,  0.2316,  0.1266, -0.3056,  0.0579, -0.2824,  0.2191,
            -0.0199,  0.2176, -0.0755, -0.1700,  0.1564,  0.2221, -0.0909,  0.0172])
    Autograd verification passed.
    O: torch.Size([2, 4, 8, 16])
    dQ: torch.Size([2, 4, 8, 16])
    dK: torch.Size([2, 4, 8, 16])
    dV: torch.Size([2, 4, 8, 16])
    dB: torch.Size([2, 4, 8, 8])
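
    As an optional, stricter check, torch.autograd.gradcheck can compare autograd's gradients of the forward pass against finite differences in double precision; since the manual gradients above were already shown to match autograd, this closes the loop. A sketch:

    import torch

    n, h, l, d = 1, 2, 4, 8

    def attn(Q, K, V, B):
        S = Q @ K.transpose(-2, -1) / (d ** 0.5) + B
        return torch.softmax(S, dim=-1) @ V

    Q, K, V = (torch.randn(n, h, l, d, dtype=torch.double, requires_grad=True) for _ in range(3))
    B = torch.randn(n, h, l, l, dtype=torch.double, requires_grad=True)

    assert torch.autograd.gradcheck(attn, (Q, K, V, B))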