quantax.utils.get_global_sharding

quantax.utils.get_global_sharding() → NamedSharding

Return the sharding that distributes arrays across all devices in jax.devices() along the array's first dimension.
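The sketch below approximates a sharding with this behaviour, built directly from `jax.sharding` primitives. It is not quantax's implementation. The helper name `make_global_sharding` and the mesh axis name `"devices"` are hypothetical choices made for illustration. The idea is a one-dimensional mesh over every device in `jax.devices()`, with only the first array dimension partitioned across it.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_global_sharding() -> NamedSharding:
    # Hypothetical helper: build a 1-D mesh spanning every device JAX can see.
    # The axis name "devices" is an arbitrary label chosen for this sketch.
    mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
    # Split only the first array dimension over the mesh axis.
    # Dimensions not named in the PartitionSpec are replicated.
    return NamedSharding(mesh, PartitionSpec("devices"))


sharding = make_global_sharding()

# The first dimension should be divisible by the device count so that
# every device receives an equal block of rows.
n = jax.device_count()
x = jax.device_put(jnp.arange(4 * n * 3).reshape(4 * n, 3), sharding)
print(x.sharding)
```

Each device ends up holding a `(4, 3)` block of rows. Presumably, the `NamedSharding` returned by `quantax.utils.get_global_sharding()` can be passed to `jax.device_put` in the same way.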