Low-rank structure

Goal: mitigate the overparametrization of wide linear layers.

We solve a very simple ODE: \[ f'(x)=2x, \hspace{2em} x\in[-1, 1], \] with initial condition \(f(0)=0\), whose exact solution is \(f(x)=x^2\). The MLP is chosen to be \[ \mathcal{NN}(x;\theta):= \begin{bmatrix} \square \\ \vdots \\ \square \end{bmatrix}^\intercal \tanh \left( \begin{bmatrix} \square & \cdots & \square \\ \vdots & \ddots & \vdots \\ \square & \cdots & \square \end{bmatrix}\tanh \left( \begin{bmatrix} \square \\ \vdots \\ \square \end{bmatrix} x + \begin{bmatrix} \square \\ \vdots \\ \square \end{bmatrix} \right) + \begin{bmatrix} \square \\ \vdots \\ \square \end{bmatrix} \right), \] where the width is 100, so that \(\dim\theta=100+100+100\cdot100+100+100=10400\) (input weights and biases, hidden weight matrix, hidden biases, output weights). We choose the collocation points to be \[ x_1=-1, x_2=-0.99, x_3=-0.98, \cdots, x_{199}=0.98, x_{200}=0.99, x_{201}=1, \] that is, 201 equally spaced points with spacing 0.01. In general, collocation points can be placed on a uniform grid over the ODE/PDE domain or sampled at random.
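As a minimal NumPy sketch of the setup above, the code builds the parameters of \(\mathcal{NN}(x;\theta)\) with the stated shapes, checks the count \(\dim\theta=10400\), and generates the 201 collocation points. The Glorot-normal initialization and the names `init_mlp` and `mlp` are assumptions for illustration; the text does not say how the network is initialized.

```python
import numpy as np

WIDTH = 100  # hidden width used in the text

def init_mlp(width=WIDTH, seed=0):
    """Hypothetical initialization of NN(x; theta).

    Shapes follow the formula: input weights a and bias b (width,),
    hidden matrix W (width, width) and bias d (width,), output weights c (width,).
    Glorot-normal scaling, std = sqrt(2 / (fan_in + fan_out)), is an assumption.
    """
    rng = np.random.default_rng(seed)
    return {
        "a": rng.normal(0.0, np.sqrt(2.0 / (1 + width)), width),
        "b": np.zeros(width),
        "W": rng.normal(0.0, np.sqrt(2.0 / (2 * width)), (width, width)),
        "d": np.zeros(width),
        "c": rng.normal(0.0, np.sqrt(2.0 / (width + 1)), width),
    }

def mlp(x, p):
    """Evaluate NN(x) = c^T tanh(W tanh(a x + b) + d) at a batch of scalars x."""
    h1 = np.tanh(np.outer(x, p["a"]) + p["b"])  # (N, width)
    h2 = np.tanh(h1 @ p["W"].T + p["d"])         # (N, width)
    return h2 @ p["c"]                             # (N,)

params = init_mlp()
n_params = sum(v.size for v in params.values())  # 100 + 100 + 10000 + 100 + 100
x_col = np.linspace(-1.0, 1.0, 201)              # x_1 = -1, ..., x_201 = 1, spacing 0.01
values = mlp(x_col, params)
```

Note that with this (assumed) initialization the hidden matrix has entries of standard deviation \(\sqrt{2/200}=0.1=1/\sqrt{100}\).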

The loss function is \[ L(\theta)= \lambda\cdot(\mathcal{NN}(0;\theta)-0)^2 + \sum_{j=1}^{201} [ \mathcal{NN}'(x_j;\theta) - 2x_j ]^2, \] where \(\mathcal{NN}'\) denotes the derivative with respect to \(x\), and \(\lambda\) is a tunable weight that balances the data term (the initial condition) against the physics term (the ODE residual at the collocation points). The loss is minimized with Adam, a gradient-based optimizer with momentum.
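A minimal NumPy sketch of this loss, assuming the parameters are passed as a tuple `(a, b, W, d, c)` (a hypothetical layout). The derivative \(\mathcal{NN}'(x)\) is written out by the chain rule with \(\tanh'=1-\tanh^2\). Gradients with respect to \(\theta\) for the Adam updates are not derived here; in practice an autodiff framework would supply them.

```python
import numpy as np

def mlp_and_dx(x, a, b, W, d, c):
    """Return NN(x) and dNN/dx for NN(x) = c^T tanh(W tanh(a x + b) + d)."""
    h1 = np.tanh(np.outer(x, a) + b)      # (N, width)
    dh1 = (1.0 - h1**2) * a               # d h1 / dx
    h2 = np.tanh(h1 @ W.T + d)            # (N, width)
    dh2 = (1.0 - h2**2) * (dh1 @ W.T)     # d h2 / dx = tanh'(z2) * (W dh1/dx)
    return h2 @ c, dh2 @ c

def pinn_loss(params, x_col, lam=1.0):
    """L = lam * (NN(0) - 0)^2 + sum_j (NN'(x_j) - 2 x_j)^2."""
    u0, _ = mlp_and_dx(np.array([0.0]), *params)
    _, du = mlp_and_dx(x_col, *params)
    data_term = (u0[0] - 0.0) ** 2                   # initial condition f(0) = 0
    physics_term = np.sum((du - 2.0 * x_col) ** 2)   # residual of f'(x) = 2x
    return lam * data_term + physics_term

# Illustrative random parameters (width 100) and the 201 collocation points.
rng = np.random.default_rng(0)
width = 100
params = (rng.normal(size=width), rng.normal(size=width),
          rng.normal(size=(width, width)) / np.sqrt(width),
          rng.normal(size=width), rng.normal(size=width) / np.sqrt(width))
x_col = np.linspace(-1.0, 1.0, 201)
loss_value = pinn_loss(params, x_col, lam=10.0)
```

Setting every parameter to zero gives \(\mathcal{NN}\equiv0\), so the loss reduces to \(\sum_j (2x_j)^2 = 270.68\), which is a handy sanity check.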

At each epoch, we compute the singular value distribution of the \(100\times100\) hidden weight matrix (the square matrix in the middle layer) of the MLP.
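A minimal sketch of the per-epoch spectrum computation with NumPy. The text does not say how the distribution is summarized; the histogram and the share of \(\|W\|_F^2\) carried by the top singular value are illustrative choices, and `W_history` is a hypothetical stand-in for weight snapshots saved during training.

```python
import numpy as np

def spectrum(W):
    """Singular values of a weight matrix, in descending order."""
    return np.linalg.svd(W, compute_uv=False)

def spectrum_summary(W, bins=20):
    """Histogram of singular values plus the share of ||W||_F^2 in the top one."""
    s = spectrum(W)
    counts, edges = np.histogram(s, bins=bins)
    top_energy = s[0] ** 2 / np.sum(s ** 2)  # sigma_1^2 / ||W||_F^2
    return s, counts, edges, top_energy

# Hypothetical per-epoch logging: random matrices stand in for the
# 100x100 hidden matrix saved after each epoch (not produced by training here).
rng = np.random.default_rng(0)
W_history = [rng.normal(size=(100, 100)) / 10.0 for _ in range(3)]
log = [spectrum_summary(W) for W in W_history]
```

A top-energy share close to 1 would indicate that the matrix is nearly rank one.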

Theorem (Marchenko-Pastur, 1967). Let \(X_{ij}, 1\leq i\leq p, 1\leq j\leq n\) be i.i.d. random variables with \(\mathbb{E}[X_{ij}]=0\) and \(\mathbb{E}[X_{ij}^2]=1\). Denote by \(\lambda_{1}\geq\cdots\geq\lambda_{p}\geq0\) the eigenvalues of the symmetric matrix \(W_p:=\frac{1}{n} X X^\intercal\in\mathbb{R}^{p\times p}\). Then, as \(p\to\infty\) and \(n\to\infty\) with \(\frac{p}{n}\to\gamma\in(0,\infty)\), the empirical spectral distribution converges almost surely, \[ \frac{1}{p} \sum_{k=1}^{p} 1_{\{\lambda_k \leq x\}} \to F_{\gamma}(x), \] at every continuity point \(x\) of \(F_\gamma\), and its distributional derivative is \[ F'_{\gamma}(x) = \frac{\sqrt{(b-x)(x-a)}}{2\pi \gamma x}\,1_{[a,b]}(x) + \max\left(0, 1-\frac{1}{\gamma}\right)\delta_0(x), \] where \(a=\left(1-\sqrt{\gamma}\right)^2\) and \(b=\left(1+\sqrt{\gamma}\right)^2\). The point mass at \(0\) appears only when \(\gamma>1\), since then \(p>n\) and \(W_p\) has rank at most \(n<p\).
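A short numerical check of the theorem, as a sketch using standard Gaussian entries (one admissible choice of \(X_{ij}\)) and the hypothetical helper `mp_density`. It evaluates the continuous part of \(F'_\gamma\) and samples the eigenvalues of \(W_p\) for \(\gamma=p/n=0.5\).

```python
import numpy as np

def mp_density(x, gamma):
    """Continuous part of the Marchenko-Pastur density for ratio gamma = p/n."""
    a, b = (1 - np.sqrt(gamma)) ** 2, (1 + np.sqrt(gamma)) ** 2
    x = np.asarray(x, dtype=float)
    out = np.zeros_like(x)
    inside = (x > a) & (x < b)
    xi = x[inside]
    out[inside] = np.sqrt((b - xi) * (xi - a)) / (2 * np.pi * gamma * xi)
    return out

def mp_mass(gamma, N=200_000):
    """Midpoint-rule integral of the continuous part over [a, b]."""
    a, b = (1 - np.sqrt(gamma)) ** 2, (1 + np.sqrt(gamma)) ** 2
    xs = a + (np.arange(N) + 0.5) * (b - a) / N
    return mp_density(xs, gamma).sum() * (b - a) / N

p, n = 500, 1000                          # gamma = p / n = 0.5
rng = np.random.default_rng(1)
X = rng.normal(size=(p, n))               # i.i.d., mean 0, variance 1
eigs = np.linalg.eigvalsh(X @ X.T / n)    # eigenvalues of W_p = X X^T / n
```

The continuous part integrates to \(1\) for \(\gamma\leq1\) and to \(1/\gamma\) for \(\gamma>1\); the missing \(1-1/\gamma\) is exactly the atom at \(0\).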

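In the square case \(p=n\) (so \(\gamma=1\), \(a=0\), \(b=4\)), denote the singular values of \(\frac{1}{\sqrt{n}}X\in\mathbb{R}^{n\times n}\) by \(\sigma_1\geq\sigma_2\geq\cdots\geq\sigma_n\geq0\). They are the square roots of the eigenvalues of \(W_p\), so \(\sigma_k\leq x\) if and only if \(\lambda_k\leq x^2\), and their distribution follows the quarter-circle law with density \(\frac{1}{\pi}\sqrt{4-x^2}\) on \([0,2]\): \[ \lim_{n\to\infty} \frac{1}{n}\sum_{k=1}^{n} 1_{\{\sigma_k\leq x\}} =\frac{x\sqrt{4-x^2}}{2\pi} + 1 - \frac{2}{\pi}\arctan\frac{\sqrt{4-x^2}}{x} = \frac{x\sqrt{4-x^2}}{2\pi} + \frac{2}{\pi}\arcsin\frac{x}{2}, \hspace{1em} x\in(0,2). \]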
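A sketch comparing the quarter-circle CDF with the singular values of one sampled Gaussian matrix \(X/\sqrt{n}\) (Gaussian entries and \(n=400\) are illustrative assumptions). The helper `quarter_circle_cdf` uses the \(\arcsin\) form, which agrees with the \(\arctan\) form because \(\arctan\frac{\sqrt{4-x^2}}{x}=\arccos\frac{x}{2}=\frac{\pi}{2}-\arcsin\frac{x}{2}\) on \((0,2)\).

```python
import numpy as np

def quarter_circle_cdf(x):
    """Limit CDF of singular values of X / sqrt(n), X square with i.i.d. (0, 1) entries."""
    x = np.clip(np.asarray(x, dtype=float), 0.0, 2.0)
    return x * np.sqrt(4 - x**2) / (2 * np.pi) + (2 / np.pi) * np.arcsin(x / 2)

n = 400
rng = np.random.default_rng(2)
sigma = np.linalg.svd(rng.normal(size=(n, n)) / np.sqrt(n), compute_uv=False)

# Empirical CDF of the sampled singular values on a few grid points.
grid = np.linspace(0.1, 1.9, 10)
empirical = np.array([(sigma <= t).mean() for t in grid])
max_gap = np.max(np.abs(empirical - quarter_circle_cdf(grid)))
```

If the hidden matrix were initialized with entries of standard deviation \(1/\sqrt{n}\) (for example Glorot-normal on a \(100\times100\) layer, an assumption not stated in the text), its initial singular values would follow this law approximately.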

Observing this feature, we use an alternative architecture, the Kronecker neural network [Jagtap, Shin, Kawaguchi and Karniadakis (2022)] (rank=1), \[ \mathcal{KNN}(x;\theta):= \begin{bmatrix} \square \\ \vdots \\ \square \end{bmatrix}^\intercal \tanh \left( (A\otimes B)\tanh \left( \begin{bmatrix} \square \\ \vdots \\ \square \end{bmatrix} x + \begin{bmatrix} \square \\ \vdots \\ \square \end{bmatrix} \right) + \begin{bmatrix} \square \\ \vdots \\ \square \end{bmatrix} \right), \] to solve the same ODE. For \(A\in\mathbb{R}^{m\times n}\) and \(B\in\mathbb{R}^{r\times s}\), the Kronecker product \(A\otimes B\in\mathbb{R}^{mr\times ns}\) is defined as \[ A\otimes B = \begin{bmatrix} a_{1,1} B & \cdots & a_{1,n}B \\ \vdots & \ddots & \vdots \\ a_{m,1}B & \cdots & a_{m,n}B \end{bmatrix}. \]
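A small NumPy illustration of the definition using `np.kron`. The text does not state the shapes of \(A\) and \(B\); the sketch shows two hypothetical choices that both give a \(100\times100\) matrix from 200 trainable entries but have very different ranks, since \(\operatorname{rank}(A\otimes B)=\operatorname{rank}(A)\operatorname{rank}(B)\).

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.normal(size=(2, 3))        # A in R^{m x n}, m = 2, n = 3
B = rng.normal(size=(4, 5))        # B in R^{r x s}, r = 4, s = 5
K = np.kron(A, B)                  # (m*r) x (n*s) = 8 x 15

# Block (i, j) of A ⊗ B equals a_{i,j} * B.
i, j = 1, 2
block = K[i * 4:(i + 1) * 4, j * 5:(j + 1) * 5]

# Two hypothetical ways to build a 100x100 hidden matrix from 200 numbers:
u, v = rng.normal(size=(100, 1)), rng.normal(size=(1, 100))
K_outer = np.kron(u, v)            # column ⊗ row = outer product u v^T, rank 1
A10, B10 = rng.normal(size=(10, 10)), rng.normal(size=(10, 10))
K_10x10 = np.kron(A10, B10)        # rank(A10) * rank(B10), generically 100
```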

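When \(A\in\mathbb{R}^{100\times1}\) is a column and \(B\in\mathbb{R}^{1\times100}\) is a row, \(A\otimes B=AB\) is a rank-one \(100\times100\) matrix, so this architecture is equivalent to replacing the hidden weight matrix with a rank-one product \(uv^\intercal\). This has the same form as the leading term \(\sigma_1 U_1 V_1^\intercal\) of a singular value decomposition (SVD), where \(U_1,V_1\) are the first left and right singular vectors and \(\sigma_1\) is absorbed into the factors. The resulting architecture has \(\dim\theta=100+100+200+100+100=600\), compared with 10400 for the full MLP, yet it preserves much of the representational capacity of the full-rank MLP.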
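A minimal sketch of the rank-one KNN forward pass, assuming \(A\) is a \(100\times1\) column `u` and \(B\) a \(1\times100\) row `v` (hypothetical names). It applies \(uv^\intercal\) as two cheap products instead of forming the dense matrix, counts the 600 parameters, and shows the best rank-one approximation \(\sigma_1U_1V_1^\intercal\) of a dense matrix via NumPy's SVD (Eckart-Young).

```python
import numpy as np

WIDTH = 100

def knn_rank1(x, a, b, u, v, d, c):
    """KNN forward pass with hidden matrix A ⊗ B = u v^T, u: (width, 1), v: (1, width).

    Computes c^T tanh(u (v h1) + d) without forming the width x width matrix.
    """
    h1 = np.tanh(np.outer(x, a) + b)  # (N, width)
    z = (h1 @ v.T) @ u.T + d          # row i is (v . h1_i) u^T + d
    return np.tanh(z) @ c

rng = np.random.default_rng(4)
a, b, d, c = (rng.normal(size=WIDTH) for _ in range(4))
u, v = rng.normal(size=(WIDTH, 1)), rng.normal(size=(1, WIDTH))
n_params = sum(t.size for t in (a, b, u, v, d, c))  # 100 + 100 + 200 + 100 + 100

# Best rank-one approximation of a dense matrix: sigma_1 * U_1 V_1^T.
W = rng.normal(size=(WIDTH, WIDTH)) / np.sqrt(WIDTH)
U, S, Vt = np.linalg.svd(W)
W1 = S[0] * np.outer(U[:, 0], Vt[0])
```

Applying the factors separately costs \(O(\text{width})\) per input instead of \(O(\text{width}^2)\), and the squared Frobenius error of `W1` equals \(\sum_{k\geq2}\sigma_k^2\).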