lrux.merge_pf_delays
- lrux.merge_pf_delays(carrier: PfCarrier) → PfCarrier
Merge the delayed updates in the carrier.
When \(\tau\) reaches the maximum number of delayed iterations \(T\) specified in init_pf_carrier, i.e. current_delay == max_delay - 1, the current \(A_\tau\) should be set as the new \(A_0\). Its inverse is given by

\[A_\tau^{-1} = A_0^{-1} + \sum_{t=1}^\tau a_t R_t^{-1} a_t^T\]

new_carrier.Ainv is replaced by \(A_\tau^{-1}\), and a and Rinv are set to 0. See the example in pf_lru_delayed for details.

Tip
This function is compatible with jax.jit and jax.vmap. We recommend setting donate_argnums=0 in jax.jit to reuse the memory of carrier if it is no longer needed. Donating the carrier lets XLA reuse its buffers for the output instead of allocating new ones, which can greatly reduce time and memory cost. For instance:

```python
import jax
from lrux import merge_pf_delays

merge_vmap = jax.vmap(merge_pf_delays)
merge_jit = jax.jit(merge_vmap, donate_argnums=0)
```