quantax.utils.filter_replicate
- quantax.utils.filter_replicate(tree: PyTree) → PyTree
Transform the arrays in pytree to be replicated on all devices.
See get_replicate_sharding for the sharding.
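
As a rough illustration of what "replicated on all devices" means for a pytree, the sketch below uses plain JAX only. It is not quantax's implementation: `replicate_arrays` is a hypothetical stand-in, the one-axis mesh named `"x"` is an assumption, and the sharding actually used by the library is whatever `get_replicate_sharding` returns.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate_arrays(tree):
    """Hypothetical stand-in: copy every jax.Array leaf onto all devices."""
    # Assumption: a 1D mesh over all visible devices with an arbitrary axis name.
    mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
    # An empty PartitionSpec splits no array axis, so each device holds a full copy.
    replicated = NamedSharding(mesh, PartitionSpec())

    def put(leaf):
        # Only array leaves are moved; other leaves (e.g. strings) pass through.
        if isinstance(leaf, jax.Array):
            return jax.device_put(leaf, replicated)
        return leaf

    return jax.tree_util.tree_map(put, tree)


params = {"w": jnp.ones((4, 2)), "b": jnp.zeros(2), "name": "layer0"}
out = replicate_arrays(params)
```

After the call, the array leaves keep their shapes and values but carry a fully replicated sharding, while the non-array leaf `"name"` is returned unchanged. With quantax installed, the corresponding call would presumably be `quantax.utils.filter_replicate(params)`.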