In this post, we will compute the second-order and third-order total derivatives of the softmax map. At some point, I thought that the higher-order derivatives could be used to cheaply compute approximations of the softmax map, but that’s a story for another post.
$$ \newcommand{\bR}{\mathbb{R}} \newcommand{\Diag}{\Delta} \newcommand{\Exp}{\mathrm{Exp}} \definecolor{magicmint}{rgb}{0.67, 0.94, 0.82} \definecolor{lesserbox}{rgb}{0.85, 0.95, 1.0} $$
Softmax
To get started, recall that the softmax map \(\sigma : \bR^n \to \bR^n\) sends \(x\) to
$$ \sigma(x) = \begin{bmatrix} \displaystyle \frac{\exp(x^1)}{\sum_{i=1}^{n} \exp(x^i)} \\ \vdots \\ \displaystyle \frac{\exp(x^n)}{\sum_{i=1}^{n} \exp(x^i)} \end{bmatrix}, $$where \(x^i\) is the \(i\)th component of \(x\). Clearly, \(\sigma\) is smooth and satisfies
$$ \langle \sigma(x), 1_n \rangle = 1, $$where \(\langle \cdot, \cdot \rangle\) is the Euclidean inner product on \(\bR^n\) and \(1_n\) is the all-ones vector.
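As a quick numerical illustration of the definition, here is a minimal sketch in plain Python. The function name `softmax` and the sample point are assumptions made for illustration, and the subtraction of \(\max_i x^i\) before exponentiating is a standard stability trick that is not part of the definition above; it leaves the ratio unchanged.

```python
import math

def softmax(x):
    """Softmax of a list of floats (hypothetical helper for illustration)."""
    m = max(x)  # assumption: shift by the max to avoid overflow; the ratio is unchanged
    e = [math.exp(xi - m) for xi in x]
    total = sum(e)
    return [ei / total for ei in e]

x = [1.0, 2.0, 3.0]
s = softmax(x)

# The components are positive and their sum, <sigma(x), 1_n>, should be 1
# up to floating-point rounding.
print(s)
print(sum(s))
```

The printed sum should agree with 1 up to rounding error.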
Total derivatives
First-order total derivative
We begin by rewriting \(\sigma\) in a more convenient form as
$$ \sigma(x) = \frac{\Exp(x)}{\langle \Exp(x), 1_n \rangle}, $$where \(\Exp : \bR^n \to \bR^n\) is the component-wise application of \(\exp\). That is,
$$ \Exp(x) = \begin{bmatrix} \exp(x^1) \\ \vdots \\ \exp(x^n) \end{bmatrix}. $$The total derivative of \(\Exp\) at \(x\) is
$$ \begin{align} d\Exp(x) \cdot h = \Exp(x) \odot h, \end{align} $$where \(\odot\) is the component-wise product. To see this, observe that
$$ \Exp(x) = \begin{bmatrix} \exp \circ\, \pi^1(x) \\ \vdots \\ \exp \circ\, \pi^n(x) \end{bmatrix}, $$where \(\pi^i : \bR^n \to \bR\) is projection onto the \(i\)th factor. By the chain rule, we have
$$ \begin{align*} d\Exp(x) \cdot h = \begin{bmatrix} \exp(x^1) \pi^1(h) \\ \vdots \\ \exp(x^n) \pi^n(h) \end{bmatrix} = \Exp(x) \odot h. \end{align*} $$Using this fact, together with the quotient rule, the total derivative of \(\sigma\) at \(x\) is
$$ \begin{align*} d\sigma(x) \cdot h &= \frac{\langle \Exp(x), 1_n \rangle d\Exp(x) \cdot h - \langle d \Exp(x) \cdot h, 1_n \rangle \Exp(x) }{\langle \Exp(x), 1_n \rangle^2} \\ &= \frac{\langle \Exp(x), 1_n \rangle \Exp(x) \odot h - \langle \Exp(x) \odot h, 1_n \rangle \Exp(x) }{\langle \Exp(x), 1_n \rangle^2} \\ &= \frac{\Exp(x) \odot h}{\langle \Exp(x), 1_n \rangle} - \frac{\langle \Exp(x) \odot h, 1_n \rangle \Exp(x)}{\langle \Exp(x), 1_n \rangle^2} \\ &= \frac{\Exp(x)}{\langle \Exp(x), 1_n \rangle} \odot h - \left\langle \frac{\Exp(x)}{\langle \Exp(x), 1_n \rangle}, h \right\rangle \frac{\Exp(x)}{\langle \Exp(x), 1_n \rangle} \\ &= \sigma(x) \odot h - \langle \sigma(x), h \rangle \sigma(x). \end{align*} $$Reiterated on a single line:
$$ \colorbox{magicmint}{ $d\sigma(x) \cdot h = \sigma(x) \odot h - \langle \sigma(x), h \rangle \sigma(x)$. } $$Finally, note that we can write
$$ d \sigma(x) \cdot h = \sigma(x) \odot h - \sigma(x) \sigma(x)^t h, $$since \(\langle \sigma(x), h \rangle = \sigma(x)^t h\). Let \(\Diag : \bR^n \to \bR^{n \times n}\) send \(x\) to the diagonal matrix whose \((i,i)\)th entry is \(x^i\). Then \(\sigma(x) \odot h = [\Diag \sigma(x)] h\), so the Jacobian matrix of \(\sigma\) at \(x\) is
$$ \colorbox{lesserbox}{ $J\sigma(x) = \Diag \sigma(x) - \sigma(x) \sigma(x)^t$. } $$
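The formulas for the first-order total derivative and the Jacobian can be sanity-checked numerically. The sketch below compares the closed form with a central finite difference of \(\sigma\) along \(h\), and with the Jacobian applied to \(h\). The helper names (`softmax`, `dsoftmax`, `jacobian`), the test point, and the step size are all assumptions chosen for illustration.

```python
import math

def softmax(x):
    m = max(x)  # stability shift (an assumption, not part of the derivation)
    e = [math.exp(xi - m) for xi in x]
    total = sum(e)
    return [ei / total for ei in e]

def dsoftmax(x, h):
    """d sigma(x) . h = sigma(x) (.) h - <sigma(x), h> sigma(x)."""
    s = softmax(x)
    inner = sum(si * hi for si, hi in zip(s, h))
    return [si * hi - inner * si for si, hi in zip(s, h)]

def jacobian(x):
    """J sigma(x) = Diag(sigma(x)) - sigma(x) sigma(x)^t, as nested lists."""
    s = softmax(x)
    n = len(s)
    return [[(s[i] if i == j else 0.0) - s[i] * s[j] for j in range(n)]
            for i in range(n)]

x = [0.5, -1.0, 2.0]
h = [0.3, 0.1, -0.2]
eps = 1e-6

# Central finite difference of sigma along the direction h.
plus = softmax([xi + eps * hi for xi, hi in zip(x, h)])
minus = softmax([xi - eps * hi for xi, hi in zip(x, h)])
fd = [(p - q) / (2 * eps) for p, q in zip(plus, minus)]

exact = dsoftmax(x, h)
J = jacobian(x)
Jh = [sum(J[i][j] * h[j] for j in range(len(h))) for i in range(len(h))]

fd_error = max(abs(a - b) for a, b in zip(exact, fd))
jac_error = max(abs(a - b) for a, b in zip(exact, Jh))
print(fd_error, jac_error)
```

Both discrepancies should be small. Since the components of \(\sigma(x)\) always sum to one, the components of \(d\sigma(x) \cdot h\) should also sum to zero, which makes for another cheap check.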
Second-order total derivative
Moving to the second-order total derivative, recall that
$$ d^2 \sigma(x) \cdot (h_1, h_2) = d[x \mapsto d\sigma(x) \cdot h_1](x) \cdot h_2. $$Using the expression for \(d\sigma(x)\) from the previous section, we obtain
$$ \begin{align*} d^2 \sigma(x) \cdot (h_1, h_2) &= (d\sigma(x) \cdot h_2) \odot h_1 \\ &\qquad - \, \langle d\sigma(x) \cdot h_2, h_1 \rangle \sigma(x) \\ &\qquad - \, \langle \sigma(x), h_1 \rangle d\sigma(x) \cdot h_2. \end{align*} $$The first term on the right-hand side is
$$ \begin{align*} (d\sigma(x) \cdot h_2) \odot h_1 &= (\sigma(x) \odot h_2 - \langle \sigma(x), h_2 \rangle \sigma(x)) \odot h_1 \\ &= \sigma(x) \odot h_1 \odot h_2 - \langle \sigma(x), h_2 \rangle \sigma(x) \odot h_1. \end{align*} $$The second term is
$$ \begin{align*} - \langle d\sigma(x) \cdot h_2, h_1 \rangle \sigma(x) &= - \langle \sigma(x) \odot h_2 - \langle \sigma(x), h_2 \rangle \sigma(x), h_1 \rangle \sigma(x) \\ &= - \langle \sigma(x) \odot h_2, h_1 \rangle \sigma(x) + \langle \sigma(x), h_2 \rangle \langle \sigma(x), h_1 \rangle \sigma(x) \\ &= - \langle \sigma(x), h_1 \odot h_2 \rangle \sigma(x) + \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle \sigma(x). \end{align*} $$Finally, the third term is
$$ \begin{align*} - \langle \sigma(x), h_1 \rangle d\sigma(x) \cdot h_2 &= - \langle \sigma(x), h_1 \rangle (\sigma(x) \odot h_2 - \langle \sigma(x), h_2 \rangle \sigma(x)) \\ &= - \langle \sigma(x), h_1 \rangle \sigma(x) \odot h_2 + \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle \sigma(x). \end{align*} $$In total, we have
$$ \colorbox{magicmint}{ $ \begin{align*} d^2 \sigma(x) \cdot (h_1, h_2) &= \sigma(x) \odot h_1 \odot h_2 \\ &\qquad - \, \langle \sigma(x), h_2 \rangle \sigma(x) \odot h_1 \\ &\qquad - \, \langle \sigma(x), h_1 \rangle \sigma(x) \odot h_2 \\ &\qquad - \, \langle \sigma(x), h_1 \odot h_2 \rangle \sigma(x) \\ &\qquad + \, 2 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle \sigma(x). \end{align*} $ } $$Note that the right-hand side of the above expression is symmetric and bilinear.
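To guard against sign or bookkeeping slips, the second-order formula can be compared with a finite difference of the first-order formula. The sketch below differentiates \(x \mapsto d\sigma(x) \cdot h_1\) along \(h_2\) numerically and also checks symmetry in \((h_1, h_2)\). The helper names, the test vectors, and the step size are assumptions for illustration only.

```python
import math

def softmax(x):
    m = max(x)  # stability shift (an assumption, not part of the derivation)
    e = [math.exp(xi - m) for xi in x]
    total = sum(e)
    return [ei / total for ei in e]

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def dsoftmax(x, h):
    s = softmax(x)
    a = dot(s, h)
    return [si * (hi - a) for si, hi in zip(s, h)]

def d2softmax(x, h1, h2):
    """Closed form for d^2 sigma(x) . (h1, h2), written component by component."""
    s = softmax(x)
    a1, a2 = dot(s, h1), dot(s, h2)
    a12 = sum(si * u * v for si, u, v in zip(s, h1, h2))  # <sigma, h1 (.) h2>
    return [si * (u * v - a2 * u - a1 * v - a12 + 2 * a1 * a2)
            for si, u, v in zip(s, h1, h2)]

x = [0.5, -1.0, 2.0]
h1 = [0.3, 0.1, -0.2]
h2 = [-0.4, 0.25, 0.1]
eps = 1e-5

# Central finite difference of x -> d sigma(x) . h1 along h2.
xp = [xi + eps * v for xi, v in zip(x, h2)]
xm = [xi - eps * v for xi, v in zip(x, h2)]
fd = [(p - q) / (2 * eps) for p, q in zip(dsoftmax(xp, h1), dsoftmax(xm, h1))]

exact = d2softmax(x, h1, h2)
swapped = d2softmax(x, h2, h1)
fd_error = max(abs(a - b) for a, b in zip(exact, fd))
sym_error = max(abs(a - b) for a, b in zip(exact, swapped))
print(fd_error, sym_error)
```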
Third-order total derivative
Moving to the third-order total derivative, recall that
$$ d^3 \sigma(x) \cdot (h_1, h_2, h_3) = d[x \mapsto d^2\sigma(x) \cdot (h_1, h_2)](x) \cdot h_3. $$Using the expression for \(d^2 \sigma(x)\) from the previous section, we obtain
$$ \begin{align*} d^3 \sigma(x) \cdot (h_1, h_2, h_3) &= [d\sigma(x) \cdot h_3] \odot h_1 \odot h_2 \\ &\qquad - \, \langle d\sigma(x) \cdot h_3, h_2 \rangle \sigma(x) \odot h_1 \\ &\qquad - \, \langle \sigma(x), h_2 \rangle [d\sigma(x) \cdot h_3] \odot h_1 \\ &\qquad - \, \langle d\sigma(x) \cdot h_3, h_1 \rangle \sigma(x) \odot h_2 \\ &\qquad - \, \langle \sigma(x), h_1 \rangle [d\sigma(x) \cdot h_3] \odot h_2 \\ &\qquad - \, \langle d\sigma(x) \cdot h_3, h_1 \odot h_2 \rangle \sigma(x) \\ &\qquad - \, \langle \sigma(x), h_1 \odot h_2 \rangle d\sigma(x) \cdot h_3 \\ &\qquad + \, 2 \langle d\sigma(x) \cdot h_3, h_1 \rangle \langle \sigma(x), h_2 \rangle \sigma(x) \\ &\qquad + \, 2 \langle \sigma(x), h_1 \rangle \langle d\sigma(x) \cdot h_3, h_2 \rangle \sigma(x) \\ &\qquad + \, 2 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle d\sigma(x) \cdot h_3. \end{align*} $$The first term on the right-hand side is
$$ \begin{align*} \sigma(x) \odot h_1 \odot h_2 \odot h_3 - \langle \sigma(x), h_3 \rangle \sigma(x) \odot h_1 \odot h_2. \end{align*} $$The second term is
$$ \begin{align*} - \langle \sigma(x), h_2 \odot h_3 \rangle \sigma(x) \odot h_1 + \langle \sigma(x), h_2 \rangle \langle \sigma(x), h_3 \rangle \sigma(x) \odot h_1. \end{align*} $$The third term is
$$ \begin{align*} - \langle \sigma(x), h_2 \rangle \sigma(x) \odot h_1 \odot h_3 + \langle \sigma(x), h_2 \rangle \langle \sigma(x), h_3 \rangle \sigma(x) \odot h_1. \end{align*} $$The fourth term is
$$ \begin{align*} - \langle \sigma(x), h_1 \odot h_3 \rangle \sigma(x) \odot h_2 + \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_3 \rangle \sigma(x) \odot h_2. \end{align*} $$The fifth term is
$$ \begin{align*} - \langle \sigma(x), h_1 \rangle \sigma(x) \odot h_2 \odot h_3 + \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_3 \rangle \sigma(x) \odot h_2. \end{align*} $$The sixth term is
$$ \begin{align*} - \langle \sigma(x), h_1 \odot h_2 \odot h_3 \rangle \sigma(x) + \langle \sigma(x), h_3 \rangle \langle \sigma(x), h_1 \odot h_2 \rangle \sigma(x). \end{align*} $$The seventh term is
$$ \begin{align*} - \langle \sigma(x), h_1 \odot h_2 \rangle \sigma(x) \odot h_3 + \langle \sigma(x), h_1 \odot h_2 \rangle \langle \sigma(x), h_3 \rangle \sigma(x). \end{align*} $$The eighth term is
$$ \begin{align*} 2 \langle \sigma(x), h_1 \odot h_3 \rangle \langle \sigma(x), h_2 \rangle \sigma(x) - 2 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle \langle \sigma(x), h_3 \rangle \sigma(x). \end{align*} $$The ninth term is
$$ \begin{align*} 2 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \odot h_3 \rangle \sigma(x) - 2 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle \langle \sigma(x), h_3 \rangle \sigma(x). \end{align*} $$Finally, the tenth term is
$$ \begin{align*} 2 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle \sigma(x) \odot h_3 - 2 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle \langle \sigma(x), h_3 \rangle \sigma(x). \end{align*} $$In total, we have
$$ \colorbox{magicmint}{ $ \begin{align*} d^3 \sigma(x) \cdot (h_1, h_2, h_3) &= \sigma(x) \odot h_1 \odot h_2 \odot h_3 \\ &\qquad + \, 2 \langle \sigma(x), h_2 \rangle \langle \sigma(x), h_3 \rangle \sigma(x) \odot h_1 \\ &\qquad + \, 2 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_3 \rangle \sigma(x) \odot h_2 \\ &\qquad + \, 2 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle \sigma(x) \odot h_3 \\ &\qquad - \, \langle \sigma(x), h_3 \rangle \sigma(x) \odot h_1 \odot h_2 \\ &\qquad - \, \langle \sigma(x), h_2 \rangle \sigma(x) \odot h_1 \odot h_3 \\ &\qquad - \, \langle \sigma(x), h_1 \rangle \sigma(x) \odot h_2 \odot h_3 \\ &\qquad - \, \langle \sigma(x), h_2 \odot h_3 \rangle \sigma(x) \odot h_1 \\ &\qquad - \, \langle \sigma(x), h_1 \odot h_3 \rangle \sigma(x) \odot h_2 \\ &\qquad - \, \langle \sigma(x), h_1 \odot h_2 \rangle \sigma(x) \odot h_3 \\ &\qquad - \, \langle \sigma(x), h_1 \odot h_2 \odot h_3 \rangle \sigma(x) \\ &\qquad + \, 2 \langle \sigma(x), h_2 \odot h_3 \rangle \langle \sigma(x), h_1 \rangle \sigma(x) \\ &\qquad + \, 2 \langle \sigma(x), h_1 \odot h_3 \rangle \langle \sigma(x), h_2 \rangle \sigma(x) \\ &\qquad + \, 2 \langle \sigma(x), h_1 \odot h_2 \rangle \langle \sigma(x), h_3 \rangle \sigma(x) \\ &\qquad - \, 6 \langle \sigma(x), h_1 \rangle \langle \sigma(x), h_2 \rangle \langle \sigma(x), h_3 \rangle \sigma(x). \end{align*} $ } $$One can check that the right-hand side of the above expression is symmetric and trilinear.
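With fifteen terms it is easy to drop a sign, so a numerical check is reassuring. The sketch below implements the boxed third-order formula component by component and compares it with a central finite difference of the second-order formula along \(h_3\). It also checks that swapping arguments leaves the value unchanged. All helper names, test vectors, and the step size are assumptions for illustration.

```python
import math

def softmax(x):
    m = max(x)  # stability shift (an assumption, not part of the derivation)
    e = [math.exp(xi - m) for xi in x]
    total = sum(e)
    return [ei / total for ei in e]

def dot3(s, *hs):
    """<sigma, h_a (.) h_b (.) ...> for one or more direction vectors."""
    return sum(si * math.prod(c) for si, c in zip(s, zip(*hs)))

def d2softmax(x, h1, h2):
    s = softmax(x)
    a1, a2, a12 = dot3(s, h1), dot3(s, h2), dot3(s, h1, h2)
    return [si * (u * v - a2 * u - a1 * v - a12 + 2 * a1 * a2)
            for si, u, v in zip(s, h1, h2)]

def d3softmax(x, h1, h2, h3):
    """Closed form for d^3 sigma(x) . (h1, h2, h3)."""
    s = softmax(x)
    a1, a2, a3 = dot3(s, h1), dot3(s, h2), dot3(s, h3)
    a12, a13, a23 = dot3(s, h1, h2), dot3(s, h1, h3), dot3(s, h2, h3)
    a123 = dot3(s, h1, h2, h3)
    scalar = -a123 + 2 * (a23 * a1 + a13 * a2 + a12 * a3) - 6 * a1 * a2 * a3
    return [si * (u * v * w
                  + 2 * (a2 * a3 * u + a1 * a3 * v + a1 * a2 * w)
                  - (a3 * u * v + a2 * u * w + a1 * v * w)
                  - (a23 * u + a13 * v + a12 * w)
                  + scalar)
            for si, u, v, w in zip(s, h1, h2, h3)]

x = [0.5, -1.0, 2.0]
h1, h2, h3 = [0.3, 0.1, -0.2], [-0.4, 0.25, 0.1], [0.2, -0.3, 0.5]
eps = 1e-5

# Central finite difference of x -> d^2 sigma(x) . (h1, h2) along h3.
xp = [xi + eps * w for xi, w in zip(x, h3)]
xm = [xi - eps * w for xi, w in zip(x, h3)]
fd = [(p - q) / (2 * eps)
      for p, q in zip(d2softmax(xp, h1, h2), d2softmax(xm, h1, h2))]

exact = d3softmax(x, h1, h2, h3)
permuted = d3softmax(x, h3, h1, h2)
fd_error = max(abs(a - b) for a, b in zip(exact, fd))
sym_error = max(abs(a - b) for a, b in zip(exact, permuted))
print(fd_error, sym_error)
```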
Taylor series approximations
The derivatives computed above give polynomial approximations to \(\sigma\) near a point \(x\). Recall that for a smooth map \(f : \bR^n \to \bR^n\), the \(p\)th-order Taylor series approximation to \(f\) at \(x\) is
$$ f(x + h) \approx \sum_{k=0}^{p} \frac{1}{k!} d^k f(x) \cdot [h]^k, $$where \(d^0 f = f\) and \([h]^k\) is the \(k\)-tuple whose components are all equal to \(h\).
First-order approximation
The first-order Taylor series approximation to \(\sigma\) at \(x\) is
$$ \colorbox{lesserbox}{ $ \begin{align*} \sigma(x + h) \approx \sigma(x) + \sigma(x) \odot h - \langle \sigma(x), h \rangle \sigma(x). \end{align*} $ } $$
Second-order approximation
First observe that
$$ \begin{align*} d^2 \sigma(x) \cdot (h, h) &= \sigma(x) \odot h \odot h - 2 \langle \sigma(x), h \rangle \sigma(x) \odot h \\ &\qquad - \, \langle \sigma(x), h \odot h \rangle \sigma(x) + 2 \langle \sigma(x), h \rangle \langle \sigma(x), h \rangle \sigma(x). \end{align*} $$The second-order Taylor series approximation to \(\sigma\) at \(x\) is
$$ \colorbox{lesserbox}{ $ \begin{align*} \sigma(x + h) &\approx \sigma(x) + \sigma(x) \odot h - \langle \sigma(x), h \rangle \sigma(x) \\ &\qquad + \, \frac{1}{2} \sigma(x) \odot h \odot h - \langle \sigma(x), h \rangle \sigma(x) \odot h \\ &\qquad - \, \frac{1}{2} \, \langle \sigma(x), h \odot h \rangle \sigma(x) + \langle \sigma(x), h \rangle \langle \sigma(x), h \rangle \sigma(x). \end{align*} $ } $$
Third-order approximation
First observe that
$$ \begin{align*} d^3 \sigma(x) \cdot (h, h, h) &= \sigma(x) \odot h \odot h \odot h + 6 \langle \sigma(x), h \rangle \langle \sigma(x), h \rangle \sigma(x) \odot h \\ &\qquad - \, 3 \langle \sigma(x), h \rangle \sigma(x) \odot h \odot h - 3 \langle \sigma(x), h \odot h \rangle \sigma(x) \odot h \\ &\qquad - \, \langle \sigma(x), h \odot h \odot h \rangle \sigma(x) + 6 \langle \sigma(x), h \odot h \rangle \langle \sigma(x), h \rangle \sigma(x) \\ &\qquad - \, 6 \langle \sigma(x), h \rangle \langle \sigma(x), h \rangle \langle \sigma(x), h \rangle \sigma(x). \end{align*} $$The third-order Taylor series approximation to \(\sigma\) at \(x\) is
$$ \colorbox{lesserbox}{ $ \begin{align*} \sigma(x + h) &\approx \sigma(x) + \sigma(x) \odot h - \langle \sigma(x), h \rangle \sigma(x) \\ &\qquad + \, \frac{1}{2} \sigma(x) \odot h \odot h - \langle \sigma(x), h \rangle \sigma(x) \odot h \\ &\qquad - \, \frac{1}{2} \, \langle \sigma(x), h \odot h \rangle \sigma(x) + \langle \sigma(x), h \rangle \langle \sigma(x), h \rangle \sigma(x) \\ &\qquad + \, \frac{1}{6} \sigma(x) \odot h \odot h \odot h + \langle \sigma(x), h \rangle \langle \sigma(x), h \rangle \sigma(x) \odot h \\ &\qquad - \, \frac{1}{2} \langle \sigma(x), h \rangle \sigma(x) \odot h \odot h - \frac{1}{2} \langle \sigma(x), h \odot h \rangle \sigma(x) \odot h \\ &\qquad - \, \frac{1}{6} \langle \sigma(x), h \odot h \odot h \rangle \sigma(x) + \langle \sigma(x), h \odot h \rangle \langle \sigma(x), h \rangle \sigma(x) \\ &\qquad - \, \langle \sigma(x), h \rangle \langle \sigma(x), h \rangle \langle \sigma(x), h \rangle \sigma(x). \end{align*} $ } $$
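To see how the three approximations behave in practice, the sketch below evaluates each of them at a small displacement \(h\) and compares against a direct evaluation of \(\sigma(x + h)\). Every term is a multiple of \(\sigma(x)\) taken component-wise, so the code factors \(\sigma(x)\) out. The function names, the test point, and the displacement are assumptions for illustration; this says nothing about whether the approximations are cheaper than evaluating softmax directly.

```python
import math

def softmax(x):
    m = max(x)  # stability shift (an assumption, not part of the derivation)
    e = [math.exp(xi - m) for xi in x]
    total = sum(e)
    return [ei / total for ei in e]

def taylor_softmax(x, h, order):
    """Taylor approximation of sigma(x + h) about x, for order 1, 2, or 3."""
    s = softmax(x)
    a = sum(si * hi for si, hi in zip(s, h))        # <sigma, h>
    b = sum(si * hi ** 2 for si, hi in zip(s, h))   # <sigma, h (.) h>
    c = sum(si * hi ** 3 for si, hi in zip(s, h))   # <sigma, h (.) h (.) h>
    out = []
    for si, hi in zip(s, h):
        t = 1.0
        if order >= 1:
            t += hi - a
        if order >= 2:
            t += 0.5 * hi ** 2 - a * hi - 0.5 * b + a * a
        if order >= 3:
            t += (hi ** 3 / 6 + a * a * hi - 0.5 * a * hi ** 2
                  - 0.5 * b * hi - c / 6 + a * b - a ** 3)
        out.append(si * t)
    return out

x = [0.5, -1.0, 2.0]
h = [0.03, 0.01, -0.02]
target = softmax([xi + hi for xi, hi in zip(x, h)])

errors = []
for order in (1, 2, 3):
    approx = taylor_softmax(x, h, order)
    errors.append(max(abs(p - q) for p, q in zip(approx, target)))
    print(order, errors[-1])
```

For a displacement this small, the maximum error should shrink as the order increases. For large \(h\) a truncated Taylor expansion need not be accurate at all.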