quantax.state.UnrestrictedDetState
- class quantax.state.UnrestrictedDetState
Bases: MeanFieldFermionState
Unrestricted determinant mean-field state, a wrapper of UnrestrictedDet.
- __init__(model: Module | None = None, param_file: str | Path | BinaryIO | None = None, max_parallel: None | int | Tuple[int, int] = None, use_refmodel: bool = True)
- Parameters:
model – Variational model. Should be an equinox.Module.
param_file – File for loading parameters, saved either by save or by equinox.tree_serialise_leaves. By default, no parameters are loaded.
symm – Symmetry of the network, defaults to Identity. Denoting the network output as \(f(s)\) and the symmetry elements as \(T_i\) with characters \(\omega_i\), the wave function is given by \(\psi(s) = \sum_i \omega_i \, f(T_i s) / n_{symm}\).
max_parallel –
The maximum chunk size allowed per device. Specifying a limited value is important for avoiding memory overflow. For many Hamiltonians, it also improves efficiency by keeping the size of each forward pass constant, which avoids re-jitting. The allowed input formats are:
- None:
No chunk size (default).
- int:
The same chunk size for all forward and backward passes.
- Tuple[int, int]:
The chunk size for forward and backward passes respectively.
- Tuple[int, int, int]:
The chunk size for forward pass, backward pass and ref_forward_with_updates respectively.
use_ref – Whether ref_forward and ref_forward_with_updates are used when the model is a RefModel. When the model is not a RefModel, this argument has no effect. Defaults to True.
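The symmetrization formula in the symm parameter can be illustrated with a minimal pure-Python sketch. It assumes a 4-site chain whose symmetry group is the cyclic translations, with all characters \(\omega_i = 1\) (zero momentum); the helpers symmetrize, translate and the toy network f are hypothetical and are not the quantax Symmetry API.

```python
def symmetrize(f, transforms, characters, s):
    """psi(s) = sum_i omega_i * f(T_i s) / n_symm."""
    n_symm = len(transforms)
    total = 0j
    for T, omega in zip(transforms, characters):
        total += omega * f(T(s))
    return total / n_symm


def translate(n):
    """Hypothetical symmetry element: cyclic shift of a configuration by n sites."""
    return lambda s: s[n:] + s[:n]


N_SITES = 4
transforms = [translate(n) for n in range(N_SITES)]
characters = [1.0] * N_SITES  # trivial representation (zero momentum)


def f(s):
    # Toy, non-symmetric "network" output.
    return (s[0] + 2 * s[1] + 3 * s[2]) ** 2


def symmetric_psi(s):
    return symmetrize(f, transforms, characters, s)


print(symmetric_psi((1, 1, -1, -1)))  # -> (8+0j)
```

Although f itself is not translation invariant, symmetric_psi returns the same value for every translated copy of the input.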
- rho_from_model = Partial( func=_JitWrapper( fn='UnrestrictedDetState.rho_from_model', filter_warning=False, donate_first=False, donate_rest=False ), args=(quantax.state.fermion_mf.UnrestrictedDetState,), keywords={} )
- property Nparticles: Tuple[int, int] | None
Conserved particle numbers of the state; None means the particle number is not fixed.
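For intuition about an unrestricted determinant amplitude, the sketch below assumes separate orbital matrices for spin-up and spin-down fermions (shapes Nsites x N_up and Nsites x N_down, consistent with a two-entry Nparticles) and evaluates \(\psi\) as the product of two determinants built from the rows of the occupied sites. The function names are hypothetical, the fermionic ordering sign is ignored, and this is not the UnrestrictedDet implementation.

```python
def det(m):
    """Determinant by Gaussian elimination with partial pivoting."""
    a = [list(map(float, row)) for row in m]
    n = len(a)
    result = 1.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        if a[pivot][col] == 0.0:
            return 0.0
        if pivot != col:
            a[col], a[pivot] = a[pivot], a[col]
            result = -result  # a row swap flips the sign
        result *= a[col][col]
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for c in range(col, n):
                a[r][c] -= factor * a[col][c]
    return result


def unrestricted_det_amplitude(orb_up, orb_dn, occ_up, occ_dn):
    """Hypothetical helper: det(U_up[occ_up]) * det(U_dn[occ_dn])."""
    m_up = [orb_up[i] for i in occ_up]  # rows of occupied spin-up sites
    m_dn = [orb_dn[i] for i in occ_dn]  # rows of occupied spin-down sites
    return det(m_up) * det(m_dn)


orb_up = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 sites x 2 spin-up orbitals
orb_dn = [[1.0], [2.0], [3.0]]                  # 3 sites x 1 spin-down orbital
print(unrestricted_det_amplitude(orb_up, orb_dn, [0, 2], [1]))  # -> 2.0
```

Because the two spin species never share orbitals, the amplitude factorizes into independent determinants, which is what "unrestricted" refers to.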
- property Nsites: int
Number of sites
- property backward_chunk: int
The maximum chunk size of backward pass allowed per device.
- property basis
QuSpin basis of the state
- combine(params: Module, others: Module) -> Module
Combine two pytrees, one containing all parameters and the other containing all other elements, into one variational model. This is similar to equinox.combine.
- Parameters:
params – The pytree containing only parameters.
others – The pytree containing other elements.
- property dtype: dtype
The parameter data type of the variational state.
- property energy: float | None
The energy in the previous optimization step.
- expectation(operator: Operator) -> jax.Array
Compute the expectation value of an operator.
- Parameters:
operator – The operator whose expectation value is computed. It should be an instance of Operator.
model – The mean-field model to use. If None, use the current model.
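The quantity returned by expectation is \(\langle\psi|O|\psi\rangle / \langle\psi|\psi\rangle\). The brute-force sketch below computes it for a dense operator matrix and a dense amplitude vector; it only illustrates the definition and is an assumption about neither the Operator class nor how the mean-field state evaluates it internally.

```python
def expectation_dense(op, psi):
    """<psi|O|psi> / <psi|psi> for a dense matrix op and amplitude list psi."""
    n = len(psi)
    o_psi = [sum(op[i][j] * psi[j] for j in range(n)) for i in range(n)]
    numerator = sum(a.conjugate() * b for a, b in zip(psi, o_psi))
    denominator = sum(abs(a) ** 2 for a in psi)
    return numerator / denominator


sigma_x = [[0, 1], [1, 0]]
print(expectation_dense(sigma_x, [1, 1]))   # -> 1.0
print(expectation_dense(sigma_x, [1, -1]))  # -> -1.0
```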
- property forward_chunk: int
The maximum chunk size of forward pass allowed per device.
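The chunk sizes bound how many inputs are evaluated at once. A minimal sketch of fixed-size chunking is shown below, assuming the last chunk is padded so every call sees the same input length; under jax.jit a constant shape avoids recompilation. The function chunked_forward is hypothetical and the actual quantax strategy may differ.

```python
def chunked_forward(fn, inputs, chunk_size=None):
    """Apply fn to inputs in chunks of at most chunk_size elements.

    The last chunk is padded with copies of its final element so that fn
    always receives exactly chunk_size inputs; padded outputs are dropped.
    """
    if chunk_size is None:
        return fn(inputs)  # no chunking: evaluate everything at once
    outputs = []
    for start in range(0, len(inputs), chunk_size):
        chunk = inputs[start:start + chunk_size]
        n_valid = len(chunk)
        chunk = chunk + [chunk[-1]] * (chunk_size - n_valid)
        outputs.extend(fn(chunk)[:n_valid])
    return outputs


print(chunked_forward(lambda xs: [2 * x for x in xs], [1, 2, 3, 4, 5], 2))
# -> [2, 4, 6, 8, 10]
```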
- get_loss_fn(hamiltonian: Operator)
Get the loss function for optimization.
- Parameters:
hamiltonian – The Hamiltonian to compute the gradient of.
- get_params_flatten() -> Array
Obtain a flattened 1D array of all parameters.
- get_params_unflatten(params: Array) -> PyTree
Obtain the parameters pytree from a flattened 1D array of all parameters.
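The flatten/unflatten pair is a round trip between a nested parameter tree and a flat list. The pure-Python sketch below assumes the tree is made of dicts, lists and numbers, and visits dict keys in sorted order so both directions agree on the leaf ordering. The helpers are hypothetical; in JAX the same role is played by jax.flatten_util.ravel_pytree, though that is not a claim about how quantax implements these methods.

```python
def flatten(tree):
    """Return (leaves, structure); structure marks each leaf position with None."""
    leaves = []

    def walk(node):
        if isinstance(node, dict):
            return {k: walk(node[k]) for k in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)(walk(x) for x in node)
        leaves.append(node)
        return None

    structure = walk(tree)
    return leaves, structure


def unflatten(structure, leaves):
    """Rebuild the tree by filling leaf positions in the same traversal order."""
    it = iter(leaves)

    def build(node):
        if isinstance(node, dict):
            return {k: build(node[k]) for k in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)(build(x) for x in node)
        return next(it)

    return build(structure)


params = {"w": [[1.0, 2.0], [3.0, 4.0]], "b": [0.5, -0.5]}
flat, structure = flatten(params)
print(flat)  # -> [0.5, -0.5, 1.0, 2.0, 3.0, 4.0]
```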
- get_step(hamiltonian: Operator) -> jax.Array
Get the gradient of the energy with respect to the mean-field parameters.
- Parameters:
hamiltonian – The Hamiltonian to compute the gradient of.
- property holomorphic: bool
Whether the variational state is holomorphic.
- init_internal(s: Array) -> PyTree
Initialize the internal state of the model for the given input s.
- classmethod is_paired() -> bool
Whether the state is a paired (Pfaffian) state rather than a determinant state. Defaults to False.
- jacobian(fock_states: Array) -> Array
Compute the jacobian matrix \(\frac{1}{\psi} \frac{\partial \psi}{\partial \theta}\). See VS_TYPE for the definition of the jacobian for different kinds of networks.
- Parameters:
fock_states – The input Fock states.
- Returns:
A 2D jacobian matrix with the first dimension for different inputs and the second dimension for different parameters. The parameters are ordered as in get_params_flatten.
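The entries of this jacobian are the log-derivatives \(O_k(s) = \frac{1}{\psi(s)} \frac{\partial \psi(s)}{\partial \theta_k}\). The sketch below estimates them with central finite differences for a hypothetical toy wave function \(\psi(s) = \exp(\sum_k \theta_k s_k)\), whose exact log-derivatives are simply \(O_k(s) = s_k\). Rows index inputs and columns index parameters, as described above; quantax itself uses automatic differentiation rather than this numerical scheme.

```python
import math


def toy_psi(thetas, s):
    # Hypothetical wave function with known log-derivatives O_k(s) = s_k.
    return math.exp(sum(t * x for t, x in zip(thetas, s)))


def log_derivative_jacobian(psi, thetas, states, eps=1e-6):
    """Matrix O[i][k] = (d psi(s_i) / d theta_k) / psi(s_i) via central differences."""
    jac = []
    for s in states:
        value = psi(thetas, s)
        row = []
        for k in range(len(thetas)):
            plus = list(thetas)
            plus[k] += eps
            minus = list(thetas)
            minus[k] -= eps
            dpsi = (psi(plus, s) - psi(minus, s)) / (2 * eps)
            row.append(dpsi / value)
        jac.append(row)
    return jac


jac = log_derivative_jacobian(toy_psi, [0.1, -0.2, 0.3], [(1, -1, 1), (-1, -1, 1)])
print([[round(o, 6) for o in row] for row in jac])
```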
- property model: Module
The variational model used in the variational state.
- norm(ord: int | None = None) -> ndarray | Array | LogArray | ScaleArray
Norm of the state.
- Parameters:
ord – Order of the norm, defaults to the 2-norm \(\sqrt{\sum_s |\psi(s)|^2}\).
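Assuming ord is interpreted as the order \(p\) of a vector norm over all basis amplitudes, \(\left(\sum_s |\psi(s)|^p\right)^{1/p}\), the computation on an explicitly enumerated list of amplitudes looks like the following sketch. The function state_norm is hypothetical and works on a plain list rather than on quantax array types.

```python
def state_norm(amplitudes, ord=None):
    """(sum_s |psi(s)|^ord)^(1/ord); ord=None means the 2-norm."""
    if ord is None:
        ord = 2
    return sum(abs(a) ** ord for a in amplitudes) ** (1.0 / ord)


print(state_norm([3, 4j]))         # -> 5.0
print(state_norm([3, 4j], ord=1))  # -> 7.0
```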
- property nparams: int
Number of total parameters in the variational state.
- property nsymm: int
Number of symmetry group elements
- partition(model: Module | None = None) -> Tuple[Module, Module]
Split the variational model into two pytrees, one containing all parameters and the other containing all other elements, similar to equinox.partition.
- Parameters:
model – The model to be split, defaults to the variational model of the state.
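The partition/combine pair splits a model into a parameter tree and a tree of everything else, and then merges them back. The pure-Python sketch below mimics that idea on nested dicts, marking missing entries with None as Equinox does; the helpers and the is_param predicate are hypothetical. With Equinox itself the analogous calls are equinox.partition(model, equinox.is_array) and equinox.combine(params, static).

```python
def partition(tree, is_param):
    """Split a nested dict into (params, others); absent entries become None."""
    params, others = {}, {}
    for key, value in tree.items():
        if isinstance(value, dict):
            params[key], others[key] = partition(value, is_param)
        elif is_param(value):
            params[key], others[key] = value, None
        else:
            params[key], others[key] = None, value
    return params, others


def combine(params, others):
    """Inverse of partition: take each entry from whichever tree is not None."""
    out = {}
    for key in params:
        if isinstance(params[key], dict):
            out[key] = combine(params[key], others[key])
        else:
            out[key] = params[key] if params[key] is not None else others[key]
    return out


model = {"orbitals": [[1.0, 0.0], [0.0, 1.0]], "name": "det",
         "nested": {"scale": 2.0, "activation": "exp"}}
is_param = lambda v: isinstance(v, (float, list))  # hypothetical filter
params, others = partition(model, is_param)
print(params)  # parameters only; non-parameter leaves replaced by None
```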
- property ref_chunk: int
The maximum chunk size of ref_forward_with_updates allowed per device.
- ref_forward(s: ndarray | Array, s_old: Array, nflips: int, idx_segment: Array, internal: PyTree) -> ndarray | Array | LogArray | ScaleArray
Compute the forward pass given reference internal state of the model.
- Parameters:
s – Input states s with entries \(\pm 1\).
s_old – The old states before the updates, with entries \(\pm 1\).
nflips – The number of flips in the updates.
idx_segment – The indices of the segment to be updated, which is used to select the old states and internal.
internal – The internal state of the model, which is initialized by
init_internal.
- Returns:
The output wave function \(\psi(s)\).
- ref_forward_with_updates(s: ndarray | Array, s_old: Array, nflips: int, internal: PyTree) -> Tuple[ndarray | Array | LogArray | ScaleArray, PyTree]
Compute the forward pass and updates given reference internal state of the model.
- Parameters:
s – Input states s with entries \(\pm 1\).
s_old – The old states before the updates, with entries \(\pm 1\).
nflips – The number of flips in the updates.
internal – The internal state of the model, which is initialized by
init_internal.
- Returns:
A tuple of the output wave function \(\psi(s)\) and the updated internal state of the model.
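To see why a reference internal state speeds up evaluation after a few flips, consider a single determinant \(\det M\) whose rows are the orbitals of the occupied sites. If one particle moves, one row \(r\) is replaced by a new row \(v\), and the matrix determinant lemma gives \(\det M' / \det M = \sum_j v_j (M^{-1})_{jr}\), an \(O(N)\) cost once \(M^{-1}\) is cached; the Sherman-Morrison formula then updates \(M^{-1}\) in \(O(N^2)\). The sketch assumes the cached internal state is \(M^{-1}\) and ignores the fermionic reordering sign; the helper names are hypothetical and do not describe quantax's actual internal format.

```python
def det(m):
    """Determinant by Gaussian elimination with partial pivoting."""
    a = [list(map(float, row)) for row in m]
    n, result = len(a), 1.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        if a[pivot][col] == 0.0:
            return 0.0
        if pivot != col:
            a[col], a[pivot] = a[pivot], a[col]
            result = -result
        result *= a[col][col]
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            a[r] = [x - factor * y for x, y in zip(a[r], a[col])]
    return result


def inverse(m):
    """Gauss-Jordan inverse of a small invertible matrix."""
    n = len(m)
    a = [list(map(float, row)) + [1.0 if i == j else 0.0 for j in range(n)]
         for i, row in enumerate(m)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        p = a[col][col]
        a[col] = [x / p for x in a[col]]
        for r in range(n):
            if r != col:
                factor = a[r][col]
                a[r] = [x - factor * y for x, y in zip(a[r], a[col])]
    return [row[n:] for row in a]


def ratio_after_row_replacement(m_inv, r, new_row):
    """det(M') / det(M) with row r of M replaced by new_row, in O(N)."""
    return sum(v * m_inv[j][r] for j, v in enumerate(new_row))


def update_inverse(m_inv, r, new_row, ratio):
    """Sherman-Morrison update of M^{-1} after replacing row r, in O(N^2)."""
    n = len(m_inv)
    col_r = [m_inv[i][r] for i in range(n)]
    w = [sum(new_row[j] * m_inv[j][k] for j in range(n)) - (1.0 if k == r else 0.0)
         for k in range(n)]
    return [[m_inv[i][k] - col_r[i] * w[k] / ratio for k in range(n)] for i in range(n)]


m = [[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]]
new_row = [1.0, 0.0, 2.0]
m_inv = inverse(m)  # cached "internal" state (assumption)
ratio = ratio_after_row_replacement(m_inv, 1, new_row)
print(ratio)  # approximately -8/18
```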
- save(file: str | Path | BinaryIO) -> None
Save the variational model to the given file. This file can be loaded when initializing
Variational.
- to_flax_model(package='netket', make_complex: bool = False)
Convert the state to a flax model compatible with other packages. Training the generated state in other packages may be unstable, but the state can be used to measure observables.
- Parameters:
package –
Convert the current state to the format of the given package. The supported packages are
- netket (default)
input 1/-1, output \(\log\psi\)
- jvmc
input 1/0, output \(\log\psi\)
make_complex – Whether the network output should be made complex explicitly. This is necessary when \(\psi\) is real but contains negative values.
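The two conventions above differ in the input encoding (1/-1 versus 1/0) and share a \(\log\psi\) output, which is why make_complex matters: the real logarithm of a negative amplitude is undefined, while the complex logarithm gives \(\log|\psi| + i\pi\). A minimal sketch follows, assuming the mapping \(+1 \to 1\), \(-1 \to 0\); both helpers are hypothetical and are not part of the quantax, NetKet or jVMC APIs.

```python
import cmath
import math


def spins_to_occupations(s):
    """Map +1/-1 inputs to 1/0 inputs (assumed mapping: +1 -> 1, -1 -> 0)."""
    return [1 if x == 1 else 0 for x in s]


def log_psi(psi_value, make_complex=False):
    """Return log(psi); a negative real psi requires the complex logarithm."""
    if make_complex or isinstance(psi_value, complex):
        return cmath.log(psi_value)  # log|psi| + i*arg(psi)
    return math.log(psi_value)  # raises ValueError when psi_value < 0


print(spins_to_occupations([1, -1, -1, 1]))  # -> [1, 0, 0, 1]
print(log_psi(-2.0, make_complex=True))      # real part log 2, imaginary part pi
```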
- todense(symm: Symmetry | None = None) -> DenseState
Obtain the quantax.state.DenseState corresponding to the current state.
- Parameters:
symm – The symmetry of the state, defaults to the current symmetry of the state.
Warning
Users are responsible for ensuring that the state satisfies the given
symm.
- update(step: Array) -> None
Update the variational parameters of the state as \(\theta' = \theta - \delta\theta\).
- Parameters:
step – The update step \(\delta\theta\).
Note
The update direction is \(-\delta\theta\) instead of \(\delta\theta\).
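Because update subtracts the step, a gradient-descent loop passes a step proportional to the gradient itself, not its negative. The sketch below shows this sign convention on the hypothetical toy energy \(E(\theta) = \sum_k \theta_k^2\) with plain lists standing in for parameter arrays; it does not call quantax.

```python
def update(params, step):
    """theta' = theta - delta_theta (the step is subtracted)."""
    return [p - d for p, d in zip(params, step)]


theta = [1.0, -2.0]
learning_rate = 0.25
for _ in range(3):
    grad = [2.0 * t for t in theta]  # gradient of sum(theta_k ** 2)
    theta = update(theta, [learning_rate * g for g in grad])

print(theta)  # -> [0.125, -0.25]
```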
- property use_ref: bool
Whether ref_forward and ref_forward_with_updates are used for the forward pass.