quantax.utils.get_replicate_sharding

quantax.utils.get_replicate_sharding() → NamedSharding

Return the sharding that replicates arrays across all devices in jax.devices().
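The following is a minimal sketch of how a sharding of this kind can be built with plain JAX. It is not necessarily how quantax implements it. The assumed construction is a one-dimensional `Mesh` over every device in `jax.devices()`, combined with an empty `PartitionSpec()`, which splits no array axis, so each device holds a full copy. The helper name `replicate_sharding_sketch` and the mesh axis name `"x"` are hypothetical and chosen only for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate_sharding_sketch() -> NamedSharding:
    # Hypothetical helper, not part of quantax: lay out all devices on a 1-D mesh.
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("x",))
    # An empty PartitionSpec partitions no array dimension,
    # so every device in the mesh receives the whole array.
    return NamedSharding(mesh, PartitionSpec())


sharding = replicate_sharding_sketch()
# Place an array according to the sharding: one full copy per device.
x = jax.device_put(jnp.arange(8.0), sharding)
print(x.sharding.is_fully_replicated)  # True
```

A `NamedSharding` returned by `quantax.utils.get_replicate_sharding()` can presumably be passed to `jax.device_put` in the same way to replicate an array across all devices.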