Neural network backflow#
Paper title: Backflow Transformations via Neural Networks for Quantum Many-Body Wave Functions
Paper authors: Di Luo and Bryan K. Clark
Phys. Rev. Lett. 122, 226401 (2019)
In this example, we use neural network backflow (NNBF) to solve for the ground state of the 4x4 Hubbard model at 1/8 hole doping.
Related tutorials: Exact diagonalization, Build your network, Samples and Measurement, Fermion mean field
Estimated cost: 1 A100 x 20 min
Define system and perform ED#
The Hubbard Hamiltonian is
\[
H = -t \sum_{\langle i,j \rangle, \sigma} \left( c^\dagger_{i\sigma} c_{j\sigma} + \mathrm{h.c.} \right) + U \sum_i n_{i\uparrow} n_{i\downarrow},
\]
with \(U = 8\) in units of the hopping \(t\) in this example.
In the 4x4 system at 1/8 hole doping, the ground-state sector is defined by momentum \(k=0\), the B1 representation of the \(C_{4v}\) group, and spin-inversion eigenvalue \(-1\).
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import quantax as qtx
from quantax.symmetry import TransND, C4v, SpinInverse
import matplotlib.pyplot as plt
from IPython.display import clear_output
%config InlineBackend.figure_format = 'svg'
# 4x4 lattice with 1/8 hole doping, 7 spin-up and 7 spin-down fermions
lattice = qtx.sites.Square(
4, particle_type=qtx.PARTICLE_TYPE.spinful_fermion, Nparticles=(7, 7)
)
N = lattice.Nsites
H = qtx.operator.Hubbard(U=8)
symm = TransND() @ C4v(repr="B1") @ SpinInverse(-1)
E_gs, wf_gs = H.diagonalize(symm)
E_gs = E_gs[0]
print(E_gs)
-11.86883556956402
Slater-Jastrow determinant state (S0)#
The Jastrow-Slater wave function can be expressed as
\[
\psi(n) = J(n) \, \det(n \star \phi_\uparrow) \, \det(n \star \phi_\downarrow),
\]
where \(n \star \phi\) means the rows of the full orbital matrix \(\phi\) are selected according to the occupied fermion modes in \(n\), and \(J(n)\) is the Jastrow factor given by
\[
J(n) = \exp\!\left( -\tfrac{1}{2} \, n^T v \, n \right).
\]
The variational parameters are \(\phi_\uparrow\), \(\phi_\downarrow\), and \(v\). We start by training this simple wave function to obtain a good initial state.
from quantax.nn import fermion_idx
from quantax.utils import LogArray
def _init_params():
keys = qtx.get_subkeys(3)
Nup, Ndn = lattice.Nparticles
phi_up = jr.normal(keys[0], (N, Nup))
phi_dn = jr.normal(keys[1], (N, Ndn))
M = lattice.Nfmodes
v = jr.normal(keys[2], (M, M)) / M
return phi_up, phi_dn, v
def _slater_forward(phi_up, phi_dn, n):
idx = fermion_idx(n)
Nup = lattice.Nparticles[0]
idx_up = idx[:Nup]
idx_dn = idx[Nup:] - N
M_up = phi_up[idx_up]
M_dn = phi_dn[idx_dn]
sign_up, logabs_up = jnp.linalg.slogdet(M_up)
psi_up = LogArray(sign_up, logabs_up)
sign_dn, logabs_dn = jnp.linalg.slogdet(M_dn)
psi_dn = LogArray(sign_dn, logabs_dn)
return psi_up * psi_dn
class JastrowSlater(eqx.Module):
phi_up: jax.Array
phi_dn: jax.Array
v: jax.Array
def __init__(self):
self.phi_up, self.phi_dn, self.v = _init_params()
def __call__(self, n: jax.Array) -> jax.Array:
jastrow = qtx.nn.exp_by_log(-0.5 * n @ self.v @ n)
return jastrow * _slater_forward(self.phi_up, self.phi_dn, n)
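The wave-function amplitude is returned in log representation: a LogArray stores a sign together with a log-magnitude, so products of determinants and Jastrow factors stay numerically stable. Below is a small illustration, assuming LogArray accepts scalar sign and log-magnitude arguments as in _slater_forward above.
x = LogArray(jnp.array(1.0), jnp.array(-500.0))  # represents exp(-500)
y = LogArray(jnp.array(-1.0), jnp.array(490.0))  # represents -exp(490)
z = x * y  # ~ -exp(-10); exp(+-500) is never materialized, so no overflow or underflow
print(z)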
model_s0 = JastrowSlater()
state_s0 = qtx.state.Variational(model_s0, max_parallel=8192*45)
sampler = qtx.sampler.ParticleHop(state_s0, 8192, sweep_steps=10 * N)
optimizer = qtx.optimizer.SR(state_s0, H)
energy = qtx.utils.DataTracer()
for i in range(500):
samples = sampler.sweep()
step = optimizer.get_step(samples)
state_s0.update(step * 0.02)
energy.append(optimizer.energy)
if i % 10 == 0:
clear_output()
energy.plot(start=-200, batch=10, baseline=E_gs)
plt.show()
The relative error of the Jastrow-Slater determinant state is around 5%, as presented by the S0 percentage error in Fig.2. The optimization may converge to a local minimum, so restart the training if you cannot reach a similar accuracy.
E = energy[-20:].mean()
print("Relative error:", jnp.abs((E - E_gs) / E_gs))
Relative error: 0.04863341849280481
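Before moving on, it can be convenient to checkpoint the optimized S0 parameters so that the backflow training below can restart from them without redoing this optimization. Below is a minimal sketch using equinox's serialization utilities; the file name model_s0.eqx is only an example.
# Save the optimized S0 parameters to disk.
eqx.tree_serialise_leaves("model_s0.eqx", state_s0.model)
# They can later be restored into a freshly constructed model with the same structure.
model_s0_restored = eqx.tree_deserialise_leaves("model_s0.eqx", JastrowSlater())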
Slater determinant NNBF (SN)#
Now we start applying neural networks. The network provides a configuration-dependent correction to the single-particle orbitals,
\[
\phi_\sigma(n) = \phi_\sigma + a_\sigma(n),
\]
where \(a_\sigma(n)\) is given by a neural network. The full wave function is given by
\[
\psi(n) = J(n) \, \det\!\big(n \star \phi_\uparrow(n)\big) \, \det\!\big(n \star \phi_\downarrow(n)\big).
\]
class SlaterBackflow(eqx.Module):
mlp_up: eqx.nn.MLP
mlp_dn: eqx.nn.MLP
phi_up: jax.Array
phi_dn: jax.Array
v: jax.Array
def __init__(self, width: int):
self.phi_up, self.phi_dn, self.v = _init_params()
keys = qtx.get_subkeys(2)
self.mlp_up = eqx.nn.MLP(
in_size=lattice.Nfmodes,
out_size=self.phi_up.size,
width_size=width,
depth=1,
use_final_bias=False, # final bias is phi
key=keys[0],
)
self.mlp_dn = eqx.nn.MLP(
in_size=lattice.Nfmodes,
out_size=self.phi_up.size,
width_size=width,
depth=1,
use_final_bias=False, # final bias is phi
key=keys[1],
)
def __call__(self, n: jax.Array) -> jax.Array:
        jastrow = qtx.nn.exp_by_log(-0.5 * n @ self.v @ n)  # same Jastrow form as the S0 model
phi_up = self.phi_up + self.mlp_up(n).reshape(self.phi_up.shape)
phi_dn = self.phi_dn + self.mlp_dn(n).reshape(self.phi_dn.shape)
return jastrow * _slater_forward(phi_up, phi_dn, n)
model_sn = SlaterBackflow(width=256)
# Initialize the backflow model with the optimized S0 parameters
model0 = state_s0.model
model_sn = eqx.tree_at(lambda model: model.phi_up, model_sn, model0.phi_up)
model_sn = eqx.tree_at(lambda model: model.phi_dn, model_sn, model0.phi_dn)
model_sn = eqx.tree_at(lambda model: model.v, model_sn, model0.v)
state_sn = qtx.state.Variational(model_sn, max_parallel=8192*45)
sampler = qtx.sampler.ParticleHop(state_sn, 8192, sweep_steps=10 * N)
optimizer = qtx.optimizer.SR(state_sn, H)
energy = qtx.utils.DataTracer()
for i in range(500):
samples = sampler.sweep()
step = optimizer.get_step(samples)
state_sn.update(step * 0.05)
energy.append(optimizer.energy)
if i % 10 == 0:
clear_output()
energy.plot(start=-200, batch=10, baseline=E_gs)
plt.show()
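For reference, the number of variational parameters in the backflow model can be counted with standard jax and equinox utilities, as sketched below.
params = eqx.filter(state_sn.model, eqx.is_array)
n_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
print("Number of variational parameters:", n_params)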
The relative error of the Jastrow-Slater backflow state is around 1.6%, as presented by the SN percentage error in Fig.2.
E = energy[-20:].mean()
print("Relative error:", jnp.abs((E - E_gs) / E_gs))
Relative error: 0.01732872707164705
Plots with the SN result#
We can then measure the charge and spin densities, as presented in Fig.3.
from quantax.operator import number_u, number_d
def charge(i):
return number_u(i) + number_d(i)
def spin(i):
return number_u(i) - number_d(i)
samples = sampler.sweep()
charge_op = [charge(i) for i in range(N)]
spin_op = [spin(i) for i in range(N)]
charge_density = [op.expectation(state_sn, samples) for op in charge_op]
charge_density = np.asarray(charge_density).reshape(lattice.shape[1:])
spin_density = [op.expectation(state_sn, samples) for op in spin_op]
spin_density = np.asarray(spin_density).reshape(lattice.shape[1:])
fig, axes = plt.subplots(1, 2, constrained_layout=True)
im = axes[0].imshow(charge_density, cmap='viridis', vmin=0.65, vmax=1)
axes[0].axis('off')
axes[0].set_title("Charge density")
fig.colorbar(im, ax=axes[0], shrink=0.5)
im = axes[1].imshow(spin_density, cmap='viridis', vmin=-1, vmax=1)
axes[1].axis('off')
axes[1].set_title("Spin density")
fig.colorbar(im, ax=axes[1], shrink=0.5)
plt.show()
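As a quick sanity check on the measured densities (using only the arrays computed above), the charge density should sum to the total particle number 7 + 7 = 14, and the spin density should sum to approximately zero, since the state contains equal numbers of up and down fermions.
print("Total charge:", charge_density.sum())  # expect ~14
print("Total spin:", spin_density.sum())  # expect ~0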
The backflow parameters shown in Fig.4 and Fig.5 of the paper are reproduced below.
n = samples.spins[0]
print("Occupation number:\n", n.reshape(2, -1))
phi0u = state_sn.model.phi_up
phi0d = state_sn.model.phi_dn
phinu = phi0u + state_sn.model.mlp_up(n).reshape(phi0u.shape)
phind = phi0d + state_sn.model.mlp_dn(n).reshape(phi0d.shape)
fig, axes = plt.subplots(2, 2, figsize=(8, 4), constrained_layout=True)
phi_stack = jnp.stack([phi0u, phi0d, phinu, phind])
vmax = jnp.max(jnp.abs(phi_stack))
axes[0, 0].imshow(phi0u.T, cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto")
axes[0, 0].set_ylabel("s.p.o")
axes[1, 0].imshow(phi0d.T, cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto")
axes[1, 0].set_xlabel("Sites")
axes[1, 0].set_ylabel("s.p.o")
axes[0, 1].imshow(phinu.T, cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto")
im = axes[1, 1].imshow(phind.T, cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto")
axes[1, 1].set_xlabel("Sites")
fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8)
plt.show()
Occupation number:
[[-1 -1 1 -1 -1 -1 -1 1 1 -1 1 -1 1 1 -1 1]
[ 1 1 -1 1 1 -1 1 -1 -1 1 -1 -1 -1 -1 1 -1]]
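The configuration uses the ±1 occupation convention shown above (+1 occupied, -1 empty), with the first N entries corresponding to one spin species and the remaining N to the other. As a quick check, each spin species should contain exactly 7 occupied modes.
print("Particles per spin:", (n.reshape(2, -1) == 1).sum(axis=1))  # expect [7 7]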
Wup = state_sn.model.mlp_up.layers[0].weight
Wdn = state_sn.model.mlp_dn.layers[0].weight
vmax_up = jnp.max(jnp.abs(Wup))
vmax_dn = jnp.max(jnp.abs(Wdn))
vmax = max(vmax_up, vmax_dn)
bup = state_sn.model.mlp_up.layers[0].bias
bdn = state_sn.model.mlp_dn.layers[0].bias
idx_up = jnp.argsort(bup)[::-1]
idx_dn = jnp.argsort(bdn)[::-1]
fig, axes = plt.subplots(3, 2, figsize=(6, 8))
axes[0, 0].imshow(
Wup[idx_up[0:32]], cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto"
)
axes[0, 0].axis("off")
axes[1, 0].imshow(
Wup[idx_up[95:128]], cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto"
)
axes[1, 0].axis("off")
axes[2, 0].imshow(
Wup[idx_up[223:256]], cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto"
)
axes[2, 0].axis("off")
axes[0, 1].imshow(
Wdn[idx_dn[0:32]], cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto"
)
axes[0, 1].axis("off")
axes[1, 1].imshow(
Wdn[idx_dn[95:128]], cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto"
)
axes[1, 1].axis("off")
im = axes[2, 1].imshow(
Wdn[idx_dn[223:256]], cmap="RdYlBu", vmin=-vmax, vmax=vmax, aspect="auto"
)
axes[2, 1].axis("off")
fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8)
plt.show()
Another wave function utilized in the paper is pairing NNBF (PN), which is left for readers to try.
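As a starting point, the sketch below shows how a pairing backflow ansatz could be structured in the same style as SlaterBackflow above: the two Slater determinants are replaced by a single determinant of a pairing matrix F between up and down modes, corrected by a neural network. This is only an illustration under the conventions used in this tutorial, not necessarily the exact construction of the paper.
class PairingBackflow(eqx.Module):
    mlp: eqx.nn.MLP
    F: jax.Array  # (N, N) pairing matrix between up and down modes
    v: jax.Array

    def __init__(self, width: int):
        keys = qtx.get_subkeys(3)
        self.F = jr.normal(keys[0], (N, N))
        M = lattice.Nfmodes
        self.v = jr.normal(keys[1], (M, M)) / M
        self.mlp = eqx.nn.MLP(
            in_size=lattice.Nfmodes,
            out_size=N * N,
            width_size=width,
            depth=1,
            use_final_bias=False,  # final bias is F
            key=keys[2],
        )

    def __call__(self, n: jax.Array) -> jax.Array:
        idx = fermion_idx(n)
        Nup = lattice.Nparticles[0]
        idx_up = idx[:Nup]
        idx_dn = idx[Nup:] - N
        F = self.F + self.mlp(n).reshape(N, N)
        # Determinant of F restricted to occupied up rows and occupied down
        # columns (requires equal numbers of up and down fermions).
        sign, logabs = jnp.linalg.slogdet(F[idx_up][:, idx_dn])
        jastrow = qtx.nn.exp_by_log(-0.5 * n @ self.v @ n)
        return jastrow * LogArray(sign, logabs)
The resulting model can be wrapped in qtx.state.Variational and trained exactly like the SN model above.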