lrux.merge_det_delays

lrux.merge_det_delays(carrier: DetCarrier) → DetCarrier

Merge the delayed updates in the carrier.

When \(\tau\) reaches the maximum number of delayed iterations \(T\) specified in init_det_carrier (i.e., current_delay == max_delay - 1), the current \(A_\tau\) should become the new \(A_0\), whose inverse is given by

\[A_\tau^{-1} = A_0^{-1} - \sum_{t=1}^\tau a_t b_t^T\]

In the returned carrier (new_carrier), Ainv is replaced by \(A_\tau^{-1}\), and a and b are reset to 0. See the example in det_lru_delayed for details.
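As a rough numerical check of the merge formula, the sketch below uses only jax.numpy and makes no lrux calls. It assumes each delayed step is a rank-k update \(A_t = A_{t-1} + u_t v_t^T\) whose Woodbury correction is factored as \(a_t b_t^T\) with \(a_t = A_{t-1}^{-1} u_t (I + v_t^T A_{t-1}^{-1} u_t)^{-1}\) and \(b_t = A_{t-1}^{-T} v_t\). This factorization, the stacked (T, n, k) layout of a and b, and the helpers low_rank_apply and merge_toy are illustrative assumptions, not necessarily what lrux stores internally. The merge is triggered on the same condition as above, current_delay == max_delay - 1.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 for a tight comparison

n, k = 6, 2            # matrix size and rank of each update (illustrative)
max_delay = 3          # T: number of delayed updates kept before merging
n_steps = 2 * max_delay
keys = jax.random.split(jax.random.PRNGKey(0), 2 * n_steps + 1)

A = jax.random.normal(keys[0], (n, n)) + n * jnp.eye(n)  # well-conditioned A_0
A0inv = jnp.linalg.inv(A)
# Hypothetical storage of the delayed factors: a[t] and b[t] each have shape (n, k).
a = jnp.zeros((max_delay, n, k))
b = jnp.zeros((max_delay, n, k))


def low_rank_apply(M0, p, q, x):
    """Return (M0 - sum_t p_t q_t^T) @ x without forming the sum; zero slots add nothing."""
    return M0 @ x - jnp.einsum("tnk,tmk,mj->nj", p, q, x)


def merge_toy(A0inv, a, b):
    """Hypothetical merge: A_tau^{-1} = A_0^{-1} - sum_t a_t b_t^T, then reset a and b to 0."""
    Ainv = A0inv - jnp.einsum("tnk,tmk->nm", a, b)
    return Ainv, jnp.zeros_like(a), jnp.zeros_like(b)


current_delay = 0
for step in range(n_steps):
    # Small random rank-k update: A_t = A_{t-1} + u v^T.
    u = 0.3 * jax.random.normal(keys[2 * step + 1], (n, k))
    v = 0.3 * jax.random.normal(keys[2 * step + 2], (n, k))
    Ainv_u = low_rank_apply(A0inv, a, b, u)      # A_{t-1}^{-1} u, applied implicitly
    AinvT_v = low_rank_apply(A0inv.T, b, a, v)   # A_{t-1}^{-T} v, applied implicitly
    C = jnp.eye(k) + v.T @ Ainv_u                 # k x k capacitance matrix
    # Woodbury: A_t^{-1} = A_{t-1}^{-1} - a_t b_t^T with the two factors below.
    a = a.at[current_delay].set(jnp.linalg.solve(C.T, Ainv_u.T).T)  # A_{t-1}^{-1} u C^{-1}
    b = b.at[current_delay].set(AinvT_v)
    A = A + u @ v.T

    if current_delay == max_delay - 1:  # tau has reached T: fold updates into A_0^{-1}
        A0inv, a, b = merge_toy(A0inv, a, b)
        current_delay = 0
    else:
        current_delay += 1

print(jnp.allclose(A0inv, jnp.linalg.inv(A)))
```

Applying \(A_{t-1}^{-1}\) implicitly through the stored factors is what makes delaying the merge worthwhile: each step costs matrix-vector style work in the small rank k, and the full \(n \times n\) update of the inverse is paid only once every T steps.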

Tip

This function is compatible with jax.jit and jax.vmap. If carrier is no longer needed after the call, we recommend setting donate_argnums=0 in jax.jit so that its memory is reused for the output, which can greatly reduce time and memory cost. For instance,

import jax
from lrux import merge_det_delays
merge_vmap = jax.vmap(merge_det_delays)
merge_jit = jax.jit(merge_vmap, donate_argnums=0)  # reuse the input carrier's buffers
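The same vmap/jit/donation pattern can be tried without lrux on a stand-in pytree. The sketch below assumes a hypothetical ToyCarrier NamedTuple with Ainv, a and b fields and a hypothetical toy_merge function; the real DetCarrier layout may differ. It shows that jax.vmap maps over a leading batch axis of every leaf and that the donated input is not used again after the call. On backends that cannot donate buffers, JAX may only emit a warning.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyCarrier(NamedTuple):
    """Hypothetical stand-in for DetCarrier; field names and shapes are assumptions."""
    Ainv: jax.Array  # (n, n) current A_0^{-1}
    a: jax.Array     # (T, n, k) delayed factors a_t
    b: jax.Array     # (T, n, k) delayed factors b_t


def toy_merge(carrier: ToyCarrier) -> ToyCarrier:
    """Fold the delayed updates into Ainv and reset a and b to zero."""
    Ainv = carrier.Ainv - jnp.einsum("tnk,tmk->nm", carrier.a, carrier.b)
    return ToyCarrier(Ainv, jnp.zeros_like(carrier.a), jnp.zeros_like(carrier.b))


# vmap adds a leading batch axis to every leaf; donate_argnums=0 lets XLA
# reuse the input carrier's buffers for the output.
toy_merge_jit = jax.jit(jax.vmap(toy_merge), donate_argnums=0)

batch, n, k, T = 4, 5, 2, 3
ka, kb = jax.random.split(jax.random.PRNGKey(1))
carrier = ToyCarrier(
    Ainv=jnp.broadcast_to(jnp.eye(n), (batch, n, n)),
    a=jax.random.normal(ka, (batch, T, n, k)),
    b=jax.random.normal(kb, (batch, T, n, k)),
)
# Reference result computed before the call, because `carrier` is donated.
expected = carrier.Ainv - jnp.einsum("btnk,btmk->bnm", carrier.a, carrier.b)

new_carrier = toy_merge_jit(carrier)
del carrier  # the donated input must not be used after the call
print(new_carrier.Ainv.shape)  # (4, 5, 5)
```

Computing anything that depends on the old carrier before the call, and dropping the reference afterwards, is the usual discipline when donating arguments.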