global_defs
- quantax.set_default_dtype(dtype: dtype) -> None
Set the default data type for computations in quantax. The recommended values are
jnp.float64 or jnp.complex128. Defaults to jnp.float64.
Note
This doesn’t alter the computation inside quantax.model.
- quantax.get_default_dtype() -> dtype
Return the default data type for the computation in quantax.
- quantax.set_random_seed(seed: int) -> None
Set the initial random seed for the computation in jax. By default, the seed is a random number generated by numpy.
- quantax.get_subkeys(num: int | None = None) -> Array
Get a certain number of jax PRNG keys as an array.
- Parameters:
num – The number of keys to return. If num is None (default), a single key is returned instead of an array of keys.
Warning
This function is not jittable, because it reads and writes the global key stored in quantax.
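The global-key pattern described above can be sketched in plain JAX. This is only an illustration of the idea, not quantax's implementation: the names `_key` and `get_subkeys_sketch` and the seed value are hypothetical. It shows why such a function has a side effect (it overwrites module-level state on every call), which is what makes it unsuitable for use inside `jax.jit`.

```python
import jax

# Hypothetical module-level key; quantax's real internals may differ.
_key = jax.random.PRNGKey(42)


def get_subkeys_sketch(num=None):
    """Split the global key, store one part back, and return the rest."""
    global _key
    n = 1 if num is None else num
    keys = jax.random.split(_key, n + 1)
    _key = keys[0]  # side effect: the global key advances on every call
    # A single key when num is None, otherwise an array of num keys.
    return keys[1] if num is None else keys[1:]


single = get_subkeys_sketch()   # one key
batch = get_subkeys_sketch(3)   # array of three keys
```

Because every call advances the stored key, two consecutive calls return different keys, so callers never have to thread a key through their own code.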
- quantax.get_sites() -> Sites
Get the quantax.sites.Sites used in the current program.
Warning
Unlike other NQS packages, quantax defines the geometry graph and the Hilbert space as global constants, which shouldn’t be changed within a single program.
- quantax.get_lattice() -> Lattice
Get the quantax.sites.Lattice used in the current program. This is similar to get_sites, but raises an error if the defined Sites is not a Lattice.
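The relationship between the two getters can be illustrated with a minimal, pure-Python sketch. Everything here is a hypothetical stand-in: the classes `Sites` and `Lattice`, the helper `define_sites`, and the choice of `RuntimeError` are assumptions made for illustration, and refusing a second definition is a design choice of this sketch rather than documented quantax behaviour.

```python
from typing import Optional


class Sites:
    """Hypothetical stand-in for quantax.sites.Sites."""

    def __init__(self, nsites: int):
        self.nsites = nsites


class Lattice(Sites):
    """Hypothetical stand-in for quantax.sites.Lattice."""


# Hypothetical module-level slot holding the single geometry of the program.
_sites: Optional[Sites] = None


def define_sites(sites: Sites) -> None:
    """Register the geometry once; this sketch refuses to replace it."""
    global _sites
    if _sites is not None:
        raise RuntimeError("Sites are already defined for this program.")
    _sites = sites


def get_sites_sketch() -> Sites:
    """Return the registered geometry, whatever its concrete type."""
    if _sites is None:
        raise RuntimeError("No sites have been defined yet.")
    return _sites


def get_lattice_sketch() -> Lattice:
    """Like get_sites_sketch, but insist that the geometry is a Lattice."""
    sites = get_sites_sketch()
    if not isinstance(sites, Lattice):
        raise RuntimeError("The defined sites are not a lattice.")
    return sites


define_sites(Sites(nsites=4))        # a generic graph, not a lattice
print(get_sites_sketch().nsites)
try:
    get_lattice_sketch()
except RuntimeError as err:
    print(err)
```

In this sketch the stricter getter lets lattice-specific code fail early with a clear message instead of failing later on a missing lattice attribute.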