Samples and Measurement#

In this tutorial, we will introduce how to use Quantax to generate samples.

In the beginning, we prepare ground states of transverse-field Ising models from which we will generate samples.

import numpy as np
import jax
import jax.numpy as jnp
import quantax as qtx
import matplotlib.pyplot as plt

lattice = qtx.sites.Square(4)

H0 = qtx.operator.Ising(h=1.0)
E0, wf0 = H0.diagonalize()
state0 = qtx.state.DenseState(wf0)

H1 = qtx.operator.Ising(h=3.0)
E1, wf1 = H1.diagonalize()
state1 = qtx.state.DenseState(wf1)

H2 = qtx.operator.Ising(h=10.0)
E2, wf2 = H2.diagonalize()
state2 = qtx.state.DenseState(wf2)

Exact sampler#

ExactSampler generates samples by sampling the full wavefunction. If the provided state is not a DenseState, it will generate a DenseState internally through forward pass. The sampler is defined below.

sampler0 = qtx.sampler.ExactSampler(state0, nsamples=1024)
sampler1 = qtx.sampler.ExactSampler(state1, nsamples=1024)
sampler2 = qtx.sampler.ExactSampler(state2, nsamples=1024)

Now we can generate samples.

samples0 = sampler0.sweep()
samples1 = sampler1.sweep()
samples2 = sampler2.sweep()

print(samples0)
print("spins shape", samples0.spins.shape)
print("psi shape", samples0.psi.shape)
Samples(spins=Array([[ 1,  1,  1, ...,  1,  1,  1],
       [ 1,  1,  1, ...,  1,  1,  1],
       [-1, -1, -1, ..., -1, -1, -1],
       ...,
       [ 1,  1,  1, ...,  1,  1,  1],
       [-1, -1, -1, ..., -1, -1, -1],
       [ 1,  1,  1, ...,  1,  1,  1]], dtype=int8), psi=array([-0.62183999, -0.00981954, -0.62184007, ..., -0.62183999,
       -0.62184007, -0.62183999]), state_internal=None, reweight_factor=Array([1., 1., 1., ..., 1., 1., 1.], dtype=float64))
spins shape (1024, 16)
psi shape (1024,)

The generated sample is a PyTree with 4 terms: spins, psi, state_internal, and reweight_factor. Beginners usually only need spins and psi, which correspond to the generated spin configurations \(s\) and their wavefunctions \(\psi(s)\). The probability of samples is given by \(|\psi(s)|^2\)

Let’s check how magnetization changes by looking at

\[ \left< M^2 \right> = \left< \left( \frac{1}{N} \sum_i \sigma_i^z \right)^2 \right>. \]
from quantax.operator import sigma_z

N = lattice.Nsites
M = sum(sigma_z(i) for i in range(N)) / N
M2 = M @ M

h = [1.0, 3.0, 10.0]

exactM2 = np.array([state0 @ M2 @ state0, state1 @ M2 @ state1, state2 @ M2 @ state2])
sampledM2 = np.array([
    M2.expectation(state0, samples0),
    M2.expectation(state1, samples1),
    M2.expectation(state2, samples2),
])

plt.semilogx(h, exactM2, 'ks--', label='Exact')
plt.semilogx(h, sampledM2, 'o-', label='Sampled')
plt.xlabel('Transverse field h')
plt.ylabel(r'$\langle M^2 \rangle$')
plt.ylim(0, 1)
plt.legend()
plt.show()
../_images/3ce112ce0ad2f1502064e8b98591ece9d33461928c1e2fa8beeb3df16076ca90.png

One can also show snapshots from generated samples.

fig, axes = plt.subplots(1, 3, figsize=(9, 3))

axes[0].imshow(samples0.spins[0].reshape(4, 4))
axes[0].axis('off')
axes[0].set_title("h=0.0")

axes[1].imshow(samples1.spins[0].reshape(4, 4))
axes[1].axis('off')
axes[1].set_title("h=3.0")

axes[2].imshow(samples2.spins[0].reshape(4, 4))
axes[2].axis('off')
axes[2].set_title("h=10.0")

plt.show()
../_images/d04af5de4e26ead990b1ed267452417f32d992bd7fcfd850446f017ade881468.png

Metropolis sampler#

ExactSampler in general cannot be used in large systems, in which case we need metropolis samplers. Below we show how to create a LocalFlip sampler, which proposes new spin configurations by flipping spins locally.

sampler0 = qtx.sampler.LocalFlip(state0, nsamples=1024)
sampler1 = qtx.sampler.LocalFlip(state1, nsamples=1024)
sampler2 = qtx.sampler.LocalFlip(state2, nsamples=1024)

You might have noticed that it takes some time to construct the Metropolis sampler. It’s because the spins stored in the sampler are thermalized during initialization. You can manually set an argument thermal_steps=0 to initialize without thermalization, but please remember to thermalize spin configurations in the future.

