quantax.utils.chunk_shard_vmap

quantax.utils.chunk_shard_vmap(f: Callable, in_axes: tuple | int | None, out_axes: tuple | int | None, shard_axes: tuple | int | None = None, chunk_size: int | None = None) → Callable

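Converts f into chunk_map(shard_map(vmap(f)), use_scan=True): f is vectorized with vmap, the vectorized function is sharded across devices with shard_map, and the sharded function is evaluated chunk by chunk by chunk_map using a scan.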

Parameters:
  • f – The function to be converted. The arguments of f will be sharded.

  • in_axes – The mapped axes of the input arguments of f.

  • out_axes – The mapped axes of the outputs of f.

  • shard_axes – The sharded axes of the input arguments of f. Defaults to in_axes.

  • chunk_size – The chunk size on each machine. If None, no chunking is applied.
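To illustrate what the vmap and chunk_map(..., use_scan=True) layers of this composition do, here is a minimal single-device sketch in plain JAX. It is not quantax's implementation: the helper chunked_vmap and the toy function energy are hypothetical, the shard_map step is omitted, and the sketch assumes every entry of in_axes is 0 or None, outputs are mapped along axis 0, and the batch size is a multiple of chunk_size.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(f, in_axes, chunk_size=None):
    # Hypothetical helper, not part of quantax. Assumptions: each entry of
    # in_axes is 0 or None, outputs are mapped along axis 0, and the batch
    # size is a multiple of chunk_size. Sharding (shard_map) is omitted.
    vf = jax.vmap(f, in_axes=in_axes)
    if chunk_size is None:
        return vf  # no chunking: behaves like a plain vmap

    def wrapped(*args):
        axes = in_axes if isinstance(in_axes, tuple) else (in_axes,) * len(args)
        mapped = [i for i, ax in enumerate(axes) if ax is not None]
        batch = args[mapped[0]].shape[0]
        n_chunks = batch // chunk_size

        # Reshape each mapped argument to (n_chunks, chunk_size, ...) so that
        # lax.scan iterates over the leading axis one chunk at a time.
        xs = tuple(
            args[i].reshape((n_chunks, chunk_size) + args[i].shape[1:])
            for i in mapped
        )

        def body(carry, chunk):
            # Unmapped arguments are passed through unchanged; mapped ones
            # are replaced by the current chunk.
            full = list(args)
            for i, c in zip(mapped, chunk):
                full[i] = c
            return carry, vf(*full)

        _, ys = jax.lax.scan(body, None, xs)
        # Merge the (n_chunks, chunk_size) leading axes back into one batch axis.
        return jax.tree_util.tree_map(
            lambda y: y.reshape((n_chunks * chunk_size,) + y.shape[2:]), ys
        )

    return wrapped


# Usage: x is mapped along axis 0, w is broadcast to every sample (in_axes=None).
def energy(x, w):
    return jnp.sum(w * x**2)


xs = jnp.arange(12.0).reshape(6, 2)
w = jnp.array([1.0, 0.5])
out = chunked_vmap(energy, in_axes=(0, None), chunk_size=2)(xs, w)
ref = jax.vmap(energy, in_axes=(0, None))(xs, w)
```

With chunk_size=2, lax.scan processes the six samples in three chunks of two, which typically limits peak memory to one chunk's worth of intermediates while producing the same result as a single vmap; with chunk_size=None the helper falls back to a plain vmap, mirroring the documented behavior of chunk_size.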