Implicit Differentiation for Hyperparameter Optimization

2021/02/06

Implicit differentiation is often used for hyperparameter optimization, to avoid backpropagating through an optimization process. A recent example can be found in this paper that applies the technique to learn, among other things, data augmentation that improves generalization. Define the following hyperparameter optimization problem, where $\lambda$ denotes hyperparameters and $\theta^\star$ denotes (optimized) parameters:

$$\lambda^\star = \arg\min_\lambda \; g(\theta^\star(\lambda), \lambda) \tag{1}$$

The optimized parameters $\theta^\star$ are the output of another optimization, namely the training process, and this training problem should depend on the hyperparameters:

πœƒβ‹†(Ξ») = arg ⁑ min ⁑ πœƒf(πœƒ,Ξ») (2)

Notice that we have two separate losses: for example, $g$ is the validation loss and $f$ is the training loss. Now, to solve the hyperparameter optimization, assume we are doing something like gradient descent (GD) on $\lambda$ with current iterate $\lambda_t$, so with learning rate $\eta$ the update is:

$$\lambda_{t+1} = \lambda_t - \eta \, \frac{dg}{d\lambda}\big(\theta^\star(\lambda_t), \lambda_t\big)$$
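In code, the outer problem is just this update applied repeatedly. Continuing the sketch, with a plain gradient-descent `inner_solve` standing in for the training run and a `hypergrad` function (defined further below via the implicit function theorem) returning $\frac{dg}{d\lambda}$:

```python
def inner_solve(lam, steps=500, lr=1e-2):
    """Approximately solve the training problem (Eq. 2) by gradient descent."""
    theta = jnp.zeros(X_tr.shape[1])
    df_dtheta = jax.grad(f, argnums=0)
    for _ in range(steps):
        theta = theta - lr * df_dtheta(theta, lam)
    return theta

def outer_step(lam, eta=1e-2):
    """One hyperparameter update: lam_{t+1} = lam_t - eta * dg/dlam(theta*(lam_t), lam_t)."""
    theta_star = inner_solve(lam)
    return lam - eta * hypergrad(theta_star, lam)  # hypergrad is defined below
```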

Expanding the derivative via the chain rule (and dropping arguments to reduce clutter):

$$\frac{dg}{d\lambda} = \frac{\partial g}{\partial \theta} \frac{d\theta^\star}{d\lambda} + \frac{\partial g}{\partial \lambda}$$
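In the sketch, two of these three terms are a single autodiff call each; the remaining factor is the troublesome one:

```python
dg_dtheta = jax.grad(g, argnums=0)  # βˆ‚g/βˆ‚ΞΈ, cheap
dg_dlam = jax.grad(g, argnums=1)    # βˆ‚g/βˆ‚Ξ» (direct term; identically zero for the toy g above)
```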

The hard part above is computing the term $\frac{d\theta^\star}{d\lambda}(\lambda_t)$, which requires differentiating through the training optimization (Eq. 2) that produced $\theta^\star$. If the training optimization is itself gradient descent, we could treat the entire GD procedure as one giant forward pass and backpropagate through the GD updates. However, if producing $\theta^\star$ took thousands of GD steps, this backpropagation can be both slow and memory intensive. Instead, the implicit function theorem (IFT) gives us an alternative to backpropagation: since we know $\theta^\star(\lambda_t)$ is an optimum of the training problem, we have:

βˆ‚f βˆ‚πœƒ (πœƒβ‹†,Ξ» t) = 0

Now, applying $\frac{d}{d\lambda}$ to both sides, expanding by the chain rule, and re-arranging, we have:

$$\begin{aligned}
0 &= \frac{d}{d\lambda} \frac{\partial f}{\partial \theta} \\
0 &= \frac{\partial^2 f}{\partial \theta^2} \frac{d\theta^\star}{d\lambda} + \frac{\partial^2 f}{\partial \lambda \, \partial \theta} \\
\frac{d\theta^\star}{d\lambda} &= -\left(\frac{\partial^2 f}{\partial \theta^2}\right)^{-1} \frac{\partial^2 f}{\partial \lambda \, \partial \theta}
\end{aligned}$$
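In the JAX sketch, this formula can be implemented directly when $\theta$ is low-dimensional enough to form the dense Hessian; `jax.hessian` gives $\frac{\partial^2 f}{\partial \theta^2}$ and `jax.jacfwd` applied to $\frac{\partial f}{\partial \theta}$ gives the mixed term:

```python
def hypergrad(theta_star, lam):
    """dg/dlam via the implicit function theorem, forming the dense Hessian explicitly."""
    # βˆ‚Β²f/βˆ‚ΞΈΒ² at (ΞΈ*, Ξ»)
    H = jax.hessian(f, argnums=0)(theta_star, lam)
    # βˆ‚Β²f/βˆ‚Ξ»βˆ‚ΞΈ: derivative of βˆ‚f/βˆ‚ΞΈ with respect to Ξ»
    mixed = jax.jacfwd(jax.grad(f, argnums=0), argnums=1)(theta_star, lam)
    # Implicit derivative dΞΈ*/dΞ» = -(βˆ‚Β²f/βˆ‚ΞΈΒ²)^(-1) βˆ‚Β²f/βˆ‚Ξ»βˆ‚ΞΈ
    dtheta_dlam = -jnp.linalg.solve(H, mixed)
    # Chain rule: dg/dΞ» = βˆ‚g/βˆ‚ΞΈ Β· dΞΈ*/dΞ» + βˆ‚g/βˆ‚Ξ»
    return dg_dtheta(theta_star, lam) @ dtheta_dlam + dg_dlam(theta_star, lam)
```

With this in place the `outer_step` sketch above is complete, and something like `outer_step(0.1)` performs one hyperparameter update.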

We have assumed that the Hessian $\frac{\partial^2 f}{\partial \theta^2}$ above is invertible, and even when it is, its inverse can be hard to compute exactly in practice, so people often use approximations instead.
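One common family of approximations avoids forming or inverting the Hessian altogether and only uses Hessian-vector products. The sketch below solves the linear system with conjugate gradients via `jax.scipy.sparse.linalg.cg` (the training Hessian is symmetric, so CG applies); a truncated Neumann series is another popular choice.

```python
from jax.scipy.sparse.linalg import cg

def hypergrad_cg(theta_star, lam, maxiter=50):
    """Approximate dg/dlam: replace the explicit Hessian inverse with a CG solve
    that only needs Hessian-vector products."""
    df_dtheta = jax.grad(f, argnums=0)

    def hvp(v):
        # (βˆ‚Β²f/βˆ‚ΞΈΒ²) v via forward-over-reverse autodiff, without a dense Hessian.
        return jax.jvp(lambda th: df_dtheta(th, lam), (theta_star,), (v,))[1]

    # v β‰ˆ (βˆ‚Β²f/βˆ‚ΞΈΒ²)^(-1) βˆ‚g/βˆ‚ΞΈ
    v, _ = cg(hvp, dg_dtheta(theta_star, lam), maxiter=maxiter)
    # vα΅€ βˆ‚Β²f/βˆ‚Ξ»βˆ‚ΞΈ, again without materializing the mixed second derivative.
    mixed_vp = jax.grad(lambda lm: jnp.vdot(v, df_dtheta(theta_star, lm)))(lam)
    return -mixed_vp + dg_dlam(theta_star, lam)
```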

Actually, the above is the second part of the IFT, so we have skipped a step. The first part of the implicit function theorem says that if there is a point $(\theta^\star, \lambda_t)$ and a smooth function $h$ satisfying $h(\theta^\star, \lambda_t) = 0$ with $\frac{\partial h}{\partial \theta}(\theta^\star, \lambda_t)$ invertible, then there exists a differentiable function $y(\lambda)$ from a neighborhood of $\lambda_t$ to a neighborhood of $\theta^\star$ satisfying $\theta^\star = y(\lambda_t)$ and $h(y(\lambda), \lambda) = 0$. That is, around the point $(\theta^\star, \lambda_t)$ we can write $\theta$ as a function $y(\lambda)$, and we can differentiate $y(\lambda)$ implicitly as illustrated above. In the calculations above we have simply used $h = \frac{\partial f}{\partial \theta}$ and overloaded notation to write $\theta^\star(\lambda)$ instead of $\theta^\star = y(\lambda)$.
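For a concrete instance of these conditions, take the ridge-regression sketch from above (writing $X, y$ for the training data), $f(\theta, \lambda) = \frac{1}{2}\lVert X\theta - y\rVert^2 + \frac{\lambda}{2}\lVert\theta\rVert^2$. Here $h = \frac{\partial f}{\partial \theta} = X^\top(X\theta - y) + \lambda\theta$, the Jacobian $\frac{\partial h}{\partial \theta} = X^\top X + \lambda I$ is invertible for any $\lambda > 0$, so the IFT applies, and since $\frac{\partial h}{\partial \lambda} = \theta$ the implicit derivative has the closed form

$$\frac{d\theta^\star}{d\lambda} = -\left(X^\top X + \lambda I\right)^{-1} \theta^\star,$$

which is what the code sketches above compute, exactly or approximately, via autodiff.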