lrux.merge_pf_delays

lrux.merge_pf_delays(carrier: PfCarrier) → PfCarrier

Merge the delayed updates in the carrier.

When \(\tau\) reaches the maximum number of delayed iterations \(T\) specified in init_pf_carrier, i.e., when current_delay == max_delay - 1, the current \(A_\tau\) should be set as the new \(A_0\), whose inverse is given by

\[A_\tau^{-1} = A_0^{-1} + \sum_{t=1}^\tau a_t R_t^{-1} a_t^T\]
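Stacking the stored updates shows that the merge is a single low-rank correction. Writing \(\mathbf{a}_\tau = [a_1, \dots, a_\tau]\) for the horizontal concatenation of the update blocks (notation introduced here only for illustration), the same sum can be written in block form:

\[A_\tau^{-1} = A_0^{-1} + \mathbf{a}_\tau \begin{pmatrix} R_1^{-1} & & \\ & \ddots & \\ & & R_\tau^{-1} \end{pmatrix} \mathbf{a}_\tau^T\]

In addition, if \(A_0^{-1}\) and every \(R_t^{-1}\) are skew-symmetric, then \((a_t R_t^{-1} a_t^T)^T = a_t (R_t^{-1})^T a_t^T = -a_t R_t^{-1} a_t^T\), so the merged inverse \(A_\tau^{-1}\) is skew-symmetric as well.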

In the returned carrier (new_carrier), Ainv is replaced by \(A_\tau^{-1}\), and a and Rinv are reset to 0. See the example in pf_lru_delayed for details.
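To make the replacement and reset concrete, here is a minimal pure-JAX sketch that does not call lrux. ToyCarrier and toy_merge are hypothetical names introduced for illustration; the real PfCarrier fields, their shapes (assumed here as Ainv of shape (n, n), a of shape (T, n, k), Rinv of shape (T, k, k)), and any other bookkeeping such as current_delay are not modeled. The sketch also assumes unused delay slots hold zeros, so summing over all T slots equals summing over \(t = 1, \dots, \tau\).

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyCarrier(NamedTuple):
    """Hypothetical stand-in for PfCarrier; the real fields and shapes may differ."""
    Ainv: jax.Array  # (n, n): current base inverse A_0^{-1}
    a: jax.Array     # (T, n, k): stacked delayed blocks a_t (unused slots assumed zero)
    Rinv: jax.Array  # (T, k, k): stacked small inverses R_t^{-1} (unused slots assumed zero)


def toy_merge(carrier: ToyCarrier) -> ToyCarrier:
    # sum_t a_t R_t^{-1} a_t^T as one contraction; zero slots contribute nothing.
    correction = jnp.einsum("tik,tkl,tjl->ij", carrier.a, carrier.Rinv, carrier.a)
    # The merged inverse becomes the new base; the delayed buffers are cleared.
    return ToyCarrier(
        Ainv=carrier.Ainv + correction,
        a=jnp.zeros_like(carrier.a),
        Rinv=jnp.zeros_like(carrier.Rinv),
    )


# Batched, buffer-donating variant following the pattern recommended in the Tip.
toy_merge_batched = jax.jit(jax.vmap(toy_merge), donate_argnums=0)
```

Because ToyCarrier is a NamedTuple, JAX treats it as a pytree: jax.vmap maps over a leading batch axis of every field, and donate_argnums=0 donates all of its buffers. After a donating call, only the returned carrier should be used.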

Tip

This function is compatible with jax.jit and jax.vmap. If carrier is no longer needed after the call, we recommend setting donate_argnums=0 in jax.jit so that its memory can be reused for the output. This avoids allocating fresh buffers for the returned carrier and can greatly reduce the time and memory cost. For instance,

import jax
from lrux import merge_pf_delays
merge_vmap = jax.vmap(merge_pf_delays)
merge_jit = jax.jit(merge_vmap, donate_argnums=0)