lrux.det_lru

lrux.det_lru(Ainv: Array, u: Array | int | Tuple[Array, Array] | Tuple[Array, int], v: Array | int | Tuple[Array, Array] | Tuple[Array, int], return_update: bool = False) → Array | Tuple[Array, Array]

Low-rank update of the determinant \(\det(A_1) = \det(A_0 + vu^T)\).

Parameters:
  • Ainv – Inverse \(A_0^{-1}\) of the original matrix, with shape (n, n).

  • u – Low-rank update vector(s) \(u\). See the note below for details.

  • v – Low-rank update vector(s) \(v\). See the note below for details.

  • return_update – Whether to also return the new matrix inverse \(A_1^{-1}\). Defaults to False.

Returns:

ratio:

The ratio between the two determinants

\[r = \frac{\det(A_1)}{\det(A_0)} = \det(R)\]

where

\[R = I + u^T A_0^{-1} v\]

new_Ainv:

The new matrix inverse

\[A_1^{-1} = (A_0 + vu^T)^{-1} = A_0^{-1} - A_0^{-1} v R^{-1} u^T A_0^{-1}\]

Only returned when return_update is True.
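To make the two formulas concrete, here is a minimal dense sketch in jax.numpy. The helper dense_det_lru is hypothetical and is not lrux's implementation: it assumes u and v are already dense (n, k) arrays and simply evaluates R, the determinant ratio, and the Woodbury update exactly as written above, on an illustrative well-conditioned matrix.

```python
import jax
import jax.numpy as jnp


def dense_det_lru(Ainv, u, v, return_update=False):
    # Hypothetical dense reference (not lrux code); u and v are dense (n, k) arrays
    # such that A1 = A0 + v @ u.T.
    Ainv_v = Ainv @ v                        # A0^{-1} v, shape (n, k)
    R = jnp.eye(u.shape[1]) + u.T @ Ainv_v   # R = I + u^T A0^{-1} v, shape (k, k)
    ratio = jnp.linalg.det(R)                # det(A1) / det(A0)
    if not return_update:
        return ratio
    # Woodbury identity: A1^{-1} = A0^{-1} - A0^{-1} v R^{-1} u^T A0^{-1}
    new_Ainv = Ainv - Ainv_v @ jnp.linalg.solve(R, u.T @ Ainv)
    return ratio, new_Ainv


# Illustrative inputs: a well-conditioned A0 and a small rank-2 update.
key_a, key_u, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
n, k = 4, 2
A0 = 8.0 * jnp.eye(n) + jax.random.normal(key_a, (n, n))
u = 0.5 * jax.random.normal(key_u, (n, k))
v = 0.5 * jax.random.normal(key_v, (n, k))
A1 = A0 + v @ u.T

ratio, new_Ainv = dense_det_lru(jnp.linalg.inv(A0), u, v, return_update=True)
```

Comparing ratio with jnp.linalg.det(A1) / jnp.linalg.det(A0), and new_Ainv with jnp.linalg.inv(A1), should agree up to floating-point error.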

Tip

This function is compatible with jax.jit and jax.vmap. Note that return_update is a static argument: it should be marked static in jax.jit and must not be mapped over by jax.vmap.

We also recommend setting donate_argnums=0 in jax.jit to donate the memory of Ainv when it is no longer needed. XLA can then reuse that buffer for the updated inverse instead of allocating a new (n, n) array, which can substantially reduce time and memory cost. For instance,

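import jax
from lrux import det_lru
lru_vmap = jax.vmap(det_lru, in_axes=(0, 0, 0, None))  # map over Ainv, u, v; return_update is not mapped
lru_jit = jax.jit(lru_vmap, static_argnums=3, donate_argnums=0)  # return_update is static; Ainv is donated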
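The jit/vmap/donation pattern above can be exercised end to end without lrux by substituting a hypothetical dense stand-in, dense_det_lru (assuming dense (n, k) inputs), for det_lru. The sketch batches over a leading axis, keeps return_update static, and donates the batched Ainv. If the backend cannot use the donated buffer, JAX emits a warning rather than an error.

```python
import jax
import jax.numpy as jnp


def dense_det_lru(Ainv, u, v, return_update=False):
    # Hypothetical dense stand-in for det_lru (not lrux code); u, v have shape (n, k).
    Ainv_v = Ainv @ v
    R = jnp.eye(u.shape[1]) + u.T @ Ainv_v
    ratio = jnp.linalg.det(R)
    if not return_update:  # plain Python branch, valid only because the flag is static
        return ratio
    return ratio, Ainv - Ainv_v @ jnp.linalg.solve(R, u.T @ Ainv)


# Map over Ainv, u and v; return_update is passed through unmapped (in_axes=None).
lru_vmap = jax.vmap(dense_det_lru, in_axes=(0, 0, 0, None))
# return_update (argument 3) is static; the batched Ainv (argument 0) is donated.
lru_jit = jax.jit(lru_vmap, static_argnums=3, donate_argnums=0)

batch, n, k = 3, 4, 1
A0 = 8.0 * jnp.eye(n) + jax.random.normal(jax.random.PRNGKey(1), (batch, n, n))
u = 0.5 * jax.random.normal(jax.random.PRNGKey(2), (batch, n, k))
v = 0.5 * jax.random.normal(jax.random.PRNGKey(3), (batch, n, k))
Ainv = jnp.linalg.inv(A0)

ratio, new_Ainv = lru_jit(Ainv, u, v, True)
# Ainv has been donated and must not be read after this call.
```

Because return_update is static, calling lru_jit with False instead triggers a separate compilation that returns only the batched ratios.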

Note

We often need to define u and v as one-hot vectors, for instance, v = jnp.array([0, 1, 0, 0]). In this case, a matrix product such as Ainv @ v reduces to the column slice Ainv[:, 1], which avoids a full matrix-vector product and is much faster.
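A quick check of this equivalence, as a sketch with an arbitrary 4x4 matrix: multiplying by a one-hot vector from the right selects a column, and multiplying by a one-hot vector from the left (as in u^T Ainv) selects a row.

```python
import jax.numpy as jnp

n = 4
Ainv = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)  # any (n, n) matrix works

# Right multiplication by a one-hot v selects a column of Ainv.
v_dense = jnp.array([0.0, 1.0, 0.0, 0.0])  # one-hot vector, 1 at index 1
col_matmul = Ainv @ v_dense   # full matrix-vector product, O(n^2)
col_slice = Ainv[:, 1]        # column gather, O(n)

# Left multiplication by a one-hot u (as in u^T Ainv) selects a row of Ainv.
u_dense = jnp.zeros(n).at[2].set(1.0)  # one-hot vector, 1 at index 2
row_matmul = u_dense @ Ainv
row_slice = Ainv[2, :]
```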

Therefore, u and v can be provided not only as dense arrays but also as one-hot indices. The acceptable inputs are listed below.

An array with shape (n,) or (n, k):

Dense array of u or v.

An integer or array of integers with size k:

One-hot indices. The full array is defined as u_full = jnp.zeros((n, k)).at[u, jnp.arange(k)].set(1) and similarly for v. For example, when you need a full matrix

u = jnp.array([
    [0, 0],
    [1, 0], 
    [0, 0], 
    [0, 1],
])

you can alternatively specify

u = jnp.array([1, 3])

A tuple of a dense array with shape (n,) or (n, k0) and one-hot indices (an integer or an integer array with shape (k1,)):

A concatenation of the previous two. For example, when you need a full matrix

u = jnp.array([
    [x00, x01, 0, 0],
    [x10, x11, 1, 0], 
    [x20, x21, 0, 0], 
    [x30, x31, 0, 1],
])

you can alternatively specify

x = jnp.array([
    [x00, x01],
    [x10, x11], 
    [x20, x21], 
    [x30, x31],
])
e = jnp.array([1, 3])
u = (x, e)

As illustrated by the examples below, the one-hot columns of v usually need to sit to the left of its dense columns. Therefore, the same definition v = (x, e) actually represents a different array:

v = jnp.array([
    [0, 0, x00, x01],
    [1, 0, x10, x11], 
    [0, 0, x20, x21], 
    [0, 1, x30, x31],
])
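The three input forms can be summarized by a small expansion routine. The helpers one_hot_columns and to_dense below are hypothetical and only illustrate the layouts described above. They assume that dense parts are floating-point arrays and that indices are integers, which is not necessarily how lrux itself tells them apart. one_hot_first=False produces the layout described for u, and one_hot_first=True the layout for v.

```python
import jax.numpy as jnp


def one_hot_columns(idx, n):
    # Expand an integer or integer array of size k into an (n, k) one-hot matrix,
    # following u_full = jnp.zeros((n, k)).at[u, jnp.arange(k)].set(1).
    idx = jnp.atleast_1d(jnp.asarray(idx))
    k = idx.shape[0]
    return jnp.zeros((n, k)).at[idx, jnp.arange(k)].set(1.0)


def to_dense(x, n, one_hot_first):
    # Hypothetical helper (not part of lrux). Assumes dense parts are floating point
    # and one-hot indices are integers. one_hot_first=False gives the layout used for u,
    # one_hot_first=True the layout used for v.
    if isinstance(x, tuple):                       # (dense part, one-hot indices)
        dense, idx = x
        dense = jnp.asarray(dense).reshape(n, -1)  # (n,) -> (n, 1)
        onehot = one_hot_columns(idx, n)
        parts = (onehot, dense) if one_hot_first else (dense, onehot)
        return jnp.concatenate(parts, axis=1)
    x = jnp.asarray(x)
    if jnp.issubdtype(x.dtype, jnp.integer):       # one-hot indices only
        return one_hot_columns(x, n)
    return x.reshape(n, -1)                        # dense array of shape (n,) or (n, k)


x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
e = jnp.array([1, 3])
u_full = to_dense((x, e), 4, one_hot_first=False)  # columns of x, then e_1, e_3
v_full = to_dense((x, e), 4, one_hot_first=True)   # e_1, e_3, then columns of x
```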

Example

Here are examples of how to define u and v before calling det_lru(Ainv, u, v). Keep in mind that the low-rank update we need takes the form

\[A_1 - A_0 = vu^T\]

Rank-1 row update

\[\begin{split}A_1 - A_0 = \begin{pmatrix} 0 & 0 & 0 & 0 \\ u_0 & u_1 & u_2 & u_3 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ \end{pmatrix} = \begin{pmatrix} 0 \\ 1 \\ 0 \\ 0 \end{pmatrix} (u_0, u_1, u_2, u_3)\end{split}\]
u = jnp.array([u0, u1, u2, u3])
v = 1
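Because v is one-hot here, \(A_0^{-1} v\) is just column 1 of Ainv, so R collapses to the scalar 1 + u @ Ainv[:, 1]. The sketch below checks this against a direct determinant ratio; the random, well-conditioned A0 and the entries of u are illustrative only.

```python
import jax
import jax.numpy as jnp

n = 4
A0 = 8.0 * jnp.eye(n) + jax.random.normal(jax.random.PRNGKey(0), (n, n))  # illustrative
Ainv = jnp.linalg.inv(A0)

u = jnp.array([0.3, -1.2, 0.5, 2.0])  # values added to row 1
v = 1                                  # one-hot index of the updated row

# A0^{-1} v is column v of Ainv, so R = 1 + u^T A0^{-1} v is a scalar.
ratio = 1.0 + u @ Ainv[:, v]

A1 = A0.at[v, :].add(u)  # A1 = A0 + e_1 u^T changes only row 1
```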

Rank-1 column update

\[\begin{split}A_1 - A_0 = \begin{pmatrix} 0 & 0 & v_0 & 0 \\ 0 & 0 & v_1 & 0 \\ 0 & 0 & v_2 & 0 \\ 0 & 0 & v_3 & 0 \\ \end{pmatrix} = \begin{pmatrix} v_0 \\ v_1 \\ v_2 \\ v_3 \end{pmatrix} (0, 0, 1, 0)\end{split}\]
u = 2
v = jnp.array([v0, v1, v2, v3])
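Symmetrically, with u one-hot at index 2, the product \(u^T A_0^{-1}\) is row 2 of Ainv, so the ratio becomes 1 + Ainv[2, :] @ v. A sketch of the check, with illustrative values for v:

```python
import jax
import jax.numpy as jnp

n = 4
A0 = 8.0 * jnp.eye(n) + jax.random.normal(jax.random.PRNGKey(0), (n, n))  # illustrative
Ainv = jnp.linalg.inv(A0)

u = 2                                  # one-hot index of the updated column
v = jnp.array([1.0, 0.4, -0.7, 0.2])   # values added to column 2

# u^T A0^{-1} is row u of Ainv, so R = 1 + u^T A0^{-1} v is a scalar.
ratio = 1.0 + Ainv[u, :] @ v

A1 = A0.at[:, u].add(v)  # A1 = A0 + v e_2^T changes only column 2
```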

Rank-2 row update

\[\begin{split}A_1 - A_0 = \begin{pmatrix} 0 & 0 & 0 & 0 \\ u_{00} & u_{01} & u_{02} & u_{03} \\ 0 & 0 & 0 & 0 \\ u_{10} & u_{11} & u_{12} & u_{13} \\ \end{pmatrix} = \begin{pmatrix} 0 & 0 \\ 1 & 0 \\ 0 & 0 \\ 0 & 1 \end{pmatrix} \begin{pmatrix} u_{00} & u_{01} & u_{02} & u_{03} \\ u_{10} & u_{11} & u_{12} & u_{13} \\ \end{pmatrix}\end{split}\]
u = jnp.array([[u00, u10], [u01, u11], [u02, u12], [u03, u13]])
v = jnp.array([1, 3])
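For the rank-2 case, R is a 2x2 matrix built from columns 1 and 3 of Ainv, namely R = I + u^T Ainv[:, [1, 3]]. The sketch below verifies this with illustrative numbers; as in the layout above, u[j, i] holds \(u_{ij}\).

```python
import jax
import jax.numpy as jnp

n = 4
A0 = 8.0 * jnp.eye(n) + jax.random.normal(jax.random.PRNGKey(0), (n, n))  # illustrative
Ainv = jnp.linalg.inv(A0)

# u[j, i] = u_{ij}: column 0 is added to row 1, column 1 to row 3.
u = jnp.array([[0.3, -0.4], [1.0, 0.2], [-0.5, 0.6], [0.1, 1.5]])
v = jnp.array([1, 3])  # one-hot indices of the updated rows

# A0^{-1} v keeps only columns 1 and 3 of Ainv, so R is 2x2.
R = jnp.eye(2) + u.T @ Ainv[:, v]
ratio = jnp.linalg.det(R)

A1 = A0.at[v, :].add(u.T)  # rows 1 and 3 receive the two rows of u^T
```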

Simultaneous update of row and column

\[\begin{split}A_1 - A_0 = \begin{pmatrix} 0 & 0 & v_0 & 0 \\ u_0 & u_1 & u_2 + v_1 & u_3 \\ 0 & 0 & v_2 & 0 \\ 0 & 0 & v_3 & 0 \\ \end{pmatrix} = \begin{pmatrix} 0 & v_0 \\ 1 & v_1 \\ 0 & v_2 \\ 0 & v_3 \end{pmatrix} \begin{pmatrix} u_0 & u_1 & u_2 & u_3 \\ 0 & 0 & 1 & 0 \\ \end{pmatrix}\end{split}\]
u = (jnp.array([u0, u1, u2, u3]), 2)
v = (jnp.array([v0, v1, v2, v3]), 1)
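To see what these tuples stand for, the sketch below expands them by hand into the dense factors of the factorization above (dense column first in u, one-hot column first in v), confirms that \(vu^T\) reproduces the row-and-column pattern, and evaluates the 2x2 determinant ratio. The numerical values are illustrative only.

```python
import jax
import jax.numpy as jnp

n = 4
A0 = 8.0 * jnp.eye(n) + jax.random.normal(jax.random.PRNGKey(0), (n, n))  # illustrative
Ainv = jnp.linalg.inv(A0)
E = jnp.eye(n)  # rows of E are the one-hot vectors e_0 ... e_3

u_dense = jnp.array([0.3, -1.2, 0.5, 2.0])
v_dense = jnp.array([1.0, 0.4, -0.7, 0.2])

# u = (u_dense, 2): dense column first, then the one-hot column e_2.
u_full = jnp.stack([u_dense, E[2]], axis=1)
# v = (v_dense, 1): one-hot column e_1 first, then the dense column.
v_full = jnp.stack([E[1], v_dense], axis=1)

delta = v_full @ u_full.T  # row 1 gets u, column 2 gets v, entry (1, 2) gets u_2 + v_1
A1 = A0 + delta

R = jnp.eye(2) + u_full.T @ Ainv @ v_full  # 2x2
ratio = jnp.linalg.det(R)
```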