Implicit differentiation is often used for hyperparameter optimization, to avoid backpropagating through an optimization process. A recent example can be found in this paper, which applies the technique to learn, among other things, data augmentation that improves generalization. Define the following hyperparameter optimization problem, where $\lambda$ denotes hyperparameters and $w$ denotes (optimized) parameters:

$$\lambda^* = \arg\min_{\lambda} \mathcal{L}_V\big(\lambda, w^*(\lambda)\big) \tag{1}$$
The optimized parameters $w^*$ are the output of another optimization, namely the training process, and the training problem depends on the hyperparameters:

$$w^*(\lambda) = \arg\min_{w} \mathcal{L}_T(\lambda, w) \tag{2}$$
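Notice that the two problems use separate losses: for example, $\mathcal{L}_V$ is the validation loss and $\mathcal{L}_T$ is the training loss. Now, to solve the hyperparameter optimization, assume we run something like gradient descent (GD) on $\lambda$ with current iterate $\lambda_t$, so with learning rate $\alpha$ the update is:

$$\lambda_{t+1} = \lambda_t - \alpha \, \frac{d\mathcal{L}_V\big(\lambda_t, w^*(\lambda_t)\big)}{d\lambda}$$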
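To make the outer update concrete, here is a minimal sketch on a toy ridge-regression problem (our assumption, not the setup of the paper): the inner problem is $\mathcal{L}_T(\lambda, w) = \tfrac12\|X_{tr} w - y_{tr}\|^2 + \tfrac{\lambda}{2}\|w\|^2$ and the outer problem is the squared validation error. Because this inner problem has the closed form $w^*(\lambda) = (X_{tr}^\top X_{tr} + \lambda I)^{-1} X_{tr}^\top y_{tr}$, we can get $dw^*/d\lambda = -(X_{tr}^\top X_{tr} + \lambda I)^{-1} w^*$ by differentiating it directly; in general no such closed form exists, which is the difficulty the rest of this section addresses. The names `w_star`, `val_loss`, `hypergrad`, and `outer_gd` are hypothetical, and the step-halving safeguard and clipping of $\lambda$ at 0 are our additions on top of plain GD.

```python
import numpy as np

# Hypothetical toy problem: ridge regression with a single hyperparameter lam.
rng = np.random.default_rng(0)
d = 5
w_true = rng.normal(size=d)
X_tr, X_va = rng.normal(size=(100, d)), rng.normal(size=(100, d))
y_tr = X_tr @ w_true + 0.5 * rng.normal(size=100)
y_va = X_va @ w_true + 0.5 * rng.normal(size=100)

def w_star(lam):
    # Inner problem (Eq. 2): argmin_w 0.5*||X_tr w - y_tr||^2 + 0.5*lam*||w||^2
    return np.linalg.solve(X_tr.T @ X_tr + lam * np.eye(d), X_tr.T @ y_tr)

def val_loss(lam):
    # Outer objective (Eq. 1): depends on lam only through w*(lam).
    r = X_va @ w_star(lam) - y_va
    return 0.5 * r @ r

def hypergrad(lam):
    # Chain rule: dL_V/dlam = dL_V/dw . dw*/dlam (no direct lam term here).
    w = w_star(lam)
    dw = -np.linalg.solve(X_tr.T @ X_tr + lam * np.eye(d), w)
    return (X_va @ w - y_va) @ X_va @ dw

def outer_gd(lam, alpha=1.0, steps=50):
    loss = val_loss(lam)
    for _ in range(steps):
        g = hypergrad(lam)
        step = alpha
        # Safeguard (not part of plain GD): halve the step until the
        # validation loss does not increase; clip lam at 0.
        while step > 1e-12:
            cand = max(lam - step * g, 0.0)
            cand_loss = val_loss(cand)
            if cand_loss <= loss:
                lam, loss = cand, cand_loss
                break
            step *= 0.5
    return lam, loss

lam_opt, loss_opt = outer_gd(10.0)
```

Each outer step needs a fresh $w^*(\lambda_t)$ and $dw^*/d\lambda$, which is exactly why computing the latter cheaply matters.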
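Expanding the derivative via the chain rule (and dropping arguments to reduce clutter):

$$\frac{d\mathcal{L}_V}{d\lambda} = \frac{\partial \mathcal{L}_V}{\partial \lambda} + \frac{\partial \mathcal{L}_V}{\partial w}\,\frac{dw^*}{d\lambda}$$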
The hard part above is computing the term $\frac{dw^*}{d\lambda}$, which requires differentiating through the training optimization (Eq. 2) that produced $w^*$. If the training optimization is also gradient descent, we could treat the entire GD procedure as one giant forward pass and backpropagate through the GD updates. However, if producing $w^*$ took thousands of GD steps, this backpropagation is both slow and memory-intensive, since naively every intermediate iterate must be stored. Instead, the implicit function theorem (IFT) gives us an alternative to backpropagation: since we know $w^*$ is an optimum of the training problem, the training gradient vanishes there:

$$\frac{\partial \mathcal{L}_T\big(\lambda, w^*(\lambda)\big)}{\partial w} = 0$$
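For contrast, the following sketch differentiates through an unrolled GD training run on a hypothetical ridge-regression inner problem (our assumption), where $\nabla_w \mathcal{L}_T = (A + \lambda I)w - b$ with $A = X^\top X$, $b = X^\top y$. Since $\lambda$ is a scalar here, we carry the tangent $dw_k/d\lambda$ forward alongside each update (forward mode); this yields the same derivative as backpropagating through the unrolled steps, but reverse mode would additionally need to store every iterate $w_k$, which is the memory cost described above. The function name `unrolled_gd` and the step-size choice are illustrative.

```python
import numpy as np

# Hypothetical inner problem: L_T(lam, w) = 0.5*||X w - y||^2 + 0.5*lam*||w||^2
rng = np.random.default_rng(0)
d = 5
X = rng.normal(size=(100, d))
y = X @ rng.normal(size=d) + 0.5 * rng.normal(size=100)
A, b = X.T @ X, X.T @ y

def unrolled_gd(lam, steps=2000):
    H = A + lam * np.eye(d)
    eta = 1.0 / np.linalg.eigvalsh(H).max()  # stable step size for this quadratic
    w = np.zeros(d)  # GD iterate w_k
    t = np.zeros(d)  # tangent dw_k/dlam, propagated with the forward pass
    for _ in range(steps):
        # Differentiate w_{k+1} = w_k - eta*((A + lam I) w_k - b) w.r.t. lam,
        # using the old w_k before it is overwritten:
        # d/dlam [(A + lam I) w_k] = w_k + (A + lam I) t_k
        t = t - eta * (H @ t + w)
        w = w - eta * (H @ w - b)
    return w, t

w_K, dw_K = unrolled_gd(1.0)
```

With enough steps, `dw_K` approaches $-(A + \lambda I)^{-1} w^*$, the value the implicit approach obtains from $w^*$ alone without replaying training.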
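Now applying $\frac{d}{d\lambda}$ to both sides, expanding by the chain rule, and rearranging, we have:

$$\frac{\partial^2 \mathcal{L}_T}{\partial w\,\partial \lambda^\top} + \frac{\partial^2 \mathcal{L}_T}{\partial w\,\partial w^\top}\,\frac{dw^*}{d\lambda} = 0
\quad\Longrightarrow\quad
\frac{dw^*}{d\lambda} = -\left[\frac{\partial^2 \mathcal{L}_T}{\partial w\,\partial w^\top}\right]^{-1} \frac{\partial^2 \mathcal{L}_T}{\partial w\,\partial \lambda^\top}$$

Substituting this into the chain-rule expansion gives the hypergradient using only quantities evaluated at $w^*$:

$$\frac{d\mathcal{L}_V}{d\lambda} = \frac{\partial \mathcal{L}_V}{\partial \lambda} - \frac{\partial \mathcal{L}_V}{\partial w}\left[\frac{\partial^2 \mathcal{L}_T}{\partial w\,\partial w^\top}\right]^{-1} \frac{\partial^2 \mathcal{L}_T}{\partial w\,\partial \lambda^\top}$$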
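As a sanity check of the implicit-differentiation formula, this minimal sketch applies it to a hypothetical ridge-regression problem (our assumption), where the Hessian is $X_{tr}^\top X_{tr} + \lambda I$ and the mixed partial $\partial^2 \mathcal{L}_T / \partial w\,\partial\lambda$ is simply $w$. It then compares the IFT hypergradient with a central finite difference of the validation loss. The names `train`, `val_loss`, and `ift_hypergrad` are illustrative; `train` happens to use a direct solve, but the IFT step only consumes its output $w^*$.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 5
w_true = rng.normal(size=d)
X_tr, X_va = rng.normal(size=(100, d)), rng.normal(size=(100, d))
y_tr = X_tr @ w_true + 0.5 * rng.normal(size=100)
y_va = X_va @ w_true + 0.5 * rng.normal(size=100)

def train(lam):
    # Any solver would do; IFT only needs the final w*, not the path to it.
    return np.linalg.solve(X_tr.T @ X_tr + lam * np.eye(d), X_tr.T @ y_tr)

def val_loss(w):
    r = X_va @ w - y_va
    return 0.5 * r @ r

def ift_hypergrad(lam):
    w = train(lam)
    hess_ww = X_tr.T @ X_tr + lam * np.eye(d)   # d^2 L_T / dw dw^T at w*
    cross = w                                   # d^2 L_T / dw dlam at w*
    dw_dlam = -np.linalg.solve(hess_ww, cross)  # IFT: -[Hessian]^{-1} cross
    dLv_dw = X_va.T @ (X_va @ w - y_va)         # dL_V/dw at w*
    dLv_dlam_direct = 0.0                       # this L_V has no explicit lam
    return dLv_dlam_direct + dLv_dw @ dw_dlam

lam, h = 2.0, 1e-5
hg = ift_hypergrad(lam)
fd = (val_loss(train(lam + h)) - val_loss(train(lam - h))) / (2 * h)
```

The two numbers should agree to several digits; note that `hess_ww` is formed explicitly only because $d = 5$ here.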
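We have assumed that the Hessian $\frac{\partial^2 \mathcal{L}_T}{\partial w\,\partial w^\top}$ is invertible. Even when it is, for models with many parameters it is too expensive to form and invert exactly, so in practice people often use approximations instead, such as conjugate gradient or a truncated Neumann series, both of which only need Hessian-vector products.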
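One such approximation is a truncated Neumann series, $H^{-1} v \approx \alpha \sum_{j=0}^{K} (I - \alpha H)^j v$, which converges for symmetric positive definite $H$ when $0 < \alpha < 2/\lambda_{\max}(H)$. The sketch below assumes a hypothetical `hvp` callable returning $Hp$; in a deep-learning setting that product would come from autodiff rather than a stored matrix. For symmetric $H$, $\frac{\partial \mathcal{L}_V}{\partial w} H^{-1}$ is the transpose of $H^{-1}\big(\frac{\partial \mathcal{L}_V}{\partial w}\big)^\top$, so $v$ would be the validation gradient. The toy matrix and its eigenvalue range are assumptions for the demo.

```python
import numpy as np

def neumann_inverse_hvp(hvp, v, alpha, num_terms):
    """Approximate H^{-1} v by alpha * sum_{j=0}^{num_terms} (I - alpha H)^j v,
    using only Hessian-vector products hvp(p) = H @ p."""
    p = v.copy()    # current term (I - alpha H)^j v
    acc = v.copy()  # running sum of the terms
    for _ in range(num_terms):
        p = p - alpha * hvp(p)
        acc = acc + p
    return alpha * acc

# Toy SPD "Hessian" with eigenvalues spread over [1, 10].
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(5, 5)))
H = Q @ np.diag(np.linspace(1.0, 10.0, 5)) @ Q.T
v = rng.normal(size=5)  # stand-in for (dL_V/dw)^T

approx = neumann_inverse_hvp(lambda p: H @ p, v, alpha=0.1, num_terms=300)
exact = np.linalg.solve(H, v)
```

Fewer terms trade accuracy for compute: the error shrinks roughly like $(1 - \alpha\lambda_{\min})^{K+1}$, so poorly conditioned Hessians need many more terms.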
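Strictly speaking, the above is the second part of the IFT, so we have skipped a step. The first part of the implicit function theorem says that if there exist a point $(\lambda_0, w_0)$ and a smooth function $f$ satisfying $f(\lambda_0, w_0) = 0$ with $\frac{\partial f}{\partial w}(\lambda_0, w_0)$ invertible, then there exists a differentiable function $g$ from a neighborhood of $\lambda_0$ to a neighborhood of $w_0$ satisfying $g(\lambda_0) = w_0$ and $f(\lambda, g(\lambda)) = 0$ for every $\lambda$ in that neighborhood. That is, around the point $(\lambda_0, w_0)$ we can write $w$ as a function $w = g(\lambda)$ and differentiate implicitly as shown above. In the calculations above we have simply used $f = \frac{\partial \mathcal{L}_T}{\partial w}$ and overloaded notation to write $w^*(\lambda)$ instead of $g(\lambda)$.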
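A scalar example may make both parts of the theorem concrete. Take the hypothetical relation $f(\lambda, w) = w^3 + w - \lambda$ (chosen for illustration, not from the original). Since $\partial f/\partial w = 3w^2 + 1 > 0$ everywhere, the IFT guarantees a differentiable $g$ with $f(\lambda, g(\lambda)) = 0$, and its derivative is $g'(\lambda) = -\frac{\partial f/\partial \lambda}{\partial f/\partial w} = \frac{1}{3g(\lambda)^2 + 1}$. At $\lambda = 2$ the solution is $w = 1$, so $g'(2) = 1/4$. The sketch computes $g$ numerically with Newton's method and checks the IFT derivative against a finite difference.

```python
def f(lam, w):
    # Implicit relation f(lam, w) = 0 that defines w = g(lam).
    return w**3 + w - lam

def g(lam, iters=50):
    # Solve f(lam, w) = 0 for w with Newton's method; df/dw = 3w^2 + 1 > 0.
    w = 0.0
    for _ in range(iters):
        w -= f(lam, w) / (3 * w**2 + 1)
    return w

def dg_dlam(lam):
    # IFT: dg/dlam = -(df/dlam) / (df/dw), evaluated at w = g(lam).
    # Here df/dlam = -1, so this equals 1 / (3 g^2 + 1).
    w = g(lam)
    return -(-1.0) / (3 * w**2 + 1)

h = 1e-6
fd = (g(2.0 + h) - g(2.0 - h)) / (2 * h)
```

Note that `dg_dlam` never differentiates through the Newton iterations; it only uses the converged root, mirroring how the hyperparameter case uses only $w^*$.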