quantax.utils.filter_replicate

quantax.utils.filter_replicate(tree: PyTree) → PyTree

Transform the array leaves of the pytree tree so that each array is replicated on all devices, and return the resulting pytree. See get_replicate_sharding for the sharding that is applied.
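The idea can be illustrated with a minimal sketch in plain JAX. It assumes that the replicate sharding is a NamedSharding over a one-dimensional mesh of all devices with an empty PartitionSpec, and that only array leaves are moved while other leaves pass through unchanged. The names replicate_sharding_sketch and filter_replicate_sketch are hypothetical stand-ins, not part of quantax, and the real get_replicate_sharding may construct its sharding differently.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate_sharding_sketch() -> NamedSharding:
    # Hypothetical stand-in for get_replicate_sharding (assumption):
    # a 1-D mesh over every available device, paired with an empty
    # PartitionSpec so that no array axis is split and every device
    # holds a full copy of the data.
    mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
    return NamedSharding(mesh, PartitionSpec())


def filter_replicate_sketch(tree):
    # Hypothetical sketch: replicate array leaves, leave all other leaves as-is.
    sharding = replicate_sharding_sketch()

    def _replicate(leaf):
        if isinstance(leaf, (jax.Array, np.ndarray)):
            # device_put copies the array onto all devices in the mesh.
            return jax.device_put(leaf, sharding)
        return leaf  # non-array leaves (ints, strings, ...) pass through

    return jax.tree_util.tree_map(_replicate, tree)


# Usage: a mixed pytree with one array leaf and two non-array leaves.
params = {"w": jnp.ones((2, 3)), "step": 0, "name": "layer"}
out = filter_replicate_sketch(params)
```

Filtering on the leaf type keeps static metadata such as step counters or names out of device placement, which is the usual reason for a "filter" variant of a pytree transformation.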