lrux.det_lru_delayed
- lrux.det_lru_delayed(carrier: DetCarrier, u: Array | int | Tuple[Array, Array] | Tuple[Array, int], v: Array | int | Tuple[Array, Array] | Tuple[Array, int], return_update: bool = False, current_delay: int | None = None) → Array | Tuple[Array, DetCarrier]
Delayed low-rank update of a determinant.
- Parameters:
carrier – The existing delayed update quantities, including \(A_0^{-1}\), and
\[a_t = A_{t-1}^{-1} v_t\]
\[b_t = (A_{t-1}^{-1})^T u_t\]
with \(t\) from 1 to \(\tau-1\). Initially provided by init_det_carrier.
u – Low-rank update vector(s) \(u_\tau\), the same as \(u\) in lrux.det_lru. The rank of u shouldn't exceed the maximum allowed rank specified in init_det_carrier.
v – Low-rank update vector(s) \(v_\tau\), the same as \(v\) in lrux.det_lru. The rank of v shouldn't exceed the maximum allowed rank specified in init_det_carrier.
return_update – Whether the new carrier with updated quantities should be returned. Defaults to False.
current_delay – The current iteration \(\tau\) of delayed updates; must be specified when return_update is True. As Python starts counting at 0, the actual \(\tau\) value is given by current_delay + 1.
- Returns:
- ratio:
The ratio between two determinants
\[r_\tau = \frac{\det(A_\tau)}{\det(A_{\tau-1})} = \det(R_\tau),\]
where
\[R_\tau = I + u_\tau^T A_0^{-1} v_\tau - \sum_{t=1}^{\tau-1} (u_\tau^T a_t) (b_t^T v_\tau).\]
- new_carrier:
Only returned when return_update is True. The new carrier contains the quantities from the input carrier and, in addition,
\[a_\tau = A_{\tau-1}^{-1} v_\tau = A_0^{-1} v_\tau - \sum_{t=1}^{\tau-1} a_t (b_t^T v_\tau)\]
\[b_\tau = (A_{\tau-1}^{-1})^T u_\tau = (A_0^{-1})^T u_\tau - \sum_{t=1}^{\tau-1} b_t (a_t^T u_\tau)\]
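The formulas above can be checked numerically with a small NumPy sketch that does not use lrux. It assumes the update convention \(A_\tau = A_{\tau-1} + v_\tau u_\tau^T\), with u and v of shape (n, k), as in the complete example on this page. One detail matters: by the Woodbury identity, \(A_\tau^{-1} = A_{\tau-1}^{-1} - A_{\tau-1}^{-1} v_\tau R_\tau^{-1} u_\tau^T A_{\tau-1}^{-1}\), so for the sums above to reproduce \(A_{\tau-1}^{-1}\) exactly, the stored quantities must carry the factor \(R_\tau^{-1}\) somewhere. This sketch folds it into \(a_\tau\); how lrux normalizes its stored quantities internally is not stated here, and the helper delayed_step is hypothetical.

```python
import numpy as np


def delayed_step(Ainv0, a_list, b_list, u, v):
    """One delayed low-rank update A_tau = A_{tau-1} + v @ u.T (hypothetical helper).

    Ainv0          : (n, n) inverse of the reference matrix A_0
    a_list, b_list : stored a_t, b_t for t = 1..tau-1, each of shape (n, k_t);
                     R_t^{-1} is assumed to be folded into a_t already
    u, v           : (n, k) update factors u_tau, v_tau
    Returns (ratio, a_tau, b_tau) with R_tau^{-1} folded into a_tau.
    """
    # Start from A_0^{-1} v and (A_0^{-1})^T u ...
    a_new = Ainv0 @ v
    b_new = Ainv0.T @ u
    # ... and subtract the delayed corrections, using A_{tau-1}^{-1} = A_0^{-1} - sum_t a_t b_t^T
    for a_t, b_t in zip(a_list, b_list):
        a_new = a_new - a_t @ (b_t.T @ v)
        b_new = b_new - b_t @ (a_t.T @ u)
    # R_tau = I + u^T A_{tau-1}^{-1} v; its determinant is det(A_tau) / det(A_{tau-1})
    R = np.eye(u.shape[1]) + u.T @ a_new
    ratio = np.linalg.det(R)
    # a_tau @ inv(R_tau), computed with a linear solve instead of an explicit inverse
    a_new = np.linalg.solve(R.T, a_new.T).T
    return ratio, a_new, b_new


rng = np.random.default_rng(0)
n, k = 8, 2
A = rng.standard_normal((n, n))
Ainv0 = np.linalg.inv(A)
a_list, b_list, ratios_ok = [], [], []
for tau in range(1, 5):
    u = rng.standard_normal((n, k))
    v = rng.standard_normal((n, k))
    ratio, a_tau, b_tau = delayed_step(Ainv0, a_list, b_list, u, v)
    A_next = A + v @ u.T
    # Compare with the brute-force determinant ratio
    ratios_ok.append(np.isclose(ratio, np.linalg.det(A_next) / np.linalg.det(A)))
    a_list.append(a_tau)
    b_list.append(b_tau)
    A = A_next
```

Only the reference inverse \(A_0^{-1}\) is ever formed explicitly; each step costs a few (n, k) products per stored pair, which is why the number of stored pairs is capped and periodically merged.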
Warning
This function is only recommended for advanced users who understand why and when to use delayed updates. Otherwise, please use det_lru.
Warning
When current_delay reaches the maximum delayed iteration, i.e. current_delay == max_delay - 1, one should call merge_det_delays to merge the delayed updates into the carrier and reset it for the next round. See the example below for details.
Tip
Similar to det_lru, this function is compatible with jax.jit and jax.vmap, but return_update and current_delay are static arguments that shouldn't be jitted or vmapped. We also recommend setting donate_argnums=0 in jax.jit to reuse the memory of carrier if it's no longer needed. For instance,

```python
# batch over carrier, u and v; keep return_update and current_delay unbatched
lru_vmap = jax.vmap(det_lru_delayed, in_axes=(0, 0, 0, None, None))
# mark return_update and current_delay as static, and donate the carrier buffer
lru_jit = jax.jit(lru_vmap, static_argnums=(3, 4), donate_argnums=0)
```
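The bookkeeping described in the warnings above, with \(\tau\) = current_delay + 1 and a merge after the step where current_delay == max_delay - 1, can be sketched in plain Python. The helper delay_schedule is hypothetical and not part of lrux; it assumes current_delay cycles as i % max_delay over a global step counter i, which is the pattern the complete example uses.

```python
def delay_schedule(n_steps, max_delay):
    """Return (current_delay, tau, merge_after) for each update step (hypothetical helper)."""
    schedule = []
    for i in range(n_steps):
        current_delay = i % max_delay                 # 0-based position within the current round
        tau = current_delay + 1                       # 1-based delayed iteration
        merge_after = current_delay == max_delay - 1  # carrier is full after this step
        schedule.append((current_delay, tau, merge_after))
    return schedule


for row in delay_schedule(7, 3):
    print(row)
```

Because current_delay is a static argument under jax.jit, each distinct value is traced and compiled separately, so up to max_delay compilations per input shape are to be expected.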
Here is a complete example of delayed updates.
```python
import os

os.environ["JAX_ENABLE_X64"] = "1"  # must be set before importing jax

import random

import jax
import jax.numpy as jnp
import jax.random as jr
from lrux import init_det_carrier, merge_det_delays, det_lru_delayed


def _get_key():
    seed = random.randint(0, 2**31 - 1)
    return jr.key(seed)


dtype = jnp.float64
n = 10
max_delay = n // 2
max_rank = 2

A = jr.normal(_get_key(), (n, n), dtype)
carrier = init_det_carrier(A, max_delay, max_rank)
detA0 = jnp.linalg.det(A)

# return_update (3) and current_delay (4) are static; the carrier buffer is donated
lru_fn = jax.jit(det_lru_delayed, static_argnums=(3, 4), donate_argnums=0)
merge_fn = jax.jit(merge_det_delays, donate_argnums=0)

for i in range(20):
    current_delay = i % max_delay
    k = random.randint(0, max_rank)  # rank of this update, 0 <= k <= max_rank
    u = jr.normal(_get_key(), (n, k), dtype)
    v = jr.normal(_get_key(), (n, k), dtype)
    ratio, carrier = lru_fn(carrier, u, v, True, current_delay)
    if current_delay == max_delay - 1:
        # maximum delay reached: merge the delayed updates and reset the carrier
        carrier = merge_fn(carrier)

    # verify the low-rank update result
    A += v @ u.T
    detA1 = jnp.linalg.det(A)
    assert jnp.allclose(ratio, detA1 / detA0)
    detA0 = detA1
```
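Merging can be understood through the Woodbury identity. With the update convention \(A_t = A_{t-1} + v_t u_t^T\) used in the example above, and with the factor \(R_t^{-1}\) folded into \(a_t\) (the stored quantities must carry this factor somewhere for a plain sum to be exact), the delayed corrections telescope to \(A_\tau^{-1} = A_0^{-1} - \sum_{t=1}^{\tau} a_t b_t^T\). Folding this sum into a new reference inverse and emptying the stored pairs is what a merge has to achieve mathematically. The NumPy sketch below checks that identity; merge_delays is a hypothetical helper, and the actual internals of merge_det_delays (for example, preallocated buffers) may differ. To stay self-contained, it builds \(a_t\) and \(b_t\) by brute force from explicit inverses instead of using the delayed formulas.

```python
import numpy as np


def merge_delays(Ainv0, a_list, b_list):
    """Fold the delayed corrections into a fresh reference inverse (hypothetical helper).

    Returns A_tau^{-1} = A_0^{-1} - sum_t a_t b_t^T and two empty lists,
    ready for the next round of delayed updates.
    """
    Ainv = Ainv0.copy()
    for a_t, b_t in zip(a_list, b_list):
        Ainv -= a_t @ b_t.T
    return Ainv, [], []


rng = np.random.default_rng(1)
n, k, max_delay = 6, 2, 3
A0 = rng.standard_normal((n, n))
A = A0.copy()
a_list, b_list = [], []
for _ in range(max_delay):
    u = rng.standard_normal((n, k))
    v = rng.standard_normal((n, k))
    Ainv_prev = np.linalg.inv(A)  # brute-force A_{t-1}^{-1}, for illustration only
    R = np.eye(k) + u.T @ Ainv_prev @ v
    a_list.append(Ainv_prev @ v @ np.linalg.inv(R))  # a_t with R_t^{-1} folded in
    b_list.append(Ainv_prev.T @ u)                   # b_t = (A_{t-1}^{-1})^T u_t
    A = A + v @ u.T

Ainv_merged, a_list, b_list = merge_delays(np.linalg.inv(A0), a_list, b_list)
```

After the merge, Ainv_merged plays the role of \(A_0^{-1}\) for the next round, and the delayed sums start again from \(t = 1\).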