The magnetization can be measured similarly.

samples0 = sampler0.sweep()
samples1 = sampler1.sweep()
samples2 = sampler2.sweep()

exactM2 = np.array([state0 @ M2 @ state0, state1 @ M2 @ state1, state2 @ M2 @ state2])
sampledM2 = np.array([
    M2.expectation(state0, samples0),
    M2.expectation(state1, samples1),
    M2.expectation(state2, samples2),
])

plt.semilogx(h, exactM2, 'ks--', label='Exact')
plt.semilogx(h, sampledM2, 'o-', label='Sampled')
plt.xlabel('Transverse field h')
plt.ylabel(r'$\langle M^2 \rangle$')
plt.ylim(0, 1)
plt.legend()
plt.show()
../_images/b996c275661476200d890e195bae8174110e8ac30df311f4835e38bc9a3d4843.png

By default, the initial thermalization takes 20 * N updates, and each sweep takes 2 * N updates. This is an empirical balance between efficiency and reliability. For better thermalization, one can specify thermal_steps and sweep_steps arguments in the initialization.

One can also customize metropolis samplers by an inheritance of qtx.sampler.Metropolis. Here we show an example of defining a sampler that flips all spins in a lattice.

from functools import partial
from jaxtyping import Key


class FlipLattice(qtx.sampler.Metropolis):

    @partial(jax.jit, static_argnums=0)
    def propose(self, key: Key, old_spins: jax.Array):
        new_spins = -old_spins
        return new_spins

The sampler we just defined is obviously not ergodic. To make it work, we want to conbine it with LocalFlip. This can be done as shown below.

sampler_local = qtx.sampler.LocalFlip(state1, nsamples=896, thermal_steps=0)
sampler_global = FlipLattice(state1, nsamples=128, thermal_steps=0)
sampler_mix = qtx.sampler.MixSampler([sampler_local, sampler_global])

sampler_mix generates 896+128=1024 samples per sweep. In each proposal, there is 896/1024 chance to update all spin configurations by LocalFlip, and 128/1024 to update by FlipLattice. Here we check the magnetization to verify that the generated samples are still correct.

samples = sampler_mix.sweep()

print("Exact magnetization for h=1.0:", state1 @ M2 @ state1)
print("Sampled magnetization for h=1.0:", M2.expectation(state1, samples1))
Exact magnetization for h=1.0: 0.29388403111101535
Sampled magnetization for h=1.0: 0.30169677734375

Measurement#

Now we introduce how to use the generated samples to measure physical quantities.

As we have shown, one can call op.expectation(state, samples) to estimate the expectation value \(\braket{\hat O}\) of an operator. To estimate the varaince \(\braket{\hat O^2} - \braket{\hat O}^2\) at the same time, one can call

expectation, var = M2.expectation(state0, samples0, return_var=True)
print("h=1.0: <M2> =", expectation, "var(M2) =", var)
h=1.0: <M2> = 0.9398193359375 var(M2) = 0.01366935670375824

As we know, the uncertainty of an observable \(O\) in Monte Carlo methods is

\[ u(O) \approx \sqrt{\frac{\mathrm{Var}(O)}{N_s}}, \]

where \(N_s\) is the number of samples. Then we can estimate the uncertainty of the observable as

uncertainty = jnp.sqrt(var / samples0.nsamples)
print(f"h=1.0: <M2> = {expectation} ± {uncertainty}")
print("Exact M2 =", exactM2[0])
h=1.0: <M2> = 0.9398193359375 ± 0.0036536257547830353
Exact M2 = 0.9399276867036591

Apart from observables, some other quantities are also important in quantum physics. An important one is the overlap (fidelity) between two states. One can rewrite it as

