lrux.merge_det_delays
lrux.merge_det_delays(carrier: DetCarrier) → DetCarrier
Merge the delayed updates in the carrier.
When \(\tau\) reaches the maximum number of delayed iterations \(T\) specified in init_det_carrier, i.e. current_delay == max_delay - 1, the current \(A_\tau\) should become the new \(A_0\). Its inverse is given by

\[A_\tau^{-1} = A_0^{-1} - \sum_{t=1}^\tau a_t b_t^T\]

The Ainv in new_carrier is replaced by \(A_\tau^{-1}\), and a and b are reset to 0. See the example in det_lru_delayed for details.

Tip
This function is compatible with jax.jit and jax.vmap. We recommend setting donate_argnums=0 in jax.jit to reuse the memory of carrier if it is no longer needed. Donation lets XLA write the output into the input buffers instead of allocating new ones, which can greatly reduce time and memory cost. After donation, the original carrier should not be accessed again. For instance,

    merge_vmap = jax.vmap(merge_det_delays)
    merge_jit = jax.jit(merge_vmap, donate_argnums=0)