utils

Data

DataTracer

The structure used to keep track of data updates.

Sharding

get_global_sharding

Return the sharding that distributes arrays across all devices in jax.devices() along the array's first dimension.

get_replicate_sharding

Return the sharding that replicates arrays across all devices in jax.devices().
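A minimal sketch of how these two shardings could be built with the public jax.sharding API. It assumes a 1D mesh over jax.devices() with a hypothetical axis name "x" and is not necessarily how quantax constructs them.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def global_sharding_sketch():
    # 1D mesh over every device; the axis name "x" is arbitrary
    mesh = Mesh(np.array(jax.devices()), ("x",))
    # split the first array dimension across the mesh axis, keep other dims whole
    return NamedSharding(mesh, P("x"))


def replicate_sharding_sketch():
    mesh = Mesh(np.array(jax.devices()), ("x",))
    # an empty PartitionSpec means every device holds the full array
    return NamedSharding(mesh, P())
```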

Array

is_sharded_array

Whether the input array is sharded.

to_global_array

Transform the array to be sharded across all devices along the first dimension.

to_replicate_array

Transform the array to be replicated across all devices.
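In plain JAX, both transformations amount to a jax.device_put with the corresponding sharding. The sketch below assumes a 1D mesh with a hypothetical axis name "x" and a first dimension divisible by the device count; the helper names are illustrative, not quantax functions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("x",))


def to_global_array_sketch(x):
    # each device receives one contiguous slice of the first dimension
    return jax.device_put(x, NamedSharding(mesh, P("x")))


def to_replicate_array_sketch(x):
    # every device receives a full copy
    return jax.device_put(x, NamedSharding(mesh, P()))


n = 2 * jax.device_count()
x = jnp.arange(n * 3).reshape(n, 3)
xg = to_global_array_sketch(x)
xr = to_replicate_array_sketch(x)
```

Only the placement changes; the values of xg and xr equal those of x.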

global_to_local

In multi-host jobs, use jax.experimental.multihost_utils.global_array_to_host_local_array to convert a sharded global array into the part that is local to each host.

local_to_global

In multi-host jobs, use jax.experimental.multihost_utils.host_local_array_to_global_array to combine host-local arrays into a global array sharded across all devices.

local_to_replicate

In multi-host jobs, use jax.experimental.multihost_utils.host_local_array_to_global_array to combine host-local arrays into a global array replicated on every device.

to_replicate_numpy

In multi-host jobs, use jax.experimental.multihost_utils.global_array_to_host_local_array to convert a sharded array into a replicated NumPy array on each host.
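The wrappers above delegate to jax.experimental.multihost_utils. A minimal round trip, assuming a 1D mesh with a hypothetical axis name "x": in a single-process run the host-local and global arrays have the same shape, while with several processes the global first dimension is the local one times the process count.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("x",))

# host-local batch held by this process; its length must split over the local devices
local = np.arange(4 * jax.local_device_count(), dtype=np.float32)

# host-local -> global array sharded along the first dimension (compare local_to_global)
global_arr = multihost_utils.host_local_array_to_global_array(local, mesh, P("x"))

# global -> host-local again (compare global_to_local)
back = multihost_utils.global_array_to_host_local_array(global_arr, mesh, P("x"))
```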

array_extend

Extend the array.

array_set

Equivalent to array.at[inds].set(array_set), but significantly faster for complex-valued inputs.
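The reference semantics are those of array.at[inds].set(values). The sketch below shows one plausible way to specialise the complex case, scattering the real and imaginary parts as two real arrays; this strategy is an assumption, its speed is not measured here, and it is not necessarily what quantax does.

```python
import jax
import jax.numpy as jnp


def array_set_sketch(array, values, inds):
    if jnp.iscomplexobj(array):
        # hypothetical fast path: two real scatters, then recombine
        re = jnp.real(array).at[inds].set(jnp.real(values))
        im = jnp.imag(array).at[inds].set(jnp.imag(values))
        return jax.lax.complex(re, im)
    # real inputs: plain scatter
    return array.at[inds].set(values)


a = jnp.zeros(4, dtype=jnp.complex64)
v = jnp.array([1 + 1j, 2 - 3j], dtype=jnp.complex64)
inds = jnp.array([0, 2])
out = array_set_sketch(a, v, inds)
```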

sharded_segment_sum

Equivalent to jax.ops.segment_sum, but avoids data transfer between devices when both data and segment_ids are properly sharded.
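For reference, the unsharded semantics that sharded_segment_sum matches: entries of data with the same segment id are summed into one output slot. Only jax.ops.segment_sum is used here; the sharded variant's signature is not shown.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 1, 2])

# segment 0 gets 1 + 2, segment 1 gets 3, segment 2 gets 4
out = ops.segment_sum(data, segment_ids, num_segments=3)  # [3., 3., 4.]
```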

Spins

ints_to_array

Convert QuSpin basis integers to an int8 state array.

array_to_ints

Convert an int8 state array to QuSpin basis integers.
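A minimal sketch of such a conversion pair. It assumes two conventions that are not confirmed by this page: site 0 maps to the most significant of the n_sites bits, and bit 1 maps to spin +1 while bit 0 maps to -1. The actual QuSpin and quantax conventions may differ.

```python
import jax.numpy as jnp


def ints_to_array_sketch(basis_ints, n_sites):
    # assumed convention: site 0 is the most significant bit
    shifts = jnp.arange(n_sites - 1, -1, -1)
    bits = (basis_ints[:, None] >> shifts) & 1
    # assumed spin convention: bit 1 -> +1, bit 0 -> -1
    return (2 * bits - 1).astype(jnp.int8)


def array_to_ints_sketch(states):
    n_sites = states.shape[-1]
    bits = (states > 0).astype(jnp.int32)
    weights = 2 ** jnp.arange(n_sites - 1, -1, -1)
    return jnp.sum(bits * weights, axis=-1)


states = ints_to_array_sketch(jnp.array([5, 0, 7]), 3)
# 5 = 0b101 -> [1, -1, 1]
```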

neel

Return a single Néel state.

stripe

Return a single stripe state.
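A sketch of both patterns on a 2D Lx by Ly square lattice with ±1 spins, flattened in row-major (x, y) order. The lattice layout, the spin values, and the choice of stripes alternating along x are all assumptions for illustration; the library functions take their geometry from elsewhere.

```python
import jax.numpy as jnp


def neel_sketch(Lx, Ly):
    x, y = jnp.meshgrid(jnp.arange(Lx), jnp.arange(Ly), indexing="ij")
    # checkerboard: nearest neighbours always have opposite spins
    return jnp.where((x + y) % 2 == 0, 1, -1).astype(jnp.int8).ravel()


def stripe_sketch(Lx, Ly):
    x, y = jnp.meshgrid(jnp.arange(Lx), jnp.arange(Ly), indexing="ij")
    # spins alternate along x only, so each column of constant x is uniform
    return jnp.where(x % 2 == 0, 1, -1).astype(jnp.int8).ravel()
```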

Sqz_factor

Spin structure factor \(\left< \frac{1}{2 \sqrt{N}} S^z_r S^z_0 e^{-iqr} \right>\)

rand_states

Random basis states.
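A minimal sketch of drawing random basis states with jax.random, assuming ±1 spins with each site independently up or down with probability 1/2 (no fixed magnetization). The real function may constrain the states further.

```python
import jax
import jax.numpy as jnp


def rand_states_sketch(key, n_states, n_sites):
    # independent fair coin per site: True -> +1, False -> -1
    up = jax.random.bernoulli(key, 0.5, (n_states, n_sites))
    return jnp.where(up, 1, -1).astype(jnp.int8)


states = rand_states_sketch(jax.random.PRNGKey(0), 5, 4)
```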

Pytree

tree_fully_flatten

Return the flattened array produced by jax.flatten_util.ravel_pytree.
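For reference, jax.flatten_util.ravel_pytree returns both the flat 1D array and a function that rebuilds the original pytree; tree_fully_flatten is described as returning the first of these.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(4)}

# flat has 2 * 3 + 4 = 10 entries; unravel maps it back to the dict
flat, unravel = ravel_pytree(params)
restored = unravel(flat)
```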

filter_replicate

Transform the pytree to be replicated on all devices.

filter_tree_map

The same as jax.tree.map, but with a filter, so the function is applied only to array leaves.
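A minimal sketch of the filtering idea, assuming "array" means a jax.Array or NumPy array leaf; other leaves (strings, Python numbers) pass through unchanged. The exact filter used by filter_tree_map may differ.

```python
import numpy as np
import jax
import jax.numpy as jnp


def filter_tree_map_sketch(f, tree):
    def _maybe_apply(leaf):
        # apply f only to array leaves, return everything else as-is
        if isinstance(leaf, (jax.Array, np.ndarray)):
            return f(leaf)
        return leaf

    return jax.tree_util.tree_map(_maybe_apply, tree)


tree = {"w": jnp.ones(3), "name": "layer", "n": 3}
out = filter_tree_map_sketch(lambda a: 2 * a, tree)
# out["n"] stays 3 because Python ints are not arrays
```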

tree_split_cpl

Split a pytree, possibly containing complex values, into two real pytrees: one for the real part and one for the imaginary part.

tree_combine_cpl

Combine two real pytrees to a complex one.
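A sketch of the split/combine pair with jax.tree_util.tree_map. One assumption to note: this combine always returns complex leaves, even for leaves that were real before splitting, whereas the library version may restore the original dtypes.

```python
import jax
import jax.numpy as jnp


def tree_split_cpl_sketch(tree):
    real = jax.tree_util.tree_map(jnp.real, tree)
    # jnp.imag of a real leaf is an array of zeros
    imag = jax.tree_util.tree_map(jnp.imag, tree)
    return real, imag


def tree_combine_cpl_sketch(real, imag):
    return jax.tree_util.tree_map(lambda r, i: r + 1j * i, real, imag)


tree = {"a": jnp.array([1 + 2j, 3 - 1j]), "b": jnp.array([0.5])}
re, im = tree_split_cpl_sketch(tree)
back = tree_combine_cpl_sketch(re, im)
```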

apply_updates

Similar to equinox.apply_updates, but keeps the model's original data types unchanged.
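A sketch of the idea without an Equinox dependency: add each update to the matching parameter, treat None updates as no-ops (as equinox.apply_updates does), and cast the result back to the parameter's dtype so type promotion cannot change it. This is an illustration, not quantax's implementation.

```python
import jax
import jax.numpy as jnp


def apply_updates_sketch(model, updates):
    def _apply(u, p):
        if u is None:
            return p  # no update for this leaf
        # p + u may promote (e.g. float16 + float32 -> float32); cast back
        return (p + u).astype(p.dtype)

    return jax.tree_util.tree_map(_apply, updates, model, is_leaf=lambda x: x is None)


model = {"w": jnp.ones(3, dtype=jnp.float16), "b": jnp.zeros(2)}
updates = {"w": jnp.full(3, 0.5, dtype=jnp.float32), "b": None}
new_model = apply_updates_sketch(model, updates)
```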

Function

chunk_shard_vmap

f -> jit(chunk_map(shard_map(vmap(f))))

chunk_map

Convert a vmapped function to a function with chunked batches and parallel computation on all available machines.

shmap

f -> shard_map(f), sharded along the first dimension
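A minimal sketch of the composition f -> jit(chunk_map(shard_map(vmap(f)))) using public JAX APIs only. Assumptions: the batch size is a multiple of a hypothetical chunk_size, chunk_size is a multiple of the device count (no padding), and chunks are processed sequentially with jax.lax.map; this is one way to chunk, not necessarily quantax's chunk_map.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:
    from jax.experimental.shard_map import shard_map


def chunk_shard_vmap_sketch(f, chunk_size):
    mesh = Mesh(np.array(jax.devices()), ("x",))
    # each device applies the vmapped f to its slice of one chunk
    sharded = shard_map(jax.vmap(f), mesh=mesh, in_specs=(P("x"),), out_specs=P("x"))

    def chunked(x):
        n = x.shape[0]
        # (n, ...) -> (n_chunks, chunk_size, ...); chunks run one after another
        chunks = x.reshape(n // chunk_size, chunk_size, *x.shape[1:])
        out = jax.lax.map(sharded, chunks)
        return out.reshape(n, *out.shape[2:])

    return jax.jit(chunked)


f = lambda v: jnp.sum(v ** 2)
chunk = 2 * jax.device_count()
x = jnp.arange(2 * chunk * 3, dtype=jnp.float32).reshape(2 * chunk, 3)
y = chunk_shard_vmap_sketch(f, chunk)(x)
```

Chunking bounds peak memory to one chunk at a time, while shard_map spreads each chunk over the devices.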

Linear algebra

det

The same as jax.numpy.linalg.det, but with a custom VJP to accelerate gradients.
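A sketch of what a hand-written VJP for the determinant can look like, using Jacobi's formula d det(A)/dA = det(A) A^{-T}. It is restricted to real matrices and is not claimed to be quantax's rule or to be faster than JAX's built-in derivative.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def det_sketch(A):
    return jnp.linalg.det(A)


def _det_fwd(A):
    d = jnp.linalg.det(A)
    return d, (A, d)  # save inputs needed by the backward pass


def _det_bwd(res, g):
    A, d = res
    # Jacobi's formula for real A: cotangent * det(A) * inverse transpose
    return (g * d * jnp.linalg.inv(A).T,)


det_sketch.defvjp(_det_fwd, _det_bwd)

A = jnp.array([[2.0, 1.0], [0.5, 3.0]])
grad_A = jax.grad(det_sketch)(A)  # cofactor matrix [[3, -0.5], [-1, 2]]
```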

pfaffian

Return the Pfaffian of the input matrix A.
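For illustration, a plain NumPy sketch of a Pfaffian via Parlett-Reid style elimination with pivoting, for real skew-symmetric matrices. It is not jit-traceable and handles real inputs only, so it is not a stand-in for the library function. A useful check is Pf(A)^2 = det(A).

```python
import numpy as np


def pfaffian_sketch(A):
    A = np.array(A, dtype=float)  # work on a copy
    n = A.shape[0]
    if n % 2 == 1:
        return 0.0  # odd-dimensional skew-symmetric matrices have zero Pfaffian
    pf = 1.0
    for k in range(0, n - 1, 2):
        # pivot: largest entry below the diagonal in column k
        kp = k + 1 + np.argmax(np.abs(A[k + 1:, k]))
        if kp != k + 1:
            # swap rows and columns k+1 and kp; each swap flips the sign
            A[[k + 1, kp], :] = A[[kp, k + 1], :]
            A[:, [k + 1, kp]] = A[:, [kp, k + 1]]
            pf = -pf
        if A[k + 1, k] == 0.0:
            return 0.0  # zero column: the matrix is singular
        pf *= A[k, k + 1]
        if k + 2 < n:
            tau = A[k, k + 2:] / A[k, k + 1]
            # eliminate row/column k from the trailing block, keeping it skew
            A[k + 2:, k + 2:] += np.outer(tau, A[k + 2:, k + 1])
            A[k + 2:, k + 2:] -= np.outer(A[k + 2:, k + 1], tau)
    return pf
```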

det_update_rows

det_update_gen

pfa_eye

pfa_update

Apply a low-rank update to a Pfaffian matrix, returning the ratio between the updated and old Pfaffians as well as the inverse of the update orbitals.
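A generic identity behind such ratios: for skew-symmetric A and C, Pf(A + B C B^T) / Pf(A) = Pf(C^{-1} + B^T A^{-1} B) / Pf(C^{-1}). The sketch below applies it to a rank-2 update (C is 2x2, so each Pfaffian is a single entry) and also returns the Woodbury-updated inverse. The names and the rank-2 restriction are assumptions; what pfa_update actually takes and returns ("the inverse of the update orbitals") may differ.

```python
import numpy as np


def pf4(A):
    # explicit Pfaffian of a 4x4 skew-symmetric matrix, used only for checking
    return A[0, 1] * A[2, 3] - A[0, 2] * A[1, 3] + A[0, 3] * A[1, 2]


def pfa_rank2_update_sketch(A_inv, B, c):
    """Update A -> A + B C B^T with C = [[0, c], [-c, 0]], given A^{-1}."""
    C_inv = np.array([[0.0, -1.0 / c], [1.0 / c, 0.0]])
    S = C_inv + B.T @ A_inv @ B  # 2x2 skew-symmetric Schur complement
    ratio = S[0, 1] / C_inv[0, 1]  # Pf of a 2x2 skew matrix is its [0, 1] entry
    # Woodbury identity for the inverse of the updated matrix
    new_inv = A_inv - A_inv @ B @ np.linalg.solve(S, B.T @ A_inv)
    return ratio, new_inv


rng = np.random.default_rng(1)
M = rng.normal(size=(4, 4))
A = M - M.T
B = rng.normal(size=(4, 2))
c = 0.7
A_new = A + B @ np.array([[0.0, c], [-c, 0.0]]) @ B.T
ratio, new_inv = pfa_rank2_update_sketch(np.linalg.inv(A), B, c)
```

Working from A^{-1} means each update costs O(n^2) for the inverse and only a small Pfaffian for the ratio, instead of a fresh O(n^3) Pfaffian.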