lrux.pf_lru
- lrux.pf_lru(Ainv: Array, u: Array | int | Tuple[Array, Array] | Tuple[Array, int], return_update: bool = False) → Array | Tuple[Array, Array]
Low-rank update of the Pfaffian, \(\mathrm{pf}(A_1) = \mathrm{pf}(A_0 - u J u^T)\). Here \(J\) is the skew-symmetric identity matrix
\[J = \begin{pmatrix} 0 & I \\ -I & 0 \end{pmatrix},\]
as given in skew_eye.
- Parameters:
Ainv – Inverse \(A_0^{-1}\) of the original skew-symmetric matrix \(A_0\), shape (n, n).
u – Low-rank update vector(s) \(u\), in the same format as \(u\) in lrux.det_lru.
return_update – Whether the new matrix inverse \(A_1^{-1}\) should also be returned. Defaults to False.
- Returns:
- ratio:
The ratio between two pfaffians
\[r = \frac{\mathrm{pf}(A_1)}{\mathrm{pf}(A_0)} = \frac{\mathrm{pf}(R)}{\mathrm{pf}(J)},\]
where
\[R = J + u^T A_0^{-1} u.\]
- new_Ainv:
The new matrix inverse
\[A_1^{-1} = (A_0 - u J u^T)^{-1} = A_0^{-1} + (A_0^{-1} u) R^{-1} (A_0^{-1} u)^T.\]
Only returned when return_update is True.
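The two formulas above can be checked against a brute-force computation. The sketch below is not lrux's implementation: `pfaffian_bruteforce`, `skew_identity` and `pf_lru_reference` are hypothetical helpers written for illustration, NumPy stands in for jax.numpy, `u` is assumed to be a dense array of shape (n, 2k), and the recursive Pfaffian is only practical for very small matrices.

```python
import numpy as np


def pfaffian_bruteforce(a: np.ndarray) -> float:
    """Hypothetical helper: Pfaffian by expansion along the first row (tiny matrices only)."""
    n = a.shape[0]
    if n == 0:
        return 1.0
    if n % 2 == 1:
        return 0.0
    total = 0.0
    for j in range(1, n):
        # Drop rows/columns 0 and j, then recurse on the remaining minor.
        keep = [i for i in range(n) if i not in (0, j)]
        total += (-1) ** (j - 1) * a[0, j] * pfaffian_bruteforce(a[np.ix_(keep, keep)])
    return total


def skew_identity(k: int) -> np.ndarray:
    """Hypothetical stand-in for skew_eye: J = [[0, I], [-I, 0]] of shape (2k, 2k)."""
    eye, zero = np.eye(k), np.zeros((k, k))
    return np.block([[zero, eye], [-eye, zero]])


def pf_lru_reference(ainv: np.ndarray, u: np.ndarray):
    """Hypothetical dense reference of the formulas above; u has shape (n, 2k)."""
    j = skew_identity(u.shape[1] // 2)
    w = ainv @ u                    # A0^{-1} u
    r = j + u.T @ w                 # R = J + u^T A0^{-1} u
    ratio = pfaffian_bruteforce(r) / pfaffian_bruteforce(j)
    # A1^{-1} = A0^{-1} + (A0^{-1} u) R^{-1} (A0^{-1} u)^T
    new_ainv = ainv + w @ np.linalg.solve(r, w.T)
    return ratio, new_ainv


# Usage: random skew-symmetric A0 with n = 6 and an update with k = 2 (u of shape (6, 4)).
rng = np.random.default_rng(0)
n, k = 6, 2
m = rng.normal(size=(n, n))
a0 = m - m.T
u = rng.normal(size=(n, 2 * k))
a1 = a0 - u @ skew_identity(k) @ u.T
ratio, new_ainv = pf_lru_reference(np.linalg.inv(a0), u)
```

Note that \(\mathrm{pf}(J)\) is not always 1: for k = 2 it equals -1 under this convention, which is why the ratio divides by it.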
Tip
This function is compatible with jax.jit and jax.vmap, but return_update must stay static: do not map it in jax.vmap (use in_axes=None for it) and mark it as static in jax.jit (static_argnums). Furthermore, we recommend setting donate_argnums=0 in jax.jit to reuse the memory of Ainv if it is no longer needed, which can greatly reduce the time and memory cost. For instance,
import jax
from lrux import pf_lru
lru_vmap = jax.vmap(pf_lru, in_axes=(0, 0, None))  # batch over Ainv and u; return_update is not mapped
lru_jit = jax.jit(lru_vmap, static_argnums=2, donate_argnums=0)  # return_update static; Ainv buffer donated
Example
Here are examples of how to define u before calling pf_lru(Ainv, u). Keep in mind that the low-rank update must be skew-symmetric and take the form
\[A_1 - A_0 = -u J u^T.\]
Update of 1 row and 1 column
\[A_1 - A_0 = \begin{pmatrix} 0 & -u_0 & 0 & 0 \\ u_0 & 0 & u_2 & u_3 \\ 0 & -u_2 & 0 & 0 \\ 0 & -u_3 & 0 & 0 \end{pmatrix} = -\begin{pmatrix} u_0 & 0 \\ u_1 & 1 \\ u_2 & 0 \\ u_3 & 0 \end{pmatrix} \begin{pmatrix} 0 & 1 \\ -1 & 0 \end{pmatrix} \begin{pmatrix} u_0 & u_1 & u_2 & u_3 \\ 0 & 1 & 0 & 0 \end{pmatrix}\]
u = (jnp.array([u0, u1, u2, u3]), 1)  # 1 is the index of the updated row/column; u1 drops out of the update
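In practice, x is usually the difference between the new and the old row. The sketch below follows our reading of the display above and is not lrux code: it builds a target \(A_1\) that differs from \(A_0\) only in row and column e, takes x = A1[e] - A0[e], expands (x, e) into the dense (n, 2) matrix shown above with a hypothetical helper `dense_u_single`, and checks that \(-u J u^T\) reproduces the change. Whether lrux forms this dense matrix internally is not assumed; the pair itself would be passed as u = (jnp.asarray(x), e).

```python
import numpy as np


def dense_u_single(x: np.ndarray, e: int) -> np.ndarray:
    """Hypothetical helper: expand u = (x, e) into the dense (n, 2) matrix [x, unit vector at e]."""
    unit = np.zeros_like(x)
    unit[e] = 1.0
    return np.stack([x, unit], axis=1)


J1 = np.array([[0.0, 1.0], [-1.0, 0.0]])  # J for k = 1

rng = np.random.default_rng(1)
n, e = 4, 1
m = rng.normal(size=(n, n))
a0 = m - m.T

# Target A1: same as A0 except that row e and column e get new values.
new_row = rng.normal(size=n)
new_row[e] = 0.0                 # the diagonal of a skew-symmetric matrix is zero
a1 = a0.copy()
a1[e, :] = new_row
a1[:, e] = -new_row

# Row e of A1 - A0 is (u0, 0, u2, u3) in the display above, so x is the row difference.
x = a1[e, :] - a0[e, :]
u_dense = dense_u_single(x, e)
delta = -u_dense @ J1 @ u_dense.T
```

Because x[e] (u_1 in the display) never enters the product, any value may be stored there.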
Update of 2 rows and 2 columns
\[\begin{split} A_1 - A_0 &= \begin{pmatrix} 0 & -u_{00} & 0 & -u_{10} \\ u_{00} & 0 & u_{02} & u_{03} - u_{11} \\ 0 & -u_{02} & 0 & -u_{12} \\ u_{10} & u_{11} - u_{03} & u_{12} & 0 \end{pmatrix} \\ &= -\begin{pmatrix} u_{00} & u_{10} & 0 & 0 \\ u_{01} & u_{11} & 1 & 0 \\ u_{02} & u_{12} & 0 & 0 \\ u_{03} & u_{13} & 0 & 1 \end{pmatrix} \begin{pmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ -1 & 0 & 0 & 0 \\ 0 & -1 & 0 & 0 \end{pmatrix} \begin{pmatrix} u_{00} & u_{01} & u_{02} & u_{03} \\ u_{10} & u_{11} & u_{12} & u_{13} \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix} \end{split}\]
x = jnp.array([[u00, u10], [u01, u11], [u02, u12], [u03, u13]])  # column j holds u_{j0}, ..., u_{j3}
e = jnp.array([1, 3])  # indices of the updated rows/columns
u = (x, e)
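The entry coupling the two updated rows appears above as \(u_{03} - u_{11}\). So when x is built from the row differences of a target matrix, that coupling must be carried by only one of the two columns; otherwise it is counted twice. The NumPy sketch below illustrates this under our reading of the display; `dense_u` and `skew_identity` are hypothetical helpers (the latter a local stand-in for skew_eye, whose exact signature is not assumed), not lrux code.

```python
import numpy as np


def skew_identity(k: int) -> np.ndarray:
    """Hypothetical stand-in for skew_eye: J = [[0, I], [-I, 0]] of shape (2k, 2k)."""
    eye, zero = np.eye(k), np.zeros((k, k))
    return np.block([[zero, eye], [-eye, zero]])


def dense_u(x: np.ndarray, e: np.ndarray) -> np.ndarray:
    """Hypothetical helper: expand u = (x, e) into [x, unit vectors at indices e], shape (n, 2k)."""
    n, k = x.shape
    units = np.zeros((n, k))
    units[e, np.arange(k)] = 1.0
    return np.concatenate([x, units], axis=1)


rng = np.random.default_rng(2)
n = 4
e = np.array([1, 3])
m0 = rng.normal(size=(n, n))
a0 = m0 - m0.T

# Target A1: equal to A0 outside rows/columns e, random skew-symmetric values inside them.
m1 = rng.normal(size=(n, n))
mask = np.zeros((n, n), dtype=bool)
mask[e, :] = True
mask[:, e] = True
a1 = np.where(mask, m1 - m1.T, a0)
d = a1 - a0

# Column j of x is row e[j] of the difference, i.e. (u_{j0}, ..., u_{j3}) above.
x = d[e, :].T
# d[1, 3] enters the display as u_{03} - u_{11}; u_{03} already holds it, so zero u_{11}.
x[e[0], 1] = 0.0
u_dense = dense_u(x, e)
delta = -u_dense @ skew_identity(2) @ u_dense.T
```

Zeroing \(u_{03}\) and keeping \(u_{11} = d_{31}\) works equally well, since only the difference \(u_{03} - u_{11}\) enters the update; \(u_{13}\) drops out entirely.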