lrux.pf_lru

lrux.pf_lru(Ainv: Array, u: Array | int | Tuple[Array, Array] | Tuple[Array, int], return_update: bool = False) -> Array | Tuple[Array, Array]

Low-rank update of the Pfaffian \(\mathrm{pf}(A_1) = \mathrm{pf}(A_0 - u J u^T)\). Here \(J\) is the skew-symmetric identity matrix

\[\begin{split}J = \begin{pmatrix} 0 & I \\ -I & 0 \end{pmatrix}\end{split}\]

as given in skew_eye.
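To make the block structure of \(J\) concrete, here is a minimal NumPy sketch that builds it for a given block size k. The helper name skew_eye_sketch is hypothetical and only mirrors the formula above; it is not the signature of skew_eye.

```python
import numpy as np

def skew_eye_sketch(k):
    """Build the 2k x 2k matrix J = [[0, I], [-I, 0]] (illustrative helper)."""
    identity = np.eye(k)
    zeros = np.zeros((k, k))
    return np.block([[zeros, identity], [-identity, zeros]])

J = skew_eye_sketch(2)  # 4 x 4, as used for an update of 2 rows and 2 columns

# J is skew-symmetric, and J @ J = -I, so its inverse is -J (equivalently J^T).
is_skew = np.array_equal(J.T, -J)
squares_to_minus_identity = np.array_equal(J @ J, -np.eye(4))
```

The property \(J^{-1} = -J\) is what allows the update formulas to be written in terms of \(R = J + u^T A_0^{-1} u\).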

Parameters:
  • Ainv – Inverse of the original skew-symmetric matrix \(A_0^{-1}\), shape (n, n)

  • u – Low-rank update vector(s) \(u\), the same as \(u\) in lrux.det_lru.

  • return_update – Whether to return the new matrix inverse \(A_1^{-1}\). Defaults to False.

Returns:

ratio:

The ratio between two pfaffians

\[r = \frac{\mathrm{pf}(A_1)}{\mathrm{pf}(A_0)} = \frac{\mathrm{pf}(R)}{\mathrm{pf}(J)}\]

where

\[R = J + u^T A_0^{-1} u\]

new_Ainv:

The new matrix inverse

\[A_1^{-1} = (A_0 - u J u^T)^{-1} = A_0^{-1} + (A_0^{-1} u) R^{-1} (A_0^{-1} u)^T\]

Only returned when return_update is True.
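The two formulas above can be checked numerically on a small case. The sketch below is a plain NumPy illustration, not the lrux implementation: it uses a hand-picked 4×4 skew-symmetric matrix, an update of row and column 1, and the closed-form Pfaffian of a 4×4 matrix. The helper pf4 and all variable names are assumptions made for this example.

```python
import numpy as np

def pf4(a):
    """Closed-form Pfaffian of a 4x4 skew-symmetric matrix."""
    return a[0, 1] * a[2, 3] - a[0, 2] * a[1, 3] + a[0, 3] * a[1, 2]

# Original skew-symmetric matrix A0 and its inverse.
A0 = np.array([
    [0.0, 1.0, 2.0, 3.0],
    [-1.0, 0.0, 4.0, 5.0],
    [-2.0, -4.0, 0.0, 6.0],
    [-3.0, -5.0, -6.0, 0.0],
])
Ainv = np.linalg.inv(A0)

# Dense u for an update of row/column 1: columns are x and the unit vector e_1.
x = np.array([1.0, 0.0, 2.0, -1.0])
e1 = np.eye(4)[:, 1]
u = np.stack([x, e1], axis=1)            # shape (4, 2)
J = np.array([[0.0, 1.0], [-1.0, 0.0]])  # 2x2 skew-symmetric identity

# Updated matrix, formed explicitly only to have a reference to compare with.
A1 = A0 - u @ J @ u.T

# Ratio from the low-rank formula: the Pfaffian of a 2x2 skew matrix is its (0, 1) entry.
R = J + u.T @ Ainv @ u
ratio = R[0, 1] / J[0, 1]

# New inverse from the low-rank formula.
Au = Ainv @ u
new_Ainv = Ainv + Au @ np.linalg.inv(R) @ Au.T

# Reference values computed directly from A0 and A1.
ratio_direct = pf4(A1) / pf4(A0)
inverse_direct = np.linalg.inv(A1)
```

In this example ratio agrees with ratio_direct and new_Ainv agrees with inverse_direct up to floating-point error, while the low-rank path only inverts the 2×2 matrix \(R\).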

Tip

This function is compatible with jax.jit and jax.vmap. Note that return_update is a static argument: mark it as static in jax.jit and do not map over it in jax.vmap.

Furthermore, we recommend setting donate_argnums=0 in jax.jit to reuse the memory of Ainv if it’s no longer needed after the call. This can greatly reduce the time and memory cost, but a donated Ainv must not be used again afterwards. For instance,

import jax
from lrux import pf_lru

lru_vmap = jax.vmap(pf_lru, in_axes=(0, 0, None))
lru_jit = jax.jit(lru_vmap, static_argnums=2, donate_argnums=0)

Example

Here are examples of how to define u before calling pf_lru(Ainv, u). Keep in mind that the low-rank update must be skew-symmetric and take the form

\[A_1 - A_0 = -u J u^T\]

Update of 1 row and 1 column

\[\begin{split}A_1 - A_0 = \begin{pmatrix} 0 & -u_0 & 0 & 0 \\ u_0 & 0 & u_2 & u_3 \\ 0 & -u_2 & 0 & 0 \\ 0 & -u_3 & 0 & 0 \\ \end{pmatrix} = -\begin{pmatrix} u_0 & 0 \\ u_1 & 1 \\ u_2 & 0 \\ u_3 & 0 \end{pmatrix} \begin{pmatrix} 0 & 1 \\ -1 & 0 \end{pmatrix} \begin{pmatrix} u_0 & u_1 & u_2 & u_3 \\ 0 & 1 & 0 & 0\\ \end{pmatrix}\end{split}\]
u = (jnp.array([u0, u1, u2, u3]), 1)

Update of 2 rows and 2 columns

\[\begin{split} A_1 - A_0 &= \begin{pmatrix} 0 & -u_{00} & 0 & -u_{10} \\ u_{00} & 0 & u_{02} & u_{03} - u_{11} \\ 0 & -u_{02} & 0 & -u_{12} \\ u_{10} & u_{11} - u_{03} & u_{12} & 0 \\ \end{pmatrix} \\ &= -\begin{pmatrix} u_{00} & u_{10} & 0 & 0 \\ u_{01} & u_{11} & 1 & 0 \\ u_{02} & u_{12} & 0 & 0 \\ u_{03} & u_{13} & 0 & 1 \\ \end{pmatrix} \begin{pmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ -1 & 0 & 0 & 0 \\ 0 & -1 & 0 & 0 \\ \end{pmatrix} \begin{pmatrix} u_{00} & u_{01} & u_{02} & u_{03} \\ u_{10} & u_{11} & u_{12} & u_{13} \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \\ \end{pmatrix} \end{split}\]
x = jnp.array([[u00, u10], [u01, u11], [u02, u12], [u03, u13]])
e = jnp.array([1, 3])
u = (x, e)
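As a sanity check on the two examples above, the sketch below expands a pair (x, e) into the dense matrix \(u\) shown in the equations (the columns of x followed by one-hot columns at the indices e) and forms \(-u J u^T\) explicitly. It uses plain NumPy with concrete numbers in place of the symbols; the helpers dense_u and update_matrix are hypothetical names for illustration and are not part of lrux, which does not need the dense form.

```python
import numpy as np

def dense_u(x, e):
    """Expand (x, e) into [x | one-hot columns at indices e] (illustrative)."""
    x = np.asarray(x, dtype=float)
    x = x.reshape(x.shape[0], -1)        # (n,) -> (n, 1); (n, k) stays (n, k)
    e = np.atleast_1d(e)                 # int -> array of one index
    onehot = np.eye(x.shape[0])[:, e]    # shape (n, k)
    return np.concatenate([x, onehot], axis=1)

def update_matrix(x, e):
    """Return A1 - A0 = -u J u^T for the update described by (x, e)."""
    u = dense_u(x, e)
    k = u.shape[1] // 2
    identity, zeros = np.eye(k), np.zeros((k, k))
    J = np.block([[zeros, identity], [-identity, zeros]])
    return -(u @ J @ u.T)

# Update of 1 row and 1 column: (u0, u1, u2, u3) = (1, 2, 3, 4), index 1.
diff_one = update_matrix([1.0, 2.0, 3.0, 4.0], 1)

# Update of 2 rows and 2 columns: columns of x are (u00..u03) and (u10..u13).
x = np.array([[1.0, 5.0], [2.0, 6.0], [3.0, 7.0], [4.0, 8.0]])
diff_two = update_matrix(x, [1, 3])
```

Both results are skew-symmetric and are nonzero only in the rows and columns listed in e. The entries \(u_1\) (first example) and \(u_{01}\), \(u_{13}\) (second example) cancel and do not affect the update.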