Restricted Boltzmann machine

Paper title: Solving the quantum many-body problem with artificial neural networks

Paper authors: Giuseppe Carleo and Matthias Troyer

arXiv:1606.02318 (2016)

Science 355, 602 (2017)

In this example, we find the ground state of the 10×10 Heisenberg model using a restricted Boltzmann machine (RBM) with channel number \(\alpha=16\).

Related tutorials: Quick start

Estimated cost: about 10 minutes on one RTX 4090 GPU
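For orientation, a translation-invariant RBM in the spirit of the original paper has, after tracing out the hidden units, the amplitude

\[ \psi(s) = e^{a\sum_i s_i} \prod_{f=1}^{\alpha} \prod_{T} 2\cosh\Big(b_f + \sum_j W^{(f)}_j\, s_{T(j)}\Big), \]

where \(T\) runs over the \(N\) lattice translations, so \(\alpha\) filters correspond to \(\alpha N\) hidden units (1600 for \(N = 100\), \(\alpha = 16\)). The NumPy sketch below evaluates \(\log\psi\) for one configuration. It is an illustration under stated assumptions, not quantax's `RBM_Conv`: the parameters are real (the paper uses complex ones), the lattice is periodic, and `rbm_log_amplitude` is a hypothetical helper.

```python
import numpy as np

# Hypothetical helper for illustration; not the quantax RBM_Conv implementation.
def rbm_log_amplitude(spins, filters, hidden_bias, visible_bias):
    """log psi(s) of a translation-invariant RBM on an L x L periodic lattice.

    spins:        (L, L) array with entries +1 / -1
    filters:      (alpha, L, L) real weights, one filter per channel
    hidden_bias:  (alpha,) one hidden bias per channel
    visible_bias: scalar bias shared by all sites
    """
    L = spins.shape[0]
    log_psi = visible_bias * spins.sum()
    for dx in range(L):
        for dy in range(L):
            # translated configuration: shifted[i, j] = spins[(i+dx) % L, (j+dy) % L]
            shifted = np.roll(spins, shift=(-dx, -dy), axis=(0, 1))
            # inputs of the alpha hidden units attached to this translation
            theta = hidden_bias + np.tensordot(filters, shifted, axes=([1, 2], [0, 1]))
            # summing over h = +/-1 gives 2 cosh(theta); logaddexp(x, -x) = log(2 cosh x)
            log_psi += np.logaddexp(theta, -theta).sum()
    return log_psi

rng = np.random.default_rng(0)
L, alpha = 10, 16  # same lattice size and channel number as above
spins = rng.permutation(np.repeat([1, -1], L * L // 2)).reshape(L, L)  # 50 up, 50 down
filters = 0.01 * rng.standard_normal((alpha, L, L))
print(rbm_log_amplitude(spins, filters, np.zeros(alpha), 0.0))
```

In the sector fixed by `Nparticles=(50, 50)` the visible-bias term is constant (\(\sum_i s_i = 0\)), so only the filters \(W^{(f)}\) and the hidden biases \(b_f\) shape the wavefunction.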

import quantax as qtx
import matplotlib.pyplot as plt
from IPython.display import clear_output
%config InlineBackend.figure_format = 'svg'

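lattice = qtx.sites.Square(10, Nparticles=(50, 50))  # 10x10 lattice with 50 spin-up and 50 spin-down sites
H = qtx.operator.Heisenberg(msr=True)                # Heisenberg Hamiltonian with the Marshall sign rule
model = qtx.model.RBM_Conv(channels=16)              # convolutional RBM with alpha = 16 channels
state = qtx.state.Variational(model, max_parallel=8192*140)
sampler = qtx.sampler.SpinExchange(state, nsamples=8192)  # spin-exchange moves keep the particle numbers fixed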

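# Stochastic reconfiguration (SR) with a pseudo-inverse solver.
# The original paper instead regularizes the S matrix with a diagonal shift.
solver = qtx.optimizer.lstsq_pinv_eig(rtol=1e-9)
optimizer = qtx.optimizer.SR(state, H, solver=solver)
E_QMC = -268.62107  # reference ground-state energy from quantum Monte Carlo

energy = qtx.utils.DataTracer()  # records the energy at every iteration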

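for i in range(1000):
    samples = sampler.sweep()             # draw a new batch of Monte Carlo samples
    step = optimizer.get_step(samples)    # SR update direction from these samples
    state.update(step * 2e-3)             # apply it with learning rate 2e-3
    energy.append(optimizer.energy)       # record the energy estimate of this iteration

    # refresh the live training plot every 10 iterations
    if i % 10 == 0:
        clear_output()
        energy.plot(start=-200, batch=10, baseline=E_QMC)
        plt.show()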
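`SpinExchange` keeps the Markov chain inside the sector fixed by `Nparticles=(50, 50)`, i.e. zero total magnetization. As a rough sketch of how such a sampler can work, assuming it proposes swaps of opposite spins on nearest-neighbour bonds and accepts them with the Metropolis probability \(\min(1, |\psi(s')/\psi(s)|^2)\), here is a single-chain NumPy version. `spin_exchange_sweep` and the toy log-amplitude are hypothetical; the real sampler draws `nsamples=8192` samples and its proposal details may differ.

```python
import numpy as np

# Hypothetical single-chain sampler for illustration; not the quantax SpinExchange code.
def spin_exchange_sweep(spins, log_psi, rng, nsteps=None):
    """One Metropolis sweep of nearest-neighbour spin exchanges (periodic L x L lattice).

    spins:   (L, L) array with entries +1 / -1
    log_psi: callable returning log psi(s) for a configuration
    Returns the updated configuration and the number of accepted moves.
    """
    L = spins.shape[0]
    nsteps = spins.size if nsteps is None else nsteps
    logp = log_psi(spins)
    naccept = 0
    for _ in range(nsteps):
        # pick a random bond: site (x, y) and its neighbour along +x or +y
        x, y = rng.integers(L, size=2)
        dx, dy = ((1, 0), (0, 1))[rng.integers(2)]
        x2, y2 = (x + dx) % L, (y + dy) % L
        if spins[x, y] == spins[x2, y2]:
            continue  # swapping two equal spins would not change the configuration
        proposal = spins.copy()
        proposal[x, y], proposal[x2, y2] = spins[x2, y2], spins[x, y]
        logp_new = log_psi(proposal)
        # Metropolis test in log space: accept with probability min(1, |psi'/psi|^2)
        if np.log(rng.random()) < 2 * np.real(logp_new - logp):
            spins, logp = proposal, logp_new
            naccept += 1
    return spins, naccept

rng = np.random.default_rng(0)
L = 10
spins = rng.permutation(np.repeat([1, -1], L * L // 2)).reshape(L, L)
# toy log-amplitude favouring antiparallel neighbours (stand-in for a trained RBM)
toy_log_psi = lambda s: -0.2 * np.sum(s * np.roll(s, 1, axis=0) + s * np.roll(s, 1, axis=1))
spins, naccept = spin_exchange_sweep(spins, toy_log_psi, rng)
print(spins.sum(), naccept)  # the total magnetization stays 0
```

Because every accepted move swaps one up spin with one down spin, the sampler never leaves the 50-up/50-down sector, which is why no explicit projection onto that sector is needed.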

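`optimizer.SR` performs stochastic reconfiguration: each step solves \(S\,\delta = F\) with \(S_{kl} = \langle O_k^* O_l\rangle - \langle O_k^*\rangle\langle O_l\rangle\), \(F_k = \langle O_k^* E_\mathrm{loc}\rangle - \langle O_k^*\rangle\langle E_\mathrm{loc}\rangle\) and \(O_k(s) = \partial_{\theta_k}\log\psi(s)\), and the training loop above scales the resulting step by 2e-3. Since \(S\) estimated from a finite sample can be singular or badly conditioned, it must be regularized. The NumPy sketch below contrasts the two options named in the code comments: an eigenvalue-truncated pseudo-inverse (what we assume `lstsq_pinv_eig(rtol=1e-9)` does, with `rtol` a cutoff relative to the largest eigenvalue) and a constant diagonal shift. `sr_direction` is a hypothetical dense-matrix helper, not quantax's solver.

```python
import numpy as np

# Hypothetical dense SR solver for illustration; not the quantax implementation.
def sr_direction(O, E_loc, rtol=1e-9, diag_shift=None):
    """Solve S @ delta = F for the stochastic-reconfiguration direction.

    O:          (Ns, Np) log-derivatives, O[s, k] = d log psi(s) / d theta_k
    E_loc:      (Ns,) local energies of the same samples
    rtol:       eigenvalues below rtol * max eigenvalue are discarded (pseudo-inverse)
    diag_shift: if given, solve (S + diag_shift * I) delta = F instead
    """
    Ns = O.shape[0]
    Oc = (O - O.mean(axis=0)) / np.sqrt(Ns)     # centred, normalised derivatives
    Ec = (E_loc - E_loc.mean()) / np.sqrt(Ns)   # centred, normalised local energies
    S = Oc.conj().T @ Oc                        # S_kl = <O_k* O_l> - <O_k*><O_l>
    F = Oc.conj().T @ Ec                        # F_k  = <O_k* E> - <O_k*><E>
    if diag_shift is not None:
        return np.linalg.solve(S + diag_shift * np.eye(S.shape[0]), F)
    w, V = np.linalg.eigh(S)                    # S is Hermitian positive semi-definite
    keep = w > rtol * w.max()
    # delta = V_keep diag(1 / w_keep) V_keep^dagger F
    return V[:, keep] @ ((V[:, keep].conj().T @ F) / w[keep])

rng = np.random.default_rng(0)
O = rng.standard_normal((64, 200))  # fewer samples than parameters: S is singular
E_loc = rng.standard_normal(64)
delta_pinv = sr_direction(O, E_loc)
delta_shift = sr_direction(O, E_loc, diag_shift=1e-3)
print(np.linalg.norm(delta_pinv), np.linalg.norm(delta_shift))
```

The pseudo-inverse drops poorly resolved directions outright, whereas a diagonal shift \(\epsilon\) replaces each \(1/w\) by \(1/(w+\epsilon)\) and so damps them smoothly.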
The relative error of the variational energy,

\[ \epsilon_\mathrm{rel} = (E_\mathrm{NQS} - E_\mathrm{QMC}) / |E_\mathrm{QMC}|, \]

is comparable to the value of about \(10^{-3}\) reported in Fig. 3(C) of the original paper. Here \(E_\mathrm{NQS}\) is the mean energy over the last 100 iterations.

E = energy[-100:].mean()  # average energy over the last 100 iterations
rel_err = (E - E_QMC) / abs(E_QMC)
print(rel_err)
0.0009721019350138251
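Inverting the definition gives \(E_\mathrm{NQS} = E_\mathrm{QMC} + \epsilon_\mathrm{rel}\,|E_\mathrm{QMC}|\), so the printed value corresponds to an averaged variational energy of about -268.36, roughly 0.26 above the QMC reference; a positive \(\epsilon_\mathrm{rel}\) is what one expects from a variational energy. A plain-Python check of this arithmetic, using two hypothetical helpers:

```python
# Hypothetical helpers for illustration only.
def relative_error(E_nqs, E_ref):
    """epsilon_rel = (E_nqs - E_ref) / |E_ref|"""
    return (E_nqs - E_ref) / abs(E_ref)

def energy_from_relative_error(eps, E_ref):
    """Inverse of relative_error: E_nqs = E_ref + eps * |E_ref|"""
    return E_ref + eps * abs(E_ref)

E_QMC = -268.62107
E_nqs = energy_from_relative_error(0.0009721019350138251, E_QMC)
print(round(E_nqs, 2))                         # -268.36
print(round(relative_error(E_nqs, E_QMC), 6))  # 0.000972
```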

Here we reproduce Fig. 2 of the original paper, which shows the learned RBM weights, one filter per channel. The color scale differs from the paper's because of differences in training details, but the patterns are similar.

import jax.numpy as jnp
from matplotlib.colors import TwoSlopeNorm

W = state.model.layers[1].weight  # RBM filter weights; W[i, 0] is the filter of channel i

# single symmetric color scale (centered at 0)
v = jnp.max(jnp.abs(W))
norm = TwoSlopeNorm(vmin=-v, vcenter=0.0, vmax=v)

fig, axes = plt.subplots(4, 4, figsize=(8, 8), constrained_layout=True)

for i, ax in enumerate(axes.flat):
    im = ax.imshow(W[i, 0], cmap='RdYlBu_r', norm=norm)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(rf'$W^{({i+1})}$', fontsize=10, fontstyle='italic')

# one horizontal colorbar under all panels
cbar = fig.colorbar(im, ax=axes, orientation='horizontal', pad=0.08, shrink=0.9)

plt.show()