quantax.model.ResConv#

class quantax.model.ResConv#

Deep convolutional residual network.

__init__(nblocks: int, channels: int, kernel_size: int | Sequence[int], final_activation: Callable[[jax.Array], numpy.ndarray | jax.Array | LogArray | ScaleArray] | None = None, trans_symm: Symmetry | None = None, dtype: numpy.dtype = jax.numpy.float32, out_dtype: numpy.dtype | None = None)#

A convolutional residual network whose outputs are summed at the end.

Parameters:
  • nblocks – The number of residual blocks. Each block contains two convolutional layers.

  • channels – The number of channels. Each layer has the same number of channels.

  • kernel_size – The kernel size. Each layer has the same kernel size.

  • final_activation – The activation function in the last layer. By default, exp_by_scale is used.

  • trans_symm – The translation symmetry to be applied in the last layer; see ConvSymmetrize.

  • dtype – The data type of the parameters. Must be a real dtype.

  • out_dtype – The data type of the output wavefunction. By default, it is the same as dtype. If out_dtype is complex, pair_cpl is applied to the output of the convolutional layers to make the final output complex.

Tip

This is the recommended architecture for deep NQS in spin systems.
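Example

The following is a minimal construction sketch based only on the arguments documented above. The chosen values (nblocks=4, channels=32, kernel_size=3) are illustrative, and any lattice or system setup that quantax may require before a network is built (e.g. through quantax.sites) is omitted here.

```python
import jax.numpy as jnp
import quantax as qtx

# NOTE: quantax typically expects the lattice/system to be defined before a
# network is built (e.g. via quantax.sites); that setup step is omitted here.

net = qtx.model.ResConv(
    nblocks=4,                # 4 residual blocks, i.e. 8 convolutional layers
    channels=32,              # every layer uses 32 channels
    kernel_size=3,            # the same kernel size in every layer
    dtype=jnp.float32,        # parameters must be a real dtype
    out_dtype=jnp.complex64,  # complex output: pair_cpl is applied to the conv outputs
)
```

Leaving final_activation at its default uses exp_by_scale as the last-layer activation, as noted above.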