Ensemble Kalman Filter [Kovachki, Stuart (2019)]

Idea: use differences between ensemble members to approximate the gradient \(\nabla_\theta L\)

Instead of computing the gradient of the loss \(\nabla_\theta L\) at a given parameter to update \(\theta\), we use the Ensemble Kalman Filter (EnKF). Let \(\theta\in\mathbb{R}^d\) be the parameter vector of an ANN, where the objective is to minimize a loss function associated with solving a PDE with the ANN. EnKF samples \[ \theta^{(1)}, \cdots, \theta^{(J)} \stackrel{\text{i.i.d.}}{\sim} \mathcal{N}(\theta, \Sigma) \] for a chosen covariance \(\Sigma\in\mathbb{R}^{d\times d}\), where each \(\theta^{(j)}\) is called a particle and \(J\) is the ensemble size.
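As a minimal sketch of the sampling step, the Python snippet below draws \(J\) particles around \(\theta\). It assumes an isotropic covariance \(\Sigma=\sigma^2 I\) so that only the standard library is needed; the function name `sample_ensemble` and the values of `sigma`, `J` and `seed` are illustrative choices, not part of the method.

```python
import random

def sample_ensemble(theta, sigma, J, seed=0):
    """Draw J particles i.i.d. from N(theta, sigma^2 I).

    Assumption: isotropic covariance Sigma = sigma^2 I. A general Sigma would
    need a matrix square root, e.g. a Cholesky factor of Sigma.
    """
    rng = random.Random(seed)  # seeded generator for reproducibility
    return [[t + sigma * rng.gauss(0.0, 1.0) for t in theta] for _ in range(J)]

theta = [0.5, -1.0, 2.0]                      # current parameter vector, d = 3
ensemble = sample_ensemble(theta, sigma=0.1, J=8)
print(len(ensemble), len(ensemble[0]))       # 8 3
```

Each particle is a full copy of the parameter vector with independent noise added, so it has the same dimension \(d\) as \(\theta\).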

Note that each particle \(\theta^{(j)}\) has the same dimension as \(\theta\), i.e. \(\theta^{(j)}\in\mathbb{R}^d\).

For each particle \(\theta^{(j)}\), used as the ANN parameter, the loss is evaluated at every collocation point \((t_n,x_n)\), \(n=1,\dots,N\), in the PDE domain of interest. Collecting these values gives the \(J\times N\) matrix \[ \mathcal{G}\left(\{\theta^{(j)}\}_{j=1}^{J}\right)= \begin{bmatrix} L(\theta^{(1)}; (t_1,x_1)) & L(\theta^{(1)}; (t_2,x_2)) & \cdots & L(\theta^{(1)}; (t_{N-1},x_{N-1})) & L(\theta^{(1)}; (t_N,x_N)) \\ L(\theta^{(2)}; (t_1,x_1)) & L(\theta^{(2)}; (t_2,x_2)) & \cdots & L(\theta^{(2)}; (t_{N-1},x_{N-1})) & L(\theta^{(2)}; (t_N,x_N)) \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ L(\theta^{(J-1)}; (t_1,x_1)) & L(\theta^{(J-1)}; (t_2,x_2)) & \cdots & L(\theta^{(J-1)}; (t_{N-1},x_{N-1})) & L(\theta^{(J-1)}; (t_N,x_N)) \\ L(\theta^{(J)}; (t_1,x_1)) & L(\theta^{(J)}; (t_2,x_2)) & \cdots & L(\theta^{(J)}; (t_{N-1},x_{N-1})) & L(\theta^{(J)}; (t_N,x_N)) \end{bmatrix}. \]
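The matrix above can be assembled row by row, one particle at a time. The sketch below is hypothetical throughout: `forward_matrix` and `toy_loss` are made-up names, and a linear model \(u_\theta(t,x)=\theta_0+\theta_1 t+\theta_2 x\), compared against an assumed target \(u^*(t,x)=1+2t-x\), stands in for the ANN and its pointwise PDE loss.

```python
def forward_matrix(ensemble, points, pointwise_loss):
    """Row j is G(theta^(j))^T = [L(theta^(j); (t_n, x_n)) for n = 1..N]."""
    return [[pointwise_loss(theta, t, x) for (t, x) in points] for theta in ensemble]

def toy_loss(theta, t, x):
    # Hypothetical stand-in for an ANN's pointwise PDE loss:
    # u_theta(t, x) = theta0 + theta1*t + theta2*x minus an assumed target 1 + 2t - x.
    return theta[0] + theta[1] * t + theta[2] * x - (1.0 + 2.0 * t - x)

points = [(0.0, 0.0), (0.5, 0.25), (1.0, 1.0)]   # N = 3 collocation points
ensemble = [[1.0, 2.0, -1.0], [0.0, 0.0, 0.0]]   # J = 2 particles, d = 3
G = forward_matrix(ensemble, points, toy_loss)   # J x N
print(G)   # [[0.0, 0.0, 0.0], [-1.0, -1.75, -2.0]]
```

The first particle reproduces the target exactly, so its row is zero; the second particle's row records how far it misses at each collocation point.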

We call the map that produces each row the forward map \(\mathcal{G}(\cdot):\mathbb{R}^{d}\to\mathbb{R}^{N}\), so the \(j\)-th row of this matrix is \(\mathcal{G}(\theta^{(j)})^\intercal\). The transpose \(^\intercal\) appears because we distinguish row vectors from column vectors, and by default an element of \(\mathbb{R}^N\) is a column vector. Following the convention that an overbar denotes the ensemble mean, \[ \bar{\mathcal{G}} := \frac{1}{J}\sum_{j=1}^{J} \mathcal{G}(\theta^{(j)}). \]
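Computing \(\bar{\mathcal{G}}\) amounts to averaging the rows of the \(J\times N\) matrix, i.e. taking a column-wise mean. A minimal sketch, with a hypothetical helper `ensemble_mean` and a small hand-made matrix (\(J=2\), \(N=3\)):

```python
def ensemble_mean(G):
    """Average of the rows G(theta^(j))^T over j, i.e. bar{G} in R^N."""
    J = len(G)
    # zip(*G) walks over columns, i.e. over collocation points n = 1..N
    return [sum(col) / J for col in zip(*G)]

G = [[0.0, 0.0, 0.0], [-1.0, -1.75, -2.0]]
g_bar = ensemble_mean(G)
print(g_bar)   # [-0.5, -0.875, -1.0]
```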

The central object is the matrix \(D(\theta)\in\mathbb{R}^{J\times J}\), which depends on the whole ensemble \(\{\theta^{(j)}\}_{j=1}^{J}\) and is defined entrywise: its entry in the \(k\)-th row and \(j\)-th column is \[ [D(\theta)]_{k,j} := \frac{1}{J} \left(\mathcal{G}(\theta^{(k)})-\bar{\mathcal{G}} \right)^\intercal \mathcal{G}(\theta^{(j)}). \]
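A direct transcription of this definition, as a sketch in plain Python (the name `d_matrix` and the example matrix are illustrative). It centers each row by \(\bar{\mathcal{G}}\) and then takes inner products with the uncentered rows:

```python
def d_matrix(G):
    """Return D with [D]_{k,j} = (1/J) (G(theta^(k)) - G_bar)^T G(theta^(j)).

    G is a J x N list of rows, row j being G(theta^(j))^T.
    """
    J = len(G)
    g_bar = [sum(col) / J for col in zip(*G)]                      # ensemble mean
    centered = [[g - m for g, m in zip(row, g_bar)] for row in G]   # G(theta^(k)) - G_bar
    return [[sum(c * g for c, g in zip(centered[k], G[j])) / J      # row k, column j
             for j in range(J)]
            for k in range(J)]

G = [[0.0, 0.0, 0.0], [-1.0, -1.75, -2.0]]
D = d_matrix(G)
print(D)   # [[0.0, -2.015625], [0.0, 2.015625]]
```

Because the centered rows \(\mathcal{G}(\theta^{(k)})-\bar{\mathcal{G}}\) sum to zero over \(k\), every column of \(D(\theta)\) sums to zero, which gives a cheap sanity check; the example also shows that \(D(\theta)\) is not symmetric in general.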

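The update rule for each particle \(\theta^{(j)}\) combines all particles, weighted by the \(j\)-th column of \(D(\theta)\): \[ \theta^{(j)} \leftarrow \theta^{(j)} - h \sum_{k=1}^{J} [D(\theta)]_{k,j}\,\theta^{(k)}, \] where \(h>0\) is a scalar that plays the same role as the learning rate in gradient-based optimization methods, and a typical choice is \[ h = \frac{1}{\epsilon+ \|D(\theta)\|_F }, \] where \(\|\cdot\|_F\) denotes the Frobenius norm and \(\epsilon>0\) is a small constant that avoids division by zero.

To see how ensemble differences replace \(\nabla_\theta L\), read the entries of \(\mathcal{G}\) as residuals with total loss \(\tfrac12\|\mathcal{G}\|^2\) and linearize \(\mathcal{G}(\theta^{(k)})-\bar{\mathcal{G}}\approx \nabla\mathcal{G}\,(\theta^{(k)}-\bar\theta)\), where \(\bar\theta\) is the mean particle and \(\nabla\mathcal{G}\in\mathbb{R}^{N\times d}\) is the Jacobian. Since \(\sum_k(\theta^{(k)}-\bar\theta)=0\), \[ \sum_{k=1}^{J} [D(\theta)]_{k,j}\,\theta^{(k)} \approx C\,\nabla\mathcal{G}^\intercal\mathcal{G}(\theta^{(j)}), \qquad C:=\frac1J\sum_{k=1}^{J}(\theta^{(k)}-\bar\theta)(\theta^{(k)}-\bar\theta)^\intercal, \] so each particle takes a gradient step on \(\tfrac12\|\mathcal{G}(\theta^{(j)})\|^2\) preconditioned by the ensemble covariance \(C\), without computing \(\nabla_\theta L\).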
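Putting the pieces together, the sketch below runs the update with the sum over \(k\) written out explicitly, \(\theta^{(j)}\leftarrow\theta^{(j)}-h\sum_k [D(\theta)]_{k,j}\theta^{(k)}\), and the adaptive step \(h=1/(\epsilon+\|D(\theta)\|_F)\). Everything problem-specific is an assumption for illustration: the hypothetical `forward_map` uses a linear model \(u_\theta(t,x)=\theta_0+\theta_1 t+\theta_2 x\) with target \(u^*(t,x)=1+2t-x\) in place of an ANN, the initial ensemble is drawn from \(\mathcal{N}(0,I)\), and \(J=20\), the 25 grid points and the 200 iterations are arbitrary.

```python
import math
import random

def forward_map(theta, points):
    # Hypothetical stand-in for the ANN: u_theta(t, x) = theta0 + theta1*t + theta2*x,
    # returning its residual against an assumed target u*(t, x) = 1 + 2t - x.
    return [theta[0] + theta[1] * t + theta[2] * x - (1.0 + 2.0 * t - x)
            for t, x in points]

def enkf_step(ensemble, points, eps=1e-8):
    """One update theta^(j) <- theta^(j) - h * sum_k D[k][j] * theta^(k) for all j."""
    J, d = len(ensemble), len(ensemble[0])
    G = [forward_map(theta, points) for theta in ensemble]          # J x N
    g_bar = [sum(col) / J for col in zip(*G)]                        # ensemble mean
    centered = [[g - m for g, m in zip(row, g_bar)] for row in G]
    D = [[sum(c * g for c, g in zip(centered[k], G[j])) / J for j in range(J)]
         for k in range(J)]
    h = 1.0 / (eps + math.sqrt(sum(v * v for row in D for v in row)))  # 1/(eps + ||D||_F)
    # All particles are updated simultaneously from the old ensemble.
    return [[ensemble[j][i] - h * sum(D[k][j] * ensemble[k][i] for k in range(J))
             for i in range(d)]
            for j in range(J)]

def mean_squared_residual(ensemble, points):
    return sum(sum(r * r for r in forward_map(th, points))
               for th in ensemble) / len(ensemble)

rng = random.Random(0)
points = [(t / 4, x / 4) for t in range(5) for x in range(5)]         # 25 collocation points
ensemble = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(20)]  # J = 20, d = 3

before = mean_squared_residual(ensemble, points)
for _ in range(200):
    ensemble = enkf_step(ensemble, points)
after = mean_squared_residual(ensemble, points)
print(f"mean squared residual: {before:.3f} -> {after:.3f}")
```

For this linear stand-in the printed residual should decrease; with an actual ANN the forward map is nonlinear, and the behaviour depends on \(J\), \(\Sigma\) and \(\epsilon\).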