quantax.utils.sharded_segment_sum
- quantax.utils.sharded_segment_sum(data: Array, segment_ids: Array, num_segments: int) -> Array
Equivalent to jax.ops.segment_sum, but avoids data transfer among devices when data and segment_ids are both properly sharded.
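The sketch below shows why a shard-local computation can reproduce jax.ops.segment_sum. It assumes, as an illustration only and not as quantax's actual implementation, that each device computes a segment sum over its own slice of data and segment_ids, after which the per-device partial results (each of length num_segments) are added together. The helper shardwise_segment_sum is hypothetical, and the "shards" are simulated on a single device with jnp.split.

```python
import jax.numpy as jnp
from jax import ops


def shardwise_segment_sum(data, segment_ids, num_segments, num_shards):
    """Hypothetical illustration: segment sum computed shard by shard.

    Assumes data and segment_ids are split identically along axis 0,
    so every element stays paired with its own segment id.
    """
    # Stand-in for device shards: split both arrays the same way.
    data_shards = jnp.split(data, num_shards)
    id_shards = jnp.split(segment_ids, num_shards)
    # Each shard only needs its local data to produce a partial result.
    partials = [
        ops.segment_sum(d, s, num_segments=num_segments)
        for d, s in zip(data_shards, id_shards)
    ]
    # On real devices this would be a cross-device reduction
    # (e.g. jax.lax.psum) over arrays of length num_segments only.
    return jnp.sum(jnp.stack(partials), axis=0)


data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
segment_ids = jnp.array([0, 1, 0, 2, 1, 0])
num_segments = 3

# Reference semantics: out[k] = sum of data[i] where segment_ids[i] == k.
reference = ops.segment_sum(data, segment_ids, num_segments=num_segments)
sharded = shardwise_segment_sum(data, segment_ids, num_segments, num_shards=2)
print(reference)  # [10.  7.  4.]
print(sharded)    # [10.  7.  4.]
```

Under this assumption, only the small num_segments-sized partial results would need to move between devices, rather than data or segment_ids themselves.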