quantax.utils.sharded_segment_sum

quantax.utils.sharded_segment_sum(data: Array, segment_ids: Array, num_segments: int) → Array

Equivalent to jax.ops.segment_sum, but avoids transferring data between devices when data and segment_ids are both properly sharded.
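The docstring does not say how the reduction is carried out. One common pattern gives this behavior, though it is an assumption here and not taken from quantax's source. Each device runs jax.ops.segment_sum on its own shard, and the partial results, each of size num_segments, are then summed across devices. Only those partial sums cross device boundaries, not `data` itself. The sketch below emulates that pattern on a single device. `local_then_global_segment_sum` and `n_shards` are hypothetical names. `jax.vmap` with an `axis_name` stands in for a device mesh, and `jax.lax.psum` stands in for the cross-device all-reduce.

```python
import jax
import jax.numpy as jnp


def local_then_global_segment_sum(data, segment_ids, num_segments, n_shards):
    """Hypothetical sketch: per-shard segment sums followed by an all-reduce.

    data and segment_ids are split along axis 0 into n_shards equal chunks,
    standing in for device-local shards. vmap with an axis_name lets
    jax.lax.psum play the role of the cross-device sum.
    """
    # (N, ...) -> (n_shards, N // n_shards, ...); assumes N divides evenly.
    data_shards = data.reshape((n_shards, -1) + data.shape[1:])
    id_shards = segment_ids.reshape((n_shards, -1))

    def per_shard(d, ids):
        # Each "device" reduces only its own rows. The result always has
        # num_segments rows, whichever ids happen to appear locally.
        local = jax.ops.segment_sum(d, ids, num_segments=num_segments)
        # Only the num_segments-sized partial sums are combined across shards.
        return jax.lax.psum(local, axis_name="shards")

    out = jax.vmap(per_shard, axis_name="shards")(data_shards, id_shards)
    # After psum every shard holds the same total, so take the first copy.
    return out[0]


data = jnp.arange(8.0)
segment_ids = jnp.array([0, 1, 0, 2, 1, 1, 2, 0])
reference = jax.ops.segment_sum(data, segment_ids, num_segments=3)
sharded = local_then_global_segment_sum(data, segment_ids, 3, n_shards=2)
```

On real hardware the per-shard function would run under a device-parallel transform such as `jax.pmap` rather than `jax.vmap`. The communication cost is then set by `num_segments` and no longer by the length of `data`.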