quantax.utils.chunk_map
- quantax.utils.chunk_map(f: Callable, in_axes: Tuple | int | None = 0, out_axes: Tuple | int | None = 0, chunk_size: int | None = None, use_scan: bool = False) → Callable
Convert a vmapped function into a function that processes its batch in chunks and computes in parallel on all available machines. If the batch size on each machine is smaller than the chunk size, the arguments are left unchanged. If it is larger than the chunk size and not a multiple of it, the batch is padded with zeros up to a multiple of the chunk size.
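To make the padding rule concrete, here is a minimal single-device sketch in JAX of how one array batched along axis 0 could be split into chunks. The helper pad_and_chunk is hypothetical and not part of quantax; it ignores sharding across machines, pytree arguments, and batch axes other than 0.

```python
import jax.numpy as jnp


def pad_and_chunk(x, chunk_size):
    """Hypothetical helper (not quantax API) illustrating the padding rule.

    Returns (chunks, batch), where chunks has shape (num_chunks, n, ...)
    and batch is the original batch size, needed later to drop padding.
    """
    batch = x.shape[0]
    if batch <= chunk_size:
        # Batch fits in one chunk: keep the data as is, no padding.
        return x[None], batch
    # Number of zero rows needed to reach the next multiple of chunk_size.
    n_pad = -batch % chunk_size
    x = jnp.pad(x, [(0, n_pad)] + [(0, 0)] * (x.ndim - 1))
    return x.reshape(-1, chunk_size, *x.shape[1:]), batch
```

For example, a batch of 5 with chunk_size=2 becomes 3 chunks of 2, the last one ending in a zero row; after computation, only the first 5 outputs would be kept.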
- Parameters:
f – The function to be converted. The arguments of f are assumed to be sharded.
in_axes – The vmapped axes of the inputs of f, i.e. the axes along which the inputs are chunked.
out_axes – The vmapped axes of the outputs of f.
chunk_size – The chunk size on each machine.
use_scan – Whether to use jax.lax.scan when applying the chunked function. Using scan speeds up compilation, but the function must be jittable.
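The following sketch illustrates what the use_scan switch changes. It assumes a single device, a single array input batched along axis 0 whose size is already a multiple of chunk_size (padding handled as described for chunk_map), and an f that returns a single array. apply_in_chunks is a hypothetical helper, not quantax's implementation.

```python
import jax
import jax.numpy as jnp


def apply_in_chunks(f, x, chunk_size, use_scan=False):
    """Hypothetical sketch: apply jax.vmap(f) to x one chunk at a time.

    Assumes x.shape[0] is a multiple of chunk_size and f returns one array.
    """
    vf = jax.vmap(f)
    # (batch, ...) -> (num_chunks, chunk_size, ...)
    chunks = x.reshape(-1, chunk_size, *x.shape[1:])
    if use_scan:
        # lax.scan traces vf once, however many chunks there are.
        _, out = jax.lax.scan(lambda carry, c: (carry, vf(c)), None, chunks)
    else:
        # Plain Python loop: under jit, each iteration is traced separately.
        out = jnp.stack([vf(c) for c in chunks])
    # Merge the (num_chunks, chunk_size) axes back into one batch axis.
    return out.reshape(-1, *out.shape[2:])
```

Because the scan body is traced and compiled once rather than unrolled per chunk, compile time under jax.jit does not grow with the number of chunks; the trade-off is that f must be traceable by JAX.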