In this post, we show that a modified attention map, in which softmax is replaced by a different normalizing map, has a simpler total derivative. The goal is to improve the computational efficiency of backpropagation through attention-type maps.

$$ \newcommand{\Attn}{\mathrm{Att}} \newcommand{\bR}{\mathbb{R}} $$

Standard attention

The standard attention map is defined by

$$ \begin{align*} \Attn : \bR^{n \times d} \times \bR^{n \times d} \times \bR^{n \times d} &\to \bR^{n \times d} \\ (Q, K, V) &\mapsto \Attn(Q, K, V) = \sigma(QK^t) V, \end{align*} $$

where \(\sigma : \bR^n \to \bR^n\) is the softmax map applied row-wise. This means that

$$ \begin{align*} \Attn(Q, K, V) &= \sum_{i=1}^{n} e_i \sigma((e_i^t Q K^t)^t)^t V \\ &= \sum_{i=1}^{n} e_i \sigma(K Q^t e_i)^t V, \end{align*} $$

where \(e_i\) is the \(i\)th Euclidean basis vector in \(\bR^n\). This expression could be written in matrix form, of course, but it seems cleaner to put everything on a single line.

Note: Our definition of the attention map does not scale the entries of \(QK^t\) by \(1/\sqrt{d}\). This is just for convenience, to avoid writing the square root everywhere.
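As a quick illustration of the definition, here is a minimal PyTorch sketch of this (unscaled) attention map; the function name attn is chosen here purely for illustration:

import torch

def attn(Q, K, V):
    """Unscaled attention: row-wise softmax of Q K^t, applied to V."""
    return torch.softmax(Q @ K.T, dim=-1) @ V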

The total derivative of \(\Attn\) is

$$ \begin{align*} d \Attn(\Theta) \cdot \tilde{\Theta} = d_Q \Attn(\Theta) \cdot \tilde{Q} + d_K \Attn(\Theta) \cdot \tilde{K} + d_V \Attn(\Theta) \cdot \tilde{V}, \end{align*} $$

where \(\Theta = (Q,K,V)\) and \(\tilde{\Theta} = (\tilde{Q},\tilde{K},\tilde{V})\).

By the Leibniz rule, the partial derivatives of \(\Attn\) are

$$ \begin{align*} d_Q \Attn(\Theta) \cdot \tilde{Q} &= \sum_{i=1}^{n} e_i [d \sigma(p_i) \cdot \tilde{q}_i]^t V \\ d_K \Attn(\Theta) \cdot \tilde{K} &= \sum_{i=1}^{n} e_i [d \sigma(p_i) \cdot \tilde{k}_i]^t V \\ d_V \Attn(\Theta) \cdot \tilde{V} &= \sum_{i=1}^{n} e_i \sigma(p_i)^t \tilde{V} = \Attn(Q, K, \tilde{V}), \end{align*} $$

where \(p_i = K Q^t e_i\), \(\tilde{q}_i = K \tilde{Q}^t e_i\), and \(\tilde{k}_i = \tilde{K} Q^t e_i\). Combining terms, we have

$$ \begin{align*} d \Attn(\Theta) \cdot \tilde{\Theta} &= \sum_{i=1}^{n} e_i [d \sigma(p_i) \cdot \tilde{z}_i]^t V + \Attn(Q, K, \tilde{V}), \end{align*} $$

where \(\tilde{z}_i = \tilde{q}_i + \tilde{k}_i\). Using the well-known formula \(d \sigma(x) \cdot h = \sigma(x) \odot h - \langle \sigma(x), h \rangle \, \sigma(x)\), we have

$$ \begin{align*} &{} \sum_{i=1}^{n} e_i [d \sigma(p_i) \cdot \tilde{z}_i]^t V \\ &\qquad = \sum_{i=1}^{n} e_i (\sigma(p_i) \odot \tilde{z}_i)^t V - \sum_{i=1}^{n} \langle \sigma(p_i), \tilde{z}_i \rangle e_i \sigma(p_i)^t V \\ &\qquad = \sum_{i=1}^{n} e_i (\sigma(p_i) \odot \tilde{z}_i)^t V - (\tilde{\iota}(Q,K) \otimes 1_d^t) \odot \Attn(Q,K,V), \end{align*} $$

where \(\odot\) is the element-wise product and \(\tilde{\iota}(Q,K) \otimes 1_d^t\) is the Kronecker product of

$$ \tilde{\iota}(Q,K) = \begin{bmatrix} \langle \sigma(p_1), \tilde{z}_1 \rangle \\ \vdots \\ \langle \sigma(p_n), \tilde{z}_n \rangle \end{bmatrix} \qquad \mbox{and} \qquad 1_d^t = (1, \dots, 1) \in \bR^{1 \times d}. $$

In total, we have

$$ \begin{align*} d\Attn(\Theta) \cdot \tilde{\Theta} &= \sum_{i=1}^{n} e_i ({\color{cornflowerblue}\sigma(p_i)} \odot \tilde{z}_i)^t V - (\underbrace{\tilde{\iota}(Q,K)}_{\color{cornflowerblue}\sigma(p_i)} \otimes 1_d^t) \odot {\color{cornflowerblue}\Attn(\Theta)} + \underbrace{\Attn(Q,K,\tilde{V})}_{\color{cornflowerblue}\sum_{i=1}^{n} e_i \sigma(p_i)^t}. \end{align*} $$

The quantities in blue can be re-used from the forward pass.
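To make the bookkeeping concrete, here is a minimal PyTorch sketch of this Jacobian-vector product under the conventions above (no \(1/\sqrt{d}\) scaling), which can be checked numerically against torch.func.jvp; the function names are chosen here for illustration:

import torch

def attn(Q, K, V):
    """Unscaled attention: row-wise softmax of Q K^t, applied to V."""
    return torch.softmax(Q @ K.T, dim=-1) @ V

def attn_jvp(Q, K, V, Qt, Kt, Vt):
    """d Attn(Theta) . Theta~, following the formula above."""
    A = torch.softmax(Q @ K.T, dim=-1)   # rows are sigma(p_i); reusable from the forward pass
    out = A @ V                          # Attn(Theta); reusable from the forward pass
    Z = Qt @ K.T + Q @ Kt.T              # rows are z~_i = q~_i + k~_i
    AZ = A * Z                           # rows are sigma(p_i), element-wise multiplied by z~_i
    iota = AZ.sum(dim=-1, keepdim=True)  # entries <sigma(p_i), z~_i>
    return AZ @ V - iota * out + A @ Vt

n, d = 4, 3
Q, K, V, Qt, Kt, Vt = (torch.randn(n, d, dtype=torch.double) for _ in range(6))
_, expected = torch.func.jvp(attn, (Q, K, V), (Qt, Kt, Vt))
assert torch.allclose(attn_jvp(Q, K, V, Qt, Kt, Vt), expected)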

In the next section, we consider how the above result changes if we replace \(\sigma\).

Other normalizing maps

In this section, \(\beta: \bR^n \to \bR^n\) is a smooth map with the following homogeneity property: There exists a smooth function \(f : \bR^n \times \bR^n \to \bR\) such that

$$ \begin{align*} d \beta(x) \cdot h = f (x, h) \beta(x) \end{align*} $$

for all \(x, h \in \bR^n\).

Replacing \(\sigma\) with \(\beta\), the \(\beta\)-attention map is defined by

$$ \begin{align*} \Attn_\beta : \bR^{n \times d} \times \bR^{n \times d} \times \bR^{n \times d} &\to \bR^{n \times d} \\ (Q, K, V) &\mapsto \Attn_\beta(Q, K, V) = \beta(QK^t) V, \end{align*} $$

where \(\beta\) is applied row-wise.

Repeating the analysis of the previous section, we have

$$ \begin{align*} \Attn_\beta (\Theta) &= \sum_{i=1}^{n} e_i \beta(p_i)^t V \end{align*} $$

and the total derivative of \(\Attn_\beta\) at \(\Theta\) is

$$ \begin{align*} d \Attn_\beta (\Theta) \cdot \tilde{\Theta} &= \sum_{i=1}^{n} e_i [d \beta(p_i) \cdot \tilde{z}_i]^t V + \Attn_\beta (Q, K, \tilde{V}) \\ &= \sum_{i=1}^{n} f(p_i, \tilde{z}_i) e_i \beta(p_i)^t V + \Attn_\beta (Q, K, \tilde{V}) \\ &= (\tilde{f}(Q,K) \otimes 1_d^t) \odot \Attn_\beta (\Theta) + \Attn_\beta (Q, K, \tilde{V}), \end{align*} $$

where \(p_i\), \(\tilde{z}_i\) are defined as in the previous section and

$$ \tilde{f}(Q,K) = \begin{bmatrix} f(p_1, \tilde{z}_1) \\ \vdots \\ f(p_n, \tilde{z}_n) \end{bmatrix}. $$

In total, we have

$$ \begin{align*} d \Attn_\beta (\Theta) \cdot \tilde{\Theta} &= (\underbrace{\tilde{f}(Q,K)}_{\color{cornflowerblue}p_i} \otimes 1_d^t) \odot {\color{cornflowerblue}\Attn_\beta (\Theta)} + \underbrace{\Attn_\beta(Q,K,\tilde{V})}_{\color{cornflowerblue}\sum_{i=1}^{n} e_i \beta(p_i)^t}. \end{align*} $$

The quantities in blue can be re-used from the forward pass.
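To illustrate how little is needed beyond the forward pass, here is a minimal PyTorch sketch of this formula, assuming a row-wise \(\beta\) with the homogeneity property and a vectorized implementation of \(f\); the names beta_attn_jvp and f are chosen here for illustration:

def beta_attn_jvp(beta, f, Q, K, V, Qt, Kt, Vt):
    """d Attn_beta(Theta) . Theta~ for a beta with the homogeneity property.

    beta: applied row-wise to the (n, n) score matrix.
    f:    f(P, Z) returns the (n, 1) column vector with entries f(p_i, z~_i).
    """
    P = Q @ K.T              # rows are p_i
    B = beta(P)              # rows are beta(p_i); reusable from the forward pass
    out = B @ V              # Attn_beta(Theta); reusable from the forward pass
    Z = Qt @ K.T + Q @ Kt.T  # rows are z~_i
    return f(P, Z) * out + B @ Vt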

The punchline is that the homogeneity property could make the computation of \(d\Attn_\beta(\Theta) \cdot \tilde{\Theta}\) more efficient by “removing” the first term of \(d\Attn(\Theta)\cdot \tilde{\Theta}\).

For a particular example, consider the simple normalizing map

$$ \begin{align*} \beta(x) = \frac{x}{1 + \| x \|}. \end{align*} $$

The total derivative of \(\beta\) is

$$ \begin{align*} d\beta(x) \cdot h &= \begin{cases} \displaystyle \frac{h}{1 + \|x\|} - \frac{\langle x, h \rangle x}{\|x\| (1 + \|x\|)^2}, & x \neq 0_n \\ h, & x = 0_n. \end{cases} \end{align*} $$
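For \(x \neq 0_n\), this is the quotient rule combined with \(d \| x \| \cdot h = \langle x, h \rangle / \| x \|\):

$$ \begin{align*} d\beta(x) \cdot h &= \frac{h}{1 + \|x\|} - \frac{x \, [d \|x\| \cdot h]}{(1 + \|x\|)^2} = \frac{h}{1 + \|x\|} - \frac{\langle x, h \rangle x}{\|x\| (1 + \|x\|)^2}. \end{align*} $$

At \(x = 0_n\), differentiability with \(d\beta(0_n) = \mathrm{id}\) follows from \(\| \beta(h) - h \| = \|h\|^2 / (1 + \|h\|) = o(\|h\|)\).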

By definition, \(f\) must satisfy

$$ \begin{align*} f(x, h) \frac{x}{1 + \| x\|} &= \begin{cases} \displaystyle \frac{h}{1 + \|x\|} - \frac{\langle x, h \rangle x}{\| x\| (1 + \|x\|)^2}, & x \neq 0_n \\ h, & x = 0_n. \end{cases} \end{align*} $$

This is clearly not possible for \(x = 0_n\) (the left-hand side vanishes because \(\beta(0_n) = 0_n\), while the right-hand side is \(h\)), so from this point forward let’s assume that we are working away from \(x = 0_n\) (this can be made rigorous, but we’ll skip that).

Taking the inner product of both sides with \(x\) and rearranging, we see that

$$ \begin{align*} f(x, h) &= \frac{\langle x, h \rangle}{\|x\|^2} - \frac{\langle x, h \rangle}{\| x\| (1 + \|x\|)} \\ &= \frac{\langle x, h \rangle}{\|x\|^2 (1 + \|x\|)}. \end{align*} $$
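A minimal PyTorch sketch of the corresponding vectorized \(f\) (valid away from \(x = 0_n\)), with the name f_beta chosen here for illustration; this is the function that would be passed as f in the beta_attn_jvp sketch above:

import torch

def f_beta(P, Z):
    """Row-wise f(p_i, z~_i) = <p_i, z~_i> / (||p_i||^2 (1 + ||p_i||)); returns shape (n, 1)."""
    inner = (P * Z).sum(dim=-1, keepdim=True)   # <p_i, z~_i>
    norm = torch.norm(P, dim=-1, keepdim=True)  # ||p_i||
    return inner / (norm**2 * (1.0 + norm))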

Plausibility of replacing \(\sigma\)

Putting aside potential efficiency gains, can we learn effectively with \(\Attn_\beta\)?

To quickly test this for

$$ \begin{align*} \beta(x) = \frac{x}{1 + \| x \|}, \end{align*} $$

we can build on the nanoGPT project. In a nutshell, we need to implement \(\beta\), disable flash attention, and adjust the causal masking to accommodate \(\beta\).

The nn.Module that implements \(\beta\) is very straightforward:

"""Implementation of beta map."""

import torch
import torch.nn as nn
from torch import Tensor

class Beta(nn.Module):
    """Row-wise normalizing map beta(x) = x / (1 + ||x||)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Compute beta along the last dimension of x."""
        return x / (1.0 + torch.norm(x, dim=-1, keepdim=True))
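As for adjusting the causal masking, one plausible approach (an assumption here, not necessarily what was done in the actual experiment) is to zero out the masked score entries before applying \(\beta\): a zero entry of \(x\) is mapped to a zero entry of \(\beta(x)\), so masked positions contribute nothing to the weighted sum over values. A hypothetical usage sketch, assuming the Beta module above is in scope:

import torch

T = 8
scores = torch.randn(T, T)                    # q @ k^t scores for one head
causal = torch.tril(torch.ones(T, T)).bool()  # lower-triangular causal mask
scores = scores.masked_fill(~causal, 0.0)     # zero (rather than -inf) for beta
weights = Beta()(scores)                      # masked entries remain exactly zero
# attended = weights @ v, for a value tensor v of shape (T, d)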

Training the character-level Tiny Shakespeare model for 10000 iterations with the nanoGPT defaults gives:

step 10000: train loss 1.0040, val loss 1.5550

The final validation loss is comparable to the results obtained using standard attention:

step 10000: train loss 0.7140, val loss 1.6204