lrux.merge_pf_delays
lrux.merge_pf_delays(carrier: PfCarrier) → PfCarrier
Merge the delayed updates in the carrier.
When \(\tau\) reaches the maximum number of delayed iterations \(T\) specified in
init_pf_carrier, i.e. current_delay == max_delay - 1, the current \(A_\tau\) should become the new \(A_0\), whose inverse is given by

\[A_\tau^{-1} = A_0^{-1} + \sum_{t=1}^\tau a_t R_t^{-1} a_t^T\]

In the returned carrier, new_carrier.Ainv is replaced by \(A_\tau^{-1}\), and a and Rinv are set to 0. See the example in pf_lru_delayed for details.

Tip
This function is compatible with jax.jit and jax.vmap. We recommend setting donate_argnums=0 in jax.jit to reuse the memory of carrier if it is no longer needed, which can greatly reduce the time and memory cost. For instance,

    import jax
    from lrux import merge_pf_delays

    merge_vmap = jax.vmap(merge_pf_delays)
    merge_jit = jax.jit(merge_vmap, donate_argnums=0)
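The merge formula and the jit/vmap tip can be tried together in a small, self-contained sketch. It does not call lrux: ToyPfCarrier, toy_merge_pf_delays and skew are hypothetical names, and the carrier layout (Ainv of shape (n, n), a of shape (T, n, k), Rinv of shape (T, k, k), with unused delay slots held at zero) is an assumption, not the actual structure of PfCarrier. The random skew-symmetric inputs only exercise the formula; they do not come from a real sequence of delayed updates.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyPfCarrier(NamedTuple):
    """Hypothetical stand-in for PfCarrier; field names and shapes are assumptions."""
    Ainv: jax.Array  # (n, n): current A_0^{-1}
    a: jax.Array     # (T, n, k): delayed factors a_t, unused slots held at zero
    Rinv: jax.Array  # (T, k, k): delayed factors R_t^{-1}, unused slots held at zero


def toy_merge_pf_delays(carrier: ToyPfCarrier) -> ToyPfCarrier:
    # A_tau^{-1} = A_0^{-1} + sum_t a_t R_t^{-1} a_t^T (zero slots add nothing)
    correction = jnp.einsum("tik,tkl,tjl->ij", carrier.a, carrier.Rinv, carrier.a)
    # Fold the correction into Ainv and clear the delayed updates
    return ToyPfCarrier(
        Ainv=carrier.Ainv + correction,
        a=jnp.zeros_like(carrier.a),
        Rinv=jnp.zeros_like(carrier.Rinv),
    )


def skew(x: jax.Array) -> jax.Array:
    # Antisymmetrize over the last two axes (Pfaffians act on skew-symmetric matrices)
    return x - jnp.swapaxes(x, -1, -2)


# Same pattern as the tip: vmap over a batch axis, jit with the carrier donated
merge_vmap = jax.vmap(toy_merge_pf_delays)
merge_jit = jax.jit(merge_vmap, donate_argnums=0)

batch, n, k, T = 8, 6, 2, 3
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
carrier = ToyPfCarrier(
    Ainv=skew(jax.random.normal(k0, (batch, n, n))),
    a=jax.random.normal(k1, (batch, T, n, k)),
    Rinv=skew(jax.random.normal(k2, (batch, T, k, k))),
)

# Reference for the first batch element: explicit loop over t = 1..T
expected = carrier.Ainv[0]
for t in range(T):
    expected = expected + carrier.a[0, t] @ carrier.Rinv[0, t] @ carrier.a[0, t].T

# The donated input may be invalidated, so rebind the name to the output
carrier = merge_jit(carrier)
```

Because the unused slots are assumed to hold zeros, summing over all T slots gives the same result as summing only up to \(\tau\), and clearing a and Rinv after the merge restores that invariant for the next round of delayed updates. Rebinding carrier to the output follows from donate_argnums=0: once donated, the input arrays may be deleted and must not be read again.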