quantax.utils.shmap

quantax.utils.shmap(f: Callable, in_axes: tuple | int | None, out_axes: tuple | int | None) -> Callable

Convert f into shard_map(f) (f -> shard_map(f)), with sharding along the first dimension.

Parameters:
  • f – The function to be converted. The arguments of f will be sharded.

  • in_axes – The sharded axes of the input arguments of f.

  • out_axes – The sharded axes of the outputs of f.