Reparameterization as preconditioning


Consider a single linear layer in a neural network as a function \(f : \mathbb{R}^n \longrightarrow \mathbb{R}^m\) parameterized by a weight matrix \(W \in \mathbb{R}^{m \times n}\):

\(\displaystyle f (x) = Wx, \quad x \in \mathbb{R}^n\)

We wish to optimize the weight matrix to minimize some loss function \(\mathcal{L} (W)\). We can modify the standard gradient descent approach with a preconditioning matrix \(P \in \mathbb{R}^{mn \times mn}\):

\(\displaystyle \operatorname{vec} (W^{(t + 1)}) =\operatorname{vec} (W^{(t)}) - P\operatorname{vec} \left( \frac{\partial \mathcal{L}}{\partial W^{(t)}} \right)\) (1)

The \(\operatorname{vec} (\ldots)\) operator stacks the columns of a matrix into a single vector. We apply it to the gradient so that the gradient can be multiplied by the preconditioning matrix. Any learning rate can be absorbed into \(P\).
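As a concrete illustration of Eq. 1, here is a minimal NumPy sketch. The helpers `vec`, `unvec`, and `preconditioned_step` are hypothetical names introduced only for this example. `vec` uses column-major (`order="F"`) flattening so that it stacks columns, matching the definition above.

```python
import numpy as np

def vec(M: np.ndarray) -> np.ndarray:
    """Stack the columns of M into one vector (column-major order)."""
    return M.reshape(-1, order="F")

def unvec(v: np.ndarray, shape: tuple[int, int]) -> np.ndarray:
    """Inverse of vec: refill a matrix of the given shape column by column."""
    return v.reshape(shape, order="F")

def preconditioned_step(W: np.ndarray, grad_W: np.ndarray, P: np.ndarray) -> np.ndarray:
    """One step of Eq. (1): vec(W_new) = vec(W) - P vec(dL/dW)."""
    return unvec(vec(W) - P @ vec(grad_W), W.shape)

# With P = lr * I, Eq. (1) reduces to plain gradient descent with learning rate lr.
rng = np.random.default_rng(0)
m, n, lr = 3, 2, 0.1
W = rng.standard_normal((m, n))
G = rng.standard_normal((m, n))  # random stand-in for dL/dW
W_new = preconditioned_step(W, G, lr * np.eye(m * n))
print(np.allclose(W_new, W - lr * G))  # True
```

Note that NumPy's default `ravel()` is row-major. Mixing the two conventions silently changes which Kronecker factor goes on which side.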

Notice that \(P\) has \((mn)^2\) entries, which can be prohibitively many for a modern neural network layer. For example, with \(m = n = 1000\), \(P\) would have \(10^{12}\) entries. There are various approaches to approximating the preconditioning matrix. Here we focus on one such approximation that implements preconditioning “implicitly” by reparameterizing the layer. In particular, we construct the weight matrix as the product of three matrices:

\(\displaystyle W^{(t)} = TU^{(t)} V, \quad T \in \mathbb{R}^{m \times m}, U^{(t)} \in \mathbb{R}^{m \times n}, V \in \mathbb{R}^{n \times n}\)

Notice that only \(U^{(t)}\) carries the superscript indexing the optimization step, because \(U^{(t)}\) is the only parameter we optimize here. The matrices \(T\) and \(V\) may instead be estimated from training statistics [1] or learned in the outer loop of a meta-learning procedure [2]. \(W^{(t)}\) carries a superscript only because each update to \(U^{(t)}\) also changes \(W^{(t)}\). \(W^{(t)}\) itself is no longer a parameter.

We will see that \(T\) and \(V\) implicitly precondition our gradient descent procedure. Since we are minimizing the loss \(\mathcal{L} (W)\) and \(W\) depends on \(U\), the chain rule gives the derivative with respect to \(U\) in terms of \(\frac{\partial \mathcal{L}}{\partial W}\) (see below for details):

\(\displaystyle \frac{\mathrm{d} \mathcal{L}}{\mathrm{d} U} = T^T \frac{\partial \mathcal{L}}{\partial W} V^T\)
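The chain-rule formula can be sanity-checked numerically against central finite differences. This is a sketch under an assumed example loss, \(\mathcal{L} (W) = \frac{1}{2} \| WX - Y \|_F^2\), chosen only because its gradient \((WX - Y) X^T\) is easy to write down. The data `X`, `Y` and all matrices are random placeholders, not anything from the original setup.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, k = 3, 4, 5
T = rng.standard_normal((m, m))
U = rng.standard_normal((m, n))
V = rng.standard_normal((n, n))
X = rng.standard_normal((n, k))  # hypothetical inputs
Y = rng.standard_normal((m, k))  # hypothetical targets

def loss_W(W):
    # Assumed example loss: L(W) = 0.5 * ||W X - Y||_F^2
    R = W @ X - Y
    return 0.5 * np.sum(R * R)

def grad_W(W):
    # Partial derivative of the assumed loss: dL/dW = (W X - Y) X^T
    return (W @ X - Y) @ X.T

def loss_U(U):
    # Loss as a function of U through the reparameterization W = T U V
    return loss_W(T @ U @ V)

# Chain-rule formula: dL/dU = T^T (dL/dW) V^T
analytic = T.T @ grad_W(T @ U @ V) @ V.T

# Central finite differences, perturbing one entry of U at a time
eps = 1e-6
numeric = np.zeros_like(U)
for s in range(m):
    for p in range(n):
        E = np.zeros_like(U)
        E[s, p] = eps
        numeric[s, p] = (loss_U(U + E) - loss_U(U - E)) / (2 * eps)

# Expected to be tiny (round-off only), since this loss is quadratic in U
print(np.max(np.abs(analytic - numeric)))
```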

Now consider running standard gradient descent on \(U\) (no preconditioning). For simplicity, assume the learning rate is \(1\). The \(U\) update is:

\(\displaystyle U^{(t + 1)} = U^{(t)} - \frac{\mathrm{d} \mathcal{L}}{\mathrm{d} U^{(t)}} = U^{(t)} - T^T \frac{\partial \mathcal{L}}{\partial W^{(t)}} V^T\)

The update to \(U\) gives us the update to \(W\):

\begin{eqnarray*} W^{(t + 1)} & = & TU^{(t + 1)} V\\ & = & T \left( U^{(t)} - T^T \frac{\partial \mathcal{L}}{\partial W^{(t)}} V^T \right) V\\ & = & W^{(t)} - (TT^T) \frac{\partial \mathcal{L}}{\partial W^{(t)}} (V^T V) \hspace{3cm} \text{(2)} \end{eqnarray*}

So \(W\) is updated by its gradient, multiplied on the left by \(TT^T\) and on the right by \(V^T V\).
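The equivalence in Eq. 2 can be checked directly. The sketch below takes one plain gradient step on `U` and compares the resulting `W` with the preconditioned update written in \(W\)-space. Unlike the text, it uses a general learning rate `lr`; both sides simply scale by `lr`. The matrix `G` is a random stand-in for \(\frac{\partial \mathcal{L}}{\partial W^{(t)}}\).

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, lr = 3, 4, 0.5
T = rng.standard_normal((m, m))
V = rng.standard_normal((n, n))
U = rng.standard_normal((m, n))
G = rng.standard_normal((m, n))  # stand-in for dL/dW evaluated at W^{(t)}

W = T @ U @ V

# Plain (unpreconditioned) gradient step on U, using dL/dU = T^T G V^T
U_next = U - lr * (T.T @ G @ V.T)
W_via_U = T @ U_next @ V

# The same step written directly in W-space, as in Eq. (2)
W_via_precond = W - lr * (T @ T.T) @ G @ (V.T @ V)

print(np.allclose(W_via_U, W_via_precond))  # True
```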

Now we want to vectorize this update to compare with Eq. 1. We can do this using the Kronecker product of matrices, denoted “\(\otimes\)”. For matrices \(A, X, B\) with appropriate dimensions:

\(\displaystyle \operatorname{vec} (AXB) = (B^T \otimes A) \operatorname{vec} (X)\)
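A quick numerical check of this identity with random matrices and `np.kron` follows. The `vec` helper is a hypothetical name for column-major flattening. The second check shows that, with NumPy's default row-major flattening, the factors appear in the opposite order, \((A \otimes B^T)\). This is a common source of bugs when translating the formula to code.

```python
import numpy as np

def vec(M):
    # Column-stacking vectorization (column-major order)
    return M.reshape(-1, order="F")

rng = np.random.default_rng(2)
A = rng.standard_normal((3, 4))
X = rng.standard_normal((4, 5))
B = rng.standard_normal((5, 2))

# Column-major identity: vec(A X B) = (B^T kron A) vec(X)
lhs = vec(A @ X @ B)
rhs = np.kron(B.T, A) @ vec(X)
print(np.allclose(lhs, rhs))  # True

# Row-major flattening (NumPy's default ravel) swaps the factors: (A kron B^T)
print(np.allclose((A @ X @ B).ravel(), np.kron(A, B.T) @ X.ravel()))  # True
```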

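Applying the Kronecker identity to Eq. 2 with \(A = TT^T\), \(X = \frac{\partial \mathcal{L}}{\partial W^{(t)}}\), and \(B = V^T V\), and noting that \(V^T V\) is symmetric so that \(B^T = V^T V\), we obtain: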

\(\displaystyle \operatorname{vec} (W^{(t + 1)}) =\operatorname{vec} (W^{(t)}) - [(V^T V) \otimes (TT^T)] \operatorname{vec} \left( \frac{\partial \mathcal{L}}{\partial W^{(t)}} \right)\)

Comparing this with Eq. 1, we see that this is a special case of preconditioned gradient descent with a Kronecker-factored preconditioning matrix:

\(\displaystyle P = (V^T V) \otimes (TT^T)\)
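For small sizes, \(P\) can be formed explicitly to confirm that multiplying by it matches the factored update \((TT^T) G (V^T V)\), which never builds \(P\). The sketch also checks that \(P\) is symmetric positive semidefinite, as a Kronecker product of two Gram matrices must be, and compares storage: \((mn)^2\) entries for \(P\) versus \(m^2 + n^2\) for the factors. All matrices are random placeholders.

```python
import numpy as np

def vec(M):
    # Column-stacking vectorization (column-major order)
    return M.reshape(-1, order="F")

rng = np.random.default_rng(3)
m, n = 3, 4
T = rng.standard_normal((m, m))
V = rng.standard_normal((n, n))
G = rng.standard_normal((m, n))  # stand-in for dL/dW

# Explicit (mn x mn) preconditioner built from the Kronecker factors
P = np.kron(V.T @ V, T @ T.T)

explicit = P @ vec(G)                      # Eq. (1)-style product with the full P
factored = vec((T @ T.T) @ G @ (V.T @ V))  # Eq. (2)-style product, never forms P
print(np.allclose(explicit, factored))     # True

# P is symmetric positive semidefinite
print(np.allclose(P, P.T), np.linalg.eigvalsh(P).min() >= -1e-10)

# Storage: (mn)^2 entries for P versus m^2 + n^2 for T and V
print(P.size, T.size + V.size)  # 144 25
```

The factored form is what makes this approach practical: the cost of applying the preconditioner is two small matrix products rather than one \(mn \times mn\) matrix-vector product.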

1 Derivative of \(\mathcal{L}\) with respect to \(U\)

For a loss function \(\mathcal{L} (W)\) where the weights are reparameterized as \(W (U) = TUV\), we obtain \(\frac{\mathrm{d} \mathcal{L}}{\mathrm{d} U}\) in terms of \(\frac{\partial \mathcal{L}}{\partial W}\) by working component by component. For brevity we use Einstein notation throughout, so repeated indices are summed over. First, we write the derivative with respect to the \((s, p)\)-th entry of \(U\) using the chain rule:

\(\displaystyle \frac{\mathrm{d} \mathcal{L}}{\mathrm{d} U_{s p}} = \frac{\partial \mathcal{L}}{\partial W_{i j}} \frac{\partial W_{i j}}{\partial U_{s p}}\)

From the definition \(W = TUV\), and using \(\frac{\partial U_{k l}}{\partial U_{s p}} = \delta_{k s} \delta_{l p}\), we have:

\(\displaystyle W_{i j} = T_{i k} U_{k l} V_{l j} \quad \Longrightarrow \quad \frac{\partial W_{i j}}{\partial U_{s p}} = T_{i s} V_{p j}\)

Substituting this back into the previous expression:

\(\displaystyle \frac{\mathrm{d} \mathcal{L}}{\mathrm{d} U_{s p}} = \frac{\partial \mathcal{L}}{\partial W_{i j}} T_{i s} V_{p j}\)

Since \(\left( \frac{\partial \mathcal{L}}{\partial W} \right)_{i j} = \frac{\partial \mathcal{L}}{\partial W_{i j}}\) by definition, and \(T_{i s} = (T^T)_{s i}\) and \(V_{p j} = (V^T)_{j p}\), the expression above is simply the component form of:

\(\displaystyle \frac{\mathrm{d} \mathcal{L}}{\mathrm{d} U} = T^T \frac{\partial \mathcal{L}}{\partial W} V^T\)
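The index expressions in this derivation map directly onto `np.einsum`. The sketch below builds the Jacobian \(\frac{\partial W_{i j}}{\partial U_{s p}} = T_{i s} V_{p j}\) as a 4-index array, checks that contracting it with \(U\) reproduces \(TUV\) (which holds because \(W\) is linear in \(U\)), and confirms that the component formula for \(\frac{\mathrm{d} \mathcal{L}}{\mathrm{d} U_{s p}}\) equals \(T^T \frac{\partial \mathcal{L}}{\partial W} V^T\). `G` is a random stand-in for \(\frac{\partial \mathcal{L}}{\partial W}\).

```python
import numpy as np

rng = np.random.default_rng(4)
m, n = 3, 4
T = rng.standard_normal((m, m))
U = rng.standard_normal((m, n))
V = rng.standard_normal((n, n))
G = rng.standard_normal((m, n))  # stand-in for dL/dW, entries G[i, j]

# Jacobian dW_{ij}/dU_{sp} = T_{is} V_{pj}, stored as J[i, j, s, p]
J = np.einsum("is,pj->ijsp", T, V)

# W is linear in U, so contracting the Jacobian with U recovers W = T U V
print(np.allclose(np.einsum("ijsp,sp->ij", J, U), T @ U @ V))  # True

# dL/dU_{sp} = G_{ij} T_{is} V_{pj}, summing over the repeated indices i and j
dU_einsum = np.einsum("ij,is,pj->sp", G, T, V)
dU_matrix = T.T @ G @ V.T
print(np.allclose(dU_einsum, dU_matrix))  # True
```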


[1] Guillaume Desjardins, Karen Simonyan, Razvan Pascanu, et al. Natural neural networks. In Advances in Neural Information Processing Systems, pages 2071–2079, 2015.

[2] Yoonho Lee and Seungjin Choi. Gradient-based meta-learning with learned layerwise metric and subspace. arXiv preprint arXiv:1801.05558, 2018.