In this post, we review the generalized Newton’s method (GeN) proposed in [1]. We then explicitly compute the learning rates prescribed by the exact version of GeN for a simple problem instance, and finish with a high-level overview of a PyTorch implementation which runs the exact version of GeN for stochastic gradient descent.
$$ \newcommand{\bN}{\mathbb{N}} \newcommand{\bR}{\mathbb{R}} \newcommand{\sL}{\mathscr{L}} \newcommand{\tr}{\mathrm{trace}} \definecolor{lesserbox}{rgb}{0.85, 0.95, 1.0} \definecolor{magicmint}{rgb}{0.67, 0.94, 0.82} $$
Generalized Newton’s method
Consider a smooth function \(f : \Theta \to \bR\), where \(\Theta\) is a finite-dimensional inner product space. Broadly speaking, gradient descent attempts to minimize \(f\) by iterating
$$ \theta_t \leftarrow \theta_{t-1} - \alpha_t \omega(\theta_{t-1}, \nabla f(\theta_{t-1})), $$where
- \(\alpha_t \in \bR_{\geq 0}\) is the learning rate at step \(t\),
- \(\omega : \Theta \times \Theta \to \Theta\) is the optimizer, and
- \(\nabla f(\theta)\) is the gradient of \(f\) at \(\theta\).
For example, in stochastic gradient descent (SGD) the optimizer is simply \(\omega(\theta_1, \theta_2) = \theta_2\), and the gradient descent iteration becomes
$$ \theta_t \leftarrow \theta_{t-1} - \alpha_t \nabla f(\theta_{t-1}). $$The generalized Newton’s method introduced in [1] is a learning rate scheduler. That is, the method is a prescription for how to choose \(\alpha_t\) at each gradient descent iteration, with the goal of accelerating the convergence of gradient descent. To motivate the generalized Newton’s method, suppose that \(\theta\) is fixed and define
$$ \begin{align*} g_\theta : \bR &\to \bR \\ \alpha &\mapsto g_\theta(\alpha) = f(\theta - \alpha \omega), \end{align*} $$where we have introduced the shorthand
$$ \omega = \omega(\theta, \nabla f(\theta)). $$By composition, \(g_\theta\) is smooth, and its second-order Taylor series expansion at \(0\) is
$$ g_\theta(\alpha) = f(\theta) - \alpha df(\theta) \cdot \omega + \frac{\alpha^2}{2} d^2 f(\theta) \cdot (\omega, \omega) + O(\alpha^3) $$for \(\alpha\) sufficiently small, say \(|\alpha| < \delta(\theta)\).
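For reference, the coefficients in this expansion follow from the chain rule applied to \(g_\theta(\alpha) = f(\theta - \alpha \omega)\):
$$ g_\theta'(0) = -df(\theta) \cdot \omega, \qquad g_\theta''(0) = d^2 f(\theta) \cdot (\omega, \omega). $$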
To minimize \(g_\theta\), we will work with the second-order approximating function
$$ \alpha \mapsto f(\theta) - \alpha df(\theta) \cdot \omega + \frac{\alpha^2}{2} d^2 f(\theta) \cdot (\omega, \omega). $$From single-variable calculus, we know that if
$$ d^2 f(\theta) \cdot (\omega, \omega) > 0, $$then the second-order approximating function attains its global minimum at
$$ \alpha_*(\theta) = \frac{\mathrm{num}(\theta)}{\mathrm{den}(\theta)} = \frac{ df(\theta) \cdot \omega }{ d^2 f(\theta) \cdot (\omega, \omega) }. $$Following [1], we say that \(\alpha_*(\theta)\) is the optimal learning rate for \(f\) at \(\theta\).
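Note: As a quick check of this formula, suppose \(f\) is an exact quadratic, say \(f(\theta) = \frac{1}{2} \langle \theta, A\theta \rangle\) with \(A\) symmetric positive-definite, and the optimizer is SGD, so that \(\omega = \nabla f(\theta) = A\theta\). Then
$$ \alpha_*(\theta) = \frac{ \langle A\theta, A\theta \rangle }{ \langle A\theta, A(A\theta) \rangle }, $$which is exactly the step size produced by an exact line search along the negative gradient direction. \(\diamond\)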
Note: The name “optimal learning rate” is somewhat misleading, since the derivation of \(\alpha_*(\theta)\) is based on a local approximation. Indeed, even though \(g_\theta\) may be quasi-parabolic globally, its Taylor series expansion at \(0\) may not accurately capture the global parabolic shape. In this way, \(\alpha_*(\theta)\) can be far from the point at which \(g_\theta\) attains its global minimum.
The above figure shows an example of this phenomenon, for a particular case where \(f\) is the norm-squared loss function for a fully-connected neural network. The blue curve is \(g_\theta\), and the green curve is the second-order approximating function. \(\diamond\)
Note: For comparison purposes, the above equation is Equation (3.1) in [1], although they use different notation – for example, they write \(\eta^*_t\) instead of \(\alpha_*(\theta)\). \(\diamond\)
We can use the optimal learning rates to drive a learning rate scheduler, by setting
$$ \alpha_t = \alpha_*(\theta_{t-1}). $$In this case, the gradient descent iteration becomes
$$ \theta_t \leftarrow \theta_{t-1} - \frac{ df(\theta_{t-1}) \cdot \omega_{t-1} }{ d^2 f(\theta_{t-1}) \cdot (\omega_{t-1}, \omega_{t-1}) } \omega_{t-1}, $$where we have introduced the shorthand
$$ \omega_{t-1} = \omega(\theta_{t-1}, \nabla f(\theta_{t-1})). $$This learning rate scheduler is the exact generalized Newton’s method (exact GeN).
To avoid the computational burden of computing second-order partial derivatives (in other words, computing full Hessians or Hessian-vector products), the authors of [1] propose the following “backpropagation-free” differencing scheme: Instead of computing \(\alpha_*(\theta_{t-1})\) at each gradient descent iteration, compute
$$ \Delta_t = \frac{\alpha_{t-1}}{2} \frac{ g_{\theta_{t-1}}(-\alpha_{t-1}) - g_{\theta_{t-1}}(\alpha_{t-1}) }{ g_{\theta_{t-1}}(-\alpha_{t-1}) - 2 g_{\theta_{t-1}}(0) + g_{\theta_{t-1}}(\alpha_{t-1}) }. $$If the denominator of \(\Delta_t\) is positive, then set
$$ \alpha_t = \gamma \alpha_{t-1} + (1 - \gamma) \Delta_t, $$where the moving average coefficient \(0 \leq \gamma < 1\) is chosen a priori. On the other hand, if the denominator of \(\Delta_t\) is not positive, then set \(\alpha_t\) to a small default learning rate.
This learning rate scheduler is the approximate generalized Newton’s method (approximate GeN). To see why the definition of \(\Delta_t\) makes sense, observe the central-difference identities
$$ \begin{align*} \frac{g_{\theta}(-\alpha) - g_{\theta}(\alpha)}{2 \alpha} = df(\theta) \cdot \omega + O(\alpha^2) \end{align*} $$and
$$ \begin{align*} \frac{g_{\theta}(-\alpha) - 2 g_{\theta}(0) + g_{\theta}(\alpha)}{\alpha^2} = d^2 f(\theta) \cdot (\omega, \omega) + O(\alpha^2), \end{align*} $$so \(\Delta_t\) agrees with \(\alpha_*(\theta_{t-1})\) up to the error of the central differences. Note: As described in [1], other schemes for calculating \(\Delta_t\) are possible – for example, we could compute a polynomial fit which incorporates more data points. \(\diamond\)
Note: Approximate GeN is “backpropagation-free” in the sense that it does not use additional backpropagation steps to compute second-order partial derivatives. Of course, backpropagation is still used to compute the parameter gradients passed to \(\omega_t\). \(\diamond\)
Note: As noted in [1], assuming that \(\alpha_*(\theta_t)\) is slowly varying with \(t\), we can recompute \(\alpha_t\) periodically to amortize the extra computation needed for exact or approximate GeN. \(\diamond\)
Note: In an implementation of exact or approximate GeN, care should be taken to ensure that each \(\alpha_t\) is in the interval where the second-order approximation is declared (or somehow determined to be) valid. If \(\alpha_t\) is outside this interval, then a reasonable policy is to set \(\alpha_t\) to a small default learning rate. \(\diamond\)
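To make the differencing scheme concrete, here is a minimal sketch of one \(\Delta_t\) computation for SGD. The helper `evaluate_loss` and the list-of-tensors parameter style are illustrative assumptions, not part of [1].

import torch

def approximate_gen_delta(evaluate_loss, params, grads, alpha_prev):
    """Sketch: central-difference estimate Delta_t for SGD.

    Args:
        evaluate_loss: Function mapping a list of parameter tensors to a
            scalar loss, so g(alpha) = evaluate_loss(theta - alpha * omega).
        params: List of parameter tensors (theta).
        grads: List of gradient tensors (omega, in the case of SGD).
        alpha_prev: Previous learning rate alpha_{t-1} (the probe step).

    Returns:
        Tuple (delta, denominator); the caller should fall back to a default
        learning rate when the denominator is not positive.
    """
    with torch.no_grad():

        def g(alpha):
            # g(alpha) = f(theta - alpha * omega)
            shifted = [p - alpha * w for p, w in zip(params, grads)]
            return evaluate_loss(shifted)

        g_minus, g_zero, g_plus = g(-alpha_prev), g(0.0), g(alpha_prev)
        denominator = g_minus - 2.0 * g_zero + g_plus
        delta = (alpha_prev / 2.0) * (g_minus - g_plus) / denominator
    return delta, denominator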
Note: In the machine-learning context, \(\theta\) represents trainable model parameters lumped into a single parameter vector. We can rephrase the calculation of \(\alpha_*(\theta)\) in terms of the individual parameter vectors, as follows. First, suppose that \(\Theta\) is the product inner product space \(\Theta = \Theta_1 \times \cdots \times \Theta_n\) with generic element \(\theta = (\theta_1,\dots,\theta_n)\). Then we can write
$$ \alpha_*(\theta) = \frac{ \sum_{i=1}^{n} d_{\theta_i} f(\theta) \cdot \omega_i }{ \sum_{i,j=1}^{n} d^2_{\theta_i,\theta_j} f(\theta) \cdot (\omega_i, \omega_j) }. $$Here we have introduced the following notation:
- \(\omega_i\) is the \(i\)th component of \(\omega = (\omega_1, \dots, \omega_n)\),
- \(d_{\theta_i} f(\theta)\) is the partial derivative of \(f\) with respect to \(\theta_i\) at \(\theta\), and
- \(d^2_{\theta_i,\theta_j} f(\theta)\) is the second-order partial derivative of \(f\) with respect to \((\theta_i,\theta_j)\) at \(\theta\).
We will see an example of this in the next section. \(\diamond\)
Example: Exact GeN for single-layer fully-connected neural networks
In this section, we compute optimal learning rates prescribed by exact GeN. The focus here is a simple problem instance, where we have a single-layer fully-connected neural network activated by a piecewise-linear function (for example, ReLU), the objective is to minimize the norm-squared loss function, and the optimizer is SGD.
We begin by defining the parameter space to be
$$ \Theta = \bR^{n_1 \times n_0} \times \bR^{n_1} $$with generic element \(\theta = (W_1,b_1)\).
The activation function is a piecewise-linear function \(\sigma_0 : \bR \to \bR\).
The activation map \(\sigma: \bR^{n_1} \to \bR^{n_1}\) is defined by
$$ \begin{align*} \sigma(x) &= (\sigma_0(x^1), \dots, \sigma_0(x^{n_1}))^t. \end{align*} $$Note that the component indices of \(x\) are written with superscripts.
Wherever it is well-defined, the total derivative of \(\sigma\) at \(x\) is
$$ d\sigma(x) \cdot h = \Delta \sigma'(x) h, $$where
$$ \sigma'(x) = (\sigma_0'(x^1), \dots, \sigma_0'(x^{n_1}))^t $$and \(\Delta : \bR^{n_1} \to \bR^{n_1 \times n_1}\) sends \(x\) to the diagonal matrix whose \((i,i)\)th component is \(x^i\).
Similarly, wherever it is well-defined, the total derivative of \(\sigma'\) at \(x\) is
$$ d\sigma'(x) \cdot h = \Delta \sigma''(x) h = 0_{n_1}, $$since \(\sigma_0\) is piecewise-linear.
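As a quick numerical check of these formulas for ReLU, we can compare the directional derivative of the activation map, computed with torch.func.jvp, against \(\Delta \sigma'(x) h\). This is a sketch; it assumes \(x\) has no zero components, so that we stay away from the nondifferentiable points.

import torch
from torch.func import jvp

x = torch.randn(5)  # almost surely has no zero components
h = torch.randn(5)

# Directional derivative of the ReLU activation map at x in direction h
_, tangent = jvp(torch.relu, (x,), (h,))

# Delta sigma'(x) h: elementwise product with the 0/1 derivative pattern
expected = (x > 0).to(x.dtype) * h

print(torch.allclose(tangent, expected))  # True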
Consider the single-layer, fully-connected neural network
$$ \begin{align*} \hat{y} : \Theta \times \bR^{n_0} &\to \bR^{n_1} \end{align*} $$defined by
$$ \begin{align*} \hat{y}(\theta, x) &= \sigma (W_1 x + b_1). \end{align*} $$For notational convenience, we introduce the intermediate computation map
$$ \begin{align*} z_1 : \Theta \times \bR^{n_0} &\to \bR^{n_1} \end{align*} $$defined by
$$ \begin{align*} z_1(\theta, x) &= W_1 x + b_1. \end{align*} $$In the rest of this post, we will assume that we are working away from those points where the activation map is not differentiable. This is necessary since piecewise-linear functions are not everywhere differentiable, in general.
We will need the first- and second-order partial derivatives of \(\hat{y}\) with respect to \(W_1, b_1\).
The first-order partial derivatives of \(\hat{y}\) at \((\theta, x)\) are
$$ \begin{align*} d_{W_1} \hat{y}(\theta,x) \cdot W &= d \sigma(z_1(\theta,x)) \circ d_{W_1} z_1(\theta,x) \cdot W \\ &= \Delta \sigma'(z_1(\theta,x)) W x \end{align*} $$and
$$ \begin{align*} d_{b_1} \hat{y}(\theta,x) \cdot b &= d \sigma(z_1(\theta,x)) \circ d_{b_1} z_1(\theta,x) \cdot b \\ &= \Delta \sigma'(z_1(\theta,x)) b. \end{align*} $$The second-order partial derivatives of \(\hat{y}\) at \((\theta, x)\) are
$$ \begin{align*} d^2_{W_1} \hat{y}(\theta,x) \cdot (V,W) &= \Delta [d \sigma'(z_1(\theta,x)) \circ \cdots] W x = 0_{n_1} \end{align*} $$and
$$ \begin{align*} d^2_{b_1} \hat{y}(\theta,x) \cdot (a,b) &= \Delta [d \sigma'(z_1(\theta,x)) \circ \cdots] b = 0_{n_1} \end{align*} $$and finally
$$ \begin{align*} d^2_{W_1,b_1} \hat{y}(\theta,x) \cdot (W,b) &= \Delta [d \sigma'(z_1(\theta,x)) \circ \cdots] b = 0_{n_1}. \end{align*} $$Note that we have written \(d^2_{W_1} \hat{y}\) instead of \(d^2_{W_1, W_1} \hat{y}\) and similarly for the \(b_1\) derivative.
Turning to optimal learning rates, suppose that we have (mini-batch) training data
$$ \begin{align*} \{ (x_i, y_i) \}_{i=1}^{M}, \quad x_i \in \bR^{n_0}, \quad y_i \in \bR^{n_1} \end{align*} $$and we are using the norm-squared loss function
$$ \begin{align*} \sL : \Theta \to \bR \end{align*} $$defined by
$$ \begin{align*} \sL(\theta) &= \frac{1}{2M} \sum_{i=1}^{M} \langle \hat{y}(\theta, x_i) - y_i, \hat{y}(\theta, x_i) - y_i \rangle, \end{align*} $$where \(\langle\cdot,\cdot\rangle\) is the Euclidean inner product on \(\bR^{n_1}\). To make use of the results in the previous section, \(\sL\) must be equal to its second-order Taylor series expansion at \(\theta\). For now, let’s assume this is true; we will prove it below, after computing the partial derivatives.
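In code, one way to realize \(\sL\) is sketched below; note that this is not the same as PyTorch’s nn.MSELoss, which averages over all \(M n_1\) components rather than over the \(M\) samples.

import torch

def norm_squared_loss(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Norm-squared loss (1 / 2M) * sum_i ||y_hat_i - y_i||^2 over a batch."""
    return 0.5 * ((y_hat - y) ** 2).sum(dim=-1).mean()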
Since the optimizer is SGD, the optimal learning rate for \(\sL\) at \(\theta\) is
$$ \begin{align*} \alpha_*(\theta) &= \frac{\mathrm{num}(\theta)}{\mathrm{den}(\theta)}, \end{align*} $$where
$$ \begin{align*} \mathrm{num}(\theta) &= d_{W_1} \sL(\theta) \cdot \nabla_{W_1} \sL(\theta) + d_{b_1} \sL(\theta) \cdot \nabla_{b_1} \sL(\theta) \\ \mathrm{den}(\theta) &= d^2_{W_1} \sL(\theta) \cdot (\nabla_{W_1} \sL(\theta), \nabla_{W_1} \sL(\theta)) \\ &\qquad + \, 2 d^2_{W_1,b_1} \sL(\theta) \cdot (\nabla_{W_1} \sL(\theta), \nabla_{b_1} \sL(\theta)) \\ &\qquad + \, d^2_{b_1} \sL(\theta) \cdot (\nabla_{b_1} \sL(\theta), \nabla_{b_1} \sL(\theta)). \end{align*} $$Recall that for \(\alpha_*(\theta)\) to be well-defined, we must have \(\mathrm{den}(\theta) > 0\).
To compute \(\alpha_*(\theta)\), we start with the partial gradients. For convenience, we set
$$ \colorbox{lesserbox} { $ \begin{align*} e_i &= \hat{y}(\theta, x_i) - y_i \\ \Delta'(z_{1,i}) &= \Delta \sigma'(z_1(\theta, x_i)). \end{align*} $ } $$Observe that, using the cyclic property of trace, we have
$$ \begin{align*} \langle v, d_{W_1} \hat{y}(\theta, x_i) \cdot W \rangle &= \tr(v^t \Delta' (z_{1,i}) Wx_i) \\ &= \tr(x_i v^t \Delta' (z_{1,i}) W) \\ &= \langle \Delta' (z_{1,i}) v x_i^t, W \rangle_F, \end{align*} $$where \(\langle A, B \rangle_F = \tr(A^t B)\) is the Frobenius inner product. This shows that
$$ \begin{align*} d_{W_1} \hat{y}(\theta, x_i)^* \cdot v = \Delta' (z_{1,i}) v x_i^t, \end{align*} $$where the superscript “\(*\)” denotes adjoint. Using this, we obtain
$$ \begin{align*} d_{W_1} \sL(\theta) \cdot W &= \frac{1}{M} \sum_{i=1}^{M} \langle d_{W_1} \hat{y}(\theta, x_i) \cdot W, e_i \rangle \\ &= \left\langle W, \frac{1}{M} \sum_{i=1}^{M} d_{W_1} \hat{y}(\theta, x_i)^* \cdot e_i \right\rangle_F \\ &= \left\langle W, \frac{1}{M} \sum_{i=1}^{M} \Delta'(z_{1,i}) e_i x_i^t \right\rangle_F. \end{align*} $$Recalling that \(\nabla_{W_1} \sL(\theta)\) is the unique element of \(\bR^{n_1 \times n_0}\) satisfying
$$ d_{W_1} \sL(\theta) \cdot W = \left\langle W, \nabla_{W_1} \sL(\theta) \right\rangle_F, $$we have
$$ \begin{align*} \nabla_{W_1} \sL(\theta) = \frac{1}{M} \sum_{i=1}^{M} \Delta'(z_{1,i}) e_i x_i^t \end{align*} $$and the first term in \(\mathrm{num}(\theta)\) is
$$ \begin{align*} d_{W_1} \sL(\theta) \cdot \nabla_{W_1} \sL(\theta) &= \frac{1}{M^2} \sum_{i,j=1}^{M} \langle \Delta'(z_{1,i}) \Delta'(z_{1,j}) e_j x_j^t x_i, e_i \rangle. \end{align*} $$In a similar way, we can compute
$$ \begin{align*} \nabla_{b_1} \sL(\theta) = \frac{1}{M} \sum_{i=1}^{M} \Delta'(z_{1,i}) e_i \end{align*} $$and the second term in \(\mathrm{num}(\theta)\) is
$$ \begin{align*} d_{b_1} \sL(\theta) \cdot \nabla_{b_1} \sL(\theta) &= \frac{1}{M^2} \sum_{i,j=1}^{M} \langle \Delta'(z_{1,i}) \Delta'(z_{1,j}) e_j, e_i \rangle. \end{align*} $$Combining terms, we have
$$ \colorbox{lesserbox} { $ \begin{align*} \mathrm{num}(\theta) = \frac{1}{M^2} \sum_{i,j=1}^{M} (1 + x_i^t x_j) \langle \Delta'(z_{1,i}) \Delta'(z_{1,j}) e_j, e_i \rangle. \end{align*} $ } $$As for the denominator of \(\alpha_*(\theta)\), we begin by computing
$$ \begin{align*} d^2_{W_1} \sL(\theta) \cdot (V, W) &= \frac{1}{M} \sum_{i=1}^{M} \langle d^2_{W_1} \hat{y}(\theta, x_i) \cdot (V, W), e_i \rangle \\ &\qquad+ \, \frac{1}{M} \sum_{i=1}^{M} \langle d_{W_1} \hat{y}(\theta, x_i) \cdot V, d_{W_1} \hat{y}(\theta, x_i) \cdot W \rangle \\ &= \frac{1}{M} \sum_{i=1}^{M} \langle \Delta'(z_{1,i}) V x_i, \Delta'(z_{1,i}) W x_i \rangle. \end{align*} $$Analogously, the other second-order partial derivatives are
$$ \begin{align*} d^2_{W_1,b_1} \sL(\theta) \cdot (W,b) &= \frac{1}{M} \sum_{i=1}^{M} \langle \Delta'(z_{1,i}) W x_i, \Delta'(z_{1,i}) b \rangle \end{align*} $$and
$$ \begin{align*} d^2_{b_1} \sL(\theta) \cdot (a, b) &= \frac{1}{M} \sum_{i=1}^{M} \langle \Delta'(z_{1,i}) a, \Delta'(z_{1,i}) b \rangle. \end{align*} $$Note that the above results imply \(d^3 \sL(\theta) \equiv 0\) and consequently \(\sL\) is equal to its second-order Taylor series expansion at \(\theta\). To see this, observe (for example) that
$$ \begin{align*} d^3_{W_1} \sL(\theta) \cdot (U,V,W) &= \frac{1}{M} \sum_{i=1}^{M} \langle \Delta [d\sigma'(z_{1,i}) \circ \cdots] V x_i, \cdots \rangle \\ &\qquad + \, \frac{1}{M} \sum_{i=1}^{M} \langle \cdots, \Delta [d\sigma'(z_{1,i}) \circ \cdots] W x_i \rangle = 0. \end{align*} $$Similarly, the other third-order partial derivatives are also identically equal to \(0\).
Plugging in the partial gradients, we obtain
$$ \begin{align*} &{} d^2_{W_1} \sL(\theta) \cdot (\nabla_{W_1} \sL(\theta), \nabla_{W_1} \sL(\theta)) \\ &= \frac{1}{M^3} \sum_{i,j,k=1}^{M} \langle \Delta'(z_{1,i}) \Delta'(z_{1,j}) e_j x_j^t x_i, \Delta'(z_{1,i}) \Delta'(z_{1,k}) e_k x_k^t x_i \rangle \end{align*} $$and
$$ \begin{align*} &{} d^2_{b_1} \sL(\theta) \cdot (\nabla_{b_1} \sL(\theta), \nabla_{b_1} \sL(\theta)) \\ &= \frac{1}{M^3} \sum_{i,j,k=1}^{M} \langle \Delta'(z_{1,i}) \Delta'(z_{1,j}) e_j, \Delta'(z_{1,i}) \Delta'(z_{1,k}) e_k \rangle. \end{align*} $$Finally,
$$ \begin{align*} &{} d^2_{W_1,b_1} \sL(\theta) \cdot (\nabla_{W_1} \sL(\theta), \nabla_{b_1} \sL(\theta)) \\ &= \frac{1}{M^3} \sum_{i,j,k=1}^{M} \langle \Delta'(z_{1,i}) \Delta'(z_{1,j}) e_j x_j^t x_i, \Delta'(z_{1,i}) \Delta'(z_{1,k}) e_k \rangle. \end{align*} $$Combining terms (and making a dummy index swap), the denominator of \(\alpha_*(\theta)\) is
$$ \colorbox{lesserbox} { $ \begin{align*} \mathrm{den}(\theta) &= \frac{1}{M^3} \sum_{i,j,k=1}^{M} (1 + x_i^t x_j) (1 + x_i^t x_k) \langle \Delta'(z_{1,i}) \Delta'(z_{1,j}) e_j, \Delta'(z_{1,i}) \Delta'(z_{1,k}) e_k \rangle. \end{align*} $ } $$Putting everything together, the optimal learning rate for \(\sL\) at \(\theta\) is
$$ \colorbox{magicmint} { $ \alpha_*(\theta) = \frac{ M \sum_{i,j=1}^{M} (1 + x_i^t x_j) \langle \Delta'(z_{1,i}) \Delta'(z_{1,j}) e_j, e_i \rangle }{ \sum_{i,j,k=1}^{M} (1 + x_i^t x_j) (1 + x_i^t x_k) \langle \Delta'(z_{1,i}) \Delta'(z_{1,j}) e_j, \Delta'(z_{1,i}) \Delta'(z_{1,k}) e_k \rangle }. $ } $$For the case \(M = 1\), this expression simplifies to
$$ \colorbox{lesserbox} { $ \alpha_*(\theta) = \frac{ \langle \Delta'(z_{1,1}) \Delta'(z_{1,1}) e_1, e_1 \rangle }{ (1 + x_1^t x_1) \langle \Delta'(z_{1,1}) \Delta'(z_{1,1}) e_1, \Delta'(z_{1,1}) \Delta'(z_{1,1}) e_1 \rangle }. $ } $$Further, if the activation function is standard ReLU, then clearly
$$ \Delta'(z_{1,1}) \Delta'(z_{1,1}) = \Delta'(z_{1,1}), $$since each diagonal entry of \(\Delta'(z_{1,1})\) is either \(0\) or \(1\). Consequently \(\langle \Delta'(z_{1,1}) e_1, e_1 \rangle = \| \Delta'(z_{1,1}) e_1 \|^2\), and provided this quantity is nonzero we have
$$ \colorbox{lesserbox} { $ \alpha_*(\theta) = \frac{ 1 }{ 1 + x_1^t x_1 }. $ } $$This directly relates the optimal learning rate to the norm of the input vector.
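As a sanity check, we can verify this closed-form expression numerically with autograd. The following sketch compares \(\mathrm{num}(\theta)/\mathrm{den}(\theta)\), computed via a vector-Hessian product, against \(1/(1 + x_1^t x_1)\); it assumes a random initialization that stays away from the ReLU kinks and gives a positive denominator.

import torch
from torch.autograd.functional import vhp

torch.manual_seed(0)
n0, n1 = 4, 3
W1 = torch.randn(n1, n0, requires_grad=True)
b1 = torch.randn(n1, requires_grad=True)
x1 = torch.randn(n0)
y1 = torch.randn(n1)

def loss(W1, b1):
    # Norm-squared loss for M = 1
    e = torch.relu(W1 @ x1 + b1) - y1
    return 0.5 * e.dot(e)

# Partial gradients of the loss
gW, gb = torch.autograd.grad(loss(W1, b1), (W1, b1))

# num = ||grad||^2, den = grad^T (Hessian) grad
num = (gW**2).sum() + (gb**2).sum()
_, (hW, hb) = vhp(loss, (W1.detach(), b1.detach()), (gW, gb))
den = (gW * hW).sum() + (gb * hb).sum()

print(num / den)                 # alpha_*(theta) from num/den
print(1.0 / (1.0 + x1.dot(x1)))  # closed-form value; should agree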
Implementation of exact GeN for SGD
This snippet collects the imports and type aliases used by the code in this section, and computes first-order approximation coefficients.
from typing import Callable, Optional

import torch
from functorch import make_functional
from jaxtyping import Real
from torch import Tensor, linalg, nn
from torch.autograd.functional import vhp
from torch.func import functional_call, grad
from torch.optim.lr_scheduler import LRScheduler

# Type aliases used by the snippets in this section (assumed definitions,
# matching how they are used below)
_Scalar = Tensor
_TensorDict = dict[str, Tensor]
_ScalarTwoTuple = tuple[Tensor, Tensor]
_ScalarThreeTuple = tuple[Tensor, Tensor, Tensor]
CriterionType = Callable[[Tensor, Tensor], Tensor]


def norm_of_tensor_dict(tensor_dict: _TensorDict, p: float = 2.0) -> _Scalar:
"""Helper function to sum the norms of each tensor in a dictionary.
Args:
tensor_dict: Dictionary containing only tensors.
ord: Order of the norm (default = 2.0).
Returns:
Scalar tensor with 2-norm.
"""
tensors = tensor_dict.values()
return sum(linalg.vector_norm(tensor, p) ** 2.0 for tensor in tensors)
def first_order_approximation_coeffs(
model: nn.Module, criterion: CriterionType, x: Real[Tensor, "..."], y: Real[Tensor, "..."]
) -> tuple[_ScalarTwoTuple, _TensorDict]:
"""Compute coefficients of first-order Taylor series approximation.
Args:
model: Network model.
criterion: Loss criterion function.
x: Input tensor.
y: Output tensor (target).
Returns:
Tuple containing:
- Tuple of scalar tensors with approximation coefficients.
- Dictionary with model parameter gradients. This can be ignored,
since it is only to avoid code duplication in the second-order
approximation code.
"""
# Extract parameters from `model` to pass to `torch.func.functional_call()`
params_dict = dict(model.named_parameters())
# Wrapper function for parameter-dependent loss
def parameterized_loss(params_dict):
y_hat = functional_call(model, params_dict, (x,))
return criterion(y_hat, y)
    with torch.no_grad():
        # Zeroth-order coefficient: the loss at the current parameters
        coeff_0 = parameterized_loss(params_dict)

        # Compute parameter gradients
        # Note: `torch.func.grad()` is a function transform, so it still
        # computes gradients inside the `torch.no_grad()` context
        grad_params_dict = grad(parameterized_loss)(params_dict)

        # First-order coefficient: squared norm of the gradient, since
        # g'(0) = -||grad f(theta)||^2 for SGD (the sign is applied below)
        coeff_1 = norm_of_tensor_dict(grad_params_dict)

    return (coeff_0, -coeff_1), grad_params_dict
This snippet builds on the previous one to compute second-order approximation coefficients.
def second_order_approximation_coeffs(
model: nn.Module, criterion: CriterionType, x: Real[Tensor, "..."], y: Real[Tensor, "..."]
) -> _ScalarThreeTuple:
"""Compute coefficients of second-order Taylor series approximation.
Args:
model: Network model.
criterion: Loss criterion function.
x: Input tensor.
y: Output tensor (target).
Returns:
Tuple of scalar tensors with approximation coefficients.
"""
# Wrapper function for parameter-dependent loss
# - This version is compatible with `make_functional()`, which is needed
# for the call to `torch.autograd.functional.vhp()`. PyTorch issues a
# warning about using `make_functional()`, but there seems to be no
# analogue of `torch.autograd.functional.vhp()` which can be used with
# `torch.func.functional_call()`.
def parameterized_loss(*params):
model_func, _ = make_functional(model)
y_hat = model_func(params, x)
return criterion(y_hat, y)
with torch.no_grad():
coeffs, grad_params_dict = first_order_approximation_coeffs(model, criterion, x, y)
coeff_2 = torch.as_tensor(0.0)
# Compute second-order coefficient
params = tuple(model.parameters())
grad_params = tuple(grad_params_dict.values())
_, prod = vhp(parameterized_loss, params, grad_params)
for i, grad_param in enumerate(grad_params):
coeff_2 += torch.dot(grad_param.flatten(), prod[i].flatten())
# Note: Minus was already applied to first-order coefficient
return (coeffs[0], coeffs[1], coeff_2 / 2.0)
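For illustration, here is a hypothetical usage of these helpers on a toy model; the model, criterion, and data below are placeholders (with norm_squared_loss as sketched earlier).

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
criterion = norm_squared_loss
x, y = torch.randn(8, 4), torch.randn(8, 3)

coeff_0, coeff_1, coeff_2 = second_order_approximation_coeffs(model, criterion, x, y)
if coeff_2.item() > 0.0:
    # Minimizer of the quadratic coeff_0 + coeff_1 * alpha + coeff_2 * alpha^2
    alpha_star = -coeff_1.item() / (2.0 * coeff_2.item())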
Now we can subclass torch.optim.lr_scheduler.LRScheduler to implement exact GeN.
class ExactGeNForSGD(LRScheduler):
"""Exact GeN for SGD.
Args:
optimizer: Optimizer.
last_epoch: Number of last epoch.
model: Network model.
criterion: Loss criterion function.
lr_min: Minimum learning rate to use.
lr_max: Maximum learning rate to use.
"""
_DEFAULT_LR = 1e-3
def __init__( # noqa: DCO010
self,
optimizer: torch.optim.Optimizer,
last_epoch: int,
model: nn.Module,
criterion: CriterionType,
lr_min: float,
lr_max: float,
) -> None:
super().__init__(optimizer, last_epoch)
self.model = model
self.criterion = criterion
self.lr_min = lr_min
self.lr_max = lr_max
self.base_lrs = [group["lr"] for group in optimizer.param_groups]
self.current_lrs = self.base_lrs.copy()
# Pylint complains that redefinition of step() has a different signature
def step( # pylint: disable=arguments-renamed
self, x: Optional[Real[Tensor, "..."]] = None, y: Optional[Real[Tensor, "..."]] = None
) -> list[float]:
"""Update learning rate(s) in the optimizer.
Args:
x: Input tensor.
y: Output tensor (target).
Returns:
List of learning rates for each parameter group.
"""
lrs = self.get_lr(x, y)
# Update learning rates in the optimizer
for param_group, lr in zip(self.optimizer.param_groups, lrs):
param_group["lr"] = lr
return lrs
def get_lr(
self, x: Optional[Real[Tensor, "..."]] = None, y: Optional[Real[Tensor, "..."]] = None
) -> list[float]:
"""Compute learning rate(s) for a particular batch.
Args:
x: Input tensor.
y: Output tensor (target).
Returns:
List of learning rates for each parameter group.
"""
# Handle initial step (in this case, `x` and `y` are not available)
if x is None and y is None:
lr = self._DEFAULT_LR
else:
# Get coefficients of second-order approximation
coeffs = second_order_approximation_coeffs(self.model, self.criterion, x, y)
coeffs = [coeff.item() for coeff in coeffs]
            if coeffs[2] <= 0.0:
                # Approximation is not strictly convex --> use default learning rate
                lr = self._DEFAULT_LR
            else:
                # Approximation is strictly convex --> use alpha_star, clamped
                # to [lr_min, lr_max]
                alpha_star = -coeffs[1] / (2.0 * coeffs[2])
                lr = min(self.lr_max, max(alpha_star, self.lr_min))
# Update current learning rate(s)
num_groups = len(self.optimizer.param_groups)
self.current_lrs = [lr for _ in range(num_groups)]
return self.current_lrs
def get_last_lr(self) -> list[float]: # noqa
return self.current_lrs
The next snippet shows how to use exact GeN for SGD in a standard training loop.
# Make SGD optimizer (the initial lr is a placeholder; the scheduler overwrites it)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
# Make exact GeN for SGD
scheduler = ExactGeNForSGD(
optimizer, -1, model, criterion, config.lr_min, config.lr_max
)
<...>
model.train()
for epoch in range(config.num_epochs):
for x, y in dataloader:
# Move data to device
x = x.to(device)
y = y.to(device)
# Zero model parameter gradients
optimizer.zero_grad()
# Run forward pass
y_hat = model(x)
# Compute loss
loss = criterion(y_hat, y)
# Run backward pass
loss.backward()
# Adjust learning rate(s) in optimizer
scheduler.step(x, y)
# Adjust model parameters using new learning rate(s)
optimizer.step()
References
[1] Zhiqi Bu and Shiyun Xu, Gradient descent with generalized Newton’s method, arXiv:2407.02772 [cs.LG]