{ "cells": [ { "cell_type": "markdown", "id": "d92c0398", "metadata": {}, "source": [ "# Local updates\n", "\n", "Updates in VMC are often local, for instance, flipping a few spins. If the variational wavefunction has a suitable internal structure, one doesn't have to recompute it from scratch after every update. Therefore, defining a separate computational graph for local updates can greatly accelerate the VMC simulation.\n", "\n", "The local update technique has been widely adopted for mean-field fermionic wavefunctions and tensor networks. In this tutorial, we will show how to define local updates for your wavefunction.\n", "\n", "As a prerequisite, please read {doc}`Build your network ` to learn how to build a network in Quantax." ] }, { "cell_type": "code", "execution_count": null, "id": "2a8f5bf6", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import equinox as eqx\n", "import quantax as qtx\n", "\n", "L = 64\n", "lattice = qtx.sites.Chain(L)" ] }, { "cell_type": "markdown", "id": "9a293e15", "metadata": {}, "source": [ "## RefModel\n", "\n", "Let's consider a restricted Boltzmann machine (RBM) wavefunction,\n", "\n", "$$\n", "\\begin{aligned}\n", " h_i &= \\sum_j W_{ij} s_j + b_i \\\\\n", " \\psi &= \\prod_i \\cosh h_i\n", "\\end{aligned}\n", "$$" ] }, { "cell_type": "code", "execution_count": 2, "id": "21b511de", "metadata": {}, "outputs": [], "source": [ "class RBM(eqx.Module):\n", "    linear: eqx.nn.Linear\n", "\n", "    def __init__(self, M: int):\n", "        key = qtx.get_subkeys()\n", "        linear = eqx.nn.Linear(L, M, key=key)\n", "        self.linear = qtx.nn.apply_lecun_normal(key, linear)\n", "\n", "    def __call__(self, x: jax.Array) -> jax.Array:\n", "        h = self.linear(x)\n", "        return jnp.prod(jnp.cosh(h))" ] }, { "cell_type": "markdown", "id": "ec6792fa", "metadata": {}, "source": [ "If $s$ is flipped locally to generate $s'$, then\n", "\n", "$$\n", "\\begin{aligned}\n", " h_i' &= \\sum_j W_{ij} s_j' + b_i = h_i + \\sum_{j \\in \\{ 
j|s_j' \\neq s_j \\}} W_{ij} (s_j' - s_j) \\\\\n", " \\psi' &= \\prod_i \\cosh h_i'\n", "\\end{aligned}\n", "$$\n", "\n", "Assume the number of hidden units $M$ is of order $O(L)$. A direct evaluation of $h$ costs $O(ML) = O(L^2)$, whereas the local update only uses the columns of $W$ belonging to the $k$ flipped spins and costs $O(Mk)$. For a fixed $k$, the local update therefore reduces the cost from $O(L^2)$ to $O(L)$.\n", "\n", "Quantax provides {py:class}`~quantax.nn.RefModel`, a subclass of `eqx.Module`, for local updates. Here, we construct `RBM_Ref`, an RBM with local updates." ] }, { "cell_type": "code", "execution_count": 32, "id": "e2993efd", "metadata": {}, "outputs": [], "source": [ "from typing import Union\n", "from jaxtyping import PyTree\n", "\n", "\n", "class RBM_Ref(qtx.nn.RefModel):\n", "    linear: eqx.nn.Linear\n", "\n", "    def __init__(self, M: int):\n", "        \"\"\"Same as the usual RBM.\"\"\"\n", "        key = qtx.get_subkeys()\n", "        linear = eqx.nn.Linear(L, M, key=key)\n", "        self.linear = qtx.nn.apply_lecun_normal(key, linear)\n", "\n", "    def __call__(self, x: jax.Array) -> jax.Array:\n", "        \"\"\"Same as the usual RBM.\"\"\"\n", "        h = self.linear(x)\n", "        return jnp.prod(jnp.cosh(h))\n", "\n", "    def init_internal(self, x: jax.Array) -> PyTree:\n", "        \"\"\"Compute the initial hidden units for local updates in `ref_forward`.\"\"\"\n", "        h = self.linear(x)\n", "        return h\n", "\n", "    def ref_forward(\n", "        self,\n", "        s: jax.Array,\n", "        s_old: jax.Array,\n", "        nflips: int,\n", "        internal: PyTree,\n", "        return_update: bool,\n", "    ) -> Union[jax.Array, tuple[jax.Array, PyTree]]:\n", "        \"\"\"\n", "        Forward pass with reference to the old configuration and the number of flipped spins.\n", "        This is the core function of local updates.\n", "        \"\"\"\n", "        # A marker that local updates are being used.\n", "        # When this method is jitted, the marker is printed only during tracing, not on every call.\n", "        print(\"Using local updates\")\n", "\n", "        diff = s - s_old\n", "        # Indices of the flipped spins; `size=nflips` gives a fixed output shape under jit\n", "        idx_flip = jnp.flatnonzero(diff, size=nflips)\n", "        # Only the columns of W belonging to flipped spins change the hidden units\n", "        h_diff = self.linear.weight[:, idx_flip] @ diff[idx_flip]\n", "        h_new = internal + h_diff\n", "        psi = jnp.prod(jnp.cosh(h_new))\n", "        if return_update:\n", "            return psi, h_new\n", "        else:\n", "            return psi" ] }, { 
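"cell_type": "markdown", "id": "b7e4f2a1", "metadata": {}, "source": [
 "The update rule above can also be checked in isolation. Below is a minimal plain-JAX sketch, independent of Quantax, shown as static text so that it does not overwrite `L`, `s_old` or `s_new` in this notebook. The sizes, the random weights and the helper names `full_hidden` and `local_hidden` are hypothetical: `full_hidden` plays the role of `init_internal`, and `local_hidden` mirrors the hidden-unit update inside `ref_forward`. Passing `nflips` through `static_argnums` is an assumption meant to mimic how it is treated as a static argument under jit.\n",
 "\n",
 "```python\n",
 "import jax\n",
 "import jax.numpy as jnp\n",
 "\n",
 "# Hypothetical sizes and random parameters for this sketch\n",
 "n_sites, n_hidden = 16, 64\n",
 "key_w, key_b, key_s = jax.random.split(jax.random.PRNGKey(0), 3)\n",
 "W = jax.random.normal(key_w, (n_hidden, n_sites)) / jnp.sqrt(n_sites)\n",
 "b = jax.random.normal(key_b, (n_hidden,))\n",
 "\n",
 "\n",
 "def full_hidden(s):\n",
 "    # Direct evaluation h = W s + b, cost O(M L)\n",
 "    return W @ s + b\n",
 "\n",
 "\n",
 "def local_hidden(s_new, s_old, h_old, nflips):\n",
 "    # Local update h' = h + W[:, flipped] @ (s' - s)[flipped], cost O(M nflips)\n",
 "    diff = s_new - s_old\n",
 "    # size=nflips gives a fixed output shape, which jit requires\n",
 "    idx = jnp.flatnonzero(diff, size=nflips)\n",
 "    return h_old + W[:, idx] @ diff[idx]\n",
 "\n",
 "\n",
 "# nflips must be static because it determines an array shape\n",
 "local_hidden_jit = jax.jit(local_hidden, static_argnums=3)\n",
 "\n",
 "s_old = jnp.where(jax.random.bernoulli(key_s, 0.5, (n_sites,)), 1.0, -1.0)\n",
 "s_new = s_old.at[jnp.array([2, 7])].multiply(-1)  # flip two spins\n",
 "\n",
 "h_old = full_hidden(s_old)\n",
 "h_new = local_hidden_jit(s_new, s_old, h_old, 2)\n",
 "psi_direct = jnp.prod(jnp.cosh(full_hidden(s_new)))\n",
 "psi_local = jnp.prod(jnp.cosh(h_new))\n",
 "print(jnp.allclose(h_new, full_hidden(s_new), atol=1e-5))\n",
 "```\n",
 "\n",
 "Because only the `nflips` columns of `W` are touched, the cost of `local_hidden` grows with the number of flipped spins rather than with the system size."
] }, {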
"cell_type": "markdown", "id": "05d1d411", "metadata": {}, "source": [ "Note that `init_internal` and `ref_forward` are automatically jitted in Quantax.\n", "In `ref_forward`, only `nflips` and `return_update` are treated as static arguments, so array shapes must not depend on the values of the other inputs. For instance, without `size=nflips`, the output shape of `jnp.flatnonzero` depends on the values in `diff`, which triggers a jit error in the sampling and local energy examples below.\n", "\n", "Here we check that the local update reproduces the direct forward pass." ] }, { "cell_type": "code", "execution_count": 33, "id": "639bd718", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing direct forward ...\n", "Direct forward psi: 3.794001513041848e+40\n", "Testing local updates ...\n", "Using local updates\n", "Local updates psi: 3.7940015130418334e+40\n" ] } ], "source": [ "model = RBM_Ref(4 * L)\n", "s_old = qtx.utils.rand_states()\n", "internal = model.init_internal(s_old)\n", "\n", "# Flip the first spin\n", "nflips = 1\n", "s_new = s_old.at[0].multiply(-1)\n", "\n", "print(\"Testing direct forward ...\")\n", "psi_direct = model(s_new)\n", "print(\"Direct forward psi: \", psi_direct)\n", "\n", "print(\"Testing local updates ...\")\n", "psi_ref = model.ref_forward(s_new, s_old, nflips, internal, return_update=False)\n", "print(\"Local updates psi: \", psi_ref)\n", "assert jnp.isclose(psi_direct, psi_ref)" ] }, { "cell_type": "markdown", "id": "c50394b6", "metadata": {}, "source": [ "## VMC with local updates\n", "\n", "To use a `RefModel` in VMC, wrap it in {py:class}`~quantax.state.Variational`. With a `RefModel`, `Variational` provides batched and jitted {py:meth}`~quantax.state.Variational.init_internal`, {py:meth}`~quantax.state.Variational.ref_forward`, and {py:meth}`~quantax.state.Variational.ref_forward_with_updates` methods."
] }, { "cell_type": "code", "execution_count": 34, "id": "d2096eda", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing direct forward ...\n", "Testing local updates ...\n", "Using local updates\n" ] } ], "source": [ "state = qtx.state.Variational(model)\n", "\n", "# A batch of 16384 random spin configurations\n", "s_old = qtx.utils.rand_states(16384)\n", "internal = state.init_internal(s_old)\n", "\n", "nflips = 1\n", "idx_segment = jnp.arange(s_old.shape[0])\n", "# Flip the first spin of every configuration\n", "s_new = s_old.at[:, 0].multiply(-1)\n", "\n", "print(\"Testing direct forward ...\")\n", "psi_direct = state(s_new)\n", "\n", "print(\"Testing local updates ...\")\n", "psi_ref = state.ref_forward(s_new, s_old, nflips, idx_segment, internal)\n", "assert jnp.allclose(psi_direct, psi_ref)" ] }, { "cell_type": "markdown", "id": "74bf1e01", "metadata": {}, "source": [ "Next, we compare the run times. In the measurements below, `ref_forward` (about 247 μs per call) is roughly 1.8 times faster than the direct forward pass (about 455 μs per call)." ] }, { "cell_type": "code", "execution_count": 25, "id": "c803370d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "455 μs ± 7.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "%timeit jax.block_until_ready(state(s_new))" ] }, { "cell_type": "code", "execution_count": 26, "id": "21e8c4b3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "247 μs ± 2.17 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "%timeit jax.block_until_ready(state.ref_forward(s_new, s_old, nflips, idx_segment, internal))" ] }, { "cell_type": "markdown", "id": "1a938f73", "metadata": {}, "source": [ "The samplers in Quantax automatically use local updates whenever possible."
] }, { "cell_type": "code", "execution_count": 35, "id": "789fa0db", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using local updates\n" ] } ], "source": [ "sampler = qtx.sampler.LocalFlip(state, nsamples=16384)\n", "samples = sampler.sweep()" ] }, { "cell_type": "markdown", "id": "d7ff5c17", "metadata": {}, "source": [ "You may run out of memory if `internal` stores too many values. To avoid this, see the documentation of `max_parallel` in {py:meth}`quantax.state.Variational.__init__`.\n", "\n", "Local updates are also used when computing local energies, as in the following example." ] }, { "cell_type": "code", "execution_count": 36, "id": "75877bd7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using local updates\n" ] } ], "source": [ "H = qtx.operator.Ising(h=1.0)\n", "Eloc = H.Oloc(state, samples)" ] }, { "cell_type": "markdown", "id": "7f436e13", "metadata": {}, "source": [ "To disable local updates, set `use_ref=False` when constructing `Variational`. This lets us check that the local energies computed with local updates agree with those computed from direct forward passes." ] }, { "cell_type": "code", "execution_count": 37, "id": "23c29af5", "metadata": {}, "outputs": [], "source": [ "state_direct = qtx.state.Variational(state.model, use_ref=False)\n", "Eloc_direct = H.Oloc(state_direct, samples)\n", "\n", "assert jnp.allclose(Eloc, Eloc_direct)" ] } ], "metadata": { "kernelspec": { "display_name": "quantax_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }