global_defs

The global_defs module lets users set and query the global constants used in a simulation: the default data type, the JAX random keys, and the Hilbert space (sites) information.

quantax.set_default_dtype(dtype: dtype) → None

Set the default data type for computations in quantax. The recommended values are jnp.float64 or jnp.complex128; the default is jnp.float64.

Note

This setting does not affect the computations inside quantax.model.
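JAX computes in 32-bit precision unless its 64-bit mode is enabled, so requesting jnp.float64 or jnp.complex128 otherwise falls back to float32 or complex64 (JAX emits a warning). Whether quantax switches this flag on for you is not stated in this reference; the sketch below enables it explicitly with JAX's own configuration API, which is harmless if it is already on. It should run at start-up, before any arrays are created.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types in JAX; do this before creating any arrays.
jax.config.update("jax_enable_x64", True)

# With x64 enabled, the dtypes recommended for quantax are honoured.
x = jnp.zeros(3, dtype=jnp.float64)
z = jnp.ones(3, dtype=jnp.complex128)
print(x.dtype, z.dtype)  # float64 complex128
```

In a quantax script, this configuration line would come before a call such as quantax.set_default_dtype(jnp.complex128).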

quantax.get_default_dtype() → dtype

Return the default data type for the computation in quantax.

quantax.get_real_dtype() → dtype

Return the default real data type for the computation in quantax.
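Conceptually, these dtype functions read and write a single module-level setting. The sketch below mimics them with NumPy dtypes (which jnp.float64 and jnp.complex128 map onto). It is a hypothetical illustration, not quantax's source, and it assumes that the real dtype is the real-component type of the default dtype.

```python
import numpy as np

# Hypothetical module-level setting; not quantax's actual implementation.
_DEFAULT_DTYPE = np.dtype(np.float64)


def set_default_dtype(dtype) -> None:
    """Store the default dtype (a float or complex type)."""
    global _DEFAULT_DTYPE
    _DEFAULT_DTYPE = np.dtype(dtype)


def get_default_dtype() -> np.dtype:
    return _DEFAULT_DTYPE


def get_real_dtype() -> np.dtype:
    # Assumption: the real dtype is the component type of the default dtype.
    # np.finfo of a complex type describes its real part, so
    # complex128 -> float64 and complex64 -> float32.
    return np.finfo(_DEFAULT_DTYPE).dtype


set_default_dtype(np.complex128)
print(get_default_dtype(), get_real_dtype())  # complex128 float64
```

Deriving the real dtype from the default one keeps quantities that are always real (such as probabilities or energies) at the same precision as the complex amplitudes.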

quantax.set_random_seed(seed: int) → None

Set the initial random seed used for JAX computations in quantax. If it is not set, the seed defaults to a random number generated by NumPy.

quantax.get_subkeys(num: int | None = None) → Array

Return a given number of JAX PRNG keys as an array.

Parameters:

num – The number of keys to return. If num is None (the default), a single key is returned instead of an array of keys.

Warning

This function is not jittable, because it reads and writes the global key stored in quantax.
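The behaviour of set_random_seed and get_subkeys, and the reason for this warning, can be illustrated with a small global-key pattern built on jax.random. The module-level _KEY and the splitting scheme below are assumptions for illustration; quantax's internal code may differ. Python code with side effects runs only while jax.jit traces a function, so the global key would not advance on later calls (and a traced value could leak into the global state). Keys should therefore be drawn outside jit and passed in as arguments.

```python
import numpy as np
import jax

# Hypothetical global key; seeded from NumPy until set_random_seed is called.
_KEY = jax.random.PRNGKey(np.random.randint(0, 2**31 - 1))


def set_random_seed(seed: int) -> None:
    """Replace the global key with one derived from `seed`."""
    global _KEY
    _KEY = jax.random.PRNGKey(seed)


def get_subkeys(num=None):
    """Split the global key, keep one part, and return fresh subkeys."""
    global _KEY
    if num is None:
        _KEY, subkey = jax.random.split(_KEY)  # a single key
        return subkey
    keys = jax.random.split(_KEY, num + 1)
    _KEY = keys[0]   # carried forward for the next call
    return keys[1:]  # `num` keys as an array


set_random_seed(42)
key = get_subkeys()
keys = get_subkeys(4)
# Pass keys into jitted code as arguments instead of calling get_subkeys inside jit.
samples = jax.jit(lambda k: jax.random.normal(k, (3,)))(key)
print(keys.shape, samples.shape)
```

Splitting and carrying one part forward guarantees that every call returns keys never handed out before, while reseeding makes the whole sequence reproducible.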

quantax.PARTICLE_TYPE()

An enum that distinguishes the supported particle types:

0: spin

1: spinful_fermion

2: spinless_fermion

3: boson (not implemented)
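The documented values map naturally onto an integer enum. The mirror below reproduces them with Python's enum.IntEnum; it is a hypothetical stand-in (quantax's own class may be defined differently), and the helper is_fermion is illustrative only, not part of quantax.

```python
from enum import IntEnum


class PARTICLE_TYPE(IntEnum):
    """Hypothetical mirror of quantax.PARTICLE_TYPE with the documented values."""
    spin = 0
    spinful_fermion = 1
    spinless_fermion = 2
    boson = 3  # documented as not implemented


def is_fermion(ptype: PARTICLE_TYPE) -> bool:
    # Illustrative helper, not part of quantax.
    return ptype in (PARTICLE_TYPE.spinful_fermion, PARTICLE_TYPE.spinless_fermion)


print(PARTICLE_TYPE(2).name)  # spinless_fermion
```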

quantax.get_sites() → Sites

Get the quantax.sites.Sites used in the current program.

Warning

Unlike other NQS packages, quantax defines the geometry graph and the Hilbert space as global constants, which should not be changed within a single program.

quantax.get_lattice() → Lattice

Get the quantax.sites.Lattice used in the current program. This is similar to get_sites, but raises an error if the defined Sites object is not a Lattice.
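The relationship between get_sites and get_lattice follows a "global definition plus type check" pattern. The sketch below uses hypothetical stand-in classes and assumes that Lattice is a subclass of Sites and that creating a Sites object defines the global sites (via a hypothetical _register_sites hook); neither detail is stated in this reference. It raises ValueError on a type mismatch, whereas the exception type quantax actually uses is not specified here.

```python
from typing import Optional


class Sites:
    """Hypothetical stand-in for quantax.sites.Sites."""

    def __init__(self, n_sites: int):
        self.n_sites = n_sites
        _register_sites(self)  # assumption: creation defines the global sites


class Lattice(Sites):
    """Hypothetical stand-in for quantax.sites.Lattice (assumed Sites subclass)."""

    def __init__(self, extent: tuple):
        n_sites = 1
        for length in extent:
            n_sites *= length
        self.extent = extent
        super().__init__(n_sites)


_SITES: Optional[Sites] = None  # one global definition per program


def _register_sites(sites: Sites) -> None:
    global _SITES
    _SITES = sites


def get_sites() -> Sites:
    if _SITES is None:  # this sketch's choice for the undefined case
        raise RuntimeError("No Sites has been defined yet.")
    return _SITES


def get_lattice() -> Lattice:
    sites = get_sites()
    if not isinstance(sites, Lattice):
        raise ValueError("The defined Sites is not a Lattice.")
    return sites


Lattice((4, 4))
print(get_lattice().n_sites)  # 16
```

Keeping a single global definition means models, samplers and operators can all query the same geometry without it being threaded through every constructor, which is also why it should not change mid-program.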