lrux.det_lru
- lrux.det_lru(Ainv: Array, u: Array | int | Tuple[Array, Array] | Tuple[Array, int], v: Array | int | Tuple[Array, Array] | Tuple[Array, int], return_update: bool = False) → Array | Tuple[Array, Array]
Low-rank update of the determinant \(\det(A_1) = \det(A_0 + vu^T)\).
- Parameters:
Ainv – Inverse of the original matrix, \(A_0^{-1}\), with shape (n, n).
u – Low-rank update vector(s) \(u\). See the note below for details.
v – Low-rank update vector(s) \(v\). See the note below for details.
return_update – Whether to also return the new matrix inverse \(A_1^{-1}\). Defaults to False.
- Returns:
- ratio:
The ratio between the two determinants
\[r = \frac{\det(A_1)}{\det(A_0)} = \det(R),\]
where
\[R = I + u^T A_0^{-1} v\]
is a \(k \times k\) matrix, with \(k\) the rank of the update.
- new_Ainv:
The new matrix inverse
\[A_1^{-1} = (A_0 + vu^T)^{-1} = A_0^{-1} - A_0^{-1} v R^{-1} u^T A_0^{-1}\]
Only returned when return_update is True.
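The two formulas above are the matrix determinant lemma and the Woodbury identity. As a rough illustration of what they compute, the following plain-NumPy sketch defines a hypothetical dense reference, `det_lru_reference` (dense u and v only, no one-hot inputs, no batching), and checks it against direct determinants and inverses of a random matrix. It is not the lrux implementation.

```python
import numpy as np


def det_lru_reference(Ainv, u, v, return_update=False):
    """Hypothetical dense reference (not lrux): ratio det(A0 + v u^T) / det(A0).

    Ainv: (n, n) inverse of A0; u, v: dense arrays of shape (n, k).
    """
    Ainv_v = Ainv @ v                       # A0^{-1} v, shape (n, k)
    R = np.eye(u.shape[1]) + u.T @ Ainv_v   # R = I + u^T A0^{-1} v, shape (k, k)
    ratio = np.linalg.det(R)                # matrix determinant lemma
    if not return_update:
        return ratio
    # Woodbury identity: A1^{-1} = A0^{-1} - A0^{-1} v R^{-1} u^T A0^{-1}
    uT_Ainv = u.T @ Ainv                    # shape (k, n)
    new_Ainv = Ainv - Ainv_v @ np.linalg.solve(R, uT_Ainv)
    return ratio, new_Ainv


rng = np.random.default_rng(0)
n, k = 5, 2
A0 = rng.standard_normal((n, n))
u = rng.standard_normal((n, k))
v = rng.standard_normal((n, k))
A1 = A0 + v @ u.T

ratio, new_Ainv = det_lru_reference(np.linalg.inv(A0), u, v, return_update=True)
assert np.isclose(ratio, np.linalg.det(A1) / np.linalg.det(A0))
assert np.allclose(new_Ainv, np.linalg.inv(A1))
```

Only the \(k \times k\) matrix \(R\) is factorized, which is why a low-rank update is much cheaper than recomputing \(\det(A_1)\) or \(A_1^{-1}\) from scratch when \(k \ll n\).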
Tip
This function is compatible with jax.jit and jax.vmap. The argument return_update is a Python bool that must stay static: mark it with static_argnums under jax.jit and map it with in_axes=None under jax.vmap. Furthermore, we recommend setting donate_argnums=0 in jax.jit to reuse the memory of Ainv if it's no longer needed. This can greatly reduce the time and memory cost. For instance,
lru_vmap = jax.vmap(det_lru, in_axes=(0, 0, 0, None))
lru_jit = jax.jit(lru_vmap, static_argnums=3, donate_argnums=0)
Note
We often need to define u and v as one-hot vectors, for instance, v = jnp.array([0, 1, 0, 0]). In this case, a matrix product like Ainv @ v can instead be computed as Ainv[:, 1], which only selects a column and avoids an O(n^2) matrix-vector product, giving a large speedup. Therefore, we allow u and v to be provided not only as dense arrays but also as one-hot indices. The acceptable inputs are listed below.
- An array with shape (n,) or (n, k):
Dense array of u or v.
- An integer or array of integers with size k:
One-hot indices. The full array is defined as
u_full = jnp.zeros((n, k)).at[u, jnp.arange(k)].set(1)
and similarly for v. For example, when you need a full matrix
u = jnp.array([
    [0, 0],
    [1, 0],
    [0, 0],
    [0, 1],
])
you can alternatively specify
u = jnp.array([1, 3])
- A tuple of a dense array with shape (n,) or (n, k0) and one-hot indices given as an integer or an integer array with shape (k1,):
The concatenation of the previous two forms, giving k = k0 + k1 columns in total. For example, when you need a full matrix
u = jnp.array([
    [x00, x01, 0, 0],
    [x10, x11, 1, 0],
    [x20, x21, 0, 0],
    [x30, x31, 0, 1],
])
you can alternatively specify
x = jnp.array([
    [x00, x01],
    [x10, x11],
    [x20, x21],
    [x30, x31],
])
e = jnp.array([1, 3])
u = (x, e)
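For v, the one-hot columns are placed to the left of the dense columns instead. In \(vu^T\) the j-th column of v multiplies the j-th column of u, and a dense column of u (a row update) usually pairs with a one-hot column of v, as in the simultaneous update of row and column in the examples below. Therefore, when we similarly define v = (x, e), it actually represents a different array
v = jnp.array([
    [0, 0, x00, x01],
    [1, 0, x10, x11],
    [0, 0, x20, x21],
    [0, 1, x30, x31],
])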
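To make the input forms concrete, here is a minimal plain-NumPy sketch of a hypothetical helper, `to_dense` (not part of lrux), that expands a u or v specification into the dense (n, k) array it stands for, following the conventions above. It assumes that integer-dtype input means one-hot indices and floating-point input is dense; lrux may distinguish the cases differently. NumPy's in-place assignment stands in for jax.numpy's `.at[...].set(...)`. The last line checks the one-hot shortcut: multiplying Ainv by a one-hot column equals selecting that column.

```python
import numpy as np


def to_dense(x, n, is_v=False):
    """Hypothetical helper (not the lrux API): expand a u/v spec to shape (n, k).

    Assumes integer-dtype input means one-hot indices and float input is dense.
    For v given as a tuple, the one-hot columns go to the left of the dense ones.
    """

    def one_hot(idx):
        idx = np.atleast_1d(np.asarray(idx))      # a bare int becomes shape (1,)
        full = np.zeros((n, idx.size))
        full[idx, np.arange(idx.size)] = 1.0      # column j has a 1 in row idx[j]
        return full

    if isinstance(x, tuple):                      # (dense, one-hot indices)
        dense, idx = x
        dense = np.asarray(dense, dtype=float).reshape(n, -1)  # (n,) -> (n, 1)
        parts = [one_hot(idx), dense] if is_v else [dense, one_hot(idx)]
        return np.concatenate(parts, axis=1)
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.integer):        # one-hot indices
        return one_hot(x)
    return x.reshape(n, -1)                       # dense (n,) or (n, k)


n = 4
x = np.arange(8.0).reshape(n, 2)        # stands in for the symbolic x00 ... x31
e = np.array([1, 3])

u_idx = to_dense(e, n)                  # [[0, 0], [1, 0], [0, 0], [0, 1]]
u_mix = to_dense((x, e), n)             # dense columns first, then one-hot
v_mix = to_dense((x, e), n, is_v=True)  # one-hot columns first, then dense

# One-hot shortcut: Ainv @ e_1 only needs column 1 of Ainv.
Ainv = np.random.default_rng(0).standard_normal((n, n))
assert np.allclose(Ainv @ to_dense(1, n), Ainv[:, [1]])
```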
Example
Here are examples of how to define u and v before calling det_lru(Ainv, u, v). Keep in mind that the low-rank update takes the form
\[A_1 - A_0 = vu^T\]
Rank-1 row update
\[\begin{split}A_1 - A_0 = \begin{pmatrix} 0 & 0 & 0 & 0 \\ u_0 & u_1 & u_2 & u_3 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ \end{pmatrix} = \begin{pmatrix} 0 \\ 1 \\ 0 \\ 0 \end{pmatrix} (u_0, u_1, u_2, u_3)\end{split}\]
u = jnp.array([u0, u1, u2, u3])
v = 1
Rank-1 column update
\[\begin{split}A_1 - A_0 = \begin{pmatrix} 0 & 0 & v_0 & 0 \\ 0 & 0 & v_1 & 0 \\ 0 & 0 & v_2 & 0 \\ 0 & 0 & v_3 & 0 \\ \end{pmatrix} = \begin{pmatrix} v_0 \\ v_1 \\ v_2 \\ v_3 \end{pmatrix} (0, 0, 1, 0)\end{split}\]
u = 2
v = jnp.array([v0, v1, v2, v3])
Rank-2 row update
\[\begin{split}A_1 - A_0 = \begin{pmatrix} 0 & 0 & 0 & 0 \\ u_{00} & u_{01} & u_{02} & u_{03} \\ 0 & 0 & 0 & 0 \\ u_{10} & u_{11} & u_{12} & u_{13} \\ \end{pmatrix} = \begin{pmatrix} 0 & 0 \\ 1 & 0 \\ 0 & 0 \\ 0 & 1 \end{pmatrix} \begin{pmatrix} u_{00} & u_{01} & u_{02} & u_{03} \\ u_{10} & u_{11} & u_{12} & u_{13} \\ \end{pmatrix}\end{split}\]
u = jnp.array([[u00, u10], [u01, u11], [u02, u12], [u03, u13]])
v = jnp.array([1, 3])
Simultaneous update of row and column
\[\begin{split}A_1 - A_0 = \begin{pmatrix} 0 & 0 & v_0 & 0 \\ u_0 & u_1 & u_2 + v_1 & u_3 \\ 0 & 0 & v_2 & 0 \\ 0 & 0 & v_3 & 0 \\ \end{pmatrix} = \begin{pmatrix} 0 & v_0 \\ 1 & v_1 \\ 0 & v_2 \\ 0 & v_3 \end{pmatrix} \begin{pmatrix} u_0 & u_1 & u_2 & u_3 \\ 0 & 0 & 1 & 0 \\ \end{pmatrix}\end{split}\]
u = (jnp.array([u0, u1, u2, u3]), 2)
v = (jnp.array([v0, v1, v2, v3]), 1)
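As a numerical sanity check of the four examples, here is a plain-NumPy sketch that does not call lrux. It builds by hand the dense u and v that each example stands for, using random numbers in place of the symbolic entries and an arbitrary well-conditioned A0 chosen only for illustration. It then asserts that \(vu^T\) reproduces \(A_1 - A_0\) and that \(\det(R)\) equals the determinant ratio.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
I = np.eye(n)
A0 = rng.standard_normal((n, n)) + n * I  # arbitrary, well-conditioned test matrix
u_vec = rng.standard_normal(n)            # stands in for (u0, u1, u2, u3)
v_vec = rng.standard_normal(n)            # stands in for (v0, v1, v2, v3)


def check(u, v, delta):
    """Assert v u^T == A1 - A0 and det(R) == det(A1) / det(A0); return the ratio."""
    assert np.allclose(v @ u.T, delta)
    R = np.eye(u.shape[1]) + u.T @ np.linalg.inv(A0) @ v
    ratio = np.linalg.det(R)
    assert np.isclose(ratio, np.linalg.det(A0 + delta) / np.linalg.det(A0))
    return ratio


# Rank-1 row update: u dense, v = one-hot index 1 (the column e_1)
delta_row = np.zeros((n, n))
delta_row[1] = u_vec
check(u_vec[:, None], I[:, [1]], delta_row)

# Rank-1 column update: u = one-hot index 2, v dense
delta_col = np.zeros((n, n))
delta_col[:, 2] = v_vec
check(I[:, [2]], v_vec[:, None], delta_col)

# Rank-2 row update: rows 1 and 3, u has shape (n, 2), v = indices [1, 3]
U = rng.standard_normal((n, 2))
delta_rows = np.zeros((n, n))
delta_rows[1] = U[:, 0]
delta_rows[3] = U[:, 1]
check(U, I[:, [1, 3]], delta_rows)

# Row 1 and column 2 together: u = (u_vec, 2) -> [u_vec, e_2], v = (v_vec, 1) -> [e_1, v_vec]
u_full = np.column_stack([u_vec, I[:, 2]])
v_full = np.column_stack([I[:, 1], v_vec])
delta_both = delta_row + delta_col       # entry (1, 2) is u2 + v1
check(u_full, v_full, delta_both)
```

The last case shows why the tuple forms of u and v are ordered differently: the dense column of u lines up with the one-hot column of v, and the one-hot column of u lines up with the dense column of v.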