lrux.pf_lru_delayed
- lrux.pf_lru_delayed(carrier: PfCarrier, u: Array | int | Tuple[Array, Array] | Tuple[Array, int], return_update: bool = False, current_delay: int | None = None) → Array | Tuple[Array, PfCarrier]
Delayed low-rank update of the pfaffian.
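A single, non-delayed low-rank update of a pfaffian can be checked on a 4x4 example. Assuming the update convention used in the complete example at the end of this page, \(A' = A - u J u^T\) with \(J\) the 2x2 skew identity, the pfaffian ratio should equal \(\mathrm{pf}(J + u^T A^{-1} u)/\mathrm{pf}(J)\); this is the \(\tau = 1\) case, where \(A_{\tau-1} = A_0\). The sketch below uses plain NumPy, not lrux; pf4 and pf2 are hypothetical helpers that evaluate the closed-form pfaffians of 4x4 and 2x2 skew-symmetric matrices.

```python
import numpy as np


def pf4(A):
    """Pfaffian of a 4x4 skew-symmetric matrix from its closed-form expansion."""
    return A[0, 1] * A[2, 3] - A[0, 2] * A[1, 3] + A[0, 3] * A[1, 2]


def pf2(R):
    """Pfaffian of a 2x2 skew-symmetric matrix."""
    return R[0, 1]


rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
A = A - A.T                              # skew-symmetric starting matrix
u = rng.standard_normal((4, 2))          # rank-2 update vectors
J = np.array([[0.0, 1.0], [-1.0, 0.0]])  # 2x2 skew identity, pf(J) = 1

A_new = A - u @ J @ u.T                  # assumed convention: A' = A - u J u^T
R = J + u.T @ np.linalg.inv(A) @ u       # small 2x2 matrix that carries the ratio

ratio_direct = pf4(A_new) / pf4(A)
ratio_lowrank = pf2(R) / pf2(J)
print(np.isclose(ratio_direct, ratio_lowrank))
```

Only a 2x2 pfaffian and one solve against \(A\) are needed for the low-rank side, which is what makes the update cheap compared with recomputing \(\mathrm{pf}(A')\).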
- Parameters:
  - carrier – The existing delayed-update quantities, including \(A_0^{-1}\), \(R_t^{-1}\), and
\[a_t = A_{t-1}^{-1} u_t\]
with \(t\) from 1 to \(\tau-1\). Initially provided by init_pf_carrier.
  - u – Low-rank update vector(s) \(u_\tau\), the same as \(u\) in lrux.det_lru. The rank of u should not exceed the maximum allowed rank specified in init_pf_carrier.
  - return_update – Whether to return the new carrier with updated quantities. Defaults to False.
  - current_delay – The current iteration \(\tau\) of delayed updates; must be specified when return_update is True. Since Python counts from 0, the actual value of \(\tau\) is current_delay + 1.
- Returns:
  - ratio: The ratio between the pfaffians after and before the update \(A_\tau = A_{\tau-1} - u_\tau J u_\tau^T\),
\[r_\tau = \frac{\mathrm{pf}(A_\tau)}{\mathrm{pf}(A_{\tau-1})} = \frac{\mathrm{pf}(R_\tau)}{\mathrm{pf}(J)},\]
where \(J\) is the skew-symmetric identity matching the rank of \(u_\tau\) (skew_eye in the example below) and
\[R_\tau = J + u_\tau^T A_0^{-1} u_\tau + \sum_{t=1}^{\tau-1} (u_\tau^T a_t) R_t^{-1} (a_t^T u_\tau).\]
  - new_carrier: Only returned when return_update is True. The new carrier contains the quantities from the input carrier and, in addition, \(R_\tau\) and
\[a_\tau = A_{\tau-1}^{-1} u_\tau = A_0^{-1} u_\tau + \sum_{t=1}^{\tau-1} a_t R_t^{-1} (a_t^T u_\tau).\]
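The delayed formulas for \(R_\tau\) and \(a_\tau\) can be checked against dense pfaffians. The NumPy sketch below is not lrux code: it keeps \(A_0^{-1}\) fixed, stores \(a_t\) and \(R_t^{-1}\) from every earlier step, and compares each \(r_\tau\) with \(\mathrm{pf}(A_\tau)/\mathrm{pf}(A_{\tau-1})\) computed directly, assuming the update convention \(A_\tau = A_{\tau-1} - u_\tau J u_\tau^T\) from the complete example at the end of this page. pfaffian (a Parlett-Reid elimination in the style of pfapack) and skew_J (block-diagonal \(J\) built from [[0, 1], [-1, 0]] blocks) are hypothetical helpers. Because \(a_\tau = A_{\tau-1}^{-1} u_\tau\), the sketch forms \(a_\tau\) first and reuses it as \(R_\tau = J + u_\tau^T a_\tau\), which expands to \(J + u_\tau^T A_0^{-1} u_\tau + \sum_{t=1}^{\tau-1} (u_\tau^T a_t) R_t^{-1} (a_t^T u_\tau)\).

```python
import numpy as np


def pfaffian(A):
    """Pfaffian of a real skew-symmetric matrix by Parlett-Reid elimination with pivoting."""
    A = np.array(A, dtype=float)  # work on a copy
    n = A.shape[0]
    if n % 2 == 1:
        return 0.0
    result = 1.0
    for k in range(0, n - 1, 2):
        # Pivot: move the largest entry of column k (below row k) to row k + 1.
        kp = k + 1 + int(np.abs(A[k + 1:, k]).argmax())
        if kp != k + 1:
            A[[k + 1, kp], :] = A[[kp, k + 1], :]
            A[:, [k + 1, kp]] = A[:, [kp, k + 1]]
            result = -result  # one transposition flips the sign of the pfaffian
        if A[k + 1, k] == 0.0:
            return 0.0
        result *= A[k, k + 1]
        if k + 2 < n:
            # Eliminate rows/columns k and k + 1 from the trailing block.
            tau = A[k, k + 2:] / A[k, k + 1]
            col = A[k + 2:, k + 1].copy()
            A[k + 2:, k + 2:] += np.outer(tau, col) - np.outer(col, tau)
    return result


def skew_J(m):
    """Hypothetical helper: 2m x 2m block-diagonal matrix of [[0, 1], [-1, 0]] blocks."""
    return np.kron(np.eye(m), np.array([[0.0, 1.0], [-1.0, 0.0]]))


def delayed_ratios(A0, updates):
    """Ratios r_1, ..., r_T using only A_0^{-1} plus the stored a_t and R_t^{-1}."""
    A0_inv = np.linalg.inv(A0)
    a_list, Rinv_list, ratios = [], [], []
    for u in updates:
        J = skew_J(u.shape[1] // 2)
        # a_tau = A_0^{-1} u + sum_t a_t R_t^{-1} (a_t^T u)
        a = A0_inv @ u
        for a_t, Rinv_t in zip(a_list, Rinv_list):
            a = a + a_t @ (Rinv_t @ (a_t.T @ u))
        R = J + u.T @ a       # R_tau = J + u^T A_{tau-1}^{-1} u
        R = 0.5 * (R - R.T)   # R is skew in exact arithmetic; drop rounding noise
        ratios.append(pfaffian(R) / pfaffian(J))
        a_list.append(a)
        Rinv_list.append(np.linalg.inv(R))
    return ratios


rng = np.random.default_rng(1)
n = 8
A = rng.standard_normal((n, n))
A = A - A.T
updates = [rng.standard_normal((n, 2 * m)) for m in (1, 2, 1)]  # ranks 2, 4, 2
ratios = delayed_ratios(A, updates)

# Reference: apply each update densely and compare pfaffian ratios.
A_prev = A
for u, r in zip(updates, ratios):
    A_next = A_prev - u @ skew_J(u.shape[1] // 2) @ u.T
    print(np.isclose(r, pfaffian(A_next) / pfaffian(A_prev)))
    A_prev = A_next
```

The dense matrix is never updated inside delayed_ratios; the cost of each step grows with the number of stored delays, which is why they are periodically merged.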
Warning
This function is only recommended for advanced users who understand why and when to use delayed updates. Otherwise, please use pf_lru.
Warning
When current_delay reaches the last allowed delayed iteration, i.e. current_delay == max_delay - 1 (max_delay being the value passed to init_pf_carrier), call merge_pf_delays to merge the delayed updates in the carrier and reset it for the next round. See the example below for details.
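The merge-and-reset cycle described in the warning above can be mimicked without lrux. The NumPy sketch below assumes the update convention \(A_\tau = A_{\tau-1} - u_\tau J u_\tau^T\) from the complete example on this page, and assumes that merging amounts to folding the delays into \(A_0^{-1}\) as \(A_0^{-1} \leftarrow A_0^{-1} + \sum_t a_t R_t^{-1} a_t^T\), which follows from the Woodbury identity; it is not the implementation of merge_pf_delays. skew_J and accumulated_inverse are hypothetical helpers. The loop checks that current_delay equals the number of pending delays (\(\tau - 1\)) and that, after every step, the accumulated inverse matches a dense inverse of \(A_\tau\).

```python
import numpy as np


def skew_J(m):
    """Hypothetical helper: 2m x 2m block-diagonal matrix of [[0, 1], [-1, 0]] blocks."""
    return np.kron(np.eye(m), np.array([[0.0, 1.0], [-1.0, 0.0]]))


def accumulated_inverse(A0_inv, a_list, Rinv_list):
    """A_tau^{-1} = A_0^{-1} + sum_t a_t R_t^{-1} a_t^T, formed densely."""
    inv = A0_inv.copy()
    for a_t, Rinv_t in zip(a_list, Rinv_list):
        inv += a_t @ Rinv_t @ a_t.T
    return inv


rng = np.random.default_rng(2)
n, rank, max_delay = 10, 2, 4
A = rng.standard_normal((n, n))
A = A - A.T
A0_inv = np.linalg.inv(A)   # plays the role of A_0^{-1} in the carrier
a_list, Rinv_list = [], []  # pending delayed quantities a_t and R_t^{-1}
num_merges = 0
checks = []

for i in range(12):
    current_delay = i % max_delay
    assert len(a_list) == current_delay  # tau - 1 delays are pending
    u = rng.standard_normal((n, rank))
    J = skew_J(rank // 2)

    # Delayed update: a_tau and R_tau from A_0^{-1} and the pending delays only.
    a = A0_inv @ u
    for a_t, Rinv_t in zip(a_list, Rinv_list):
        a = a + a_t @ (Rinv_t @ (a_t.T @ u))
    R = J + u.T @ a
    a_list.append(a)
    Rinv_list.append(np.linalg.inv(R))
    A = A - u @ J @ u.T  # dense reference matrix A_tau

    if current_delay == max_delay - 1:
        # Hypothetical merge: fold the delays into A_0^{-1} and reset the lists.
        A0_inv = accumulated_inverse(A0_inv, a_list, Rinv_list)
        a_list, Rinv_list = [], []
        num_merges += 1

    # Merged or not, A_0^{-1} plus the pending delays should equal inv(A_tau).
    checks.append(np.allclose(accumulated_inverse(A0_inv, a_list, Rinv_list), np.linalg.inv(A)))

print(all(checks), num_merges)
```

Merging bounds the number of stored delays by max_delay, at the price of one dense rank-update of \(A_0^{-1}\) per round.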
Tip
Similar to det_lru_delayed and pf_lru, this function is compatible with jax.jit and jax.vmap. However, return_update and current_delay are static arguments: list them in static_argnums for jax.jit and map them with None in the in_axes of jax.vmap. We also recommend setting donate_argnums=0 in jax.jit to reuse the memory of carrier if it is no longer needed. For instance,
```python
lru_vmap = jax.vmap(pf_lru_delayed, in_axes=(0, 0, None, None))
lru_jit = jax.jit(lru_vmap, static_argnums=(2, 3), donate_argnums=0)
```
Here is a complete example of delayed updates.
```python
import os

# Enable 64-bit floats before JAX is imported.
os.environ["JAX_ENABLE_X64"] = "1"

import random

import jax
import jax.numpy as jnp
import jax.random as jr
from lrux import skew_eye, pf, init_pf_carrier, merge_pf_delays, pf_lru_delayed


def _get_key():
    seed = random.randint(0, 2**31 - 1)
    return jr.key(seed)


dtype = jnp.float64
n = 10
k = 2  # maximum rank of each update
max_delay = n // 2

# Random skew-symmetric starting matrix A_0.
A = jr.normal(_get_key(), (n, n), dtype)
A = (A - A.T) / 2
carrier = init_pf_carrier(A, max_delay, k)
pfA0 = pf(A)

# return_update and current_delay (positions 2 and 3) are static; the carrier is donated.
lru_fn = jax.jit(pf_lru_delayed, static_argnums=(2, 3), donate_argnums=0)
merge_fn = jax.jit(merge_pf_delays, donate_argnums=0)

for i in range(20):
    current_delay = i % max_delay
    ki = random.randint(0, k // 2) * 2  # ensure ki is even
    u = jr.normal(_get_key(), (n, ki), dtype)
    ratio, carrier = lru_fn(carrier, u, True, current_delay)

    # Merge and reset the carrier after the last allowed delay.
    if current_delay == max_delay - 1:
        carrier = merge_fn(carrier)

    # Reference check: apply the same update densely and compare pfaffian ratios.
    J = skew_eye(ki // 2, dtype)
    A -= u @ J @ u.T
    pfA1 = pf(A)
    assert jnp.allclose(ratio, pfA1 / pfA0)
    pfA0 = pfA1
```