Stein variational gradient descent

One central challenge in statistics and Bayesian machine learning is dealing with intractable distributions. Our models often involve complicated distributions that are difficult to integrate out or sample from. Consider, for example, the posterior distribution over the parameters of a Bayesian neural network. Many approaches have been devised for handling such distributions, and they broadly fall into two categories: Variational Inference (VI) and Markov Chain Monte Carlo (MCMC). Here we focus on the former only.

Suppose we are working with an intractable distribution \(p\). VI approximates \(p\) by another distribution \(q\), constrained to lie in a tractable family such as a Gaussian with independent (fully factorised) components. By minimising a divergence between \(q\) and \(p\), typically the KL divergence \(\text{KL}(q || p)\), VI produces a (hopefully) decent approximation which captures some of the important aspects of the target. VI can be much faster than MCMC, but it is an approximate method. The quality of the approximation depends largely on the family chosen for \(q\), and the approximation error can be very large in many applications of interest.

Stein Variational Gradient Descent (SVGD) [Liu and Wang, 2019] is an algorithm for approximate inference in intractable distributions which avoids the restrictive approximating family that VI requires. Like VI, it minimises the KL divergence between \(q\) and \(p\), but unlike VI it makes no strong assumptions about the family of \(q\). Instead, SVGD represents \(q\) by a finite set of particles and moves these particles through a sequence of transformations, each of which brings \(q\) closer to \(p\).

Derivation of SVGD

The idea of SVGD is to approximate a target distribution \(p\) by a distribution \(q\), and to apply a sequence of transformations to \(q\) which bring it progressively closer to \(p\). At each step we apply the transformation (from a restricted family of transformations) which most rapidly reduces the KL divergence. This yields an algorithm that looks much like steepest descent.

Invertible transformations

Suppose we have an initial distribution \(q\), which we pass through a transformation \(T : \mathbb{R}^D \to \mathbb{R}^D\), that is

\[\begin{align} z = T(x),~~\text{ where } x \sim q(x). \end{align}\]

If the map \(T\) is invertible, we can compute the density of the transformed variable \(z\) via the change of variables formula. To ensure \(T\) is invertible, let us set \(T(x) = x + \epsilon \phi(x)\), where \(\epsilon\) is a small coefficient and \(\phi : \mathbb{R}^D \to \mathbb{R}^D\). If \(\phi\) is smooth with a bounded Jacobian and \(\epsilon\) is sufficiently small, then \(T\) is invertible, which means we can compute the density of \(z\). We now turn to the question of how to pick an appropriate \(\phi\).
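The change of variables step can be illustrated numerically in one dimension. The sketch below pushes a standard normal \(q\) through \(T(x) = x + \epsilon \phi(x)\). The choices \(\phi(x) = \sin(x)\) and \(\epsilon = 0.1\) are illustrative assumptions. Because \(\epsilon |\phi'(x)| < 1\), \(T\) is strictly increasing and can be inverted by bisection. The density of \(z\) is then \(q(T^{-1}(z)) / |T'(T^{-1}(z))|\), and a numerical integral checks that it is normalised.

```python
import math

EPS = 0.1  # illustrative value of epsilon


def phi(x):
    # Illustrative smooth perturbation with bounded derivative.
    return math.sin(x)


def dphi(x):
    return math.cos(x)


def T(x):
    return x + EPS * phi(x)


def T_inverse(z, tol=1e-12):
    # T is strictly increasing because EPS * |phi'(x)| < 1, and
    # |EPS * phi(x)| <= EPS, so the preimage lies in [z - EPS, z + EPS].
    lo, hi = z - EPS, z + EPS
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if T(mid) < z:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def q_pdf(x):
    # Standard normal density for q.
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def transformed_pdf(z):
    # Change of variables: q_[T](z) = q(x) / |T'(x)| with x = T^{-1}(z).
    x = T_inverse(z)
    return q_pdf(x) / abs(1.0 + EPS * dphi(x))


# Trapezoidal check that the density of z = T(x) integrates to one.
grid = [-10.0 + 20.0 * i / 4000 for i in range(4001)]
values = [transformed_pdf(z) for z in grid]
h = grid[1] - grid[0]
total = h * (sum(values) - 0.5 * (values[0] + values[-1]))
print(total)
```

In higher dimensions the same formula holds with \(|T'(x)|\) replaced by \(|\det \nabla_x T(x)|\). The bisection trick is specific to this one-dimensional illustration.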

Direction of steepest descent

Let us use the subscript notation \(q_{[T]}\) to denote the distribution obtained by passing \(q\) through \(T\). We are then interested in picking a \(T\) which minimises \(\text{KL}(q_{[T]} || p)\). First, we compute the derivative of the KL with respect to \(\epsilon\) at \(\epsilon = 0\), which is available in closed form.

Theorem 90 (Gradient of KL is the KSD)

Let \(x \sim q(x)\), and \(T(x) = x + \epsilon \phi(x)\), where \(\phi\) is a smooth function. Then

\[\begin{align} \nabla_{\epsilon}\text{KL}(q_{[T]} || p) \big|_{\epsilon = 0} = - \mathbb{E}_{x \sim q}\left[\text{trace} \mathcal{A}_p \phi(x) \right], \end{align}\]

where \(q_{[T]}\) is the density of \(T(x)\) and

\[\begin{align} \mathcal{A}_p \phi(x) = \nabla_x \log p(x)\phi^\top(x) + \nabla_x \phi(x). \end{align}\]
Proof: Gradient of KL is the KSD

Let \(p_{\left[T^{-1}\right]}\) denote the density of \(T^{-1}(y)\) when \(y \sim p\). By changing the variable of integration from \(z\) to \(x = T^{-1}(z)\), we obtain

\[\begin{split}\begin{align} \text{KL}(q_{[T]} || p) &= \int q_{[T]}(z) \log \frac{q_{[T]}(z)}{p(z)} dz \\ &= \int q(x) \left[ \log q(x) - \log p_{\left[T^{-1}\right]}(x) \right] dx. \end{align}\end{split}\]

This change of variables is convenient because only one term in the integral now depends on \(\epsilon\), namely \(\log p_{\left[T^{-1}\right]}(x)\). Taking the derivative with respect to \(\epsilon\), we obtain

\[\begin{align} \nabla_{\epsilon} \text{KL}(q_{[T]} || p) &= - \int q(x) \nabla_{\epsilon} \log p_{\left[T^{-1}\right]}(x) dx, \end{align}\]

and using the fact that

\[\begin{align} \log p_{\left[T^{-1}\right]}(x) &= \log p(T(x)) + \log |\det \nabla_x T(x)|, \end{align}\]

we obtain the expression

\[\begin{split}\begin{align} \nabla_{\epsilon} \log p_{\left[T^{-1}\right]}(x) &= \nabla \log p(T(x))^\top \nabla_\epsilon T(x) + \nabla_\epsilon \log |\det \nabla_x T(x)|, \\ &= \nabla \log p(T(x))^\top \nabla_\epsilon T(x) + \text{trace}\left[(\nabla_x T(x))^{-1} \nabla_\epsilon \nabla_x T(x)\right], \end{align}\end{split}\]

where we have used the identity

\[\begin{align} \nabla_{\epsilon} \log |\det A| = \text{trace}\left(A^{-1} \nabla_{\epsilon} A\right). \end{align}\]

Substituting this into the expression for the gradient of the KL, we arrive at

\[\begin{align} \nabla_{\epsilon} \text{KL}(q_{[T]} || p) &= - \mathbb{E}_{x \sim q} \left[\nabla \log p(T(x))^\top \nabla_\epsilon T(x) + \text{trace}\left[(\nabla_x T(x))^{-1} \nabla_\epsilon \nabla_x T(x)\right]\right]. \end{align}\]

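Setting \(T(x) = x + \epsilon \phi(x)\) and evaluating at \(\epsilon = 0\) yields the result. At that point \(T(x) = x\), \(\nabla_x T(x) = I\), \(\nabla_\epsilon T(x) = \phi(x)\) and \(\nabla_\epsilon \nabla_x T(x) = \nabla_x \phi(x)\), so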

\[\begin{split}\begin{align} \nabla_{\epsilon} \text{KL}(q_{[T]} || p) \big|_{\epsilon = 0} &= - \mathbb{E}_{x \sim q} \left[\nabla \log p(x)^\top \phi(x) + \text{trace}\left[\nabla_x \phi(x) \right]\right], \\ &= - \mathbb{E}_{x \sim q} \left[\text{trace} \mathcal{A}_p \phi(x) \right]. \end{align}\end{split}\]
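The result can be checked numerically in a one-dimensional case where both sides are computable. This sketch assumes, for illustration only, \(q = \mathcal{N}(0, 1)\), \(p = \mathcal{N}(\mu, \sigma^2)\) and a linear \(\phi(x) = a x + b\). Under these assumptions \(T\) maps \(q\) to another Gaussian, so the KL has a closed form. Its finite-difference derivative at \(\epsilon = 0\) is compared with \(-\mathbb{E}_{x \sim q}[\nabla \log p(x) \phi(x) + \phi'(x)]\), computed by quadrature. For the values below, both sides should equal \((a - b\mu)/\sigma^2 - a = 0.0375\).

```python
import math

# Illustrative assumptions: q = N(0, 1), p = N(MU, SIGMA^2), phi(x) = A * x + B.
MU, SIGMA = 1.5, 2.0
A, B = 0.3, -0.7


def kl_gaussians(m1, s1, m2, s2):
    """Closed-form KL(N(m1, s1^2) || N(m2, s2^2))."""
    return math.log(s2 / s1) + (s1**2 + (m1 - m2) ** 2) / (2 * s2**2) - 0.5


def kl_after_transform(eps):
    # T(x) = x + eps * (A x + B) maps N(0, 1) to N(eps * B, (1 + eps * A)^2).
    return kl_gaussians(eps * B, 1.0 + eps * A, MU, SIGMA)


# Left-hand side: central finite difference of the KL at eps = 0.
h = 1e-5
lhs = (kl_after_transform(h) - kl_after_transform(-h)) / (2 * h)


def stein_integrand(x):
    # grad log p(x) * phi(x) + phi'(x) for the 1D Gaussian target.
    grad_log_p = -(x - MU) / SIGMA**2
    return grad_log_p * (A * x + B) + A


def q_pdf(x):
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


# Right-hand side: -E_q[trace A_p phi(x)], by trapezoidal quadrature.
grid = [-12.0 + 24.0 * i / 6000 for i in range(6001)]
dx = grid[1] - grid[0]
vals = [q_pdf(x) * stein_integrand(x) for x in grid]
rhs = -dx * (sum(vals) - 0.5 * (vals[0] + vals[-1]))

print(lhs, rhs)
```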

This result gives the rate of change of the KL as \(\epsilon\) increases from zero, for a given \(\phi\). We would like to pick \(\phi\) such that \(-\mathbb{E}_{x \sim q} \left[\text{trace} \mathcal{A}_p \phi(x) \right]\) is as negative as possible. However, this minimisation is not well defined. Scaling \(\phi\) by an arbitrarily large constant of the appropriate sign makes the objective unboundedly negative. Moreover, minimising over all smooth \(\phi\) is neither analytically nor computationally tractable. Both issues can be resolved by considering a constrained version of this optimisation problem, using Reproducing Kernel Hilbert Spaces (RKHS).

Let \(k\) be a positive-definite kernel, defining a corresponding RKHS \(\mathcal{H}\) with inner product \(\langle\cdot, \cdot \rangle_{\mathcal{H}}\). Let also \(\mathcal{H}_D = \mathcal{H} \times \dots \times \mathcal{H}\) be the Hilbert space of \(D\)-dimensional vector-valued functions \(f = (f_1, \dots, f_D)\) with \(f_1, \dots, f_D \in \mathcal{H}\), equipped with the inner product and induced norm

\[\begin{align} \langle f, g \rangle_{\mathcal{H}_D} = \sum_{d = 1}^D \langle f_d, g_d \rangle_{\mathcal{H}}, \qquad || f ||_{\mathcal{H}_D} = \sqrt{\sum_{d = 1}^D || f_d ||_{\mathcal{H}}^2}. \end{align}\]

If we now constrain \(\phi \in \mathcal{H}_D\) with \(|| \phi ||_{\mathcal{H}_D} \leq 1\), we obtain [Liu et al., 2016] the following closed-form expression for the direction of steepest descent.

Theorem 91 (Direction of steepest descent)

The function \(\phi^* \in \mathcal{H}_D\) with \(|| \phi^* ||_{\mathcal{H}_D} \leq 1\) which maximises the rate of decrease of the KL divergence is

\[\begin{align} \phi^*(\cdot) = \beta(\cdot) / ||\beta||_{\mathcal{H}_D} ,~~\beta(\cdot) = \mathbb{E}_{x \sim q}\left[ k(x, \cdot) \nabla_x \log p(x) + \nabla_x k(x, \cdot)\right]. \end{align}\]
Proof: Direction of steepest descent

For \(f \in \mathcal{H}_D\) we have the following equality

\[\begin{split}\begin{align} \langle f, \beta \rangle_{\mathcal{H}_D} &= \sum_{d = 1}^D \langle f_d(\cdot), \beta_d(\cdot) \rangle_{\mathcal{H}} \\ &= \sum_{d = 1}^D \left\langle f_d(\cdot), \mathbb{E}_{x \sim q}\left[k(x, \cdot) \nabla_{x_d} \log p(x) + \nabla_{x_d} k(x, \cdot)\right] \right\rangle_{\mathcal{H}} \\ &= \sum_{d = 1}^D \mathbb{E}_{x \sim q}\left[\nabla_{x_d} \log p(x) \langle f_d(\cdot), k(x, \cdot) \rangle_{\mathcal{H}} + \langle f_d(\cdot), \nabla_{x_d} k(x, \cdot) \rangle_{\mathcal{H}} \right] \\ &= \sum_{d = 1}^D \mathbb{E}_{x \sim q}\left[\nabla_{x_d} \log p(x) f_d(x) + \nabla_{x_d} f_d(x) \right] \\ &= \mathbb{E}_{x \sim q}\left[\text{trace} \mathcal{A}_p f(x)\right]. \end{align}\end{split}\]

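The third equality exchanges the inner product with the expectation. The fourth uses the reproducing property \(\langle f_d, k(x, \cdot) \rangle_{\mathcal{H}} = f_d(x)\), together with its derivative counterpart \(\langle f_d, \nabla_{x_d} k(x, \cdot) \rangle_{\mathcal{H}} = \nabla_{x_d} f_d(x)\). Therefore, the \(f\) in the unit ball of \(\mathcal{H}_D\) which maximises \(\mathbb{E}_{x \sim q}\left[\text{trace} \mathcal{A}_p f(x)\right]\), and hence most rapidly decreases the KL, is the one which maximises the inner product \(\langle f, \beta \rangle_{\mathcal{H}_D}\). By the Cauchy-Schwarz inequality, \(\langle f, \beta \rangle_{\mathcal{H}_D} \leq || f ||_{\mathcal{H}_D} || \beta ||_{\mathcal{H}_D} \leq || \beta ||_{\mathcal{H}_D}\), with equality when \(f = \beta / || \beta ||_{\mathcal{H}_D}\).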
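To make \(\beta\) concrete, the sketch below evaluates it by numerical quadrature in one dimension. It assumes, for illustration only, a standard normal target \(p\), a Gaussian \(q = \mathcal{N}(m, 1)\) and an EQ kernel with unit lengthscale. When \(q = p\), \(\beta\) vanishes, so the KL has zero slope in every direction, as expected at the minimiser. When \(q\) is shifted to \(m = 2\), \(\beta\) at the centre of \(q\) is negative, so it moves mass towards \(p\). A short calculation gives \(\beta(2) = -\sqrt{2}\) in that case.

```python
import math

LENGTHSCALE = 1.0  # illustrative EQ kernel lengthscale


def kernel(x, y):
    return math.exp(-0.5 * ((x - y) / LENGTHSCALE) ** 2)


def kernel_grad_x(x, y):
    # Derivative of k(x, y) with respect to its first argument x.
    return -(x - y) / LENGTHSCALE**2 * kernel(x, y)


def grad_log_p(x):
    # Score of the illustrative target p = N(0, 1).
    return -x


def normal_pdf(x, mean):
    return math.exp(-0.5 * (x - mean) ** 2) / math.sqrt(2.0 * math.pi)


def beta(y, q_mean, lo=-12.0, hi=14.0, steps=5200):
    """beta(y) = E_{x ~ q}[k(x, y) grad log p(x) + grad_x k(x, y)] for
    q = N(q_mean, 1), computed with the trapezoidal rule on [lo, hi]."""
    h = (hi - lo) / steps
    total = 0.0
    for i in range(steps + 1):
        x = lo + i * h
        weight = 0.5 if i in (0, steps) else 1.0
        integrand = kernel(x, y) * grad_log_p(x) + kernel_grad_x(x, y)
        total += weight * normal_pdf(x, q_mean) * integrand
    return total * h


print(beta(2.0, q_mean=0.0))  # q = p: close to zero
print(beta(2.0, q_mean=2.0))  # q shifted right: close to -sqrt(2)
```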

Empirical approximation

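Now suppose we approximate \(q\) at the \(i^{th}\) iteration by a finite set of \(N\) particles at locations \(x_n^{(i)}, n = 1, \dots, N\). We replace the expectation in \(\beta\) by an average over these particles and absorb the normalisation \(||\beta||_{\mathcal{H}_D}\) into the step size \(\epsilon^{(i)}\). This gives the following iterative algorithm.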

Definition 73 (Stein variational gradient descent)

Given a distribution \(p(x)\), a positive-definite kernel \(k(x, x')\) and a set of particles with initial positions \(\{x_n^{(0)}\}_{n=1}^N\), Stein variational gradient descent evolves the particles according to

\[\begin{align} x^{(i + 1)}_n = x^{(i)}_n + \frac{\epsilon^{(i)}}{N}\sum_{m = 1}^N \left[ k(x^{(i)}_m, x^{(i)}_n) \nabla_x \log p(x)\big|_{x = x^{(i)}_m} + \nabla_x k(x, x^{(i)}_n) \big|_{x = x^{(i)}_m}\right]. \end{align}\]
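The update can be sketched in a few lines of pure Python for one-dimensional particles. This is an illustration, not the implementation used later. It assumes a Gaussian target \(p = \mathcal{N}(2, 1)\), an EQ kernel with a fixed unit lengthscale and a constant step size, and the helper names are hypothetical. The first term in the sum drives each particle along the kernel-weighted scores \(\nabla_x \log p\) of its neighbours. The second term pushes particles apart, which keeps them from collapsing onto the mode.

```python
import math


def eq_kernel(x, y, lengthscale):
    """1D EQ kernel k(x, y) = exp(-(x - y)^2 / (2 lengthscale^2))."""
    return math.exp(-0.5 * ((x - y) / lengthscale) ** 2)


def eq_kernel_grad_x(x, y, lengthscale):
    """Derivative of k(x, y) with respect to its first argument x."""
    return -(x - y) / lengthscale**2 * eq_kernel(x, y, lengthscale)


def svgd_step(particles, grad_log_p, step_size, lengthscale):
    """One SVGD update for a list of 1D particles."""
    n = len(particles)
    scores = [grad_log_p(x) for x in particles]
    updated = []
    for x_n in particles:
        direction = 0.0
        for x_m, score_m in zip(particles, scores):
            # Driving term: kernel-weighted score at particle x_m.
            direction += eq_kernel(x_m, x_n, lengthscale) * score_m
            # Repulsive term: gradient of k(x, x_n) at x = x_m, which points
            # from x_m towards x_n and so pushes x_n away from x_m.
            direction += eq_kernel_grad_x(x_m, x_n, lengthscale)
        updated.append(x_n + step_size * direction / n)
    return updated


def grad_log_p(x):
    # Score of the illustrative target p = N(2, 1).
    return -(x - 2.0)


particles = [-1.0 + 2.0 * i / 9 for i in range(10)]
for _ in range(1000):
    particles = svgd_step(particles, grad_log_p, step_size=0.1, lengthscale=1.0)

print(sum(particles) / len(particles))
```

Starting from particles spread over \([-1, 1]\), the particle mean should drift to roughly 2, the mean of this illustrative target. The double loop makes the \(O(N^2)\) cost of each step explicit.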

Implementation

The SVGD algorithm is straightforward to implement, and each step is cheap. It needs one gradient of \(\log p\) per particle, plus \(O(N^2)\) kernel evaluations and kernel gradients. We will use SVGD to approximate a mixture of Gaussians, a target with multiple modes.


Although SVGD can use any positive-definite kernel, we will focus on the standard exponentiated quadratic (EQ) kernel

\[\begin{align} k(x, x') = \exp\left(-\frac{1}{2} \sum_{d = 1}^D \frac{(x_d - x'_d)^2}{\ell_d^2}\right), \end{align}\]

with lengthscales \(\ell_1, \dots, \ell_D\). It is implemented by the eq function below. The svgd_grad function computes the SVGD update direction for a set of particles, using TensorFlow's batch_jacobian to obtain the kernel gradients.

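import tensorflow as tf


def eq(lengthscales):
    """Return an EQ kernel with per-dimension lengthscales."""

    def kernel(x, x_):
        # Pairwise differences, shape (N, M, D).
        diff = x[:, None, :] - x_[None, :, :]
        # Scaled squared distances, shape (N, M).
        quad = tf.reduce_sum((diff / lengthscales) ** 2, axis=2)
        # Kernel matrix with entries k(x_n, x_m).
        return tf.exp(-0.5 * quad)

    return kernel


@tf.function
def svgd_grad(x, logprob, kernel):
    """SVGD update direction for particles x of shape (N, D).

    logprob must map (N, D) particles to (N,) log-densities.
    """
    # Second copy of the particles, not watched by the tape, so kernel
    # gradients are taken only with respect to the first argument.
    x_ = tf.convert_to_tensor(x[:], dtype=tf.float32)
    x = tf.convert_to_tensor(x, dtype=tf.float32)

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x)
        logp = logprob(x)
        k = kernel(x, x_)  # k[n, m] = k(x_n, x_m)

    # Score of each particle, shape (N, D).
    dlogp = tape.gradient(logp, x)
    # dk[n, m] = gradient of k(x_n, x_m) with respect to x_n, shape (N, M, D).
    dk = tape.batch_jacobian(k, x)

    # Row n: sum_m k(x_m, x_n) grad log p(x_m) (k is symmetric) plus the
    # repulsive term sum_m grad_x k(x, x_n) at x = x_m, averaged over particles.
    svg = (k @ dlogp + tf.reduce_sum(dk, axis=0)) / x.shape[0]

    return svg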
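For the EQ kernel, the Jacobian that svgd_grad obtains with batch_jacobian also has a simple closed form, \(\nabla_x k(x, x') = -\left((x_d - x'_d)/\ell_d^2\right)_{d=1}^D k(x, x')\). The pure-Python sketch below computes the repulsive term \(\sum_m \nabla_x k(x, x_n)|_{x = x_m}\) in this form and checks the kernel gradient against central finite differences. Its helper names are hypothetical and are not part of the code above.

```python
import math


def eq_kernel(x, y, lengthscales):
    """EQ kernel with per-dimension lengthscales, like the eq function above."""
    quad = sum(((a - b) / ell) ** 2 for a, b, ell in zip(x, y, lengthscales))
    return math.exp(-0.5 * quad)


def eq_kernel_grad_x(x, y, lengthscales):
    """Closed-form gradient of k(x, y) with respect to its first argument."""
    kxy = eq_kernel(x, y, lengthscales)
    return [-(a - b) / ell**2 * kxy for a, b, ell in zip(x, y, lengthscales)]


def repulsive_term(particles, n, lengthscales):
    """Sum over m of grad_x k(x, x_n) evaluated at x = x_m."""
    total = [0.0] * len(particles[n])
    for x_m in particles:
        grad = eq_kernel_grad_x(x_m, particles[n], lengthscales)
        total = [t + g for t, g in zip(total, grad)]
    return total


def finite_difference_grad(x, y, lengthscales, h=1e-6):
    """Central finite-difference gradient of k(x, y) in x, for checking."""
    grad = []
    for d in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[d] += h
        x_minus[d] -= h
        diff = eq_kernel(x_plus, y, lengthscales) - eq_kernel(x_minus, y, lengthscales)
        grad.append(diff / (2 * h))
    return grad


particles = [[0.0, 1.0], [0.5, -0.3], [-1.2, 0.4]]
lengthscales = [1.0, 0.5]
closed_form = eq_kernel_grad_x(particles[1], particles[0], lengthscales)
numerical = finite_difference_grad(particles[1], particles[0], lengthscales)
print(closed_form, numerical)
print(repulsive_term(particles, 0, lengthscales))
```

The autodiff route in svgd_grad works for any differentiable kernel. The closed form is only a convenience specific to the EQ kernel.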

Demo on mixture of Gaussians

We can now run SVGD using a modest number of particles initialised in between the two modes.


We observe that some of the particles fall into each of the two modes, capturing the bimodality of the target, something which VI with a mean-field Gaussian \(q\) cannot do.
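A comparable experiment can be sketched in pure Python. This is not the code used to produce the original demo. The mixture \(\tfrac{1}{2}\mathcal{N}(-2, 1) + \tfrac{1}{2}\mathcal{N}(2, 1)\), the 20 particles placed symmetrically in \([-0.5, 0.5]\), the unit lengthscale and the step size are all illustrative assumptions.

```python
import math

# Illustrative target (assumption): p = 0.5 N(-2, 1) + 0.5 N(2, 1).
MEANS = (-2.0, 2.0)


def grad_log_p(x):
    # Score of an equal-weight, unit-variance mixture: responsibility-weighted
    # component scores, with responsibilities computed stably.
    logits = [-0.5 * (x - m) ** 2 for m in MEANS]
    top = max(logits)
    weights = [math.exp(logit - top) for logit in logits]
    norm = sum(weights)
    return sum(w / norm * -(x - m) for w, m in zip(weights, MEANS))


def svgd_step(particles, step_size, lengthscale):
    n = len(particles)
    scores = [grad_log_p(x) for x in particles]
    updated = []
    for x_n in particles:
        direction = 0.0
        for x_m, s_m in zip(particles, scores):
            k = math.exp(-0.5 * ((x_m - x_n) / lengthscale) ** 2)
            # Driving term plus repulsive term grad_x k(x, x_n) at x = x_m.
            direction += k * s_m + (x_n - x_m) / lengthscale**2 * k
        updated.append(x_n + step_size * direction / n)
    return updated


# 20 particles initialised between the two modes, symmetric about zero.
particles = [-0.5 + i / 19 for i in range(20)]
for _ in range(1500):
    particles = svgd_step(particles, step_size=0.1, lengthscale=1.0)

left = [x for x in particles if x < 0]
right = [x for x in particles if x > 0]
print(len(left), len(right))
print(sum(left) / len(left), sum(right) / len(right))
```

Because this setup is symmetric, half of the particles should end up around each mode, with each group centred near \(\pm 2\).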

Failure mode on mixture of Gaussians

However, SVGD also has failure modes. If we initialise the particles on one mode of a mixture of two well-separated Gaussians, the optimisation gets stuck in a local optimum which fails to capture the other mode. SVGD can represent more expressive approximate distributions than mean-field VI, but there is no guarantee that the optimisation will find such a distribution.
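The failure mode can be reproduced with the same kind of sketch. Again, the target \(\tfrac{1}{2}\mathcal{N}(-5, 0.5^2) + \tfrac{1}{2}\mathcal{N}(5, 0.5^2)\), the initialisation in \([4, 6]\), the kernel lengthscale and the step size are illustrative assumptions. The scores are only evaluated at the particles, and all particles sit near \(x = 5\), where the contribution of the mode at \(-5\) to \(\nabla_x \log p\) is negligible. As a result, the update never feels the pull of that mode.

```python
import math

# Illustrative target (assumption): p = 0.5 N(-5, 0.5^2) + 0.5 N(5, 0.5^2).
MEANS, STD = (-5.0, 5.0), 0.5


def grad_log_p(x):
    # Responsibility-weighted component scores, computed stably.
    logits = [-0.5 * ((x - m) / STD) ** 2 for m in MEANS]
    top = max(logits)
    weights = [math.exp(logit - top) for logit in logits]
    norm = sum(weights)
    return sum(w / norm * -(x - m) / STD**2 for w, m in zip(weights, MEANS))


def svgd_step(particles, step_size, lengthscale):
    n = len(particles)
    scores = [grad_log_p(x) for x in particles]
    updated = []
    for x_n in particles:
        direction = 0.0
        for x_m, s_m in zip(particles, scores):
            k = math.exp(-0.5 * ((x_m - x_n) / lengthscale) ** 2)
            # Driving term plus repulsive term grad_x k(x, x_n) at x = x_m.
            direction += k * s_m + (x_n - x_m) / lengthscale**2 * k
        updated.append(x_n + step_size * direction / n)
    return updated


# All 20 particles start around the right-hand mode.
particles = [4.0 + 2.0 * i / 19 for i in range(20)]
for _ in range(1000):
    particles = svgd_step(particles, step_size=0.05, lengthscale=1.0)

print(min(particles), max(particles))
```

All particles should remain near 5, and the mode at \(-5\) receives none of them.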


Conclusion

This section presented SVGD, a general-purpose algorithm for approximate inference. SVGD simulates a set of particles, regarded as an empirical approximation of a distribution \(q\) which itself approximates the target distribution \(p\). It evolves \(q\) through a sequence of transformations, each chosen as the direction of steepest decrease of the KL divergence between \(q\) and \(p\) within the unit ball of an RKHS. In this way, SVGD can produce a flexible approximation to the target \(p\).

References

LLJ16

Qiang Liu, Jason Lee, and Michael Jordan. A kernelized Stein discrepancy for goodness-of-fit tests. In International Conference on Machine Learning, 276–284. PMLR, 2016.

LW19

Qiang Liu and Dilin Wang. Stein variational gradient descent: a general purpose Bayesian inference algorithm. 2019. arXiv:1608.04471.