{ "cells": [ { "cell_type": "markdown", "id": "dce4d856", "metadata": {}, "source": [ "# Neural network backflow\n", "\n", "Paper title: Backflow Transformations via Neural Networks for Quantum Many-Body Wave Functions\n", "\n", "Paper authors: Di Luo and Bryan K. Clark\n", "\n", "[arxiv:1807.10770 (2018)](https://arxiv.org/abs/1807.10770)\n", "\n", "[Phys. Rev. Lett. 122, 226401 (2019)](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.122.226401)\n", "\n", "In this example, we use neural network backflow (NNBF) to find the ground state of the 4x4 Hubbard model at 1/8 hole doping.\n", "\n", "Related tutorials: {doc}`Exact diagonalization <../tutorials/exact_diag>`, {doc}`Build your network <../tutorials/build_net>`, {doc}`Samples and Measurement <../tutorials/samples>`, {doc}`Fermion mean field <../tutorials/fermion_mf>`\n", "\n", "Estimated cost: 1 A100 x 20 min" ] }, { "cell_type": "markdown", "id": "b66a6e27", "metadata": {}, "source": [ "## Define system and perform ED" ] }, { "cell_type": "markdown", "id": "41aabc1f", "metadata": {}, "source": [ "The Hubbard Hamiltonian is\n", "\n", "$$\n", "H = -t \\sum_{\\langle i j \\rangle,\\sigma} (c_{i\\sigma}^\\dagger c_{j\\sigma} + \\mathrm{h.c.})\n", "+ U \\sum_i n_{i\\uparrow} n_{i\\downarrow},\n", "$$\n", "\n", "where $\\langle i j \\rangle$ runs over nearest-neighbor pairs. In the 4x4 system with 1/8 hole doping, the ground state lies in the symmetry sector with momentum $k=0$, the B1 representation of the $C_{4v}$ group, and spin-inversion eigenvalue -1."
] }, { "cell_type": "code", "execution_count": null, "id": "723f5020", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-11.86883556956402\n" ] } ], "source": [ "import numpy as np\n", "import jax\n", "import jax.numpy as jnp\n", "import jax.random as jr\n", "import equinox as eqx\n", "import quantax as qtx\n", "from quantax.symmetry import TransND, C4v, SpinInverse\n", "import matplotlib.pyplot as plt\n", "from IPython.display import clear_output\n", "%config InlineBackend.figure_format = 'svg'\n", "\n", "# 4x4 lattice with 1/8 hole doping, 7 spin-up and 7 spin-down fermions\n", "lattice = qtx.sites.Square(\n", "    4, particle_type=qtx.PARTICLE_TYPE.spinful_fermion, Nparticles=(7, 7)\n", ")\n", "N = lattice.Nsites\n", "\n", "H = qtx.operator.Hubbard(U=8)\n", "\n", "symm = TransND() @ C4v(repr=\"B1\") @ SpinInverse(-1)\n", "\n", "E_gs, wf_gs = H.diagonalize(symm)\n", "E_gs = E_gs[0]\n", "print(E_gs)" ] }, { "cell_type": "markdown", "id": "60b7bf75", "metadata": {}, "source": [ "## Slater-Jastrow determinant state (S0)\n", "\n", "The Jastrow-Slater wave function can be expressed as\n", "\n", "$$\n", "\\psi_\\mathrm{S0}(n) = J(n) \\times \\det(n_\\uparrow \\star \\phi_\\uparrow) \\times \\det(n_\\downarrow \\star \\phi_\\downarrow),\n", "$$\n", "\n", "where $n \\star \\phi$ means the full orbital matrix is sliced according to the fermion occupation number, and $J(n)$ is the Jastrow factor given by\n", "\n", "$$\n", "J(n) = \\exp \\left( -\\frac{1}{2} \\sum_{ij} v_{ij} n_i n_j \\right).\n", "$$\n", "\n", "The variational parameters are $\\phi_\\uparrow$, $\\phi_\\downarrow$, and $v$. We start by training this simple wave function to obtain a good initial state."
] }, { "cell_type": "code", "execution_count": 2, "id": "454d1640", "metadata": {}, "outputs": [], "source": [ "from quantax.nn import fermion_idx\n", "from quantax.utils import LogArray\n", "\n", "\n", "def _init_params():\n", "    keys = qtx.get_subkeys(3)\n", "    Nup, Ndn = lattice.Nparticles\n", "    phi_up = jr.normal(keys[0], (N, Nup))\n", "    phi_dn = jr.normal(keys[1], (N, Ndn))\n", "\n", "    M = lattice.Nfmodes\n", "    v = jr.normal(keys[2], (M, M)) / M\n", "    return phi_up, phi_dn, v\n", "\n", "\n", "def _slater_forward(phi_up, phi_dn, n):\n", "    # The first Nup occupied-mode indices are spin-up; the rest are spin-down, offset by N\n", "    idx = fermion_idx(n)\n", "    Nup = lattice.Nparticles[0]\n", "    idx_up = idx[:Nup]\n", "    idx_dn = idx[Nup:] - N\n", "    # Keep only the orbital rows of the occupied sites\n", "    M_up = phi_up[idx_up]\n", "    M_dn = phi_dn[idx_dn]\n", "    sign_up, logabs_up = jnp.linalg.slogdet(M_up)\n", "    psi_up = LogArray(sign_up, logabs_up)\n", "    sign_dn, logabs_dn = jnp.linalg.slogdet(M_dn)\n", "    psi_dn = LogArray(sign_dn, logabs_dn)\n", "    return psi_up * psi_dn\n", "\n", "\n", "class JastrowSlater(eqx.Module):\n", "    phi_up: jax.Array\n", "    phi_dn: jax.Array\n", "    v: jax.Array\n", "\n", "    def __init__(self):\n", "        self.phi_up, self.phi_dn, self.v = _init_params()\n", "\n", "    def __call__(self, n: jax.Array) -> jax.Array:\n", "        jastrow = qtx.nn.exp_by_log(-0.5 * n @ self.v @ n)\n", "        return jastrow * _slater_forward(self.phi_up, self.phi_dn, n)\n", "\n", "\n", "model_s0 = JastrowSlater()\n", "state_s0 = qtx.state.Variational(model_s0, max_parallel=8192*45)\n", "sampler = qtx.sampler.ParticleHop(state_s0, 8192, sweep_steps=10 * N)\n", "optimizer = qtx.optimizer.SR(state_s0, H)" ] }, { "cell_type": "code", "execution_count": 3, "id": "8973036a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: energy during S0 optimization, with the ED ground-state energy as baseline]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "energy = qtx.utils.DataTracer()\n", "\n", "for i in range(500):\n", "    samples = sampler.sweep()\n", "    step = optimizer.get_step(samples)\n", "    state_s0.update(step * 0.02)\n", "    energy.append(optimizer.energy)\n", "\n", "    if i % 10 == 0:\n", "        clear_output()\n", "        energy.plot(start=-200, batch=10, baseline=E_gs)\n", "        plt.show()" ] }, { "cell_type": "markdown", "id": "6770da79", "metadata": {}, "source": [ "The relative error of the Jastrow-Slater determinant state is around 5%, consistent with the S0 percentage error in Fig. 2 of the paper. The optimization may converge to a local minimum, so restart the training if you do not reach a similar accuracy." ] }, { "cell_type": "code", "execution_count": 4, "id": "4e5e45b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Relative error: 0.04863341849280481\n" ] } ], "source": [ "E = energy[-20:].mean()\n", "print(\"Relative error:\", jnp.abs((E - E_gs) / E_gs))" ] }, { "cell_type": "markdown", "id": "978f9af2", "metadata": {}, "source": [ "## Slater determinant NNBF (SN)\n", "\n", "Now we bring in neural networks, which provide a configuration-dependent correction to the single-particle orbitals:\n", "\n", "$$\n", "\\phi^b_\\sigma(n) = \\phi_\\sigma + a_\\sigma(n),\n", "$$\n", "\n", "where $a_\\sigma(n)$ is given by neural networks.
The full wave function is given by\n", "\n", "$$\n", "\\psi_\\mathrm{SN}(n) = J(n) \\times \\det(n_\\uparrow \\star \\phi^b_\\uparrow(n)) \\times \\det(n_\\downarrow \\star \\phi^b_\\downarrow(n)).\n", "$$" ] }, { "cell_type": "code", "execution_count": 5, "id": "99a5bbd4", "metadata": {}, "outputs": [], "source": [ "class SlaterBackflow(eqx.Module):\n", "    mlp_up: eqx.nn.MLP\n", "    mlp_dn: eqx.nn.MLP\n", "    phi_up: jax.Array\n", "    phi_dn: jax.Array\n", "    v: jax.Array\n", "\n", "    def __init__(self, width: int):\n", "        self.phi_up, self.phi_dn, self.v = _init_params()\n", "\n", "        keys = qtx.get_subkeys(2)\n", "        self.mlp_up = eqx.nn.MLP(\n", "            in_size=lattice.Nfmodes,\n", "            out_size=self.phi_up.size,\n", "            width_size=width,\n", "            depth=1,\n", "            use_final_bias=False,  # phi_up plays the role of the final bias\n", "            key=keys[0],\n", "        )\n", "\n", "        self.mlp_dn = eqx.nn.MLP(\n", "            in_size=lattice.Nfmodes,\n", "            out_size=self.phi_dn.size,\n", "            width_size=width,\n", "            depth=1,\n", "            use_final_bias=False,  # phi_dn plays the role of the final bias\n", "            key=keys[1],\n", "        )\n", "\n", "    def __call__(self, n: jax.Array) -> jax.Array:\n", "        # Same Jastrow factor as in JastrowSlater, so the S0 parameters carry over\n", "        jastrow = qtx.nn.exp_by_log(-0.5 * n @ self.v @ n)\n", "        phi_up = self.phi_up + self.mlp_up(n).reshape(self.phi_up.shape)\n", "        phi_dn = self.phi_dn + self.mlp_dn(n).reshape(self.phi_dn.shape)\n", "        return jastrow * _slater_forward(phi_up, phi_dn, n)\n", "\n", "\n", "model_sn = SlaterBackflow(width=256)\n", "\n", "# Initialize the backflow model with the optimized S0 parameters\n", "model0 = state_s0.model\n", "model_sn = eqx.tree_at(lambda model: model.phi_up, model_sn, model0.phi_up)\n", "model_sn = eqx.tree_at(lambda model: model.phi_dn, model_sn, model0.phi_dn)\n", "model_sn = eqx.tree_at(lambda model: model.v, model_sn, model0.v)\n", "\n", "state_sn = qtx.state.Variational(model_sn, max_parallel=8192*45)\n", "sampler = qtx.sampler.ParticleHop(state_sn, 8192, sweep_steps=10 * N)\n", "optimizer = qtx.optimizer.SR(state_sn, H)" ] }, { "cell_type": "code", "execution_count": 6, "id": "e027c28b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: energy during SN optimization, with the ED ground-state energy as baseline]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "energy = qtx.utils.DataTracer()\n", "\n", "for i in range(500):\n", "    samples = sampler.sweep()\n", "    step = optimizer.get_step(samples)\n", "    state_sn.update(step * 0.05)\n", "    energy.append(optimizer.energy)\n", "\n", "    if i % 10 == 0:\n", "        clear_output()\n", "        energy.plot(start=-200, batch=10, baseline=E_gs)\n", "        plt.show()" ] }, { "cell_type": "markdown", "id": "d23a5c07", "metadata": {}, "source": [ "The relative error of the Jastrow-Slater backflow state is around 1.7% in this run, comparable to the SN percentage error in Fig. 2 of the paper." ] }, { "cell_type": "code", "execution_count": 7, "id": "296ed39c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Relative error: 0.01732872707164705\n" ] } ], "source": [ "E = energy[-20:].mean()\n", "print(\"Relative error:\", jnp.abs((E - E_gs) / E_gs))" ] }, { "cell_type": "markdown", "id": "ca41230e", "metadata": {}, "source": [ "## Plot with SN result\n", "\n", "We can now measure the charge and spin densities, as presented in Fig. 3 of the paper."
] }, { "cell_type": "code", "execution_count": 8, "id": "ca7132e0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: two panels titled \"Charge density\" and \"Spin density\", each with a colorbar]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from quantax.operator import number_u, number_d\n", "\n", "def charge(i):\n", "    return number_u(i) + number_d(i)\n", "\n", "def spin(i):\n", "    return number_u(i) - number_d(i)\n", "\n", "samples = sampler.sweep()\n", "\n", "charge_op = [charge(i) for i in range(N)]\n", "spin_op = [spin(i) for i in range(N)]\n", "\n", "charge_density = [op.expectation(state_sn, samples) for op in charge_op]\n", "charge_density = np.asarray(charge_density).reshape(lattice.shape[1:])\n", "spin_density = [op.expectation(state_sn, samples) for op in spin_op]\n", "spin_density = np.asarray(spin_density).reshape(lattice.shape[1:])\n", "\n", "fig, axes = plt.subplots(1, 2, constrained_layout=True)\n", "\n", "im = axes[0].imshow(charge_density, cmap='viridis', vmin=0.65, vmax=1)\n", "axes[0].axis('off')\n", "axes[0].set_title(\"Charge density\")\n", "fig.colorbar(im, ax=axes[0], shrink=0.5)\n", "\n", "im = axes[1].imshow(spin_density, cmap='viridis', vmin=-1, vmax=1)\n", "axes[1].axis('off')\n", "axes[1].set_title(\"Spin density\")\n", "fig.colorbar(im, ax=axes[1], shrink=0.5)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "a3c9e47d", "metadata": {}, "source": [ "The backflow orbitals and first-layer network weights shown below correspond to Fig. 4 and Fig. 5 of the paper."
] }, { "cell_type": "code", "execution_count": 9, "id": "a6ac23e9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Occupation number:\n", " [[-1 -1  1 -1 -1 -1 -1  1  1 -1  1 -1  1  1 -1  1]\n", " [ 1  1 -1  1  1 -1  1 -1 -1  1 -1 -1 -1 -1  1 -1]]\n" ] }, { "data": { "text/plain": [ "[Figure: 2x2 panels of the bare orbitals (left) and backflow-corrected orbitals (right) for spin up (top) and spin down (bottom); x-axis \"Sites\", y-axis \"s.p.o\"]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "n = samples.spins[0]\n", "print(\"Occupation number:\\n\", n.reshape(2, -1))\n", "\n", "phi0u = state_sn.model.phi_up\n", "phi0d = state_sn.model.phi_dn\n", "phinu = phi0u + state_sn.model.mlp_up(n).reshape(phi0u.shape)\n", "phind = phi0d + state_sn.model.mlp_dn(n).reshape(phi0d.shape)\n", "\n", "fig, axes = plt.subplots(2, 2, figsize=(8, 4), constrained_layout=True)\n", "\n", "phi_stack = jnp.stack([phi0u, phi0d, phinu, phind])\n", "vmax = jnp.max(jnp.abs(phi_stack))\n", "\n", "axes[0, 0].imshow(phi0u.T, cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\")\n", "axes[0, 0].set_ylabel(\"s.p.o\")\n", "\n", "axes[1, 0].imshow(phi0d.T, cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\")\n", "axes[1, 0].set_xlabel(\"Sites\")\n", "axes[1, 0].set_ylabel(\"s.p.o\")\n", "\n", "axes[0, 1].imshow(phinu.T, cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\")\n", "\n", "im = axes[1, 1].imshow(phind.T, cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\")\n", "axes[1, 1].set_xlabel(\"Sites\")\n", "\n", "fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8)\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "id": "f03143ed", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: 3x2 panels of first-layer MLP weights for spin up (left) and spin down (right), with hidden units sorted by bias]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "Wup = state_sn.model.mlp_up.layers[0].weight\n", "Wdn = state_sn.model.mlp_dn.layers[0].weight\n", "vmax_up = jnp.max(jnp.abs(Wup))\n", "vmax_dn = jnp.max(jnp.abs(Wdn))\n", "vmax = max(vmax_up, vmax_dn)\n", "\n", "# Sort hidden units by first-layer bias, largest first\n", "bup = state_sn.model.mlp_up.layers[0].bias\n", "bdn = state_sn.model.mlp_dn.layers[0].bias\n", "idx_up = jnp.argsort(bup)[::-1]\n", "idx_dn = jnp.argsort(bdn)[::-1]\n", "\n", "fig, axes = plt.subplots(3, 2, figsize=(6, 8))\n", "\n", "# Each panel shows a block of 32 hidden units\n", "axes[0, 0].imshow(\n", "    Wup[idx_up[0:32]], cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\"\n", ")\n", "axes[0, 0].axis(\"off\")\n", "axes[1, 0].imshow(\n", "    Wup[idx_up[96:128]], cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\"\n", ")\n", "axes[1, 0].axis(\"off\")\n", "axes[2, 0].imshow(\n", "    Wup[idx_up[224:256]], cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\"\n", ")\n", "axes[2, 0].axis(\"off\")\n", "\n", "axes[0, 1].imshow(\n", "    Wdn[idx_dn[0:32]], cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\"\n", ")\n", "axes[0, 1].axis(\"off\")\n", "axes[1, 1].imshow(\n", "    Wdn[idx_dn[96:128]], cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\"\n", ")\n", "axes[1, 1].axis(\"off\")\n", "im = axes[2, 1].imshow(\n", "    Wdn[idx_dn[224:256]], cmap=\"RdYlBu\", vmin=-vmax, vmax=vmax, aspect=\"auto\"\n", ")\n", "axes[2, 1].axis(\"off\")\n", "\n", "fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "2f544694", "metadata": {}, "source": [ "The paper also uses a pairing NNBF (PN) wave function, which is left as an exercise for the reader."
] } ], "metadata": { "kernelspec": { "display_name": "quantax_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }