quantax.utils.chunk_shard_vmap
- quantax.utils.chunk_shard_vmap(f: Callable, in_axes: tuple | int | None, out_axes: tuple | int | None, chunk_size: int | None = None, shard_axes: tuple | int | None = None) → Callable
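Converts f into jit(chunk_map(shard_map(vmap(f)))). In this composition, f is vectorized with vmap, its arguments are sharded with shard_map, the batch is evaluated chunk by chunk with chunk_map, and the result is compiled with jit.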
- Parameters:
f – The function to be converted. The arguments of f will be sharded.
in_axes – The mapped axes of the input arguments of f.
out_axes – The mapped axes of the outputs of f.
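To make the chunk_map(vmap(f)) part of the composition concrete, here is a minimal sketch in plain JAX. It is not quantax's implementation. `chunked_vmap` is a hypothetical helper that assumes a single argument mapped along axis 0 and a single array output, and it omits the shard_map stage entirely. The batch is padded to a multiple of `chunk_size`, split into chunks, and `jax.lax.map` applies the vmapped function one chunk at a time. This bounds peak memory to one chunk's intermediates.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(f, chunk_size):
    """Hypothetical helper: jit(chunk_map(vmap(f))) for one argument on axis 0.

    The sharding (shard_map) stage of chunk_shard_vmap is deliberately omitted.
    """
    batched_f = jax.vmap(f)

    def run(x):
        n = x.shape[0]  # static under jit
        n_chunks = -(-n // chunk_size)  # ceiling division
        pad = n_chunks * chunk_size - n
        # Zero-pad the batch so it splits evenly into chunks of chunk_size.
        x_padded = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        x_chunks = x_padded.reshape((n_chunks, chunk_size) + x.shape[1:])
        # lax.map evaluates the vmapped function sequentially over the chunks.
        out_chunks = jax.lax.map(batched_f, x_chunks)
        out = out_chunks.reshape((n_chunks * chunk_size,) + out_chunks.shape[2:])
        return out[:n]  # drop the rows that came from padding

    return jax.jit(run)


def energy(x):
    # Toy per-sample function: squared norm of one sample.
    return jnp.sum(x ** 2)


x = jnp.arange(10.0).reshape(5, 2)  # 5 samples, 2 features each
out = chunked_vmap(energy, chunk_size=2)(x)
# Same values as jax.vmap(energy)(x), computed 2 samples at a time.
```

The padded rows are evaluated and then discarded, which keeps every chunk the same shape so `lax.map` only traces `batched_f` once.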