# Implicit Differentiation for Hyperparameter Optimization

## 2021/02/06

Implicit differentiation is often used for hyperparameter optimization, to avoid backpropagating through an optimization process. A recent example can be found in this paper, which applies the technique to learn, among other things, data augmentation that improves generalization. Define the following hyperparameter optimization problem, where $\lambda$ denotes hyperparameters and $\theta^\star$ denotes (optimized) parameters:

 $\lambda^\star = \arg\min_{\lambda} g\left(\theta^\star(\lambda), \lambda\right)$ (1)

The optimized parameters $\theta^\star$ are themselves the output of another optimization, the training process, which depends on the hyperparameters:

 $\theta^\star(\lambda) = \arg\min_{\theta} f\left(\theta, \lambda\right)$ (2)

Notice we have two separate losses: for example, $g$ is the validation loss and $f$ is the training loss. Now, to solve the hyperparameter optimization, assume we are doing something like gradient descent (GD) on $\lambda$ with current iterate $\lambda_t$, so with learning rate $\eta$ the update is:

 $\lambda_{t+1} = \lambda_t - \eta \, \frac{dg}{d\lambda}\left(\theta^\star(\lambda_t), \lambda_t\right)$

Expanding the derivative via the chain rule (and dropping arguments to reduce clutter):

 $\frac{dg}{d\lambda} = \frac{\partial g}{\partial \theta} \frac{d\theta^\star}{d\lambda} + \frac{\partial g}{\partial \lambda}$
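To make this concrete, here is a minimal made-up bilevel problem (not from the paper) where every piece of the chain rule is available in closed form: a training loss $f(\theta, \lambda) = \frac{1}{2}(\theta - \lambda)^2$, so that $\theta^\star(\lambda) = \lambda$ and $\frac{d\theta^\star}{d\lambda} = 1$, and a validation loss $g(\theta, \lambda) = \theta^2$ with no direct dependence on $\lambda$:

```python
# Toy bilevel problem (hypothetical, for illustration only):
#   inner:  f(theta, lam) = 0.5 * (theta - lam)**2  =>  theta_star(lam) = lam
#   outer:  g(theta, lam) = theta**2
# Chain rule: dg/dlam = (dg/dtheta) * (dtheta*/dlam) + dg/dlam
#                     = 2 * theta_star(lam) * 1       + 0

def theta_star(lam):
    return lam  # closed-form inner solution

def hypergrad(lam):
    return 2.0 * theta_star(lam) * 1.0  # dg/dtheta * dtheta*/dlam

lam, eta = 3.0, 0.1
for _ in range(100):
    lam -= eta * hypergrad(lam)  # the GD update on the hyperparameter

print(lam)  # approaches 0, the optimal hyperparameter
```

Here gradient descent on $\lambda$ drives the hyperparameter to the optimum $\lambda^\star = 0$; in realistic problems the term $\frac{d\theta^\star}{d\lambda}$ has no closed form, which is the difficulty discussed next.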

The hard part above is computing the term $\frac{d\theta^\star}{d\lambda}(\lambda_t)$, which requires differentiating through the training optimization (Eq. 2) that produced $\theta^\star$. If the training optimization is also gradient descent, we could treat the entire GD procedure as one giant forward pass and backpropagate through the GD updates. However, if producing $\theta^\star$ took thousands of GD steps, this backpropagation can be both slow and memory-intensive. Instead, the implicit function theorem (IFT) gives us an alternative to backpropagation: since we know $\theta^\star(\lambda_t)$ is an optimum of the training problem, we have:

 $\frac{\partial f}{\partial \theta}\left(\theta^\star, \lambda_t\right) = 0$

Now applying $\frac{d}{d\lambda}$ to both sides, expanding by the chain rule, and re-arranging, we have:

 $\begin{aligned} 0 &= \frac{d}{d\lambda}\frac{\partial f}{\partial \theta} \\ 0 &= \frac{\partial^2 f}{\partial \theta^2}\frac{d\theta^\star}{d\lambda} + \frac{\partial^2 f}{\partial \lambda \, \partial \theta} \\ \frac{d\theta^\star}{d\lambda} &= -\left[\frac{\partial^2 f}{\partial \theta^2}\right]^{-1} \frac{\partial^2 f}{\partial \lambda \, \partial \theta} \end{aligned}$
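The formula can be checked numerically on a problem where $\theta^\star(\lambda)$ is known in closed form. Take the (hypothetical) quadratic training loss $f(\theta, \lambda) = \frac{1}{2}\theta^\top A \theta - \lambda^\top \theta$ with $A$ positive definite: stationarity gives $A\theta - \lambda = 0$, so $\theta^\star(\lambda) = A^{-1}\lambda$ and the true Jacobian $\frac{d\theta^\star}{d\lambda}$ is $A^{-1}$, which is exactly what the IFT formula recovers:

```python
import numpy as np

# Hypothetical quadratic training loss: f(theta, lam) = 0.5 theta^T A theta - lam^T theta.
# Then df/dtheta = A theta - lam = 0  =>  theta_star(lam) = A^{-1} lam,
# so the true Jacobian d theta_star / d lam is A^{-1}.

rng = np.random.default_rng(0)
M = rng.normal(size=(5, 5))
A = M @ M.T + 5 * np.eye(5)       # symmetric positive definite => invertible Hessian

hess = A                          # d^2 f / d theta^2
mixed = -np.eye(5)                # d^2 f / (d lam d theta)
ift_jacobian = -np.linalg.inv(hess) @ mixed   # the IFT formula above
true_jacobian = np.linalg.inv(A)              # closed-form answer

print(np.allclose(ift_jacobian, true_jacobian))  # True
```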

We have assumed that the above Hessian is invertible, and even when it is, its inverse can be hard to compute exactly, so people often use approximations instead.
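One such approximation (used, for instance, in place of forming the inverse explicitly; the matrix `H` below is just a stand-in for $\frac{\partial^2 f}{\partial \theta^2}$) is a truncated Neumann series for the inverse-Hessian-vector product, $H^{-1}v = \alpha \sum_{k \ge 0} (I - \alpha H)^k v$, which converges when $\|I - \alpha H\| < 1$ and only needs Hessian-vector products, never the full inverse:

```python
import numpy as np

# Neumann-series approximation of an inverse-Hessian-vector product:
#   H^{-1} v = alpha * sum_k (I - alpha H)^k v,   valid when ||I - alpha H|| < 1.
# Only Hessian-vector products are required.

def neumann_ihvp(hvp, v, alpha, num_terms):
    """Approximate H^{-1} v given a Hessian-vector-product function hvp."""
    p = v.copy()
    total = v.copy()
    for _ in range(num_terms):
        p = p - alpha * hvp(p)   # p <- (I - alpha H) p
        total += p               # accumulate the series
    return alpha * total

rng = np.random.default_rng(1)
M = rng.normal(size=(4, 4))
H = M @ M.T + 4 * np.eye(4)                  # SPD stand-in for the Hessian
v = rng.normal(size=4)
alpha = 1.0 / np.linalg.norm(H, 2)           # small enough for convergence
approx = neumann_ihvp(lambda x: H @ x, v, alpha, num_terms=2000)
exact = np.linalg.solve(H, v)
print(np.max(np.abs(approx - exact)))        # small approximation error
```

With few terms the series gives a cheap, biased estimate; conjugate gradient is another common choice for the same linear solve.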

Actually, the above is the second part of the IFT, so we have skipped a step. The first part of the implicit function theorem says: if there exists a point $(\theta^\star, \lambda_t)$ and a smooth function $h$ satisfying $h(\theta^\star, \lambda_t) = 0$ such that $\frac{\partial h}{\partial \theta}(\theta^\star, \lambda_t)$ is invertible, then there exists some differentiable function $y(\lambda)$ from a neighborhood of $\lambda_t$ to a neighborhood of $\theta^\star$ satisfying $\theta^\star = y(\lambda_t)$ and $h(y(\lambda), \lambda) = 0$. I.e., around the point $(\theta^\star, \lambda_t)$ we can write $\theta$ as a function $y(\lambda)$, and we can differentiate $y(\lambda)$ implicitly as illustrated above. In the calculations above we have simply used $h = \frac{\partial f}{\partial \theta}$ and overloaded notation to write $\theta^\star(\lambda)$ instead of $\theta^\star = y(\lambda)$.
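A one-dimensional illustration of both parts of the theorem (with a made-up $h$, not one from the derivation above): take $h(\theta, \lambda) = \theta^3 + \theta - \lambda$. Since $\frac{\partial h}{\partial \theta} = 3\theta^2 + 1$ is never zero, the IFT applies everywhere, and implicit differentiation of $h(y(\lambda), \lambda) = 0$ gives $y'(\lambda) = 1 / (3\,y(\lambda)^2 + 1)$, which we can check against a finite difference:

```python
# Illustration of the IFT in 1-D with h(theta, lam) = theta**3 + theta - lam.
# h is strictly increasing in theta, so each lam has a unique root theta = y(lam),
# and dh/dtheta = 3 theta**2 + 1 > 0 guarantees the IFT applies everywhere.
# Implicit differentiation: dy/dlam = 1 / (3 y(lam)**2 + 1).

def y(lam, iters=60):
    theta = 0.0
    for _ in range(iters):                       # Newton's method on h(., lam)
        theta -= (theta**3 + theta - lam) / (3 * theta**2 + 1)
    return theta

lam = 2.0                                        # here y(2) = 1, since 1 + 1 = 2
implicit = 1.0 / (3 * y(lam)**2 + 1)             # derivative from the IFT
eps = 1e-6
finite_diff = (y(lam + eps) - y(lam - eps)) / (2 * eps)
print(implicit, finite_diff)                     # the two derivatives agree
```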