quantax.state.Variational
- class quantax.state.Variational
Bases: State
Variational state. This is a wrapper around a jittable variational ansatz. The variational model should be given as an equinox.Module. For details of Equinox, see the Equinox documentation.
Warning
Many intermediate values are stored in the class, so most methods in this class, such as __call__, are not jittable.
Warning
Many quantities are computed only once, during initialization. Do not modify the private attributes unless an update function is provided for them. If changes are necessary, define a new state instead.
- __init__(model: Module, param_file: str | Path | BinaryIO | None = None, symm: Symmetry | None = None, max_parallel: None | int | Tuple[int, int] = None, factor: Callable | None = None)
- Parameters:
model – Variational model. Should be an equinox.Module.
param_file – File from which to load parameters, saved by save or equinox.tree_serialise_leaves. Defaults to None, in which case no parameters are loaded.
symm – Symmetry of the network. Defaults to Identity.
max_parallel – The maximum number of forward passes allowed per device. Defaults to no limit. Setting a limit is important for large batches to avoid running out of memory. For Heisenberg-like Hamiltonians, it also makes computing the local energy more efficient by keeping the number of forward passes constant and avoiding re-jitting.
factor – An additional factor multiplied onto the network outputs. Its parameters are not updated together with the variational state, which is useful for expressing fixed sign structures.
Denoting the network and factor outputs by \(f(s)\) and \(g(s)\), and the symmetry elements by \(T_i\) with characters \(\omega_i\), the final wave function is \(\psi(s) = \sum_i \omega_i \, f(T_i s) g(T_i s) / n_{symm}\), where \(n_{symm}\) is the number of symmetry elements.
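As a minimal sketch of the symmetrized wave function \(\psi(s) = \sum_i \omega_i f(T_i s) g(T_i s) / n_{symm}\), the NumPy code below uses a 1D ring whose symmetry elements \(T_i\) are translations by i sites. The toy network f, the toy sign factor g, and the momentum characters \(\omega_i = e^{-iki}\) are illustrative assumptions; they are not quantax objects, and the sign convention of the characters in quantax may differ.

```python
import numpy as np


def f(s, w):
    """Toy network output f(s): exp of weighted nearest-neighbour products plus a site-0 term.

    The bond weights w differ per bond, so f alone is NOT translation invariant.
    """
    return np.exp(np.dot(w, s * np.roll(s, -1)) + 0.3 * s[0])


def g(s):
    """Toy fixed factor g(s): the sign (-1)^(number of up spins on odd sites)."""
    return (-1.0) ** np.count_nonzero(s[1::2] > 0)


def symmetrized_psi(s, w, k=0.0):
    """psi(s) = sum_i omega_i f(T_i s) g(T_i s) / n_symm over all 1D translations.

    T_i s = np.roll(s, i) and omega_i = exp(-1j * k * i) (assumed convention).
    k should be 2*pi*m/N so that the characters are periodic on the ring.
    """
    n_symm = s.size
    total = 0.0 + 0.0j
    for i in range(n_symm):
        t_s = np.roll(s, i)              # apply the symmetry element T_i
        omega = np.exp(-1j * k * i)      # character of T_i
        total += omega * f(t_s, w) * g(t_s)
    return total / n_symm


s = np.array([1, -1, 1, 1, -1, -1])
w = np.linspace(0.1, 0.6, 6)
psi_0 = symmetrized_psi(s, w)
```

For k = 0 the result is invariant under a one-site translation of s, even though f by itself is not; for k = 2πm/N it picks up the phase \(e^{ik}\) instead.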
- __call__(s: ndarray | Array) → Array
Compute \(\psi(s)\) for input states s.
- Parameters:
s – Input states s with entries \(\pm 1\).
Warning
This function is not jittable.
Note
The returned value is \(\psi(s)\) instead of \(\log\psi(s)\).
- property model: Module
The variational model used in the variational state.
- property holomorphic: bool
Whether the variational state is holomorphic.
- property forward_chunk: int
The maximum number of forward passes allowed per device.
- property backward_chunk: int
The maximum number of backward passes allowed per device.
- property ref_chunk: int
The maximum number of reference forward passes with updates allowed per device.
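The chunk sizes above bound how many configurations pass through the network at once on each device. A minimal NumPy sketch of the idea, assuming a plain Python callable in place of the jitted network: the batch is cut into chunks of a fixed size and the last chunk is padded, so every call sees the same input shape (under jax.jit this would avoid re-compilation). The helper name chunked_forward is hypothetical and is not part of quantax.

```python
import numpy as np


def chunked_forward(fn, states, chunk):
    """Evaluate fn on `states` in chunks of exactly `chunk` rows.

    The last chunk is padded with copies of the first state so that every call
    sees the same input shape; outputs for the padded rows are discarded.
    """
    n = states.shape[0]
    n_pad = (-n) % chunk
    if n_pad:
        padding = np.repeat(states[:1], n_pad, axis=0)
        states = np.concatenate([states, padding], axis=0)
    outputs = [fn(states[i:i + chunk]) for i in range(0, states.shape[0], chunk)]
    return np.concatenate(outputs)[:n]


rng = np.random.default_rng(0)
states = rng.choice([-1, 1], size=(10, 4))
seen_shapes = []


def toy_network(x):
    """Stand-in for a forward pass; records the input shape of each call."""
    seen_shapes.append(x.shape)
    return x.sum(axis=1)


out = chunked_forward(toy_network, states, 3)
```

With 10 states and a chunk size of 3, the sketch makes four calls, each with input shape (3, 4), and returns exactly 10 outputs.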
- property nparams: int
Total number of parameters in the variational state.
- property dtype: dtype
The parameter data type of the variational state.
- jacobian(fock_states: Array) → Array
Compute the Jacobian matrix \(\frac{1}{\psi} \frac{\partial \psi}{\partial \theta}\). See VS_TYPE for the definition of the Jacobian for different kinds of networks.
- Parameters:
fock_states – The input Fock states.
- Returns:
A 2D Jacobian matrix whose first dimension runs over the inputs and whose second dimension runs over the parameters. The parameters are ordered as in get_params_flatten.
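For a real-parameter toy ansatz, the definition \(O_{nk} = \frac{1}{\psi(s_n)} \frac{\partial \psi(s_n)}{\partial \theta_k}\) can be checked numerically. The sketch below assumes \(\psi(s) = \exp(\sum_k \theta_k s_k s_{k+1})\) on a ring, for which \(O_{nk} = s_k s_{k+1}\) exactly, and estimates the matrix with central finite differences. It only illustrates the shape and meaning of the returned matrix; the complex and non-holomorphic cases follow the VS_TYPE definitions, which this sketch does not cover.

```python
import numpy as np


def psi(s, theta):
    """Toy real ansatz psi(s) = exp(sum_k theta_k s_k s_{k+1}) on a ring."""
    return np.exp(np.sum(theta * s * np.roll(s, -1)))


def jacobian_fd(states, theta, eps=1e-6):
    """Central finite-difference estimate of O[n, k] = (1/psi) d psi / d theta_k.

    Rows index the input states, columns index the parameters.
    """
    jac = np.empty((states.shape[0], theta.size))
    for n, s in enumerate(states):
        p0 = psi(s, theta)
        for k in range(theta.size):
            d = np.zeros_like(theta)
            d[k] = eps
            jac[n, k] = (psi(s, theta + d) - psi(s, theta - d)) / (2 * eps * p0)
    return jac


states = np.array([[1, -1, 1, 1], [1, 1, 1, 1], [-1, 1, -1, 1]])
theta = np.array([0.1, -0.2, 0.3, 0.05])
jac = jacobian_fd(states, theta)
```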
- partition(model: Module | None = None) → Tuple[Module, Module]
Split the variational model into two pytrees, one containing all parameters and the other containing all other elements, similar to partition in Equinox.
- Parameters:
model – The model to be split. Defaults to the variational model of the variational state.
- combine(params: Module, others: Module) → Module
Combine two pytrees, one containing all parameters and the other containing all other elements, into one variational model. This is similar to combine in Equinox.
- Parameters:
params – The pytree containing only parameters.
others – The pytree containing other elements.
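A minimal sketch of the partition/combine round trip, using a flat dict as a stand-in for an equinox.Module. Following the Equinox convention, each half keeps the structure of the input, with None in the slots that belong to the other half. Treating floating-point and complex arrays as parameters is an assumption made for this illustration, and the helpers partition_dict and combine_dicts are hypothetical.

```python
import numpy as np


def _is_param(value):
    """Assumed criterion: inexact (float or complex) arrays are parameters."""
    return isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.inexact)


def partition_dict(tree):
    """Split a dict into (params, others), filling the other half's slots with None."""
    params = {k: (v if _is_param(v) else None) for k, v in tree.items()}
    others = {k: (None if _is_param(v) else v) for k, v in tree.items()}
    return params, others


def combine_dicts(params, others):
    """Merge the two halves back, taking the non-None entry for each key."""
    return {k: (params[k] if params[k] is not None else others[k]) for k in params}


model = {
    "W": np.ones((2, 3)),
    "b": np.zeros(3),
    "activation": "tanh",   # static configuration, not a parameter
    "idx": np.arange(3),    # integer array, treated as a non-parameter here
}
params, others = partition_dict(model)
rebuilt = combine_dicts(params, others)
```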
- get_params_flatten() → Array
Obtain a flattened 1D array of all parameters.
- get_params_unflatten(params: Array) → PyTree
Obtain the parameters pytree from a flattened 1D array of all parameters.
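Flattening and unflattening are inverse operations on the parameter pytree. The sketch below uses hypothetical helpers flatten_params and unflatten_params on a dict of NumPy arrays and concatenates the leaves in sorted-key order; that order is an assumption of the sketch, not the order quantax uses. For real JAX pytrees, jax.flatten_util.ravel_pytree provides the same kind of round trip.

```python
import numpy as np


def flatten_params(params):
    """Concatenate all parameter arrays, in sorted-key order, into one 1D array."""
    return np.concatenate([np.ravel(params[k]) for k in sorted(params)])


def unflatten_params(flat, template):
    """Rebuild a dict with the shapes of `template` from a flat 1D array."""
    out, offset = {}, 0
    for k in sorted(template):
        shape = np.shape(template[k])
        size = int(np.prod(shape))
        out[k] = flat[offset:offset + size].reshape(shape)
        offset += size
    return out


params = {"W": np.arange(6.0).reshape(2, 3), "b": np.array([7.0, 8.0])}
flat = flatten_params(params)
restored = unflatten_params(flat, params)
```

The length of the flat array plays the role of nparams, and a fixed ordering is what lets each Jacobian column be matched to a single parameter.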
- rescale(factor: Array | ndarray | bool | number | bool | int | float | complex | None = None) → None
Rescale the variational state according to the maximum wave function value stored during the forward pass.
Note
This only works if the given model has a rescale function, which is the case for most models provided in quantax.model. Without a rescale function in the model, overflow is very likely to occur during the training of variational states.
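Because the network output is \(\psi(s)\) rather than \(\log\psi(s)\), large amplitudes can exceed the floating-point range. The NumPy sketch below starts from hypothetical log-amplitudes and shows why dividing by the largest amplitude helps: exp(712) overflows float64, but subtracting the maximum log-amplitude first keeps every value finite while preserving all ratios. It illustrates the principle only and is not quantax's rescale implementation.

```python
import numpy as np


def rescale_amplitudes(log_amp):
    """Return psi / max|psi| computed stably from log|psi|, plus the log of the scale."""
    log_scale = np.max(log_amp)
    return np.exp(log_amp - log_scale), log_scale


# Hypothetical log|psi| values that have grown during training.
log_amp = np.array([700.0, 705.0, 712.0])

with np.errstate(over="ignore"):
    naive = np.exp(log_amp)  # exp(712) exceeds the float64 maximum (~1.8e308) -> inf

rescaled, log_scale = rescale_amplitudes(log_amp)
```

Multiplying \(\psi\) by a global constant leaves normalized expectation values unchanged, which is why this kind of rescaling does not alter the physics.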
- update(step: Array, rescale: bool = True) → None
Update the variational parameters of the state as \(\theta' = \theta - \delta\theta\).
- Parameters:
step – The update step \(\delta\theta\).
rescale – Whether the rescale function should be called to rescale the variational state. Defaults to True.
Note
The update direction is \(-\delta\theta\) instead of \(\delta\theta\).
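Because update applies \(\theta' = \theta - \delta\theta\), an optimizer should pass the descent step itself, for example \(\delta\theta = \eta \nabla E\) for plain gradient descent, not its negative. The toy NumPy sketch below uses a hypothetical apply_update helper and a quadratic energy \(E(\theta) = \sum_k \theta_k^2\) to show this convention converging to the minimum.

```python
import numpy as np


def apply_update(theta, step):
    """Mimic the documented convention theta' = theta - step."""
    return theta - step


theta = np.array([1.0, -2.0])
lr = 0.1
for _ in range(50):
    grad = 2.0 * theta                      # gradient of E(theta) = sum(theta**2)
    theta = apply_update(theta, lr * grad)  # pass +lr*grad; the minus sign is applied inside
# each iteration multiplies theta by (1 - 2 * lr) = 0.8
```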
- save(file: str | Path | BinaryIO) → None
Save the variational model to the given file. The file can be loaded through param_file when initializing Variational.
- to_flax_model(package='netket', make_complex: bool = False)
Convert the state to a Flax model compatible with other packages. Training the generated state in other packages is likely to be unstable, but the state can be used to measure observables.
- Parameters:
package – Convert the current state to the format of the given package. The supported packages are
- netket (default): input 1/-1, output \(\log\psi\)
- jvmc: input 1/0, output \(\log\psi\)
make_complex – Whether the network output should be made complex explicitly. This is necessary when \(\psi\) is real but contains negative values.
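The two packages expect different input encodings, and both expect \(\log\psi\) as output. The NumPy sketch below assumes that the spin value 1 maps to 1 and -1 maps to 0 when going from the netket to the jvmc encoding (the mapping is an assumption, not stated above), and shows why make_complex matters: the logarithm of a negative real amplitude is undefined in real arithmetic (NumPy returns nan) but equals \(\log|\psi| + i\pi\) once \(\psi\) is cast to a complex type.

```python
import numpy as np


def spins_to_occupations(s):
    """Map 1/-1 inputs to 1/0 inputs (assumed mapping: 1 -> 1, -1 -> 0)."""
    return (np.asarray(s) + 1) // 2


psi = np.array([0.5, -0.25])  # a real wave function with a negative amplitude

with np.errstate(invalid="ignore"):
    log_real = np.log(psi)                 # nan for the negative entry

log_complex = np.log(psi.astype(complex))  # log|psi| + i*pi for negative entries
```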