quantax.state.MultiPfState

class quantax.state.MultiPfState

Bases: MeanFieldFermionState

Multi-Pfaffian mean-field state, a wrapper around MultiPf.

__init__(model: Module | None = None, param_file: str | Path | BinaryIO | None = None, max_parallel: None | int | Tuple[int, int] = None, use_refmodel: bool = True)
Parameters:
  • model – Variational model. Should be an equinox.Module.

  • param_file – File from which to load parameters, as saved by save or equinox.tree_serialise_leaves. Defaults to not loading parameters.

  • symm – Symmetry of the network, defaults to Identity. Denoting the network output as \(f(s)\) and symmetry elements as \(T_i\) with characters \(\omega_i\), the wave function is given by \(\psi(s) = \sum_i \omega_i \, f(T_i s) / n_{symm}\).

  • max_parallel

    The maximum chunk size allowed per device. Setting a finite value is important to avoid running out of memory. For many Hamiltonians, it also improves efficiency by keeping the size of each forward pass constant, which avoids re-jitting. The allowed input formats are:

    • None:

      No chunk size (default).

    • int:

      The same chunk size for all forward and backward passes.

    • Tuple[int, int]:

      The chunk size for forward and backward passes respectively.

    • Tuple[int, int, int]:

      The chunk size for forward pass, backward pass and ref_forward_with_updates respectively.

  • use_ref – Whether ref_forward and ref_forward_with_updates are used when the model is a RefModel. When the model is not a RefModel, this argument has no effect. Defaults to True.
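The symm formula can be checked on a toy example. The sketch below is a minimal NumPy illustration, not quantax's implementation: the network f is a hypothetical stand-in, and the symmetry elements \(T_i\) are assumed to be cyclic translations by i sites (np.roll). With all characters equal to 1 the symmetrized amplitude is translation invariant; with momentum characters \(\omega_j = e^{2\pi i k j/L}\) a shift multiplies it by a phase under this sign convention.

```python
import numpy as np

def f(s):
    # Hypothetical toy network output f(s); any function of the configuration works.
    weights = np.arange(1, len(s) + 1)
    return np.exp(0.1 * np.dot(weights, s))

def symmetrize(f, s, characters):
    # psi(s) = sum_i omega_i f(T_i s) / n_symm, with T_i assumed to be a cyclic shift by i sites.
    n_symm = len(characters)
    total = sum(omega * f(np.roll(s, i)) for i, omega in enumerate(characters))
    return total / n_symm

L = 4
s = np.array([1, -1, -1, 1])

# All characters equal to 1: the symmetrized state is translation invariant.
trivial = np.ones(L)
psi = symmetrize(f, s, trivial)
psi_shifted = symmetrize(f, np.roll(s, 1), trivial)
print(np.isclose(psi, psi_shifted))  # True

# Characters exp(2*pi*i*k*j/L): shifting the input multiplies psi by exp(-2*pi*i*k/L).
k = 1
momentum = np.exp(2j * np.pi * k * np.arange(L) / L)
psi_k = symmetrize(f, s, momentum)
psi_k_shifted = symmetrize(f, np.roll(s, 1), momentum)
print(np.isclose(psi_k_shifted, np.exp(-2j * np.pi * k / L) * psi_k))  # True
```

The phase convention (whether \(T_i\) or its inverse is paired with \(\omega_i\)) is a choice made for this sketch.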

classmethod is_paired() bool

Whether the state is a paired state (Pfaffian) or not (determinant).
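Paired states are built from Pfaffians of antisymmetric pairing matrices, unpaired ones from determinants. The two are related by \(\mathrm{Pf}(A)^2 = \det(A)\) for an antisymmetric matrix \(A\) of even dimension. The sketch below computes a Pfaffian by naive expansion along the first row and checks that identity; it is for illustration only and is not the algorithm quantax uses (it scales factorially).

```python
import numpy as np

def pfaffian(A):
    """Pfaffian of an antisymmetric matrix by expansion along row 0 (tiny matrices only)."""
    n = A.shape[0]
    if n == 0:
        return 1.0
    if n % 2 == 1:
        return 0.0  # odd-dimensional antisymmetric matrices have zero Pfaffian
    total = 0.0
    for j in range(1, n):
        # Remove rows/columns 0 and j, then recurse on the remaining minor.
        keep = [m for m in range(n) if m not in (0, j)]
        minor = A[np.ix_(keep, keep)]
        total += (-1) ** (j + 1) * A[0, j] * pfaffian(minor)
    return total

rng = np.random.default_rng(0)
M = rng.normal(size=(6, 6))
A = M - M.T  # antisymmetric by construction
pf = pfaffian(A)
print(np.isclose(pf**2, np.linalg.det(A)))  # True
```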

expectation(operator: Operator, model: MultiPf | None = None) jax.Array

Compute the expectation value of an operator.

Parameters:
  • operator – The operator to compute the expectation value of. It should be an instance of Operator.

  • model – The mean-field model to use. If None, use the current model.
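For intuition, the quantity being computed is \(\langle\psi|O|\psi\rangle / \langle\psi|\psi\rangle\). The sketch below evaluates that definition for a small dense wave-function vector; it is only a reference formula, since a mean-field state such as MultiPfState presumably evaluates it without building the dense vector. The helper expectation_dense is hypothetical.

```python
import numpy as np

def expectation_dense(psi, O):
    # <psi|O|psi> / <psi|psi> for a dense (possibly unnormalized) wave function vector.
    return np.vdot(psi, O @ psi) / np.vdot(psi, psi)

# Two-site example: O = sigma^z on site 0, basis order |00>, |01>, |10>, |11>.
sz = np.diag([1.0, -1.0])
O = np.kron(sz, np.eye(2))
psi = np.array([2.0, 0.0, 0.0, 1.0])  # unnormalized 2|00> + |11>
print(expectation_dense(psi, O))  # 0.6
```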

property Nparticles: Tuple[int, int] | None

Conserved particle numbers of the state, or None if the particle number is not conserved.

property Nsites: int

Number of sites.

property backward_chunk: int

The maximum chunk size of backward pass allowed per device.

property basis

QuSpin basis of the state.

combine(params: Module, others: Module) Module

Combine two pytrees, one containing all parameters and the other containing all other elements, into one variational model. This is similar to combine in Equinox.

Parameters:
  • params – The pytree containing only parameters.

  • others – The pytree containing other elements.

property dtype: dtype

The parameter data type of the variational state.

property energy: float | None

The energy in the previous optimization step.

property forward_chunk: int

The maximum chunk size of forward pass allowed per device.
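The effect of a chunk size can be sketched as follows: samples are processed in fixed-size pieces, and the last piece is padded so every call sees the same shape, which is what avoids re-jitting. This is a minimal single-device NumPy sketch under that assumption; chunked_apply is a hypothetical helper, and quantax's own per-device splitting is not reproduced here.

```python
import numpy as np

def chunked_apply(fn, samples, chunk_size=None):
    """Evaluate fn on samples in chunks of at most chunk_size rows.

    The last chunk is padded by repeating the final sample, so every call to fn
    sees the same input shape; under jax.jit this avoids re-compilation.
    """
    if chunk_size is None:
        return fn(samples)  # no chunking: a single pass over all samples
    n = samples.shape[0]
    n_chunks = -(-n // chunk_size)  # ceiling division
    pad = n_chunks * chunk_size - n
    padded = np.concatenate([samples, np.repeat(samples[-1:], pad, axis=0)])
    outputs = [fn(padded[i * chunk_size:(i + 1) * chunk_size]) for i in range(n_chunks)]
    return np.concatenate(outputs)[:n]  # drop the outputs of the padding rows

samples = np.random.default_rng(1).choice([-1, 1], size=(10, 4))
row_sum = lambda x: x.sum(axis=1)  # stand-in for a forward pass
print(np.array_equal(chunked_apply(row_sum, samples, 3), row_sum(samples)))  # True
```

The trade-off is a little wasted work on padding rows in exchange for bounded memory and a single compiled shape.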

get_loss_fn(hamiltonian: Operator)

Get the loss function for optimization.

Parameters:

hamiltonian – The Hamiltonian whose energy defines the loss.

get_params_flatten() Array

Obtain a flattened 1D array of all parameters.

get_params_unflatten(params: Array) PyTree

Obtain the parameters pytree from a flattened 1D array of all parameters.
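The flatten/unflatten round trip can be sketched with a plain dict of arrays standing in for the parameter pytree. This is an illustration of the idea only: params_flatten and params_unflatten are hypothetical helpers, and the leaf order (sorted keys) is an assumption of the sketch. In JAX, jax.flatten_util.ravel_pytree provides the same round trip for general pytrees.

```python
import numpy as np

def params_flatten(params):
    # Concatenate all leaves, in sorted-key order, into one 1D array.
    return np.concatenate([np.ravel(params[k]) for k in sorted(params)])

def params_unflatten(flat, template):
    # Rebuild a dict with the same keys and shapes as `template` from a 1D array.
    out, offset = {}, 0
    for k in sorted(template):
        size = np.size(template[k])
        out[k] = flat[offset:offset + size].reshape(np.shape(template[k]))
        offset += size
    return out

params = {"F": np.arange(6.0).reshape(2, 3), "b": np.array([10.0, 20.0])}
flat = params_flatten(params)
print(flat.shape)  # (8,)
restored = params_unflatten(flat, params)
```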

get_step(hamiltonian: Operator) jax.Array

Get the gradient of the energy with respect to the mean-field parameters.

Parameters:

hamiltonian – The Hamiltonian to compute the gradient of.

property holomorphic: bool

Whether the variational state is holomorphic.

init_internal(s: Array) PyTree

Initialize the internal state of the model for the given input s.

jacobian(fock_states: Array) Array

Compute the Jacobian matrix \(\frac{1}{\psi} \frac{\partial \psi}{\partial \theta}\). See VS_TYPE for the definition of the Jacobian for different kinds of networks.

Parameters:

fock_states – The input Fock states.

Returns:

A 2D Jacobian matrix whose first dimension indexes the inputs and whose second dimension indexes the parameters. The parameters are ordered as in get_params_flatten.
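The definition \(\frac{1}{\psi} \frac{\partial \psi}{\partial \theta}\) and the output layout (inputs by parameters) can be checked numerically on a toy model. The sketch assumes a hypothetical real wave function \(\psi(s) = e^{\theta \cdot s}\), for which the exact Jacobian row is the input itself; it uses central finite differences, not the automatic differentiation quantax would use, and ignores the complex/holomorphic cases covered by VS_TYPE.

```python
import numpy as np

def psi(theta, s):
    # Toy wave function psi(s) = exp(theta . s); a stand-in, not MultiPf.
    return np.exp(s @ theta)

def jacobian_fd(theta, fock_states, eps=1e-6):
    # Row b, column j: (1/psi(s_b)) * d psi(s_b) / d theta_j, by central differences.
    n_in, n_par = fock_states.shape[0], theta.size
    jac = np.empty((n_in, n_par))
    base = psi(theta, fock_states)
    for j in range(n_par):
        dtheta = np.zeros_like(theta)
        dtheta[j] = eps
        deriv = (psi(theta + dtheta, fock_states) - psi(theta - dtheta, fock_states)) / (2 * eps)
        jac[:, j] = deriv / base
    return jac

theta = np.array([0.1, -0.2, 0.3])
fock_states = np.array([[1, 0, 1], [0, 1, 1]], dtype=float)  # occupation-number inputs
J = jacobian_fd(theta, fock_states)
# For this toy model the exact answer is d log(psi) / d theta = s.
print(np.allclose(J, fock_states, atol=1e-6))  # True
```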

property model: Module

The variational model used in the variational state.

norm(ord: int | None = None) ndarray | Array | LogArray | ScaleArray

Norm of the state.

Parameters:

ord – Order of the norm, defaults to the 2-norm \(\sqrt{\sum_s |\psi(s)|^2}\).
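For a dense amplitude vector, the default 2-norm is \(\sqrt{\sum_s |\psi(s)|^2}\). The sketch below evaluates it with NumPy; it assumes, without confirmation from this page, that other ord values follow the usual vector-norm meaning of numpy.linalg.norm (for example, ord=1 gives \(\sum_s |\psi(s)|\)).

```python
import numpy as np

psi = np.array([3.0, -4.0j, 0.0])  # toy amplitudes psi(s) over three configurations
norm2 = np.sqrt(np.sum(np.abs(psi) ** 2))
print(norm2)  # 5.0
print(np.isclose(norm2, np.linalg.norm(psi)))  # True
print(np.linalg.norm(psi, ord=1))  # 7.0
```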

property nparams: int

Number of total parameters in the variational state.

property nsymm: int

Number of symmetry group elements.

partition(model: Module | None = None) Tuple[Module, Module]

Split the variational model into two pytrees, one containing all parameters and the other containing all other elements, similar to partition in Equinox.

Parameters:

model – The model to be split, defaults to the variational model in the variational state.
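The partition/combine round trip can be illustrated on a flat dict standing in for the model pytree. This minimal sketch mirrors the None-placeholder convention of Equinox's partition and combine, but the helpers below are hypothetical and are not Equinox's implementation, which works on arbitrary pytrees with filters such as eqx.is_inexact_array.

```python
import numpy as np

def partition(tree, is_param):
    # Split a flat dict into (params, others); the missing half of each entry is None.
    params = {k: (v if is_param(v) else None) for k, v in tree.items()}
    others = {k: (None if is_param(v) else v) for k, v in tree.items()}
    return params, others

def combine(params, others):
    # Merge the two halves back, taking whichever side is not None.
    return {k: params[k] if params[k] is not None else others[k] for k in params}

model = {"F": np.ones((2, 2)), "n_sites": 4, "activation": "tanh"}
is_array = lambda v: isinstance(v, np.ndarray)
params, others = partition(model, is_array)
rebuilt = combine(params, others)
```

Keeping the non-parameter half separate is what lets an optimizer or jax.grad act on the parameters alone.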

property ref_chunk: int

The maximum chunk size of ref_forward_with_updates allowed per device.

ref_forward(s: ndarray | Array, s_old: Array, nflips: int, idx_segment: Array, internal: PyTree) ndarray | Array | LogArray | ScaleArray

Compute the forward pass given reference internal state of the model.

Parameters:
  • s – Input states s with entries \(\pm 1\).

  • s_old – The old states before the updates, with entries \(\pm 1\).

  • nflips – The number of flips in the updates.

  • idx_segment – The indices of the segment to be updated, which is used to select the old states and internal.

  • internal – The internal state of the model, which is initialized by init_internal.

Returns:

The output wave function \(\psi(s)\).

ref_forward_with_updates(s: ndarray | Array, s_old: Array, nflips: int, internal: PyTree) Tuple[ndarray | Array | LogArray | ScaleArray, PyTree]

Compute the forward pass and updates given reference internal state of the model.

Parameters:
  • s – Input states s with entries \(\pm 1\).

  • s_old – The old states before the updates, with entries \(\pm 1\).

  • nflips – The number of flips in the updates.

  • internal – The internal state of the model, which is initialized by init_internal.

Returns:

A tuple of the output wave function \(\psi(s)\) and the updated internal state of the model.
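Reference forward passes reuse stored internal data so that a local change of the input is cheap to evaluate. As a general illustration of that idea (not quantax's internal representation), the sketch below stores the inverse of a determinant-state matrix: replacing row k by a new row r gives the amplitude ratio \(\det(A')/\det(A) = r \cdot (A^{-1})_{:,k}\) in O(N) work, and the stored inverse is refreshed with the Sherman-Morrison formula in O(N²). Analogous low-rank formulas exist for Pfaffians. The helper names are hypothetical.

```python
import numpy as np

def det_ratio_row_update(A_inv, k, new_row):
    # det(A') / det(A) when row k of A is replaced by new_row.
    return new_row @ A_inv[:, k]

def inverse_row_update(A_inv, k, new_row, ratio):
    # Sherman-Morrison update of A^{-1} for the same row replacement.
    u = A_inv[:, k]          # A^{-1} e_k
    v = new_row @ A_inv      # new_row^T A^{-1} (a new array, A_inv is untouched)
    v[k] -= 1.0              # subtract old_row^T A^{-1} = e_k^T
    return A_inv - np.outer(u, v) / ratio

rng = np.random.default_rng(2)
A = rng.normal(size=(4, 4))
A_inv = np.linalg.inv(A)  # the stored "internal state" in this sketch
k, new_row = 1, rng.normal(size=4)
A_new = A.copy()
A_new[k] = new_row

ratio = det_ratio_row_update(A_inv, k, new_row)
print(np.isclose(ratio, np.linalg.det(A_new) / np.linalg.det(A)))  # True
A_inv_new = inverse_row_update(A_inv, k, new_row, ratio)
print(np.allclose(A_inv_new, np.linalg.inv(A_new)))  # True
```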

rho_from_model = Partial(func=_JitWrapper(fn='MeanFieldFermionState.rho_from_model', filter_warning=False, donate_first=False, donate_rest=False), args=(quantax.state.fermion_mf.MultiPfState,), keywords={})

save(file: str | Path | BinaryIO) None

Save the variational model to the given file. The file can be loaded when initializing Variational.

property symm: Symmetry

Symmetry of the state.

to_flax_model(package='netket', make_complex: bool = False)

Convert the state to a flax model compatible with other packages. Training the generated state in other packages is probably unstable, but the state can be used to measure observables.

Parameters:
  • package – Convert the current state to the format of the given package. The supported packages are:

    • netket (default): input 1/-1, output \(\log\psi\)

    • jvmc: input 1/0, output \(\log\psi\)

  • make_complex – Whether the network output should be made complex explicitly. This is necessary when \(\psi\) is real but contains negative values.
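Two details of the conversion can be sketched in NumPy: mapping NetKet-style inputs 1/-1 to jVMC-style inputs 1/0, and why a real \(\psi\) with negative entries needs a complex \(\log\psi\), since \(\log(-x) = \log x + i\pi\). The mapping +1 → 1, -1 → 0 and both helper names are assumptions of this sketch, not the converter's actual code.

```python
import numpy as np

def spins_to_occupations(s):
    # NetKet-style +1/-1 inputs -> jVMC-style 1/0 inputs (assuming +1 -> 1 and -1 -> 0).
    return (s + 1) // 2

def log_psi(psi, make_complex=False):
    # A real psi with negative entries needs a complex log: log(-x) = log(x) + i*pi.
    if make_complex:
        return np.log(psi.astype(complex))
    return np.log(psi)  # gives nan (with a RuntimeWarning) for negative entries

s = np.array([1, -1, -1, 1])
print(spins_to_occupations(s))  # [1 0 0 1]

psi = np.array([0.5, -2.0])
lp = log_psi(psi, make_complex=True)
print(np.allclose(np.exp(lp), psi))  # True
```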

todense(symm: Symmetry | None = None) DenseState

Obtain the quantax.state.DenseState corresponding to the current state.

Parameters:

symm – The symmetry of the state, defaults to the current symmetry of the state.

Warning

Users are responsible for ensuring that the state satisfies the given symm.

update(step: Array) None

Update the variational parameters of the state as \(\theta' = \theta - \delta\theta\).

Parameters:

step – The update step \(\delta\theta\).

Note

The update direction is \(-\delta\theta\) instead of \(\delta\theta\).
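Because the step is subtracted, a plain gradient-descent update is obtained by passing the learning rate times the gradient as step. The sketch below shows that sign convention on a flat NumPy parameter vector; the update helper, the gradient values and learning_rate are hypothetical, and it assumes the step comes from something like get_step, which returns the energy gradient.

```python
import numpy as np

def update(theta, step):
    # theta' = theta - delta_theta: the step is subtracted from the parameters.
    return theta - step

theta = np.array([1.0, 2.0])
grad = np.array([0.5, -1.0])  # e.g. an energy gradient
learning_rate = 0.1
theta_new = update(theta, learning_rate * grad)  # moves against the gradient
```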

property use_ref: bool

Whether ref_forward and ref_forward_with_updates are used for updates.

property vs_type: VS_TYPE

The type of variational state.