lrux.merge_det_delays
lrux.merge_det_delays(carrier: DetCarrier) → DetCarrier
Merge the delayed updates in the carrier.
When \(\tau\) reaches the maximum number of delayed iterations \(T\) specified in init_det_carrier, i.e. current_delay == max_delay - 1, the current \(A_\tau\) should be set as the new \(A_0\), whose inverse is given by

\[A_\tau^{-1} = A_0^{-1} - \sum_{t=1}^\tau a_t b_t^T\]

The Ainv in new_carrier will be replaced by \(A_\tau^{-1}\), and a and b will be set to 0. See the example in det_lru_delayed for details.

Tip
This function is compatible with jax.jit and jax.vmap. We recommend setting donate_argnums=0 in jax.jit so that the memory of carrier is reused when it is no longer needed. This avoids allocating a fresh copy of the carrier's arrays and can greatly reduce the time and memory cost. Note that a donated carrier must not be used again after the call. For instance,

    merge_vmap = jax.vmap(merge_det_delays)
    merge_jit = jax.jit(merge_vmap, donate_argnums=0)