{
"cells": [
{
"cell_type": "markdown",
"id": "26c27d95",
"metadata": {},
"source": [
"# Restricted Boltzmann machine\n",
"\n",
"Paper title: Solving the quantum many-body problem with artificial neural networks\n",
"\n",
"Paper authors: Giuseppe Carleo and Matthias Troyer\n",
"\n",
"[arXiv:1606.02318 (2016)](https://arxiv.org/abs/1606.02318)\n",
"\n",
"[Science 355, 602 (2017)](https://doi.org/10.1126/science.aag2302)\n",
"\n",
"In this example, we compute the ground state of the $10 \\times 10$ Heisenberg model using a restricted Boltzmann machine (RBM) with channel number (hidden-unit density) $\\alpha=16$.\n",
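"\n",
"For reference, here is a sketch of the RBM ansatz from the original paper (how `qtx.model.RBM_Conv` handles bias terms and lays out its filters is an assumption not checked here). Tracing out $M = \\alpha N$ hidden units gives the amplitude\n",
"\n",
"$$\n",
"\\Psi(\\sigma) = e^{\\sum_j a_j \\sigma_j} \\prod_{i=1}^{M} 2\\cosh\\Big(b_i + \\sum_{j=1}^{N} W_{ij} \\sigma_j\\Big).\n",
"$$\n",
"\n",
"Imposing translation symmetry reduces the weights to $\\alpha$ feature filters $W^{(f)}_j$, $f = 1, \\dots, \\alpha$, each applied at every lattice translation. This is why $\\alpha$ acts as a channel number here.\n",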
"\n",
"Related tutorials: {doc}`Quick start <../tutorials/quick_start>`\n",
"\n",
"Estimated cost: 1 RTX4090 x 10 min"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06ac78b4",
"metadata": {},
"outputs": [],
"source": [
"import quantax as qtx\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import clear_output\n",
"%config InlineBackend.figure_format = 'svg'\n",
"\n",
"# 10x10 square lattice restricted to 50 up and 50 down spins (total S^z = 0)\n",
"lattice = qtx.sites.Square(10, Nparticles=(50, 50))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "98bb46b6",
"metadata": {},
"outputs": [],
"source": [
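"# Heisenberg Hamiltonian with the Marshall sign rule applied (msr=True)\n",
"H = qtx.operator.Heisenberg(msr=True)\n",
"# translation-invariant (convolutional) RBM with alpha = 16 channels\n",
"model = qtx.model.RBM_Conv(channels=16)\n",
"state = qtx.state.Variational(model, max_parallel=8192*140)\n",
"# 8192 samples per sweep, proposed via spin exchanges (preserves total S^z)\n",
"sampler = qtx.sampler.SpinExchange(state, nsamples=8192)\n",
"\n",
"# Stochastic reconfiguration (SR); S is inverted with an eigendecomposition-based\n",
"# pseudo-inverse here, whereas the original paper regularizes S with a diagonal shift\n",
"solver = qtx.optimizer.lstsq_pinv_eig(rtol=1e-9)\n",
"optimizer = qtx.optimizer.SR(state, H, solver=solver)"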
]
},
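{
"cell_type": "markdown",
"id": "b7d2f9e1",
"metadata": {},
"source": [
"The following is a brief sketch of the standard SR update, included for orientation. Quantax's internal conventions, such as the sign of `step`, are not assumed. With log-derivatives $O_k(\\sigma) = \\partial_{\\theta_k} \\ln \\Psi(\\sigma)$ and local energy $E_\\mathrm{loc}(\\sigma)$ averaged over the Monte Carlo samples,\n",
"\n",
"$$\n",
"S_{kk'} = \\langle O_k^* O_{k'} \\rangle - \\langle O_k^* \\rangle \\langle O_{k'} \\rangle, \\qquad F_k = \\langle O_k^* E_\\mathrm{loc} \\rangle - \\langle O_k^* \\rangle \\langle E_\\mathrm{loc} \\rangle,\n",
"$$\n",
"\n",
"and the parameters move along $\\delta\\theta \\propto -S^{-1} F$. Because $S$ is usually ill-conditioned, its inverse must be regularized. Assuming `rtol` is a cutoff relative to the largest eigenvalue of $S = \\sum_i \\lambda_i v_i v_i^\\dagger$, the eigendecomposition-based pseudo-inverse keeps only the well-resolved directions:\n",
"\n",
"$$\n",
"S^{+} = \\sum_{i:\\, \\lambda_i > r_\\mathrm{tol} \\lambda_\\mathrm{max}} \\lambda_i^{-1} v_i v_i^\\dagger.\n",
"$$\n",
"\n",
"A diagonal shift, by contrast, keeps every direction but inverts a stabilized matrix such as $S + \\epsilon I$ with small $\\epsilon > 0$."
]
},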
{
"cell_type": "code",
"execution_count": 3,
"id": "81bc1c84",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# QMC reference ground-state energy of the 10x10 lattice\n",
"E_QMC = -268.62107\n",
"\n",
"energy = qtx.utils.DataTracer()\n",
"\n",
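"for i in range(1000):\n",
"    # draw a fresh batch of Monte Carlo samples from |psi|^2\n",
"    samples = sampler.sweep()\n",
"    # SR update direction; this also evaluates the energy of the samples\n",
"    step = optimizer.get_step(samples)\n",
"    state.update(step * 2e-3)  # learning rate 2e-3\n",
"    energy.append(optimizer.energy)\n",
"\n",
"    # refresh the live energy plot every 10 steps\n",
"    if i % 10 == 0:\n",
"        clear_output()\n",
"        energy.plot(start=-200, batch=10, baseline=E_QMC)\n",
"        plt.show()"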
]
},
{
"cell_type": "markdown",
"id": "407e4aeb",
"metadata": {},
"source": [
"The relative error of the variational energy,\n",
"\n",
"$$\n",
"\\epsilon_\\mathrm{rel} = (E_\\mathrm{NQS} - E_\\mathrm{QMC}) / |E_\\mathrm{QMC}|,\n",
"$$\n",
"\n",
"is about $10^{-3}$, similar to the result presented in Fig. 3(C) of the original paper. Here $E_\\mathrm{NQS}$ is the mean energy over the last 100 optimization steps."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4cb058fd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0009721019350138251\n"
]
}
],
"source": [
"# average the last 100 energy estimates to suppress Monte Carlo noise\n",
"E = energy[-100:].mean()\n",
"rel_err = (E - E_QMC) / abs(E_QMC)\n",
"print(rel_err)"
]
},
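{
"cell_type": "markdown",
"id": "c4e1a7b2",
"metadata": {},
"source": [
"As a quick sanity check, inverting the definition of $\\epsilon_\\mathrm{rel}$ converts the printed value back into an absolute energy. This worked example uses the numbers above; the digits will vary from run to run because training is stochastic.\n",
"\n",
"$$\n",
"E_\\mathrm{NQS} = E_\\mathrm{QMC} + \\epsilon_\\mathrm{rel}\\,|E_\\mathrm{QMC}| \\approx -268.621 + (9.72 \\times 10^{-4})(268.621) \\approx -268.621 + 0.261 = -268.360\n",
"$$"
]
},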
{
"cell_type": "markdown",
"id": "9cc31a4e",
"metadata": {},
"source": [
"Here we reproduce Fig. 2 of the original paper, which shows the learned RBM weights as one filter per channel (16 in total).\n",
"The color scale differs from the paper's because of differences in training details, but the spatial patterns are similar."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "48580faa",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import jax.numpy as jnp\n",
"from matplotlib.colors import TwoSlopeNorm\n",
"\n",
"# learned convolutional weights; W[i, 0] is the 2D filter of channel i\n",
"W = state.model.layers[1].weight\n",
"\n",
"# single symmetric color scale (centered at 0)\n",
"v = jnp.max(jnp.abs(W))\n",
"norm = TwoSlopeNorm(vmin=-v, vcenter=0.0, vmax=v)\n",
"\n",
"fig, axes = plt.subplots(4, 4, figsize=(8, 8), constrained_layout=True)\n",
"\n",
"for i, ax in enumerate(axes.flat):\n",
"    im = ax.imshow(W[i, 0], cmap='RdYlBu_r', norm=norm)\n",
"    ax.set_xticks([])\n",
"    ax.set_yticks([])\n",
"    # doubled braces are literal in the f-string, so titles render as W^{(1)}, W^{(2)}, ...\n",
"    ax.set_title(rf'$W^{{({i+1})}}$', fontsize=10, fontstyle='italic')\n",
"\n",
"# one horizontal colorbar under all panels\n",
"cbar = fig.colorbar(im, ax=axes, orientation='horizontal', pad=0.08, shrink=0.9)\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "quantax_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}