Build your network#
Variational wavefunctions in Quantax are built on Equinox. In this tutorial, we introduce how to build your own network with Equinox and how to test its performance in Quantax.
Reference:
Equinox: neural networks in JAX via callable PyTrees and filtered transformations
import jax
import jax.numpy as jnp
import jax.random as jr
import quantax as qtx
import matplotlib.pyplot as plt
L = 8
Equinox quick start#
Equinox is a minimalist neural network library built directly on top of JAX. Unlike Flax and Haiku, which come with higher-level abstractions, Equinox emphasizes flexibility and transparency: everything is just PyTrees and functions, making it easy to integrate with raw JAX code. This means you don’t have to fight against the framework when experimenting with unconventional architectures, physics-inspired models, or custom training loops.
A custom Equinox network typically looks like the following.
import equinox as eqx
class Linear(eqx.Module):
weight: jax.Array
bias: jax.Array
def __init__(self, in_size, out_size, key):
wkey, bkey = jr.split(key)
self.weight = jr.normal(wkey, (out_size, in_size))
self.bias = jr.normal(bkey, (out_size,))
def __call__(self, x):
return self.weight @ x + self.bias
key = jr.key(0)
linear = Linear(2, 3, key)
print(linear)
Linear(weight=f64[3,2], bias=f64[3])
eqx.Module has two important properties.
1. It’s a PyTree. In this example, weight and bias are leaves of this PyTree.
print("weight:", linear.weight)
print("bias:", linear.bias)
vals, treedef = jax.tree.flatten(linear)
print("Flattened leaves:", vals)
weight: [[ 1.88002989 -0.48121497]
[ 0.41545723 2.38184008]
[-0.57536705 -0.37054353]]
bias: [-1.4008841 1.432145 0.6248107]
Flattened leaves: [Array([[ 1.88002989, -0.48121497],
[ 0.41545723, 2.38184008],
[-0.57536705, -0.37054353]], dtype=float64), Array([-1.4008841, 1.432145 , 0.6248107], dtype=float64)]
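The flattened leaves and the treedef fully describe the module, so they can be used to rebuild it. A small round-trip sketch (not part of the original example) using standard JAX:
rebuilt = jax.tree.unflatten(treedef, vals)  # rebuild the Linear module from treedef + leaves
print(rebuilt)  # an equivalent Linear(weight=f64[3,2], bias=f64[3])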
2. It’s callable, since the __call__ method is defined in this class.
inputs = jnp.array([1.0, 2.0])
outputs = linear(inputs)
jitted_fn = jax.jit(lambda linear, x: linear(x)) # `linear` is jittable as it's a PyTree
jitted_outputs = jitted_fn(linear, inputs)
jacobian = jax.jacrev(jitted_fn)(linear, inputs)
print("outputs:", outputs)
print("jitted outputs:", jitted_outputs)
print("jacobian:", jacobian)
outputs: [-0.48328416 6.61128238 -0.69164342]
jitted outputs: [-0.48328416 6.61128238 -0.69164342]
jacobian: Linear(weight=f64[3,3,2], bias=f64[3,3])
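Because the module and its gradients share the same PyTree structure, a gradient-descent update is just a tree map over the leaves. A minimal sketch with a toy loss (not part of the Equinox example above):
loss_fn = lambda model, x: jnp.sum(model(x) ** 2)  # toy scalar loss
grads = jax.grad(loss_fn)(linear, inputs)          # gradients with the same PyTree structure as `linear`
linear_new = jax.tree.map(lambda p, g: p - 0.1 * g, linear, grads)  # one SGD step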
Apart from that, Equinox provides convenient filtered functions such as filter_jit. It is similar to jax.jit but also works on PyTrees with non-jittable leaves, as shown in the code below. In Quantax, we use these filtered functions for better flexibility.
def summation(l):
return sum(jnp.sum(x) for x in l if isinstance(x, jax.Array))
l = [1, 2.0, jnp.array([1.0, 2.0]), "string", jnp.array([3.0, 4.0])]
try:
out = jax.jit(summation)(l)
except TypeError as e:
print("`jax.jit` failed due to non-jittable string data type in the list.")
out = eqx.filter_jit(summation)(l)
print("eqx.filter_jit successful, output: ", out)
try:
g = jax.grad(summation)(l)
except TypeError as e:
print("`jax.grad` failed due to non-jittable string data type in the list.")
g = eqx.filter_grad(summation)(l)
print("eqx.filter_grad successful, gradient: ", g)
`jax.jit` failed due to non-jittable string data type in the list.
eqx.filter_jit successful, output: 10.0
`jax.grad` failed due to non-jittable string data type in the list.
eqx.filter_grad successful, gradient: [None, None, Array([1., 1.], dtype=float64), None, Array([1., 1.], dtype=float64)]
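Roughly speaking, the filtered transformations partition a PyTree into array leaves and static leaves, apply the JAX transformation to the arrays only, and recombine the results. A sketch of this partitioning with eqx.partition and eqx.combine:
arrays, static = eqx.partition(l, eqx.is_array)  # split the list by the `is_array` filter
print("array leaves:", arrays)                   # non-array leaves replaced by None
print("static leaves:", static)                  # array leaves replaced by None
recombined = eqx.combine(arrays, static)         # merge back into the original list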
Build wavefunction#
Let’s start by building the following variational wavefunction,
\[
\psi(s) = \sum_i \exp\!\left( \left[ W^{(2)} \, \mathrm{ReLU}\!\left( W^{(1)} s + b^{(1)} \right) + b^{(2)} \right]_i \right),
\]
where the network has an array input \(s\) and a scalar output \(\psi\), and \(W^{(1)}\), \(b^{(1)}\), \(W^{(2)}\), and \(b^{(2)}\) are variational parameters.
import equinox as eqx
class MyModel(eqx.Module):
layer1: eqx.nn.Linear # eqx.nn.Linear is a built-in linear layer
layer2: eqx.nn.Linear
def __init__(self, in_size: int, width: int):
keys = qtx.get_subkeys(2) # Convenient function in Quantax to provide keys
layer1 = eqx.nn.Linear(in_size, width, key=keys[0])
self.layer1 = qtx.nn.apply_he_normal(keys[0], layer1) # He initialization
layer2 = eqx.nn.Linear(width, width, key=keys[1])
self.layer2 = qtx.nn.apply_lecun_normal(keys[1], layer2) # LeCun initialization
def __call__(self, x):
x = jax.nn.relu(self.layer1(x))
x = self.layer2(x)
psi = jnp.sum(jnp.exp(x))
return psi
model = MyModel(in_size=L, width=16)
print(model)
MyModel(
layer1=Linear(
weight=f64[16,8],
bias=f64[16],
in_features=8,
out_features=16,
use_bias=True
),
layer2=Linear(
weight=f64[16,16],
bias=f64[16],
in_features=16,
out_features=16,
use_bias=True
)
)
We can test it by making a forward pass.
s = jnp.ones(L)
psi = model(s)
print("psi =", psi)
psi = 94.73527623491115
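The model maps a single configuration to a scalar, so a batched evaluation with plain JAX is simply a vmap over the batch dimension. A quick sketch (the Variational wrapper introduced below provides batched evaluation for you):
batch = jnp.ones((4, L))            # 4 spin configurations
psi_batch = jax.vmap(model)(batch)  # vectorize the forward pass over the batch
print(psi_batch.shape)              # (4,)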
Now let’s use this new network in Quantax. To use it as a variational state, wrap it in Variational, which supports batched forward passes.
lattice = qtx.sites.Chain(L)
state = qtx.state.Variational(model)
print("Number of parameters:", state.nparams)
s = qtx.utils.rand_states(8) # 8 random spin configurations
psi = state(s) # Batched forward pass
print("psi =", psi)
Number of parameters: 416
psi = [31.38266678 22.64277817 23.10738473 39.89670204 24.86820092 38.14627733
39.62772614 27.06081813]
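The parameter count can be verified by hand: layer1 has \(8 \times 16 + 16 = 144\) parameters and layer2 has \(16 \times 16 + 16 = 272\), giving 416 in total.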
Test by exact reconfiguration#
Exact reconfiguration (ER) is an optimization method that approximates imaginary-time evolution without Monte Carlo sampling; it is only feasible in small systems. We can use ER to quickly test the expressive power of a neural network.
H = qtx.operator.Ising(h=1.0)
E, wf = H.diagonalize()
exact_state = qtx.state.DenseState(wf)
optimizer = qtx.optimizer.ER(state, H)
energy = qtx.utils.DataTracer()
training_rate = 0.02
for i in range(100):
step = optimizer.get_step()
state.update(step * training_rate)
energy.append(optimizer.energy)
energy.plot(baseline=E)
plt.show()
In small systems, we can convert the Variational state to a DenseState and check its overlap with the exact ground state.
dense = state.todense().normalize()
overlap = abs(dense @ exact_state)
print("Overlap with the exact ground state:", overlap)
Overlap with the exact ground state: 0.9917261588189411
Now we have a nice neural quantum state for solving the Ising model!
Avoid overflow#
In neural quantum state simulations, the wavefunction amplitudes often become so large that they exceed the range of float64. Here is an example.
model = MyModel(in_size=L, width=16)
W1 = model.layer1.weight
W2 = model.layer2.weight
# Manually multiply weights by 100 to cause overflow
model = eqx.tree_at(lambda model: model.layer1.weight, model, W1 * 100)
model = eqx.tree_at(lambda model: model.layer2.weight, model, W2 * 100)
s = qtx.utils.rand_states()
model(s)
Array(inf, dtype=float64)
To avoid this problem, we define two customized data types, LogArray and ScaleArray, to store large values. They are also accepted as network outputs in Quantax. Instead of using dangerous functions like jnp.exp that might overflow, one can use qtx.nn.exp_by_scale() to output safe values expressed as a ScaleArray.
class NewModel(eqx.Module):
layer1: eqx.nn.Linear
layer2: eqx.nn.Linear
def __init__(self, in_size: int, width: int):
keys = qtx.get_subkeys(2)
layer1 = eqx.nn.Linear(in_size, width, key=keys[0])
self.layer1 = qtx.nn.apply_he_normal(keys[0], layer1)
layer2 = eqx.nn.Linear(width, width, key=keys[1])
self.layer2 = qtx.nn.apply_lecun_normal(keys[1], layer2)
def __call__(self, x):
x = jax.nn.relu(self.layer1(x))
x = self.layer2(x)
# Dangerous: psi = jnp.sum(jnp.exp(x))
# Safe:
psi = qtx.nn.exp_by_scale(x).sum()
return psi
model = NewModel(in_size=L, width=16)
model = eqx.tree_at(lambda model: model.layer1.weight, model, W1 * 100)
model = eqx.tree_at(lambda model: model.layer2.weight, model, W2 * 100)
psi = model(s)
print(psi)
ScaleArray(
significand=1.0,
exponent=11656.639936899364
)
Here the output ScaleArray is a PyTree with significand \(x\) and exponent \(\theta\). The value it represents is \(x e^\theta\), which is beyond the range of float64. In most calculations, this quantity can be treated like an ordinary array, as shown below. Note that operations which shrink the exponent, such as the fractional power in the example, can bring the value back into the float64 range, after which it can be converted to a plain jax.Array.
psi = psi.repeat(8).reshape(2, 4)
print("Reshape psi:", psi)
psi = psi.sum(axis=1)
print("Sum psi:", psi)
psi = psi ** (1 / 10000)
print("Power psi:", psi)
psi = jnp.asarray(psi)
print("To jax Array:", psi)
Reshape psi: ScaleArray(
significand=[[1. 1. 1. 1.]
[1. 1. 1. 1.]],
exponent=11656.639936899364
)
Sum psi: ScaleArray(
significand=[4. 4.],
exponent=11656.639936899364
)
Power psi: ScaleArray(
significand=[1.00013864 1.00013864],
exponent=1.1656639936899365
)
To jax Array: [3.20849706 3.20849706]
However, JAX does not fully support customized array types, so one should be careful when using them. Here we list several possible problems.
1. Manipulations like jnp.fn(array) transform customized arrays into jax.Array, which can cause overflow. To avoid this, call array.fn() instead.
2. Computations like jax_array * customized_array always call jax_array.__mul__(customized_array), which returns a jax.Array that might overflow. To avoid this, use customized_array * jax_array.
Here are some examples.
significand = jnp.array([0.0, 1.0, 2.0, 3.0])
exponent = jnp.array(10000.0)
psi = qtx.utils.ScaleArray(significand, exponent)
print("Wrong sum:", jnp.sum(psi))
print("Correct sum:", psi.sum())
a = jnp.arange(4)
print("Wrong mul:", a * psi)
print("Correct mul:", psi * a)
Wrong sum: nan
Correct sum: ScaleArray(
significand=6.0,
exponent=10000.0
)
Wrong mul: [nan inf inf inf]
Correct mul: ScaleArray(
significand=[0. 0.33333333 1.33333333 3. ],
exponent=10001.098612288668
)