Local updates

Updates in VMC are often local, for instance, a flip of a few spins. If the variational wavefunction has a suitable internal structure, one doesn’t have to recompute the wavefunction from scratch. Defining a separate computation graph for local updates can therefore greatly accelerate the VMC simulation.

The local update technique has been widely adopted in mean-field fermionic wavefunctions and tensor networks. In this tutorial, we introduce how to define local updates for your wavefunction.

As a prerequisite, please read Build your network to learn how to build a network in Quantax.

import jax
import jax.numpy as jnp
import equinox as eqx
import quantax as qtx

L = 64
lattice = qtx.sites.Chain(L)

RefModel

Let’s consider a restricted Boltzmann machine (RBM) wavefunction,

\[\begin{split} \begin{aligned} h_i &= \sum_j W_{ij} s_j + b_i \\ \psi &= \prod_i \cosh h_i \end{aligned} \end{split}\]
class RBM(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, M: int):
        key = qtx.get_subkeys()
        linear = eqx.nn.Linear(L, M, key=key)
        self.linear = qtx.nn.apply_lecun_normal(key, linear)

    def __call__(self, x: jax.Array) -> jax.Array:
        h = self.linear(x)
        return jnp.prod(jnp.cosh(h))

If \(s\) is flipped locally to generate \(s'\), then

\[\begin{split} \begin{aligned} h_i' &= \sum_j W_{ij} s_j' + b_i = h_i + \sum_{j \in \{ j|s_j' \neq s_j \}} W_{ij} (s_j' - s_j) \\ \psi' &= \prod_i \cosh h_i' \end{aligned} \end{split}\]

Assume the number of hidden units \(M\) is of order \(O(L)\). A direct forward pass then costs \(O(LM) = O(L^2)\), whereas a local update with a fixed number of flipped spins costs only \(O(M) = O(L)\).
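
As a worked special case of the formula above, consider a single spin flip at site \(k\), so that \(s_k' = -s_k\) and all other spins are unchanged. The sum over flipped sites then collapses to one term:

\[\begin{split} \begin{aligned} h_i' &= h_i + W_{ik} (s_k' - s_k) = h_i - 2 W_{ik} s_k \\ \psi' &= \prod_i \cosh \left( h_i - 2 W_{ik} s_k \right) \end{aligned} \end{split}\]

Each of the \(M\) hidden units is updated with a single multiply-add, which is where the \(O(M)\) cost of one local update comes from.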

Quantax provides RefModel, a subclass of eqx.Module, for local updates. Here, we construct RBM_Ref as an RBM with local updates.

from typing import Union
from jaxtyping import PyTree


class RBM_Ref(qtx.nn.RefModel):
    linear: eqx.nn.Linear

    def __init__(self, M: int):
        """The same as usual RBM"""
        key = qtx.get_subkeys()
        linear = eqx.nn.Linear(L, M, key=key)
        self.linear = qtx.nn.apply_lecun_normal(key, linear)

    def __call__(self, x: jax.Array) -> jax.Array:
        """The same as usual RBM"""
        h = self.linear(x)
        return jnp.prod(jnp.cosh(h))

    def init_internal(self, x: jax.Array) -> PyTree:
        """Compute the initial hidden units for local updates in `ref_forward`."""
        h = self.linear(x)
        return h

    def ref_forward(
        self,
        s: jax.Array,
        s_old: jax.Array,
        nflips: int,
        internal: PyTree,
        return_update: bool,
    ) -> Union[jax.Array, tuple[jax.Array, PyTree]]:
        """
        Forward pass with reference to the old configuration and the number of flipped spins.
        This is the core function of local updates.
        """
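        # Marker that local updates are in use. Because this method is jitted,
        # the message is printed when the function is traced, not on every call.
        print("Using local updates")

        # Nonzero entries of diff mark the flipped sites.
        diff = s - s_old
        # size=nflips keeps the output shape static under jit.
        idx_flip = jnp.flatnonzero(diff, size=nflips)
        # Update the hidden units using only the flipped columns of W.
        h_diff = self.linear.weight[:, idx_flip] @ diff[idx_flip]
        h_new = internal + h_diff
        psi = jnp.prod(jnp.cosh(h_new))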
        if return_update:
            return psi, h_new
        else:
            return psi

Note that init_internal and ref_forward are automatically jitted by Quantax. In ref_forward, only nflips and return_update are treated as static arguments, so array shapes must not depend on any other input. For instance, jnp.flatnonzero raises an error under jit in the sampling and local energy examples below unless size=nflips is specified.

The following check confirms that local updates agree with the direct forward pass.
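
The shape constraint on jnp.flatnonzero can be reproduced in plain JAX, independently of Quantax. The sketch below is only an illustration: the function names flipped_indices and flipped_indices_dynamic are hypothetical and not part of any library. It shows that a static size makes the call traceable, while omitting it fails under jit.

```python
from functools import partial

import jax
import jax.numpy as jnp


# nflips is a static argument, so the output shape is known at trace time.
@partial(jax.jit, static_argnums=2)
def flipped_indices(s, s_old, nflips):
    return jnp.flatnonzero(s - s_old, size=nflips)


# Without size, the output shape depends on the data, which jit cannot trace.
@jax.jit
def flipped_indices_dynamic(s, s_old):
    return jnp.flatnonzero(s - s_old)


s_old = jnp.array([1, -1, 1, 1, -1, 1])
s_new = s_old.at[jnp.array([1, 4])].multiply(-1)  # flip sites 1 and 4

idx = flipped_indices(s_new, s_old, 2)
print(idx)  # indices of the two flipped sites

try:
    flipped_indices_dynamic(s_new, s_old)
    dynamic_ok = True
except Exception as err:  # JAX raises a concretization error here
    dynamic_ok = False
    print("jit failed without size:", type(err).__name__)
```

Because nflips is static, each distinct value of nflips triggers a separate compilation.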

model = RBM_Ref(4 * L)
s_old = qtx.utils.rand_states()
internal = model.init_internal(s_old)

nflips = 1
s_new = s_old.at[0].multiply(-1)

print("Testing direct forward ...")
psi_direct = model(s_new)
print("Direct forward psi: ", psi_direct)

print("Testing local updates ...")
psi_ref = model.ref_forward(s_new, s_old, nflips, internal, return_update=False)
print("Local updates psi: ", psi_ref)
assert jnp.isclose(psi_direct, psi_ref)
Testing direct forward ...
Direct forward psi:  3.794001513041848e+40
Testing local updates ...
Using local updates
Local updates psi:  3.7940015130418334e+40

VMC with local updates

To use a RefModel in VMC, wrap it in Variational. When it wraps a RefModel, Variational provides batched and jitted init_internal(), ref_forward(), and ref_forward_with_updates() methods.

state = qtx.state.Variational(model)

s_old = qtx.utils.rand_states(16384)
internal = state.init_internal(s_old)

nflips = 1
idx_segment = jnp.arange(s_old.shape[0])
s_new = s_old.at[:, 0].multiply(-1)

print("Testing direct forward ...")
psi_direct = state(s_new)

print("Testing local updates ...")
psi_ref = state.ref_forward(s_new, s_old, nflips, idx_segment, internal)
assert jnp.allclose(psi_direct, psi_ref)
Testing direct forward ...
Testing local updates ...
Using local updates
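
To illustrate what batching and jitting a local update involves, here is a minimal plain-JAX sketch that does not use Quantax. It is not how Variational is implemented; the parameters W and b and the helpers forward, init_internal, ref_forward, and batched_ref are hypothetical stand-ins for the RBM above, and jax.vmap is an assumed batching strategy.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the RBM parameters; not Quantax objects.
L, M, nsamples = 8, 16, 32
key_w, key_b, key_s = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(key_w, (M, L)) / jnp.sqrt(L)
b = 0.1 * jax.random.normal(key_b, (M,))


def forward(s):
    """Direct evaluation: recompute all hidden units, O(L * M)."""
    return jnp.prod(jnp.cosh(W @ s + b))


def init_internal(s):
    """Hidden units of the reference configuration."""
    return W @ s + b


def ref_forward(s, s_old, h_old, nflips):
    """Update the hidden units from the flipped columns of W only."""
    diff = s - s_old
    idx_flip = jnp.flatnonzero(diff, size=nflips)
    h_new = h_old + W[:, idx_flip] @ diff[idx_flip]
    return jnp.prod(jnp.cosh(h_new))


@partial(jax.jit, static_argnums=3)
def batched_ref(s, s_old, h_old, nflips):
    # Bind the static nflips first, then map over the sample axis.
    return jax.vmap(partial(ref_forward, nflips=nflips))(s, s_old, h_old)


batched_forward = jax.jit(jax.vmap(forward))
batched_init = jax.jit(jax.vmap(init_internal))

# Random +1/-1 configurations; flip the first spin of every sample.
s_old = jnp.where(jax.random.bernoulli(key_s, 0.5, (nsamples, L)), 1.0, -1.0)
s_new = s_old.at[:, 0].multiply(-1)

h_old = batched_init(s_old)
psi_direct = batched_forward(s_new)
psi_ref = batched_ref(s_new, s_old, h_old, 1)
agree = bool(jnp.allclose(psi_direct, psi_ref, rtol=1e-4))
print("direct and local results agree:", agree)
```

The per-sample hidden units h_old play the role of internal: they are computed once for the reference configurations and reused for every proposed flip.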

Next we compare the run time. In the timings below, ref_forward is about 1.8 times faster than the direct forward pass.

%timeit jax.block_until_ready(state(s_new))
455 μs ± 7.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit jax.block_until_ready(state.ref_forward(s_new, s_old, nflips, idx_segment, internal))
247 μs ± 2.17 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

The samplers in Quantax automatically utilize local updates whenever possible.

sampler = qtx.sampler.LocalFlip(state, nsamples=16384)
samples = sampler.sweep()
Using local updates

You might see out-of-memory errors if internal stores too many values. To avoid this, check the documentation of max_parallel in quantax.state.Variational.__init__().

Local updates are also used when computing local energies, as in the following example.
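
As a rough guide to how large internal can get, the following back-of-the-envelope estimate uses the numbers from this tutorial. It assumes that internal holds one float64 value per hidden unit per sample, as in RBM_Ref; the variable names are illustrative only, and other models may store considerably more.

```python
nsamples = 16384      # batch size used in this tutorial
n_hidden = 4 * 64     # M = 4 * L hidden units stored per sample
bytes_per_value = 8   # assumption: float64 entries

# Memory needed to hold the internal values of the whole batch.
internal_bytes = nsamples * n_hidden * bytes_per_value
internal_mib = internal_bytes / 2**20
print(f"{internal_mib:.0f} MiB")  # prints "32 MiB"
```

The footprint grows linearly with both the number of samples and the size of internal per sample, so models with richer internal quantities reach memory limits much sooner.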

H = qtx.operator.Ising(h=1.0)
Eloc = H.Oloc(state, samples)
Using local updates

To disable local updates, set use_ref=False when constructing Variational. This lets us check that local energies computed with local updates and with direct forward passes agree.

state_direct = qtx.state.Variational(state.model, use_ref=False)
Eloc_direct = H.Oloc(state_direct, samples)

assert jnp.allclose(Eloc, Eloc_direct)