quantax.nn.ScaleFn

class quantax.nn.ScaleFn

Bases: NoGradLayer

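Apply a function to a rescaled input, \(f(x) = \mathrm{fn}(x \cdot \mathrm{scale})\). The scale is computed automatically from the function so that \(\sigma\left(\sum_i \log |f(x_i)|\right) = 0.1 \sqrt{N}\) when \(\sigma(x) = 1\), where \(\sigma\) denotes the standard deviation, the sum runs over the elements of x, and N is the number of sites in the system.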
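The calibration condition can be illustrated with a small NumPy sketch. It assumes standard-normal test inputs (so \(\sigma(x) = 1\)) and finds the scale by a simple grid search; the helper names `log_amplitude_std` and `calibrate_scale` and the grid search itself are hypothetical choices for illustration, not necessarily how quantax solves for the scale.

```python
import numpy as np


def log_amplitude_std(fn, x, scale):
    """Std over samples of sum_i log|fn(x_i * scale)|, summing over the feature axis."""
    out = np.sum(np.log(np.abs(fn(x * scale))), axis=1)
    return float(np.std(out))


def calibrate_scale(fn, features, n_sites, n_samples=1000, std0=0.1, seed=0):
    """Hypothetical helper: pick `scale` so that log_amplitude_std ~= std0 * sqrt(n_sites).

    Inputs are drawn from a standard normal, i.e. sigma(x) = 1. The grid search is an
    illustrative assumption. `fn` must not return zeros, otherwise log|fn| is -inf.
    """
    x = np.random.default_rng(seed).standard_normal((n_samples, features))
    target = std0 * np.sqrt(n_sites)
    candidates = np.linspace(0.01, 3.0, 300)  # grid step of 0.01
    errors = [abs(log_amplitude_std(fn, x, s) - target) for s in candidates]
    return float(candidates[int(np.argmin(errors))])


# RBM-style activation: cosh, with 64 input features on a 64-site system.
scale = calibrate_scale(np.cosh, features=64, n_sites=64)
print(scale)
```

For cosh the spread of the log-sum grows monotonically with the scale, so a one-dimensional search is enough in this sketch; a root finder would serve equally well.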

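The rescaling is particularly helpful for the numerical stability of networks whose wave function is a product of activations, \(\psi = \prod_i f(x_i)\), such as the restricted Boltzmann machine (RBM). Because \(\log|\psi| = \sum_i \log|f(x_i)|\), the condition above keeps the spread of the log-amplitudes at \(0.1\sqrt{N}\) for unit-variance inputs.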
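To see why the spread of \(\log|\psi|\) matters, the sketch below evaluates an RBM-like product \(\psi = \prod_i \cosh(\mathrm{scale} \cdot x_i)\) on unit-variance inputs for a few scales. The choice of cosh and the function name `rbm_like_log_psi` are assumptions for illustration. The difference between the largest and smallest \(\log|\psi|\) is the logarithm of the largest amplitude ratio between samples, which grows quickly with the scale.

```python
import numpy as np


def rbm_like_log_psi(x, scale):
    """log|psi| for psi = prod_i cosh(scale * x_i), computed as a sum of logs.

    Summing logarithms avoids overflowing the product itself when there are many factors.
    """
    return np.sum(np.log(np.cosh(scale * x)), axis=-1)


rng = np.random.default_rng(1)
x = rng.standard_normal((1000, 64))  # unit-variance inputs, sigma(x) = 1
scales = (0.1, 0.4, 2.0)
stds = []
for scale in scales:
    log_psi = rbm_like_log_psi(x, scale)
    stds.append(float(log_psi.std()))
    # exp(spread) is the largest amplitude ratio |psi(x) / psi(x')| among the samples
    spread = log_psi.max() - log_psi.min()
    print(f"scale={scale}: std(log|psi|)={stds[-1]:.3f}, max ratio=exp({spread:.1f})")
```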

Note

No matter which input data type is provided, the output data type is always given by quantax.get_default_dtype.
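The dtype behaviour in the note can be pictured with a simplified forward pass. This is a hypothetical sketch, not the quantax implementation: `ScaleFnSketch` mirrors only the `fn` and `scale` attributes, and `DEFAULT_DTYPE` is a stand-in for the value returned by `quantax.get_default_dtype`.

```python
import numpy as np

# Stand-in for quantax.get_default_dtype(); the real default depends on the quantax setup.
DEFAULT_DTYPE = np.float64


class ScaleFnSketch:
    """Hypothetical forward pass: fn(x * scale), cast to the default dtype."""

    def __init__(self, fn, scale):
        self.fn = fn        # activation applied elementwise
        self.scale = scale  # precomputed input scale

    def __call__(self, x):
        y = self.fn(np.asarray(x) * self.scale)
        # Cast to the default dtype regardless of the input dtype.
        return np.asarray(y).astype(DEFAULT_DTYPE)


layer = ScaleFnSketch(np.cosh, scale=0.5)
out_int = layer(np.array([1, 2, 3], dtype=np.int32))
out_f32 = layer(np.array([1.0, 2.0, 3.0], dtype=np.float32))
print(out_int.dtype, out_f32.dtype)
```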

__init__(fn: Callable, features: int, scaling: float = 1.0, dtype: numpy.dtype = jax.numpy.float32)
Parameters:
  • fn – The activation function to be applied.

  • features – The size of the input x, excluding the batch dimension.

  • scaling – An additional factor applied to the input to bring it to unit standard deviation, \(\sigma(x) = 1\).

  • dtype – The data type of the inputs.

Attributes

fn

scale