quantax.utils.get_distribute_sharding

quantax.utils.get_distribute_sharding() → NamedSharding

Return a sharding that splits an array along its first dimension across all devices in jax.devices(). Any remaining dimensions are not split.
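The sketch below shows how a sharding with this behavior can be built from standard JAX APIs (`jax.sharding.Mesh`, `NamedSharding`, `PartitionSpec`). It is an illustration only and may not match quantax's actual implementation. The helper name `distribute_sharding_sketch` and the mesh axis name `"devices"` are hypothetical choices made for this example.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def distribute_sharding_sketch() -> NamedSharding:
    # Hypothetical helper: build a one-dimensional mesh over every device in
    # jax.devices(). The axis name "devices" is an arbitrary choice.
    mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
    # Split dimension 0 over the "devices" mesh axis. Dimensions that the
    # PartitionSpec leaves unspecified are not split.
    return NamedSharding(mesh, PartitionSpec("devices"))


n = len(jax.devices())
sharding = distribute_sharding_sketch()

# The first dimension must be divisible by the number of devices.
data = jnp.arange(n * 4).reshape(n * 2, 2)
x = jax.device_put(data, sharding)

# Each device holds a (2, 2) slice of the rows.
print(x.sharding)
print([shard.data.shape for shard in x.addressable_shards])
```

With quantax itself, the returned sharding can be passed to `jax.device_put` in the same way, e.g. `jax.device_put(data, quantax.utils.get_distribute_sharding())`. As in the sketch, this assumes the array's first dimension is divisible by the number of devices.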