In this post, we show that a modified attention map, in which softmax is replaced by a different normalizing map, has a simpler total derivative. The goal is to improve the computational efficiency of backpropagation through attention-type maps.
Standard attention
The standard attention map is defined by

$$A(Q, K, V) = \mathrm{softmax}(QK^\top)\,V,$$

where $\mathrm{softmax}$ is applied to each row of its matrix argument,

$$\mathrm{softmax}(s)_i = \frac{e^{s_i}}{\sum_j e^{s_j}},$$

where $Q, K \in \mathbb{R}^{n \times d}$ and $V \in \mathbb{R}^{n \times d_v}$ are the query, key, and value matrices.

Note: Our definition of the attention map does not scale the entries of $QK^\top$ by the usual factor of $1/\sqrt{d}$.
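As a concrete reference, here is a minimal PyTorch sketch of this unscaled attention map (the function name and shapes are illustrative, not taken from any particular implementation):

import torch
from torch import Tensor

def attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """Unscaled attention: softmax(Q K^T) V, with softmax over the last dimension."""
    scores = q @ k.transpose(-2, -1)         # (n, n) attention scores
    weights = torch.softmax(scores, dim=-1)  # row-wise softmax
    return weights @ v                       # (n, d_v)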
The total derivative of the attention map at $(Q, K, V)$ in the direction $(\delta Q, \delta K, \delta V)$ is

$$DA(Q, K, V)[\delta Q, \delta K, \delta V] = D\,\mathrm{softmax}(S)[\delta S]\,V + \mathrm{softmax}(S)\,\delta V,$$

where $S = QK^\top$ and $\delta S$ is the corresponding perturbation of $S$.

By the Leibniz rule, the partial derivatives of $S = QK^\top$ combine into

$$\delta S = \delta Q\,K^\top + Q\,\delta K^\top,$$

where, for a single row $s$ of $S$, the derivative of softmax is

$$D\,\mathrm{softmax}(s)[\delta s] = p \odot \delta s - (p^\top \delta s)\,p,$$

where $p = \mathrm{softmax}(s)$ and $\odot$ denotes the element-wise product. Applied row by row,

$$D\,\mathrm{softmax}(S)[\delta S] = P \odot \delta S - \mathrm{diag}\big((P \odot \delta S)\,\mathbf{1}\big)\,P,$$

where $P = \mathrm{softmax}(S)$ and $\mathbf{1}$ is the all-ones vector.

In total, we have

$$DA(Q, K, V)[\delta Q, \delta K, \delta V] = \Big(P \odot \delta S - \mathrm{diag}\big((P \odot \delta S)\,\mathbf{1}\big)\,P\Big)V + P\,\delta V, \qquad \delta S = \delta Q\,K^\top + Q\,\delta K^\top.$$
The attention weights $P = \mathrm{softmax}(QK^\top)$, already computed in the forward pass, can be re-used here.
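This formula can be checked numerically. The following sketch (with arbitrary shapes and my own variable names) compares the closed-form directional derivative against autograd:

import torch
from torch.autograd.functional import jvp

torch.manual_seed(0)
n, d = 4, 3
Q, K, V = (torch.randn(n, d, dtype=torch.double) for _ in range(3))
dQ, dK, dV = (torch.randn(n, d, dtype=torch.double) for _ in range(3))

def attention(Q, K, V):
    return torch.softmax(Q @ K.T, dim=-1) @ V

# Closed-form directional derivative, re-using P from the forward pass.
P = torch.softmax(Q @ K.T, dim=-1)
dS = dQ @ K.T + Q @ dK.T
dP = P * dS - (P * dS).sum(dim=-1, keepdim=True) * P  # row-wise softmax derivative
closed_form = dP @ V + P @ dV

# Reference directional derivative from autograd.
_, reference = jvp(attention, (Q, K, V), (dQ, dK, dV))
print(torch.allclose(closed_form, reference))  # expect: True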
In the next section, we consider how the above result changes if we replace softmax with a different normalizing map.
Other normalizing maps
In this section, we consider normalizing maps other than softmax.

Replacing softmax with a map $\beta$ gives the modified attention map

$$A_\beta(Q, K, V) = \beta(QK^\top)\,V,$$

where $\beta$ is applied to each row of $QK^\top$.
Repeating the analysis of the previous section, we have

$$\delta S = \delta Q\,K^\top + Q\,\delta K^\top,$$

and the total derivative of $A_\beta$ is

$$DA_\beta(Q, K, V)[\delta Q, \delta K, \delta V] = D\beta(S)[\delta S]\,V + \beta(S)\,\delta V,$$

where $D\beta(S)[\delta S]$ applies the derivative of $\beta$ to each row of $\delta S$.

In total, we have

$$DA_\beta(Q, K, V)[\delta Q, \delta K, \delta V] = D\beta(QK^\top)\big[\delta Q\,K^\top + Q\,\delta K^\top\big]\,V + \beta(QK^\top)\,\delta V.$$
As before, the quantity $\beta(QK^\top)$ can be re-used from the forward pass.
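As a sketch (the function name and the idea of passing the normalizing map as an argument are my own, purely for illustration), the modification only swaps out the normalization step:

import torch
from torch import Tensor
from typing import Callable

def attention_with(normalize: Callable[[Tensor], Tensor], q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """Attention with softmax replaced by an arbitrary normalizing map over the last dimension."""
    return normalize(q @ k.transpose(-2, -1)) @ v

# Standard attention corresponds to normalize = lambda s: torch.softmax(s, dim=-1).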
The punchline is that the homogeneity property could make the computation of the total derivative cheaper.
For a particular example, consider the simple normalizing map

$$\beta(x) = \frac{x}{1 + \|x\|_2}.$$
The total derivative of $\beta$ is

$$D\beta(x)[\delta x] = \frac{1}{1 + \|x\|_2}\left(\delta x - \frac{x^\top \delta x}{\|x\|_2}\,\beta(x)\right).$$
By definition,
This is clearly not possible for
Rearranging terms, we see that
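The expression for $D\beta$ given above can be sanity-checked against autograd; here is a minimal sketch with arbitrary test vectors:

import torch
from torch.autograd.functional import jvp

torch.manual_seed(0)
x = torch.randn(8, dtype=torch.double)
dx = torch.randn(8, dtype=torch.double)

def beta(x):
    return x / (1.0 + torch.norm(x))

# Closed-form directional derivative of beta at x in the direction dx.
r = torch.norm(x)
closed_form = (dx - (x @ dx / r) * beta(x)) / (1.0 + r)

# Reference directional derivative from autograd.
_, reference = jvp(beta, (x,), (dx,))
print(torch.allclose(closed_form, reference))  # expect: True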
Plausibility of replacing softmax
Putting aside potential efficiency gains, can we learn effectively with $\beta$ in place of softmax?

To quickly test this for

$$\beta(x) = \frac{x}{1 + \|x\|_2},$$

we can build on the nanoGPT project. In a nutshell, we need to implement $\beta$ as an nn.Module and use it in place of softmax in the model's attention block. The nn.Module that implements $\beta$ is the following:
"""Implementation of beta map."""
import torch
import torch.nn as nn
from torch import Tensor
class Beta(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
"""Compute beta() along last dimension of x."""
return x / (1.0 + torch.norm(x, dim=-1, keepdim=True))
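To indicate where this module plugs in, here is a bare-bones sketch of an attention forward pass using Beta (my own simplification; nanoGPT's actual attention block also applies a causal mask and dropout, which would need their own adjustments):

import torch
from torch import Tensor

beta = Beta()  # the module defined above

def beta_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    att = q @ k.transpose(-2, -1)  # unscaled attention scores
    att = beta(att)                # replaces the row-wise softmax
    return att @ v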
Training a character-level model on Tiny Shakespeare for 10000 iterations with the nanoGPT defaults gives:
step 10000: train loss 1.0040, val loss 1.5550
The final validation loss is comparable to the results obtained using standard attention:
step 10000: train loss 0.7140, val loss 1.6204