Annealed importance sampling#
Simulating samples from distributions is a central problem in Statistics and Machine Learning, because it enables estimating important quantities such as integrals. For example, we are often interested in evaluating integrals of the form
$$I = \mathbb{E}_{p(x)}\left[f(x)\right] = \int f(x)\, p(x)\, dx,$$

where $p$ is a probability density and $f$ is a function whose expectation we wish to estimate. In many cases of interest, this integral is intractable.

A standard Monte Carlo method for handling intractable integrals is importance sampling. Importance sampling gets around the intractability of sampling from $p$ by instead drawing samples from a tractable proposal distribution $q$, and reweighting these samples to correct for the mismatch between $q$ and $p$. Unfortunately, if $q$ is significantly different from $p$, the importance weights have a large variance, which makes the resulting estimates unreliable. Annealed Importance Sampling (AIS) addresses this issue by gradually transforming samples from $q$ towards $p$ through a sequence of intermediate distributions, thereby reducing the variance of the importance weights.
Importance sampling#
Suppose we wish to evaluate an integral of the form

$$I = \mathbb{E}_{p(x)}\left[f(x)\right] = \int f(x)\, p(x)\, dx.$$

If we could draw samples from $p$, we could estimate $I$ by the Monte Carlo estimator

$$\hat{I}_{\text{MC}} = \frac{1}{N} \sum_{n=1}^N f(x_n), \qquad x_n \sim p(x),$$

which is unbiased and converges to $I$ as the number of samples grows. However, in many applications of interest, we cannot draw samples from $p$ directly. Importance sampling works around this by drawing samples from a proposal distribution $q$ which we can sample from, and rewriting the integral as

$$I = \int f(x) \frac{p(x)}{q(x)}\, q(x)\, dx = \mathbb{E}_{q(x)}\left[f(x) \frac{p(x)}{q(x)}\right],$$

which suggests the importance sampling estimator

$$\hat{I}_{\text{IS}} = \frac{1}{N} \sum_{n=1}^N w_n f(x_n), \qquad w_n = \frac{p(x_n)}{q(x_n)}, \qquad x_n \sim q(x).$$

A technical, but important, requirement here is that $q(x) > 0$ wherever $f(x)\, p(x) \neq 0$, otherwise the estimator is biased.

The ratios $w_n = p(x_n) / q(x_n)$ are called importance weights, and they correct for the fact that the samples were drawn from $q$ rather than $p$. The estimator $\hat{I}_{\text{IS}}$ is unbiased, so its expectation is equal to $I$. However, in practice we always end up using a finite number of samples, so the importance sampling estimate will not be exactly equal to $I$. If the variance of the estimator is large, then using a modest number of samples can give estimates with a large random error. Using more samples reduces the overall variance, but increases the computational cost.

A common issue that arises with the importance sampling estimator is that its variance can be very large whenever the proposal $q$ is significantly different from the target $p$.
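To make the estimator concrete, here is a minimal, self-contained sketch of importance sampling with TensorFlow Probability. The target, proposal and test function below are assumptions made purely for illustration, and are not necessarily the ones used in the experiments of this post.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Assumed target p (two-component Gaussian mixture), proposal q and test function f
p = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
    components_distribution=tfd.Normal(loc=[-1.5, 1.5], scale=[0.3, 0.3]),
)
q = tfd.Normal(loc=0.0, scale=1.0)
f = lambda x: 10.0 * x  # assumed test function, for illustration only

# Draw samples from the proposal and compute importance weights w = p(x) / q(x)
num_samples = 100
x = q.sample(num_samples)
w = tf.exp(p.log_prob(x) - q.log_prob(x))

# Unbiased importance sampling estimate of I = E_p[f(x)]
I_hat = tf.reduce_mean(w * f(x))
print(f"Importance Sampling estimate I = {float(I_hat):.2f}")
```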
Let’s look at a simple example of importance sampling. Suppose the target $p$ is a mixture of two Gaussians and the proposal $q$ is a single Gaussian. Of course, we can in fact draw samples from a mixture of Gaussians directly, but let’s pretend we can’t. Now, define a function $f$ whose expectation under $p$ we wish to estimate, that is $I = \mathbb{E}_{p(x)}[f(x)]$. Using importance sampling, we can obtain an estimate of this integral.
Importance Sampling estimate I = -5.40
Note that this is a single noisy estimate, which may be far from the exact value of the integral. To quantify the random error of the estimator, we can repeat the estimation over many independent trials and look at the spread of the estimates.
Importance Sampling over 100 trials I = 0.10 +/- 13.309.
So even though the estimator is unbiased, it has a large variance.
Before we go further, let’s compare the importance sampling estimate to a Monte Carlo estimate which uses samples drawn directly from $p$ (which we can do in this toy example), to see how much accuracy we lose by sampling from $q$ instead.
Monte Carlo with samples from p over 100 trials I = -0.07 +/- 1.848.
The importance sampling estimator has roughly seven times larger random error than the Monte Carlo estimator. To see why this occurs, let’s plot the raw samples drawn from $q$, together with the same samples weighted by their importance weights.
First off, this plot illustrates how importance sampling works.
Although the samples are drawn from $q$, reweighting them by their importance weights produces a weighted distribution which resembles the target $p$. However, the samples which receive a high importance weight are relatively infrequent, which means that sometimes we may get a sample with a large importance weight in one of the two Gaussian modes, but not in the other mode.
Since this sample has a large importance weight, it greatly affects the overall estimate, introducing lots of random error.
Looking at a histogram of the importance weights we see that most weights are very small, and it’s only a few large weights which dominate the value of the integral (note the logarithmic scale). Going a little further, note that the expected value of an importance weight is always equal to one, since

$$\mathbb{E}_{q(x)}\left[\frac{p(x)}{q(x)}\right] = \int p(x)\, dx = 1.$$

This means that if most of the weights are much smaller than one, then a few of them must occasionally be very large, which is another way of saying that the weights have a large variance. In fact, the variance of the importance weights is lower-bounded by a quantity which grows exponentially with the KL divergence between $p$ and $q$.
Lemma 59 (Lower bound to importance weight variance)
Given distributions $p$ and $q$, the variance of the importance weights $r(x) = p(x) / q(x)$, where $x \sim q(x)$, satisfies

$$\text{Var}_{q(x)}\left[r(x)\right] \geq e^{\text{KL}(p\,||\,q)} - 1,$$

where $\text{KL}(p\,||\,q)$ is the Kullback-Leibler divergence between $p$ and $q$.
Derivation (Lower bound to importance weight variance)
The variance in the importance weights can be written as

$$\text{Var}_{q(x)}\left[r(x)\right] = \mathbb{E}_{q(x)}\left[r(x)^2\right] - \mathbb{E}_{q(x)}\left[r(x)\right]^2 = \mathbb{E}_{p(x)}\left[\frac{p(x)}{q(x)}\right] - 1 = \mathbb{E}_{p(x)}\left[e^{\log \frac{p(x)}{q(x)}}\right] - 1.$$

By applying Jensen's inequality once, we can get a lower bound to the expectation above, to obtain

$$\mathbb{E}_{p(x)}\left[e^{\log \frac{p(x)}{q(x)}}\right] \geq e^{\mathbb{E}_{p(x)}\left[\log \frac{p(x)}{q(x)}\right]} = e^{\text{KL}(p\,||\,q)},$$

and putting the two together gives

$$\text{Var}_{q(x)}\left[r(x)\right] \geq e^{\text{KL}(p\,||\,q)} - 1,$$

which is the lower bound in the lemma.
Note that when $q$ is very different from $p$, the KL divergence $\text{KL}(p\,||\,q)$ is large, so the variance of the importance weights, and therefore also the variance of the importance sampling estimator, is correspondingly large. To keep the variance small, we need a proposal which is close to the target.
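As a quick numerical sanity check of this bound, the sketch below compares the empirical variance of the importance weights with the lower bound $e^{\text{KL}(p\,||\,q)} - 1$ for two Gaussians; the particular Gaussians are an arbitrary illustrative choice.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Illustrative (assumed) target and proposal: two Gaussians with different means
p = tfd.Normal(loc=1.0, scale=1.0)
q = tfd.Normal(loc=0.0, scale=1.0)

# Empirical variance of the importance weights r(x) = p(x) / q(x), with x ~ q
x = q.sample(100_000)
r = tf.exp(p.log_prob(x) - q.log_prob(x))
empirical_variance = tf.math.reduce_variance(r)

# Lower bound exp(KL(p || q)) - 1, using the closed-form Gaussian KL divergence
lower_bound = tf.exp(tfd.kl_divergence(p, q)) - 1.0

print(f"Empirical variance of weights: {float(empirical_variance):.3f}")
print(f"Lower bound exp(KL(p||q)) - 1: {float(lower_bound):.3f}")
```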
This is where Annealed Importance Sampling (AIS) becomes useful.
AIS is an importance sampling method which uses an annealing procedure based on Markov Chain Monte Carlo (MCMC) to produce samples whose distribution is closer to the target $p$, which in turn reduces the variance of the importance weights.
Importance-weighted MCMC#
Motivated by the above intuition, given some initial samples from the proposal $q$, a natural idea is to pass these samples through a number of MCMC transitions which leave the target $p$ invariant, moving them closer to $p$ before weighting them. The following algorithm makes this idea precise.
Definition 141 (Importance weighted MCMC algorithm)
Given a proposal density $q$, a target density $p$, and Markov transition kernels $T_1, \dots, T_K$ which each leave $p$ invariant, draw an initial sample and pass it through the transition kernels

$$x_0 \sim q(x), \qquad x_k \sim T_k(x_k \mid x_{k-1}) \;\;\text{for } k = 1, \dots, K,$$

and return the sample $x_K$ together with the importance weight

$$w = \frac{p(x_0)}{q(x_0)}.$$
Note that the only requirement on the transition kernels $T_1, \dots, T_K$ is that each of them leaves the target $p$ invariant. Intuitively, as a sample is passed through more and more such transitions, its distribution moves away from the initial proposal $q$ and gets closer to the target $p$. The marginal density of the final sample is the nested integral

$$q_K(x_K) = \int \cdots \int q(x_0)\, T_1(x_1 \mid x_0) \cdots T_K(x_K \mid x_{K-1})\, dx_0 \cdots dx_{K-1}.$$

If we wanted to use $q_K$ as the proposal density in a standard importance sampling estimator, we would need to evaluate it. Unfortunately, each of the nested integrals above is intractable, which means we cannot compute $q_K(x_K)$, and therefore cannot compute the importance weights $p(x_K) / q_K(x_K)$ directly.
We can then define the reverse transition kernels

$$\tilde{T}_k(x_{k-1} \mid x_k) = \frac{T_k(x_k \mid x_{k-1})\, p(x_{k-1})}{p(x_k)}.$$

Because each $T_k$ leaves $p$ invariant, we have

$$\int \tilde{T}_k(x_{k-1} \mid x_k)\, dx_{k-1} = \frac{1}{p(x_k)} \int T_k(x_k \mid x_{k-1})\, p(x_{k-1})\, dx_{k-1} = \frac{p(x_k)}{p(x_k)} = 1,$$

so the reverse kernels integrate to one, and are themselves valid transition kernels.
Now, consider performing importance sampling in the augmented space of entire chains $(x_0, x_1, \dots, x_K)$, with proposal and target densities

$$Q(x_0, \dots, x_K) = q(x_0) \prod_{k=1}^K T_k(x_k \mid x_{k-1}), \qquad P(x_0, \dots, x_K) = p(x_K) \prod_{k=1}^K \tilde{T}_k(x_{k-1} \mid x_k).$$

The importance weights $w = P(x_0, \dots, x_K) / Q(x_0, \dots, x_K)$ guarantee that, for any function $g$ of the augmented samples, the estimator

$$\hat{I} = \frac{1}{N} \sum_{n=1}^N w^{(n)}\, g\!\left(x_0^{(n)}, \dots, x_K^{(n)}\right)$$

is unbiased. Therefore, if we set $g$ to be a function of the final sample only, $g(x_0, \dots, x_K) = f(x_K)$, we obtain the estimator

$$\hat{I} = \frac{1}{N} \sum_{n=1}^N w^{(n)} f\!\left(x_K^{(n)}\right),$$

which is also unbiased. Crucially, the importance weights can be computed exactly. Substituting the definition of the reverse kernels into the ratio, we get

$$w = \frac{p(x_K) \prod_{k=1}^K \tilde{T}_k(x_{k-1} \mid x_k)}{q(x_0) \prod_{k=1}^K T_k(x_k \mid x_{k-1})} = \frac{p(x_K)}{q(x_0)} \prod_{k=1}^K \frac{p(x_{k-1})}{p(x_k)},$$

and all terms coming from the transition kernels cancel in the importance weight ratio, yielding

$$w = \frac{p(x_0)}{q(x_0)}.$$
By performing importance sampling in this augmented space, we have got around the issue of intractable importance weights, by cancelling out a load of terms. Unfortunately, these importance weights are exactly the same as the weights of the standard importance sampling estimator, so this algorithm does not improve on the variance of the standard estimator at all! However, it is possible to modify this algorithm to obtain better importance weights, while still taking advantage of the cancellation of the transition kernels.
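To make the algorithm concrete, here is a minimal sketch of importance-weighted MCMC using a random-walk Metropolis kernel, which leaves $p$ invariant. The densities $p$ and $q$ and the function $f$ below are assumed purely for illustration.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Assumed target p, proposal q and test function f, for illustration only
p = tfd.Normal(loc=2.0, scale=1.0)
q = tfd.Normal(loc=0.0, scale=1.5)
f = lambda x: x

num_samples, num_steps = 1000, 50

# Initial samples and standard importance weights w = p(x0) / q(x0)
x0 = q.sample(num_samples)
w = tf.exp(p.log_prob(x0) - q.log_prob(x0))

# Random-walk Metropolis transitions, each of which leaves p invariant
x = x0
for _ in range(num_steps):
    proposal = x + tfd.Normal(loc=0.0, scale=0.3).sample(num_samples)
    accept_prob = tf.exp(tf.minimum(p.log_prob(proposal) - p.log_prob(x), 0.0))
    accept = tf.random.uniform([num_samples]) < accept_prob
    x = tf.where(accept, proposal, x)

# The moved samples keep their original weights, and the estimator remains unbiased
I_hat = tf.reduce_mean(w * f(x))
print(f"Importance-weighted MCMC estimate I = {float(I_hat):.2f}")
```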
Annealed Importance Sampling#
Given a sequence of annealing parameters $0 = \beta_0 < \beta_1 < \dots < \beta_N = 1$, define the sequence of unnormalised intermediate distributions

$$\pi_n(x) \propto p(x)^{\beta_n}\, q(x)^{1 - \beta_n}, \qquad n = 0, 1, \dots, N.$$

These distributions interpolate between the proposal, $\pi_0 = q$, and the target, $\pi_N = p$.
Definition 142 (Annealed Importance Sampling)
Given a target density $p$, a proposal density $q$, annealing parameters $0 = \beta_0 < \beta_1 < \dots < \beta_N = 1$ with intermediate distributions $\pi_n(x) \propto p(x)^{\beta_n} q(x)^{1 - \beta_n}$, and Markov transition kernels $T_1, \dots, T_N$ such that each $T_n$ leaves $\pi_n$ invariant, draw an initial sample and pass it through the transition kernels

$$x_0 \sim q(x), \qquad x_n \sim T_n(x_n \mid x_{n-1}) \;\;\text{for } n = 1, \dots, N,$$

and let

$$w = \prod_{n=1}^{N} \frac{\pi_n(x_{n-1})}{\pi_{n-1}(x_{n-1})},$$

and return the sample $x_N$ together with the importance weight $w$.
Note that drawing samples according to this algorithm and setting the importance weights as above is again equivalent to performing importance sampling in the augmented space of $(x_0, \dots, x_N)$, this time with reverse kernels defined through the intermediate distributions $\pi_n$ rather than through $p$. The transition kernel terms again cancel in the weight ratio, so the resulting estimator is unbiased. Because each weight factor compares only two neighbouring, and therefore similar, annealed distributions, the resulting importance weights typically have much lower variance than the standard ones.
Implementation#
To implement this procedure, we need to specify the transition kernels and the sequence of annealing parameters $\beta_n$. For the transition kernels we use Metropolis-Hastings kernels with a Gaussian proposal, which leave the corresponding annealed distribution invariant.
```python
import tensorflow as tf
import tensorflow_probability as tfp

from typing import Callable, List, Tuple

tfd = tfp.distributions


class TransitionKernel:

    def __init__(self):
        pass

    def __call__(self, x: tf.Tensor):
        pass


class GaussianTransitionKernel(TransitionKernel):

    def __init__(self, scale: tf.Tensor):
        self.scale = scale

    def __call__(self, x: tf.Tensor, distribution: Callable):

        # Create forward proposal distribution and propose next point
        forward = tfd.Normal(loc=x, scale=self.scale)
        next_x = forward.sample()

        # Create reverse proposal distribution
        reverse = tfd.Normal(loc=next_x, scale=self.scale)

        # Compute Metropolis-Hastings acceptance probability
        log_prob_1 = reverse.log_prob(x) + distribution(next_x)
        log_prob_2 = forward.log_prob(next_x) + distribution(x)
        log_prob_ratio = log_prob_1 - log_prob_2
        p = tf.math.exp(tf.minimum(log_prob_ratio, 0.))

        # Accept-reject step
        accept = tf.random.categorical(
            [[tf.math.log(1. - p), tf.math.log(p)]],
            num_samples=1,
            dtype=tf.int32,
        )[0, 0]
        x_accept = tf.convert_to_tensor([x, next_x])[accept]

        return x_accept, accept
```
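As a quick illustration of the kernel in isolation, a single Metropolis-Hastings step targeting a standard normal log-density (an assumed example target, not one used elsewhere in this post) might look as follows.

```python
# Assumed example: one MH step of the kernel, targeting a standard normal density
kernel = GaussianTransitionKernel(scale=tf.constant(0.5, dtype=tf.float64))

log_prob = tfd.Normal(
    loc=tf.constant(0., dtype=tf.float64),
    scale=tf.constant(1., dtype=tf.float64),
).log_prob

x = tf.constant(2.0, dtype=tf.float64)
x_next, accepted = kernel(x=x, distribution=log_prob)
print(f"Moved from {float(x):.2f} to {float(x_next):.2f} (accepted = {int(accepted)})")
```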
We can then put together an AnnealedImportanceSampler, which accepts an initial and a target distribution, a transition kernel, and a tensor containing a schedule for the annealing parameters $\beta_n$.
```python
class AnnealedImportanceSampler:

    def __init__(
        self,
        initial_distribution: tfd.Distribution,
        target_distribution: tfd.Distribution,
        transition_kernel: TransitionKernel,
        betas: tf.Tensor,
    ):
        self.initial_distribution = initial_distribution
        self.target_distribution = target_distribution
        self.transition_kernel = transition_kernel
        self.betas = betas
        self.num_steps = betas.shape[0]

    def __call__(self, num_samples: int) -> Tuple[tf.Tensor, tf.Tensor]:

        # Draw samples from initial distribution
        x0 = self.initial_distribution.sample([num_samples])

        # Run an AIS chain on each of the initial samples
        samples_and_log_weights = tf.map_fn(self.run_chain, x0)

        return x0, samples_and_log_weights

    @tf.function(jit_compile=True)
    def run_chain(self, x: tf.Tensor, *args) -> tf.Tensor:

        # Initialise log importance weight with -log pi_0(x_0) = -log q(x_0)
        log_w = - self.initial_distribution.log_prob(x)

        for i in tf.range(self.num_steps):

            # Create next annealed distribution pi_{i+1}
            next_annealed_log_prob = self.log_geometric_mixture(
                beta=self.betas[i],
            )

            # Add log pi_{i+1}(x_i) to the log weight
            log_w = log_w + next_annealed_log_prob(x)

            # Propose next point, using a kernel which leaves pi_{i+1} invariant
            x, accept = self.transition_kernel(
                x=x,
                distribution=next_annealed_log_prob,
            )

            # Subtract log pi_{i+1}(x_{i+1}) from the log weight
            log_w = log_w - next_annealed_log_prob(x)

        # Add log p(x_N) to the log weight
        log_w = log_w + self.target_distribution.log_prob(x)

        return tf.convert_to_tensor([x, log_w], dtype=tf.float64)

    def log_geometric_mixture(self, beta: tf.Tensor) -> Callable:

        def _log_geometric_mixture(x: tf.Tensor) -> tf.Tensor:

            log_prob_1 = self.initial_distribution.log_prob(x)
            log_prob_2 = self.target_distribution.log_prob(x)

            return (1. - beta) * log_prob_1 + beta * log_prob_2

        return _log_geometric_mixture
```
Toy experiment#
Now we have an AIS sampler which can be used with arbitrary annealing parameters $\beta_n$ and transition kernels. For the annealing schedule we pass a grid of points through a sigmoid function, where the grid starts at negative values and ends at positive values, so that the $\beta_n$ increase smoothly from close to $0$ to close to $1$. This gives a schedule which changes slowly near its two endpoints and more quickly in between.
```python
dtype = tf.float64

transition_scale = tf.convert_to_tensor(0.3, dtype=dtype)

# Initialise transition kernel
transition_kernel = GaussianTransitionKernel(
    scale=transition_scale,
)

# Initialise betas
betas = tf.nn.sigmoid(10. * (tf.cast(tf.linspace(1e-3, 1., 1000), dtype=dtype) - 0.5))

# Initialise AIS sampler
sampler = AnnealedImportanceSampler(
    initial_distribution=q,
    target_distribution=p,
    transition_kernel=transition_kernel,
    betas=betas,
)
```
Let’s draw some samples for the same problem we considered earlier, and visualise them as before.
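The plotting code below uses the initial samples x0, the final AIS samples samples and their importance weights w. A minimal sketch of how these might be obtained from the sampler is the following.

```python
# Sketch: run the AIS sampler and unpack the final samples and importance weights
num_samples = 1000  # assumed number of samples, for illustration

x0, samples_and_log_weights = sampler(num_samples)
samples = samples_and_log_weights[:, 0]
log_w = samples_and_log_weights[:, 1]

# Convert log importance weights to importance weights
w = tf.exp(log_w)
```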
```python
plt.figure(figsize=(8, 3))

# Plot target density p
plt.plot(x_plot, p_plot, color="tab:blue")

# Plot proposal density q
plt.plot(x_plot, q_plot, color="tab:green")

# Plot initial samples drawn from q
plt.hist(
    x0,
    density=True,
    color="tab:green",
    alpha=0.4,
    bins=tf.linspace(-4., 4., 40),
    zorder=1,
    label="Initial samples from $q$",
)

# Plot samples after annealing
plt.hist(
    samples,
    density=True,
    color="tab:purple",
    alpha=0.4,
    bins=tf.linspace(-4., 4., 40),
    zorder=1,
    label="AIS samples",
)

# Plot samples weighted by their importance weights
plt.hist(
    samples,
    weights=w,
    density=True,
    color="tab:blue",
    alpha=0.4,
    bins=tf.linspace(-4., 4., 40),
    zorder=2,
    label="Weighted AIS samples",
)

plt.title("AIS samples", fontsize=22)
plt.xlabel("$x$", fontsize=18)
plt.ylabel("$p(x),~ q(x)$", fontsize=18)
plt.xticks(np.linspace(-4., 4., 5), fontsize=18)
plt.yticks(np.linspace(0., 1.0, 3), fontsize=18)
plt.xlim([-4., 4.])
plt.ylim([0, 1.1])
plt.legend(loc="upper left", fontsize=16)

twin_axis = plt.gca().twinx()
plt.plot(x_plot, f_plot, color="tab:red", label="$f(x)$")
plt.xlabel("$x$", fontsize=18)
plt.ylabel("$f(x)$", fontsize=18)
plt.legend(loc="upper right", fontsize=16)
plt.yticks(tf.linspace(-60., 60., 3), fontsize=18)

plt.show()
```
We observe that the distribution of the AIS samples (purple bars) is similar to the target distribution $p$, even before the importance weights are applied: the annealing procedure has moved the samples from $q$ towards $p$. As expected, most of the AIS importance weights are therefore close to one, rather than being mostly near zero with a few very large values, so they have far smaller variance than the standard importance sampling weights. This is reflected in the random error of the AIS estimate, which is much smaller than that of standard importance sampling.
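For completeness, here is a sketch of how a single AIS estimate might be formed from the returned samples and weights, assuming the hypothetical unpacking shown earlier and writing f for the test function of the experiment; repeating this over independent trials gives the figure reported below.

```python
# Unbiased AIS estimate of I = E_p[f(x)], using the samples and weights from above
I_hat = tf.reduce_mean(w * f(samples))
print(f"Annealed Importance Sampling estimate I = {float(I_hat):.2f}")
```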
Annealed Importance Sampling over 100 trials I = 0.02 +/- 2.198.
Conclusion#
In importance sampling, using a proposal distribution that is significantly different from the target distribution results in a large variance of the resulting importance weights. Large variance in the importance weights typically induces large variance in downstream estimates obtained using these weights. AIS is a method which helps address this issue. AIS introduces a sequence of annealed distributions, which interpolate between the target and proposal distribution, iteratively transforming samples from the proposal distribution using a sequence of Markov transition kernels which preserve the interpolating (annealed) distributions. In this way, AIS typically reduces the variance in the importance weights, which in turn reduces the variance in downstream Monte Carlo estimates. For more details, Radford Neal's original paper introducing AIS [Neal, 2001] is a classic worth reading.