quantax.nn.ScaleFn
class quantax.nn.ScaleFn
Bases: NoGradLayer
Apply a function to a rescaled input, \(f(x) = \mathrm{fn}(x \cdot \mathrm{scale})\). The scale is computed automatically from the function so that \(\sigma\left(\sum \log |f(x)|\right) = 0.1 \sqrt{N}\) when the input has \(\sigma(x) = 1\) and the system has \(N\) sites, where \(\sigma\) denotes the standard deviation.
This is particularly helpful for the numerical stability of networks whose wavefunction has the product form \(\psi = \prod f(x)\), such as the RBM.
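Under an extra assumption the condition on the scale reduces to a per-element one: if the \(N\) terms \(\log|f(x_i)|\) are independent and identically distributed, the standard deviation of their sum (that is, of \(\log|\psi|\)) is \(\sqrt{N}\) times the per-element standard deviation, so the target becomes \(\sigma(\log|\mathrm{fn}(\mathrm{scale} \cdot x)|) = 0.1\) for \(x \sim \mathcal{N}(0, 1)\). The NumPy sketch below finds such a scale by Monte Carlo sampling and bisection. It is not quantax's implementation: the helper `estimate_scale`, the Gaussian input distribution, and the assumption that the per-element standard deviation grows monotonically with the scale are illustrative choices.

```python
import numpy as np


def estimate_scale(fn, n_samples=200_000, target=0.1, lo=1e-3, hi=10.0, n_iter=60, seed=0):
    """Hypothetical helper: find s with std(log|fn(s * x)|) == target for x ~ N(0, 1).

    For N i.i.d. sites, std(sum_i log|fn(s * x_i)|) = sqrt(N) * std(log|fn(s * x)|),
    so meeting `target` per site gives target * sqrt(N) for the whole system.
    Assumes fn is nonzero on the samples and that the per-site std increases
    monotonically with s on [lo, hi].
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_samples)  # unit-variance inputs: sigma(x) = 1

    def per_site_std(s):
        return np.std(np.log(np.abs(fn(s * x))))

    # Bisection on s -> per_site_std(s) - target, which changes sign on [lo, hi].
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        if per_site_std(mid) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


# Usage: check the sqrt(N) scaling of std(log psi) for psi = prod_i cosh(scale * x_i).
scale = estimate_scale(np.cosh)
N = 100
x = np.random.default_rng(1).standard_normal((20_000, N))
log_psi = np.sum(np.log(np.cosh(scale * x)), axis=1)
print(scale, np.std(log_psi), 0.1 * np.sqrt(N))
```

With \(N = 100\) independent sites the printed standard deviation of \(\log\psi\) should come out close to \(0.1\sqrt{100} = 1\). Bisection is used only for simplicity; any one-dimensional root finder would serve.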
Note
No matter which input data type is provided, the output data type is always given by quantax.get_default_dtype.
__init__(fn: typing.Callable, features: int, scaling: float = 1.0, dtype: numpy.dtype = jax.numpy.float32)
Parameters:
fn – The activation function to be applied.
features – The size of the input x, excluding the batch dimension.
scaling – An additional factor applied to the input to rescale it to \(\sigma(x) = 1\).
dtype – The data type of the inputs.
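To show how the constructor arguments and the dtype Note could fit together in a forward pass, here is a hypothetical NumPy stand-in. `ScaleFnSketch` and `DEFAULT_DTYPE` are illustrative names, not part of quantax. The sketch assumes the input is multiplied by `scaling * scale`, takes `scale` as an explicit argument instead of computing it automatically, and uses `np.float64` in place of whatever `quantax.get_default_dtype` returns.

```python
import numpy as np

# Hypothetical stand-in for quantax.get_default_dtype(); the real default may differ.
DEFAULT_DTYPE = np.float64


class ScaleFnSketch:
    """Hypothetical NumPy stand-in for ScaleFn's forward pass (not quantax's code).

    Assumes the effective input multiplier is `scaling * scale`; `scale` is passed
    in here rather than computed automatically.
    """

    def __init__(self, fn, features, scale, scaling=1.0, dtype=np.float32):
        self.fn = fn
        self.features = features  # size of x without the batch dimension
        self.scale = scale
        self.scaling = scaling
        self.dtype = dtype  # dtype the inputs are cast to

    def __call__(self, x):
        x = np.asarray(x, dtype=self.dtype)
        if x.shape[-1] != self.features:
            raise ValueError(f"expected last axis of size {self.features}, got {x.shape[-1]}")
        y = self.fn(x * (self.scaling * self.scale))
        # Mirror the Note: the output dtype is the default dtype, whatever the input dtype was.
        return y.astype(DEFAULT_DTYPE)


# Usage with an illustrative (not computed) scale value and an int8 spin configuration.
layer = ScaleFnSketch(np.cosh, features=4, scale=0.4, scaling=0.5)
out = layer(np.array([[1, -1, 1, -1]], dtype=np.int8))
print(out.dtype)  # float64
```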
Attributes
fn
scale