quantax.utils.filter_replicate

quantax.utils.filter_replicate(tree: PyTree) → PyTree

Transform the pytree so that it is replicated on all devices. "Filter" means the transformation applies only to array leaves; non-array leaves are returned unchanged. See get_replicate_sharding for the sharding that is used.
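To illustrate the behavior described above, the sketch below replicates only the array leaves of a pytree using plain JAX. It is not quantax's implementation. `replicate_sharding_sketch` and `filter_replicate_sketch` are hypothetical names. The replicated sharding is assumed to be a `NamedSharding` over a 1-D mesh of all devices with an empty `PartitionSpec`, which stands in for whatever get_replicate_sharding actually returns.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate_sharding_sketch():
    # Hypothetical stand-in for get_replicate_sharding (an assumption):
    # a 1-D mesh over all devices, and an empty PartitionSpec so that no
    # array axis is split, i.e. every device holds a full copy.
    mesh = Mesh(np.array(jax.devices()), ("devices",))
    return NamedSharding(mesh, PartitionSpec())


def filter_replicate_sketch(tree):
    sharding = replicate_sharding_sketch()

    def replicate_leaf(leaf):
        # "Filter": only array leaves are moved; everything else passes through.
        if isinstance(leaf, (jax.Array, np.ndarray)):
            return jax.device_put(leaf, sharding)
        return leaf

    return jax.tree_util.tree_map(replicate_leaf, tree)


# Usage: the array is replicated, the int and str leaves are untouched.
tree = {"w": jnp.ones((2, 3)), "n_layers": 4, "name": "model"}
out = filter_replicate_sketch(tree)
```

Passing non-array leaves through unchanged, rather than raising, lets the function accept mixed pytrees such as models that carry static configuration next to their parameters.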