lrux.pf_lru_delayed

lrux.pf_lru_delayed(carrier: PfCarrier, u: Array | int | Tuple[Array, Array] | Tuple[Array, int], return_update: bool = False, current_delay: int | None = None) → Array | Tuple[Array, PfCarrier]

Delayed low-rank update of a pfaffian.

Parameters:
  • carrier

    The existing delayed update quantities, including \(A_0^{-1}\), \(R_t^{-1}\), and

    \[a_t = A_{t-1}^{-1} u_t\]

    with \(t\) from 1 to \(\tau-1\). Initially provided by init_pf_carrier.

  • u – Low-rank update vector(s) \(u_\tau\), the same as \(u\) in lrux.det_lru. The rank of u shouldn’t exceed the maximum allowed rank specified in init_pf_carrier.

  • return_update – Whether the new carrier with updated quantities should be returned, defaults to False.

  • current_delay – The current iteration \(\tau\) of delayed updates; must be specified when return_update is True. Because Python counts from 0, the actual \(\tau\) value is current_delay + 1.

Returns:

ratio:

The ratio between two pfaffians

\[r_\tau = \frac{\mathrm{pf}(A_\tau)}{\mathrm{pf}(A_{\tau-1})} = \frac{\mathrm{pf}(R_\tau)}{\mathrm{pf}(J)}\]

where

\[R_\tau = J + u_\tau^T A_{\tau-1}^{-1} u_\tau = J + u_\tau^T A_0^{-1} u_\tau + \sum_{t=1}^{\tau-1} (u_\tau^T a_t) R_t^{-1} (a_t^T u_\tau)\]

new_carrier:

Only returned when return_update is True. The new carrier contains the quantities from the input carrier, and in addition \(R_\tau\) and

\[a_\tau = A_{\tau-1}^{-1} u_\tau = A_0^{-1} u_\tau + \sum_{t=1}^{\tau-1} a_t R_t^{-1} (a_t^T u_\tau)\]
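The recurrence for \(a_\tau\) and the definition \(R_\tau = J + u_\tau^T A_{\tau-1}^{-1} u_\tau\) can be checked numerically without lrux. The following is a minimal NumPy sketch, not the library's implementation. It assumes rank-2 updates of the form \(A_\tau = A_{\tau-1} - u_\tau J u_\tau^T\) (as in the example at the end of this page) and assumes \(J\) is the 2×2 matrix [[0, 1], [-1, 0]]. The Python lists a_list and Rinv_list are illustrative stand-ins for the carrier contents. Since NumPy has no pfaffian routine, the ratio is compared through determinants, using \(\det = \mathrm{pf}^2\) for skew-symmetric matrices, so only its magnitude is verified here.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, steps = 6, 2, 3

# Random skew-symmetric A_0 and the 2x2 matrix J (assumed form, see lead-in).
A = rng.normal(size=(n, n))
A = (A - A.T) / 2
J = np.array([[0.0, 1.0], [-1.0, 0.0]])

A0_inv = np.linalg.inv(A)
a_list, Rinv_list = [], []  # stand-ins for the a_t and R_t^{-1} kept in the carrier

for _ in range(steps):
    u = rng.normal(size=(n, k))

    # a_tau = A_0^{-1} u + sum_t a_t R_t^{-1} (a_t^T u)
    a = A0_inv @ u
    for a_t, Rinv_t in zip(a_list, Rinv_list):
        a = a + a_t @ Rinv_t @ (a_t.T @ u)

    # R_tau = J + u^T A_{tau-1}^{-1} u = J + u^T a_tau
    R = J + u.T @ a
    ratio = R[0, 1]  # pf(R) / pf(J) for a 2x2 skew-symmetric R, with pf(J) = 1

    # Compare with dense linear algebra on the explicitly updated matrix.
    assert np.allclose(a, np.linalg.inv(A) @ u)
    A_new = A - u @ J @ u.T
    assert np.isclose(np.linalg.det(A_new) / np.linalg.det(A), ratio**2)

    a_list.append(a)
    Rinv_list.append(np.linalg.inv(R))
    A = A_new
```

Each step touches only \(A_0^{-1}\) and the stored \(n \times k\) blocks, so the full inverse is not rebuilt until the delayed updates are merged.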

Warning

This function is only recommended for heavy users who understand why and when to use delayed updates. Otherwise, please choose pf_lru.

Warning

When current_delay reaches the maximum delayed iteration, i.e. current_delay == max_delay - 1, one should call merge_pf_delays to merge the delayed updates in the carrier, and reset the carrier for the next round. See the example below for details.
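The bookkeeping behind this rule can be sketched in plain Python. The helper delay_schedule below is hypothetical and not part of lrux. It only lists, for each update step, the value of current_delay to pass, the corresponding \(\tau\), and whether a merge should follow that step.

```python
def delay_schedule(num_updates, max_delay):
    """Return (current_delay, tau, merge_after) for each update step."""
    schedule = []
    for i in range(num_updates):
        current_delay = i % max_delay                   # 0-based slot in the carrier
        tau = current_delay + 1                         # 1-based iteration used in the formulas
        merge_after = current_delay == max_delay - 1    # carrier is full after this step
        schedule.append((current_delay, tau, merge_after))
    return schedule


schedule = delay_schedule(num_updates=7, max_delay=3)
for step, (current_delay, tau, merge_after) in enumerate(schedule):
    print(step, current_delay, tau, merge_after)
```

With max_delay = 3, a merge follows every third update, after which current_delay starts again from 0.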

Tip

Similar to det_lru_delayed and pf_lru, this function is compatible with jax.jit and jax.vmap. The arguments return_update and current_delay are static: they must not be traced by jax.jit or mapped by jax.vmap.

We still recommend setting donate_argnums=0 in jax.jit to reuse the memory of carrier if it’s no longer needed. For instance,

lru_vmap = jax.vmap(pf_lru_delayed, in_axes=(0, 0, None, None))
lru_jit = jax.jit(lru_vmap, static_argnums=(2, 3), donate_argnums=0)

Here is a complete example of delayed updates.

import os
os.environ["JAX_ENABLE_X64"] = "1"

import random
import jax
import jax.numpy as jnp
import jax.random as jr
from lrux import skew_eye, pf, init_pf_carrier, merge_pf_delays, pf_lru_delayed

def _get_key():
    seed = random.randint(0, 2**31 - 1)
    return jr.key(seed)

dtype = jnp.float64
n = 10
k = 2
max_delay = n // 2
A = jr.normal(_get_key(), (n, n), dtype)
A = (A - A.T) / 2
carrier = init_pf_carrier(A, max_delay, k)
pfA0 = pf(A)

lru_fn = jax.jit(pf_lru_delayed, static_argnums=(2, 3), donate_argnums=0)
merge_fn = jax.jit(merge_pf_delays, donate_argnums=0)

for i in range(20):
    # 0-based position within the current round of delayed updates
    current_delay = i % max_delay
    ki = random.randint(0, k // 2) * 2  # ensure ki is even
    u = jr.normal(_get_key(), (n, ki), dtype)
    ratio, carrier = lru_fn(carrier, u, True, current_delay)

    # Merge the delayed updates once the maximum delay is reached
    if current_delay == max_delay - 1:
        carrier = merge_fn(carrier)

    # Check the ratio against a direct pfaffian computation
    J = skew_eye(ki // 2, dtype)
    A -= u @ J @ u.T
    pfA1 = pf(A)
    assert jnp.allclose(ratio, pfA1 / pfA0)
    pfA0 = pfA1