In this post, we show that a modified attention map, in which softmax is replaced by a different normalizing map, has a simpler total derivative. The goal is to improve the computational efficiency of backpropagation through attention-type maps.

Standard attention

The standard attention map is defined by

$$\mathrm{Att}\colon \mathbb{R}^{n\times d}\times\mathbb{R}^{n\times d}\times\mathbb{R}^{n\times d}\to\mathbb{R}^{n\times d},\qquad (Q,K,V)\mapsto \mathrm{Att}(Q,K,V)=\sigma(QK^t)V,$$

where $\sigma:\mathbb{R}^n\to\mathbb{R}^n$ is the softmax map, applied row-wise. This means that

$$\mathrm{Att}(Q,K,V)=\sum_{i=1}^n e_i\,\sigma\!\left((e_i^tQK^t)^t\right)^{t}V=\sum_{i=1}^n e_i\,\sigma(KQ^te_i)^tV,$$

where $e_i$ is the $i$th Euclidean basis vector in $\mathbb{R}^n$. This expression could be written in matrix form, of course, but it seems cleaner to put everything on a single line.
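
As a quick sanity check, the following PyTorch snippet (a throwaway sketch; the sizes n and d are arbitrary) verifies numerically that the row-wise sum above agrees with applying softmax to the rows of $QK^t$ and then multiplying by $V$:

# Sanity check: the row-wise sum expression agrees with softmax(Q K^t) V.
import torch

n, d = 5, 3
Q, K, V = (torch.randn(n, d) for _ in range(3))

# Direct form: softmax applied to each row of Q K^t, then multiplied by V.
att_direct = torch.softmax(Q @ K.T, dim=-1) @ V

# Row-wise sum: e_i sigma(K Q^t e_i)^t V, accumulated over i.
att_rowwise = torch.zeros(n, d)
for i in range(n):
    e_i = torch.zeros(n)
    e_i[i] = 1.0
    p_i = K @ Q.T @ e_i  # p_i = K Q^t e_i
    att_rowwise += torch.outer(e_i, torch.softmax(p_i, dim=0)) @ V

print(torch.allclose(att_direct, att_rowwise, atol=1e-6))  # expect: True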

Note: Our definition of the attention map does not scale the entries of $QK^t$ by $1/\sqrt{d}$. This is just for convenience, to avoid writing the square root everywhere.

The total derivative of Att is

$$d\mathrm{Att}(\Theta)\tilde\Theta=d_Q\mathrm{Att}(\Theta)\tilde Q+d_K\mathrm{Att}(\Theta)\tilde K+d_V\mathrm{Att}(\Theta)\tilde V,$$

where $\Theta=(Q,K,V)$ and $\tilde\Theta=(\tilde Q,\tilde K,\tilde V)$.

By the Leibniz rule, the partial derivatives of Att are

$$\begin{aligned}
d_Q\mathrm{Att}(\Theta)\tilde Q&=\sum_{i=1}^n e_i\,[d\sigma(p_i)\tilde q_i]^tV\\
d_K\mathrm{Att}(\Theta)\tilde K&=\sum_{i=1}^n e_i\,[d\sigma(p_i)\tilde k_i]^tV\\
d_V\mathrm{Att}(\Theta)\tilde V&=\sum_{i=1}^n e_i\,\sigma(p_i)^t\tilde V=\mathrm{Att}(Q,K,\tilde V),
\end{aligned}$$

where $p_i=KQ^te_i$, $\tilde q_i=K\tilde Q^te_i$, and $\tilde k_i=\tilde KQ^te_i$. Combining terms, we have

$$d\mathrm{Att}(\Theta)\tilde\Theta=\sum_{i=1}^n e_i\,[d\sigma(p_i)\tilde z_i]^tV+\mathrm{Att}(Q,K,\tilde V),$$

where $\tilde z_i=\tilde q_i+\tilde k_i$. Using the well-known formula $d\sigma(p)h=\sigma(p)\odot h-\langle\sigma(p),h\rangle\,\sigma(p)$ for the derivative of softmax, we have

$$\begin{aligned}
\sum_{i=1}^n e_i\,[d\sigma(p_i)\tilde z_i]^tV&=\sum_{i=1}^n e_i\,(\sigma(p_i)\odot\tilde z_i)^tV-\sum_{i=1}^n \langle\sigma(p_i),\tilde z_i\rangle\,e_i\,\sigma(p_i)^tV\\
&=\sum_{i=1}^n e_i\,(\sigma(p_i)\odot\tilde z_i)^tV-(\tilde\iota(Q,K)\otimes 1_d^t)\odot\mathrm{Att}(Q,K,V),
\end{aligned}$$

where $\odot$ is the element-wise product and $\tilde\iota(Q,K)\otimes 1_d^t$ is the Kronecker product of

$$\tilde\iota(Q,K)=\begin{bmatrix}\langle\sigma(p_1),\tilde z_1\rangle\\\vdots\\\langle\sigma(p_n),\tilde z_n\rangle\end{bmatrix}\qquad\text{and}\qquad 1_d^t=(1,\dots,1)\in\mathbb{R}^{1\times d}.$$
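
As a sanity check on the softmax derivative formula, here is a throwaway sketch that compares it against the Jacobian-vector product computed by torch.autograd.functional.jvp:

# Numerical check of d sigma(p) h = sigma(p) * h - <sigma(p), h> sigma(p).
import torch
from torch.autograd.functional import jvp

n = 7
p, h = torch.randn(n), torch.randn(n)

# Jacobian-vector product of softmax at p in direction h, via autograd.
_, dsigma_autograd = jvp(lambda x: torch.softmax(x, dim=0), (p,), (h,))

# Closed-form expression.
s = torch.softmax(p, dim=0)
dsigma_formula = s * h - torch.dot(s, h) * s

print(torch.allclose(dsigma_autograd, dsigma_formula, atol=1e-6))  # expect: True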

In total, we have

$$d\mathrm{Att}(\Theta)\tilde\Theta=\sum_{i=1}^n e_i\,(\sigma(p_i)\odot\tilde z_i)^tV-(\tilde\iota(Q,K)\otimes 1_d^t)\odot\mathrm{Att}(\Theta)+\mathrm{Att}(Q,K,\tilde V).$$

The quantities $\sigma(p_i)$ (equivalently, the matrix $\sigma(QK^t)=\sum_{i=1}^n e_i\,\sigma(p_i)^t$) and $\mathrm{Att}(\Theta)$ can be re-used from the forward pass.
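
To double-check the bookkeeping, the following sketch compares the closed-form expression for $d\mathrm{Att}(\Theta)\tilde\Theta$ against the Jacobian-vector product computed by torch.autograd.functional.jvp. The matrices S and Z below stack the rows $\sigma(p_i)^t$ and $\tilde z_i^t$, respectively; the shapes are arbitrary.

# Check the closed-form total derivative of Att against autograd's JVP.
import torch
from torch.autograd.functional import jvp

n, d = 6, 4
Q, K, V = (torch.randn(n, d) for _ in range(3))
Qt, Kt, Vt = (torch.randn(n, d) for _ in range(3))  # the tilde directions

def att(Q, K, V):
    return torch.softmax(Q @ K.T, dim=-1) @ V

# Reference: JVP of Att at (Q, K, V) in the direction (Q~, K~, V~).
_, datt_autograd = jvp(att, (Q, K, V), (Qt, Kt, Vt))

# Closed form.
S = torch.softmax(Q @ K.T, dim=-1)        # row i is sigma(p_i)^t, from the forward pass
Z = Qt @ K.T + Q @ Kt.T                   # row i is z~_i^t = (q~_i + k~_i)^t
iota = (S * Z).sum(dim=-1, keepdim=True)  # column of <sigma(p_i), z~_i>
datt_formula = (S * Z) @ V - iota * (S @ V) + S @ Vt

print(torch.allclose(datt_autograd, datt_formula, atol=1e-5))  # expect: True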

In the next section, we consider how the above result changes if we replace $\sigma$.

Other normalizing maps

In this section, $\beta:\mathbb{R}^n\to\mathbb{R}^n$ is a smooth map with the following homogeneity property: there exists a smooth function $f:\mathbb{R}^n\times\mathbb{R}^n\to\mathbb{R}$ such that

$$d\beta(x)h=f(x,h)\,\beta(x)\qquad\text{for all }x,h\in\mathbb{R}^n.$$

Replacing $\sigma$ with $\beta$, the $\beta$-attention map is defined by

$$\mathrm{Att}_\beta\colon \mathbb{R}^{n\times d}\times\mathbb{R}^{n\times d}\times\mathbb{R}^{n\times d}\to\mathbb{R}^{n\times d},\qquad (Q,K,V)\mapsto \mathrm{Att}_\beta(Q,K,V)=\beta(QK^t)V,$$

where $\beta$ is applied row-wise.

Repeating the analysis of the previous section, we have

$$\mathrm{Att}_\beta(\Theta)=\sum_{i=1}^n e_i\,\beta(p_i)^tV$$

and the total derivative of $\mathrm{Att}_\beta$ at $\Theta$ is

$$\begin{aligned}
d\mathrm{Att}_\beta(\Theta)\tilde\Theta&=\sum_{i=1}^n e_i\,[d\beta(p_i)\tilde z_i]^tV+\mathrm{Att}_\beta(Q,K,\tilde V)\\
&=\sum_{i=1}^n f(p_i,\tilde z_i)\,e_i\,\beta(p_i)^tV+\mathrm{Att}_\beta(Q,K,\tilde V)\\
&=(\tilde f(Q,K)\otimes 1_d^t)\odot\mathrm{Att}_\beta(\Theta)+\mathrm{Att}_\beta(Q,K,\tilde V),
\end{aligned}$$

where $p_i$ and $\tilde z_i$ are defined as in the previous section and

$$\tilde f(Q,K)=\begin{bmatrix}f(p_1,\tilde z_1)\\\vdots\\f(p_n,\tilde z_n)\end{bmatrix}.$$

In total, we have

$$d\mathrm{Att}_\beta(\Theta)\tilde\Theta=(\tilde f(Q,K)\otimes 1_d^t)\odot\mathrm{Att}_\beta(\Theta)+\mathrm{Att}_\beta(Q,K,\tilde V).$$

The quantities $p_i$, the matrix $\beta(QK^t)=\sum_{i=1}^n e_i\,\beta(p_i)^t$, and the output $\mathrm{Att}_\beta(\Theta)$ can be re-used from the forward pass.

The punchline is that the homogeneity property could make the computation of $d\mathrm{Att}_\beta(\Theta)\tilde\Theta$ more efficient by “removing” the first term of $d\mathrm{Att}(\Theta)\tilde\Theta$.
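
To make the simplification concrete, here is a sketch of the JVP computed via the formula above. The callables beta_fn and f_fn are hypothetical placeholders: the sketch assumes they form a pair satisfying the homogeneity property, with beta_fn applied row-wise and f_fn returning the column vector of values $f(p_i,\tilde z_i)$.

# Sketch: JVP of Att_beta via the simplified formula
#   dAtt_beta(Theta) Theta~ = (f~(Q,K) (x) 1_d^t) . Att_beta(Theta) + Att_beta(Q, K, V~).
# beta_fn and f_fn are placeholder callables assumed to satisfy d beta(x) h = f(x, h) beta(x).
import torch
from torch import Tensor

def att_beta(Q: Tensor, K: Tensor, V: Tensor, beta_fn) -> Tensor:
    """Forward pass: beta applied to each row of Q K^t, then multiplied by V."""
    return beta_fn(Q @ K.T) @ V

def datt_beta(Q, K, V, Qt, Kt, Vt, beta_fn, f_fn) -> Tensor:
    """JVP of att_beta at (Q, K, V) in the direction (Q~, K~, V~)."""
    P = Q @ K.T              # row i is p_i^t (forward-pass quantity)
    B = beta_fn(P)           # row i is beta(p_i)^t (forward-pass quantity)
    Z = Qt @ K.T + Q @ Kt.T  # row i is z~_i^t
    f_tilde = f_fn(P, Z)     # shape (n, 1): entries f(p_i, z~_i)
    return f_tilde * (B @ V) + B @ Vt

Note that B @ V is just the forward output, so beyond quantities already available from the forward pass the only new work is forming Z, evaluating f_fn, and the product B @ Vt.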

For a particular example, consider the simple normalizing map

$$\beta(x)=\frac{x}{1+\lVert x\rVert}.$$

The total derivative of $\beta$ is

$$d\beta(x)h=\begin{cases}\dfrac{h}{1+\lVert x\rVert}-\dfrac{\langle x,h\rangle\,x}{\lVert x\rVert(1+\lVert x\rVert)^2}, & x\neq 0_n,\\[2ex] h, & x=0_n.\end{cases}$$
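
Away from $x=0_n$, this case formula is easy to verify numerically (a throwaway sketch using torch.autograd.functional.jvp):

# Check d beta(x) h = h / (1 + |x|) - <x, h> x / (|x| (1 + |x|)^2) for x != 0.
import torch
from torch.autograd.functional import jvp

n = 7
x, h = torch.randn(n), torch.randn(n)

beta = lambda x: x / (1.0 + torch.norm(x))

_, dbeta_autograd = jvp(beta, (x,), (h,))

r = torch.norm(x)
dbeta_formula = h / (1.0 + r) - torch.dot(x, h) * x / (r * (1.0 + r) ** 2)

print(torch.allclose(dbeta_autograd, dbeta_formula, atol=1e-6))  # expect: True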

By definition, $f$ must satisfy

$$f(x,h)\,\frac{x}{1+\lVert x\rVert}=\begin{cases}\dfrac{h}{1+\lVert x\rVert}-\dfrac{\langle x,h\rangle\,x}{\lVert x\rVert(1+\lVert x\rVert)^2}, & x\neq 0_n,\\[2ex] h, & x=0_n.\end{cases}$$

This is clearly not possible for $x=0_n$ (the left-hand side is $0_n$ for every $f$, while the right-hand side is $h$), so from this point forward let’s assume that we are working away from $x=0_n$ (this can be made rigorous, but we’ll skip that).

Taking the inner product of each side with $x$ and rearranging, we see that

$$f(x,h)=\frac{\langle x,h\rangle}{\lVert x\rVert^2}-\frac{\langle x,h\rangle}{\lVert x\rVert(1+\lVert x\rVert)}=\frac{\langle x,h\rangle}{\lVert x\rVert^2(1+\lVert x\rVert)}.$$

Plausibility of replacing σ

Putting aside potential efficiency gains, can we learn effectively with $\mathrm{Att}_\beta$?

To quickly test this for

$$\beta(x)=\frac{x}{1+\lVert x\rVert},$$

we can build on the nanoGPT project. In a nutshell, we need to implement $\beta$, disable flash attention, and adjust the causal masking to accommodate $\beta$ (a sketch of the masking change follows the Beta module below).

The nn.Module that implements $\beta$ is very straightforward:

"""Implementation of beta map."""

import torch
import torch.nn as nn
from torch import Tensor

class Beta(nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Compute beta() along last dimension of x."""
        return x / (1.0 + torch.norm(x, dim=-1, keepdim=True))
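
Disabling flash attention means using the “manual” attention path. For the causal masking, softmax’s trick of filling masked positions with $-\infty$ no longer works with $\beta$; for $\beta(x)=x/(1+\lVert x\rVert)$, one option is to fill masked score entries with $0$ before applying $\beta$, so that they contribute nothing to the row norm and receive exactly zero weight. Here is a self-contained sketch of that idea (the function name is made up; in nanoGPT the change would live in the non-flash branch of CausalSelfAttention.forward):

# Sketch: causal beta-attention without flash attention.
# Masked scores are filled with 0.0 (rather than -inf, as with softmax) before beta,
# so masked entries add nothing to the row norm and get exactly zero weight.
import torch
from torch import Tensor

def causal_beta_attention(q: Tensor, k: Tensor, v: Tensor, beta: Beta) -> Tensor:
    """q, k, v: (..., T, hs). beta: the Beta module above, applied along the last dim."""
    T = q.size(-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    att = q @ k.transpose(-2, -1)      # (..., T, T); 1/sqrt(d) scaling omitted, as in the post
    att = att.masked_fill(~mask, 0.0)  # with softmax this would be float('-inf')
    return beta(att) @ v               # row-wise beta, then the weighted sum of values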

Training character-level Tiny Shakespeare for 10000 iterations with the nanoGPT defaults gives:

step 10000: train loss 1.0040, val loss 1.5550

The final validation loss is comparable to the results obtained using standard attention:

step 10000: train loss 0.7140, val loss 1.6204