\[\begin{split} \begin{aligned} \mathcal{F}(\psi, \phi) &= \frac{\braket{\psi|\phi} \braket{\phi|\psi}}{\braket{\psi|\psi} \braket{\phi|\phi}} \\ &= \frac{\sum_s \psi^*(s) \phi(s) \sum_{s'} \phi^*(s') \psi(s')}{||\psi||^2 ||\phi||^2} \\ &= \sum_s \frac{|\psi(s)|^2}{||\psi||^2} \frac{\phi(s)}{\psi(s)} \sum_{s'} \frac{|\phi(s')|^2}{||\phi||^2} \frac{\psi(s')}{\phi(s')} \\ &= \left< \frac{\phi(s)}{\psi(s)} \right>_{s \sim |\psi|^2} \left< \frac{\psi(s)}{\phi(s)} \right>_{s \sim |\phi|^2} \end{aligned} \end{split}\]

Here, we show the overlap between state0 and state1 as an example.

samples0 = sampler0.sweep()
samples1 = sampler1.sweep()

s0 = samples0.spins
psi_s0 = samples0.psi
phi_s0 = state1(s0)

s1 = samples1.spins
psi_s1 = state0(s1)
phi_s1 = samples1.psi

print("Estimated overlap: ", jnp.mean(phi_s0 / psi_s0) * jnp.mean(psi_s1 / phi_s1))
print("Exact overlap: ", abs(state0 @ state1) ** 2)
Estimated overlap:  0.1494855468512831
Exact overlap:  0.14122972655126168

Another important quantity is the entanglement entropy. In VMC, it’s only convenient to compute the Rényi-2 entropy

\[ S_2(A, B) = -\log \mathrm{Tr}(\rho_A^2) \]

when we split the whole system into two subregions A and B.

For a pure state \(\ket{\psi}\), matrix elements of \(\rho_A\) is given by $\( (\rho_A)_{s_A, s_A'} = \frac{\sum_{s_B} \psi(s_A, s_B) \psi^*(s_A', s_B)}{||\psi||^2}, \)\( where \)\ket{s} = \ket{s_A} \otimes \ket{s_B}$ splits a basis state into basis states in two subregions.

Then we can rewrite \(\mathrm{Tr}(\rho_A^2)\) into

\[\begin{split} \begin{aligned} \mathrm{Tr}(\rho_A^2) &= \sum_{s_A, s_A'} (\rho_A)_{s_A, s_A'} (\rho_A)_{s_A', s_A} \\ &= \frac{\sum_{s_A, s_A', s_B, s_B'} \psi(s_A, s_B') \psi^*(s_A', s_B') \psi(s_A', s_B) \psi^*(s_A, s_B)}{||\psi||^4} \\ &= \sum_{s_A, s_A', s_B, s_B'} \frac{|\psi(s_A, s_B) \psi(s_A', s_B')|^2}{||\psi||^4} \frac{\psi(s_A, s_B') \psi(s_A', s_B)}{\psi(s_A, s_B) \psi(s_A', s_B')} \\ &= \left< \frac{\psi(s_A, s_B') \psi(s_A', s_B)}{\psi(s_A, s_B) \psi(s_A', s_B')} \right>_{s \sim |\psi|^2, s' \sim |\psi|^2}. \end{aligned} \end{split}\]

Therefore, the entropy is $\( S_2(A, B) = -\log \left< \frac{\psi(s_A, s_B') \psi(s_A', s_B)}{\psi(s_A, s_B) \psi(s_A', s_B')} \right>_{s \sim |\psi|^2, s' \sim |\psi|^2}. \)$

Here, we show an example of entropy computation, where we split the whole lattice into A and B with equivalent sizes.

def entropy(state: qtx.state.Variational, samples: qtx.sampler.Samples):
    Lx, Ly = lattice.shape[1:]
    Ns = samples.nsamples
    spins = samples.spins.reshape(Ns, Lx, Ly)
    sA = spins[: Ns // 2, : Lx // 2]
    sB = spins[: Ns // 2, Lx // 2 :]
    sA_ = spins[Ns // 2 :, : Lx // 2]
    sB_ = spins[Ns // 2 :, Lx // 2 :]
    
    swap_s1 = jnp.concatenate([sA, sB_], axis=1).reshape(Ns // 2, -1)
    swap_s2 = jnp.concatenate([sA_, sB], axis=1).reshape(Ns // 2, -1)
    psi1 = state(swap_s1)
    psi2 = state(swap_s2)

    Tr_rhoA2 = jnp.mean((psi1 * psi2) / (samples.psi[: Ns // 2] * samples.psi[Ns // 2 :]))
    entropy = -jnp.log(Tr_rhoA2)
    return entropy
samples0 = sampler0.sweep()
samples1 = sampler1.sweep()
samples2 = sampler2.sweep()

S0 = entropy(state0, samples0)
S1 = entropy(state1, samples1)
S2 = entropy(state2, samples2)

From the exact wave-funciton, one can utilize quspin to compute the exact entropy.

basis = state0.basis

S0_exact = basis.ent_entropy(state0.psi, density=False, alpha=2.0)["Sent_A"]
S1_exact = basis.ent_entropy(state1.psi, density=False, alpha=2.0)["Sent_A"]
S2_exact = basis.ent_entropy(state2.psi, density=False, alpha=2.0)["Sent_A"]

plt.plot(h, [S0_exact, S1_exact, S2_exact], 'ks--', label='Exact')
plt.plot(h, [S0, S1, S2], 'o-', label='Sampled')
plt.xlabel('Transverse field h')
plt.ylabel(r'$S_2$')
plt.ylim(0, 1)
plt.legend()
plt.show()
../_images/e5b3e6bc819e1795f73aea0e77309cdf193860b27a5835d397c431ddd3be8a02.png