{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Stein variational gradient descent\n", "\n", "One central challenge in statistics and Bayesian machine learning is dealing with intractable distributions.\n", "In many cases, our models involve complicated distributions which are difficult to integrate out or sample from; consider, for example, the posterior distribution over the parameters of a Bayesian neural network.\n", "Many approaches have been devised for handling such distributions, broadly falling into two categories: Variational Inference (VI) and Markov Chain Monte Carlo (MCMC). Here we focus on the former.\n", "\n", "Suppose we are working with an intractable distribution $p$.\n", "VI approximates $p$ by a distribution $q$ constrained to lie in a tractable family, such as fully factorised (mean-field) Gaussians.\n", "By minimising a divergence between $q$ and $p$, such as the KL divergence, VI produces an approximation which, ideally, captures the important aspects of the target.\n", "VI can be much faster than MCMC, but unlike MCMC it does not become exact with more computation: the quality of the approximation is limited by the chosen family, and the resulting error can be large in many applications of interest.\n", "\n", "Stein Variational Gradient Descent (SVGD) {cite}`liu2019stein` is an algorithm for approximate inference in intractable distributions which avoids the restrictive parametric families typically used in VI.\n", "Much like VI, it minimises the KL divergence between $q$ and $p$, but unlike VI it makes no strong parametric assumptions about $q$.\n", "Instead, SVGD represents $q$ by a finite set of particles and moves them through a sequence of transformations chosen so that $q$ gets progressively closer to $p$."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "import numpy as np\n", "\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "from matplotlib_inline.backend_inline import set_matplotlib_formats\n", "\n", "set_matplotlib_formats('pdf', 'svg')\n", "\n", "matplotlib.rcParams['mathtext.fontset'] = 'stix'\n", "matplotlib.rcParams['font.family'] = 'STIXGeneral'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Derivation of SVGD\n", "\n", "The idea behind SVGD is to approximate a target distribution $p$ by a distribution $q$, applying to $q$ a sequence of transformations which bring it closer to $p$. By choosing, at each step, the transformation (from a restricted family of transformations) which most rapidly reduces the KL divergence, we obtain an algorithm that looks much like gradient descent.\n", "\n", "### Invertible transformations\n", "\n", "Suppose we have an initial distribution $q$, which we pass through a transformation $T : \\mathbb{R}^D \\to \\mathbb{R}^D$, that is\n", "\n", "$$\\begin{align}\n", "z = T(x),~~\\text{ where } x \\sim q(x).\n", "\\end{align}$$\n", "\n", "If the map $T$ is invertible and differentiable, we can compute the density of the transformed variable $z$ via the change of variables formula. To ensure $T$ is invertible, let us set $T(x) = x + \\epsilon \\phi(x)$, where $\\epsilon$ is a small coefficient and $\\phi : \\mathbb{R}^D \\to \\mathbb{R}^D$. If $\\phi$ is smooth and $\\epsilon$ is sufficiently small, then $T$ is invertible, which means we can compute the density of $z$. We now turn to the question of how to pick an appropriate $\\phi$.\n", "\n", "### Direction of steepest descent\n", "\n", "Let us use the subscript notation $q_{[T]}$ to denote the distribution obtained by passing $q$ through $T$. We are then interested in picking a $T$ which minimises $\\text{KL}(q_{[T]} || p)$. 
First, we compute the derivative of the KL divergence with respect to $\\epsilon$, which can be obtained in closed form.\n", "\n", ":::{prf:theorem} Gradient of KL is the KSD\n", "\n", "Let $x \\sim q(x)$, and $T(x) = x + \\epsilon \\phi(x)$, where $\\phi$ is a smooth function. Then\n", " \n", "$$\\begin{align}\n", "\\nabla_{\\epsilon}\\text{KL}(q_{[T]} || p) \\big|_{\\epsilon = 0} = - \\mathbb{E}_{x \\sim q}\\left[\\text{trace} \\mathcal{A}_p \\phi(x) \\right],\n", "\\end{align}$$\n", " \n", "where $q_{[T]}$ is the density of $T(x)$ and $\\mathcal{A}_p$ is the Stein operator\n", " \n", "$$\\begin{align}\n", "\\mathcal{A}_p \\phi(x) = \\nabla_x \\log p(x)\\phi^\\top(x) + \\nabla_x \\phi(x).\n", "\\end{align}$$\n", " \n", ":::\n", "\n", ":::{dropdown} Proof: Gradient of KL is the KSD\n", " \n", "Let $p_{\\left[T^{-1}\\right]}$ denote the density of $T^{-1}(y)$ when $y \\sim p$. By changing the variable of integration from $z$ to $x = T^{-1}(z)$, we obtain\n", " \n", "$$\\begin{align}\n", "\\text{KL}(q_{[T]} || p) &= \\int q_{[T]}(z) \\log \\frac{q_{[T]}(z)}{p(z)} dz \\\\\n", " &= \\int q(x) \\left[ \\log q(x) - \\log p_{\\left[T^{-1}\\right]}(x) \\right] dx.\n", "\\end{align}$$\n", " \n", "This change of variables is convenient because now only one term in the integral depends on $\\epsilon$, namely $p_{\\left[T^{-1}\\right]}(x)$. 
Now, taking the derivative with respect to $\\epsilon$, we obtain\n", " \n", "$$\\begin{align}\n", "\\nabla_{\\epsilon} \\text{KL}(q_{[T]} || p) &= - \\int q(x) \\nabla_{\\epsilon} \\log p_{\\left[T^{-1}\\right]}(x) dx.\n", "\\end{align}$$\n", " \n", "Using the change of variables formula,\n", "\n", "$$\\begin{align}\n", "\\log p_{\\left[T^{-1}\\right]}(x) &= \\log p(T(x)) + \\log |\\det \\nabla_x T(x)|,\n", "\\end{align}$$\n", " \n", "we obtain the expression\n", "\n", "$$\\begin{align}\n", "\\nabla_{\\epsilon} \\log p_{\\left[T^{-1}\\right]}(x) &= \\nabla \\log p(T(x))^\\top \\nabla_\\epsilon T(x) + \\nabla_\\epsilon \\log |\\det \\nabla_x T(x)| \\\\\n", " &= \\nabla \\log p(T(x))^\\top \\nabla_\\epsilon T(x) + \\text{trace}\\left[(\\nabla_x T(x))^{-1} \\nabla_\\epsilon \\nabla_x T(x)\\right],\n", "\\end{align}$$\n", " \n", "where we have used the identity\n", " \n", "$$\\begin{align}\n", "\\nabla_{\\epsilon} \\log |\\det A| = \\text{trace}\\left[A^{-1} \\nabla_{\\epsilon} A\\right].\n", "\\end{align}$$\n", " \n", "Substituting this into the derivative of the KL divergence, we arrive at\n", " \n", "$$\\begin{align}\n", "\\nabla_{\\epsilon} \\text{KL}(q_{[T]} || p) &= - \\mathbb{E}_{x \\sim q} \\left[\\nabla \\log p(T(x))^\\top \\nabla_\\epsilon T(x) + \\text{trace}\\left[(\\nabla_x T(x))^{-1} \\nabla_\\epsilon \\nabla_x T(x)\\right]\\right].\n", "\\end{align}$$\n", " \n", "Setting $T(x) = x + \\epsilon \\phi(x)$, for which $\\nabla_\\epsilon T(x) = \\phi(x)$ and $\\nabla_\\epsilon \\nabla_x T(x) = \\nabla_x \\phi(x)$, and evaluating at $\\epsilon = 0$, where $T(x) = x$ and $\\nabla_x T(x) = I$, yields the result\n", " \n", "$$\\begin{align}\n", "\\nabla_{\\epsilon} \\text{KL}(q_{[T]} || p) \\big|_{\\epsilon = 0} &= - \\mathbb{E}_{x \\sim q} \\left[\\nabla \\log p(x)^\\top \\phi(x) + \\text{trace}\\left[\\nabla_x \\phi(x) \\right]\\right] \\\\\n", " &= - \\mathbb{E}_{x \\sim q} \\left[\\text{trace} \\mathcal{A}_p \\phi(x) \\right].\n", "\\end{align}$$\n", ":::\n", "\n", "This result gives the rate of change of the KL divergence at $\\epsilon = 0$, for a given $\\phi$. 
We want to pick $\\phi$ such that $-\\mathbb{E}_{x \\sim q} \\left[\\text{trace} \\mathcal{A}_p \\phi(x) \\right]$ is as negative as possible. However, this minimisation is not well defined, because scaling $\\phi$ by an arbitrarily large positive constant makes the objective unbounded below. Further, over general functions $\\phi$ the problem is neither analytically nor computationally tractable. Both issues can be resolved by constraining $\\phi$ to the unit ball of a Reproducing Kernel Hilbert Space (RKHS).\n", "\n", "Let $k$ be a positive-definite kernel with corresponding RKHS $\\mathcal{H}$ and inner product $\\langle\\cdot, \\cdot \\rangle_{\\mathcal{H}}$. Let also $\\mathcal{H}_D = \\mathcal{H} \\times \\dots \\times \\mathcal{H}$ be the Hilbert space of $D$-dimensional vector-valued functions $f = (f_1, \\dots, f_D)$ with $f_1, \\dots, f_D \\in \\mathcal{H}$, equipped with the inner product\n", "\n", "$$\\begin{align}\n", "\\langle f, g \\rangle_{\\mathcal{H}_D} = \\sum_{d = 1}^D \\langle f_d, g_d \\rangle_{\\mathcal{H}},\n", "\\end{align}$$\n", "\n", "and induced norm $|| f ||_{\\mathcal{H}_D} = \\sqrt{\\langle f, f \\rangle_{\\mathcal{H}_D}}$. If we now constrain $\\phi \\in \\mathcal{H}_D$ with $|| \\phi ||_{\\mathcal{H}_D} \\leq 1$, we obtain the following analytic expression for the direction of steepest descent {cite}`liu2016kernelized`."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ ":::{prf:theorem} Direction of steepest descent\n", "\n", "The function $\\phi^* \\in \\mathcal{H}_D$ with $|| \\phi^* ||_{\\mathcal{H}_D} \\leq 1$ which maximises the rate of decrease of the KL divergence is\n", " \n", "$$\\begin{align}\n", "\\phi^*(\\cdot) = \\beta(\\cdot) / ||\\beta||_{\\mathcal{H}_D},~~\\beta(\\cdot) = \\mathbb{E}_{x \\sim q}\\left[ k(x, \\cdot) \\nabla_x \\log p(x) + \\nabla_x k(x, \\cdot)\\right].\n", "\\end{align}$$\n", " \n", ":::\n", "\n", ":::{dropdown} Proof: Direction of steepest descent\n", " \n", "For $f \\in \\mathcal{H}_D$ we have the following equality\n", " \n", "$$\\begin{align}\n", "\\langle f, \\beta \\rangle_{\\mathcal{H}_D} &= \\sum_{d = 1}^D \\langle f_d(\\cdot), \\beta_d \\rangle_{\\mathcal{H}} \\\\\n", " &= \\sum_{d = 1}^D \\left \\langle f_d(\\cdot), \\mathbb{E}_{x \\sim q}\\left[k(x, \\cdot) \\nabla_{x_d} \\log p(x) + \\nabla_{x_d} k(x, \\cdot)\\right] \\right\\rangle_{\\mathcal{H}} \\\\\n", " &= \\sum_{d = 1}^D \\mathbb{E}_{x \\sim q}\\left[\\nabla_{x_d} \\log p(x) \\langle f_d(\\cdot), k(x, \\cdot) \\rangle_{\\mathcal{H}} + \\langle f_d(\\cdot), \\nabla_{x_d} k(x, \\cdot) \\rangle_{\\mathcal{H}} \\right] \\\\\n", " &= \\sum_{d = 1}^D \\mathbb{E}_{x \\sim q}\\left[\\nabla_{x_d} \\log p(x) f_d(x) + \\nabla_{x_d} f_d(x) \\right] \\\\\n", " &= \\mathbb{E}_{x \\sim q}\\left[\\text{trace} \\mathcal{A}_p f(x)\\right],\n", "\\end{align}$$\n", " \n", "where the third line uses the linearity of the inner product and the fourth uses the reproducing property of $\\mathcal{H}$, namely $\\langle f_d, k(x, \\cdot) \\rangle_{\\mathcal{H}} = f_d(x)$ and $\\langle f_d, \\nabla_{x_d} k(x, \\cdot) \\rangle_{\\mathcal{H}} = \\nabla_{x_d} f_d(x)$. Therefore, maximising $\\mathbb{E}_{x \\sim q}\\left[\\text{trace} \\mathcal{A}_p f(x)\\right]$ over $|| f ||_{\\mathcal{H}_D} \\leq 1$ is equivalent to maximising the inner product $\\langle f, \\beta \\rangle_{\\mathcal{H}_D}$, which by the Cauchy-Schwarz inequality is achieved by $f = \\beta / ||\\beta||_{\\mathcal{H}_D}$.\n", " \n", ":::\n", "\n", "### Empirical approximation\n", "\n", "Now, suppose that at the $i^{th}$ iteration we represent $q$ by a finite set of $N$ particles at locations $x_n^{(i)}, n = 1, \\dots, N$. Replacing the expectation over $q$ in $\\beta$ by an average over the particles, and absorbing the normalisation $1 / ||\\beta||_{\\mathcal{H}_D}$ into the step size, we obtain the following iterative algorithm.\n", "\n", ":::{prf:definition} Stein variational gradient 
descent\n", "\n", "Given a distribution $p(x)$, a positive-definite kernel $k(x, x')$ and a set of particles with initial positions $\\{x_n^{(0)}\\}_{n=1}^N$, Stein variational gradient descent evolves the particles according to\n", " \n", "$$\\begin{align}\n", "x^{(i + 1)}_n = x^{(i)}_n + \\frac{\\epsilon^{(i)}}{N}\\sum_{m = 1}^N \\left[ k(x^{(i)}_m, x^{(i)}_n) \\nabla_x \\log p(x)\\big|_{x = x^{(i)}_m} + \\nabla_x k(x, x^{(i)}_n)\\big|_{x = x^{(i)}_m}\\right],\n", "\\end{align}$$\n", " \n", "where $\\epsilon^{(i)}$ is the step size at iteration $i$. The first term moves the particles towards regions of high probability under $p$, while the second acts as a repulsive force which keeps the particles from collapsing onto a single mode.\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Implementation\n", "\n", "The SVGD algorithm is straightforward to implement, and each step is cheap: it needs the score $\\nabla_x \\log p$ at each of the $N$ particles, together with the $N^2$ pairwise kernel values and kernel gradients.\n", "We will use SVGD to approximate a two-component mixture of Gaussians, a simple target with multiple modes." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "def mixture_of_gaussians_logprob(locs, scales, probs):\n", "\n", "    def logprob(x):\n", "\n", "        # Dimension of x\n", "        D = x.shape[-1]\n", "\n", "        # Normalise the mixture weights so that they sum to 1\n", "        log_probs = tf.math.log(probs[None, :] / tf.reduce_sum(probs))\n", "\n", "        # Differences between x and the Gaussian locations, shape (N, K, D)\n", "        diff = x[:, None, :] - locs[None, :, :]\n", "\n", "        # Log-density of each isotropic Gaussian component, including the normalising constant\n", "        quad = -0.5 * tf.reduce_sum((diff / scales[None, :, None]) ** 2, axis=2)\n", "        quad = quad - 0.5 * D * tf.math.log(2 * np.pi * scales[None, :] ** 2)\n", "\n", "        # Compute log-probability using the log-sum-exp trick for stability\n", "        summands = log_probs + quad\n", "        max_summand = tf.reduce_max(summands, axis=1)\n", "\n", "        summed = tf.reduce_sum(tf.exp(summands - max_summand[:, None]), axis=1)\n", "        summed = max_summand + tf.math.log(summed)\n", "\n", "        return summed\n", "\n", "    return logprob" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "tags": [ "center-output", "remove-input" ] }, "outputs": [ { "data": { 
"application/pdf": "", "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-11-30T18:01:25.780390\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Float dtype to use\n", "dtype = tf.float32\n", "\n", "# Parameters for mixture of gaussians pdf\n", "locs = tf.convert_to_tensor([[-1e0, 0.], [1e0, 0.]], dtype=dtype)\n", "scales = tf.convert_to_tensor([2e-1, 2e-1], dtype=dtype)\n", "probs = tf.convert_to_tensor([5e-1, 5e-1], dtype=dtype)\n", "\n", "# Discretisation grid reslution\n", "grid_res = 100\n", "\n", "# Create log-probabilty lambda\n", "logprob = mixture_of_gaussians_logprob(\n", " locs=locs,\n", " scales=scales,\n", " probs=probs,\n", ")\n", "\n", "# Input locations at which to compute log-probabilities\n", "x_plot = np.linspace(-2.5, 2.5, grid_res)\n", "x_plot = np.stack(np.meshgrid(x_plot, x_plot), axis=-1)\n", "x_plot = tf.convert_to_tensor(np.reshape(x_plot, (-1, 2)), dtype=dtype)\n", "\n", "# Compute log-probabilities\n", "logp = logprob(x_plot)\n", "\n", "# Reshape to 3D and 2D arrays for plotting\n", "x_plot = np.reshape(x_plot, (grid_res, grid_res, 2))\n", "logp = tf.reshape(logp, (grid_res, grid_res))\n", "\n", "# Contourplot levels corresponding to standard deviations\n", "levels = np.max(np.exp(logp)) * np.exp(- np.linspace(4, 0, 5) ** 2)\n", "\n", "# Plot density\n", "plt.figure(figsize=(5, 5))\n", "plt.contourf(\n", " x_plot[:, :, 0],\n", " x_plot[:, :, 1],\n", " np.exp(logp),\n", " cmap=\"coolwarm\",\n", ")\n", "\n", "# Format plot\n", "plt.xticks(np.linspace(-2, 2, 5), fontsize=16)\n", "plt.yticks(np.linspace(-2, 2, 5), fontsize=16)\n", "plt.xlim([-2., 2.])\n", "plt.ylim([-2., 2.])\n", "plt.xlabel('$x_1$', fontsize=20)\n", "plt.ylabel('$x_2$', fontsize=20)\n", "plt.title('Mixture of Gaussians density', fontsize=20)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Although SVGD can use any positive-semidefinite kernel, we will focus our attention to the standard EQ kernel\n", "\n", "$$\\begin{align}\n", "k(x, x') = \\exp\\left(-\\frac{1}{2\\ell^2} (x - x')^2\\right),\n", "\\end{align}$$\n", 
"\n", "implemented by the `eq` function below. The `svgd_grad` computes the SVGD gradients for a set of particles, using Tensorflow's batch jacobians." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def eq(lengthscales):\n", " \n", " def kernel(x, x_):\n", " \n", " diff = x[:, None, :] - x_[None, :, :]\n", " quad = tf.reduce_sum((diff / lengthscales) ** 2, axis=2)\n", " exp = tf.exp(-0.5 * quad)\n", " \n", " return exp\n", " \n", " return kernel\n", "\n", "\n", "@tf.function\n", "def svgd_grad(x, logprob, kernel):\n", " \n", " x_ = tf.convert_to_tensor(x[:], dtype=tf.float32)\n", " x = tf.convert_to_tensor(x, dtype=tf.float32)\n", " \n", " with tf.GradientTape(persistent=True) as tape:\n", " \n", " tape.watch(x)\n", " \n", " logp = logprob(x)\n", " k = kernel(x, x_)\n", " \n", " dlogp = tape.gradient(logp, x)\n", " dk = tape.batch_jacobian(k, x)\n", " \n", " svg = (k @ dlogp + tf.reduce_sum(dk, axis=0)) / x.shape[0]\n", " \n", " return svg" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Demo on mixture of Gaussians\n", "\n", "We can now run SVGD using a modest number of particles initialised in between the two modes." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "# Number of particles to simulate\n", "num_particles = 100\n", "\n", "# Initial positions of particles\n", "x = 2e-1 * np.random.normal(size=(num_particles, 2)).astype(np.float32)\n", "\n", "# Create EQ kernel\n", "eq_scales = tf.convert_to_tensor([2e-1, 2e-1], dtype=dtype)\n", "kernel = eq(eq_scales)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "tags": [ "center-output", "remove-input" ] }, "outputs": [ { "data": { "application/pdf": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Step size and number of gradient descent steps\n", "stepsize = 1e-1\n", "num_steps = 100\n", "\n", "plt.figure(figsize=(8, 4))\n", "\n", "for i in range(num_steps + 1):\n", " \n", " # Plot 0th iteration and last iteration\n", " if i in [0, num_steps]:\n", " \n", " # Choose appropriate subplot\n", " plt.subplot(1, 2, (i > 0) + 1)\n", " \n", " # Plot particles and probabilities\n", " plt.scatter(x[:, 0], x[:, 1], zorder=2, c='k', s=10)\n", " plt.contourf(\n", " x_plot[:, :, 0],\n", " x_plot[:, :, 1],\n", " np.exp(logp),\n", " cmap=\"coolwarm\",\n", " )\n", " \n", " plt.xlim([-2., 2.])\n", " plt.ylim([-2., 2.])\n", " \n", " plt.xticks(fontsize=16)\n", " plt.yticks(np.linspace(-2., 2., 5), fontsize=16)\n", " plt.xlabel('$x_1$', fontsize=20)\n", " \n", " if i > 0:\n", " plt.yticks([])\n", " \n", " else:\n", " plt.ylabel('$x_2$', fontsize=20)\n", " \n", " plt.title(f'Iteration {i}', fontsize=22)\n", " \n", " # Compute Stein variational gradient\n", " svg = svgd_grad(x, logprob=logprob, kernel=kernel)\n", " \n", " # Adapt particle locations using the SVG\n", " x = x + stepsize * svg.numpy()\n", " \n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We observe that some of the particles fall into each of the two modes, capturing the bimodality of the target, something which VI with a mean-field Gaussian $q$ cannot do.\n", "\n", "### Failure mode on mixture of Gaussians\n", "\n", "However, SVGD also has failure modes, as illustrated below.\n", "If we initialise the particles on one mode of two well-separated Gaussians, then the optimisation gets stuck at a local optimum which fails to capture one of the two modes of the mixture of Gaussians.\n", "Even though SVGD may be able to express more expressive approximate distributions than mean-field VI, it is not guaranteed that the optimisation will be able to find such a distribution." 
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "# Parameters for mixture of gaussians pdf\n", "locs = tf.convert_to_tensor([[-1e0, 0.], [1e0, 0.]], dtype=dtype)\n", "scales = tf.convert_to_tensor([2e-1, 2e-1], dtype=dtype)\n", "probs = tf.convert_to_tensor([0.5, 0.5], dtype=dtype)\n", "\n", "# Number of particles to simulate\n", "num_particles = 100\n", "\n", "# Initial positions of particles, centred on the mode at (1, 0)\n", "x = 2e-1 * np.random.normal(size=(num_particles, 2)).astype(np.float32)\n", "x = x + np.array([[1e0, 0.]], dtype=np.float32)\n", "\n", "# Create EQ kernel\n", "eq_scales = tf.convert_to_tensor([2e-1, 2e-1], dtype=dtype)\n", "kernel = eq(eq_scales)\n", "\n", "# Create log-probability function\n", "logprob = mixture_of_gaussians_logprob(\n", "    locs=locs,\n", "    scales=scales,\n", "    probs=probs,\n", ")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "tags": [ "center-output", "remove-input" ] }, "outputs": [ { "data": { "application/pdf": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Step size and number of gradient descent steps\n", "stepsize = 1e-1\n", "num_steps = 100\n", "\n", "plt.figure(figsize=(8, 4))\n", "\n", "for i in range(num_steps + 1):\n", " \n", " # Plot 0th iteration and last iteration\n", " if i in [0, num_steps]:\n", " \n", " # Choose appropriate subplot\n", " plt.subplot(1, 2, (i > 0) + 1)\n", " \n", " # Plot particles and probabilities\n", " plt.scatter(x[:, 0], x[:, 1], zorder=2, c='k', s=10)\n", " plt.contourf(\n", " x_plot[:, :, 0],\n", " x_plot[:, :, 1],\n", " np.exp(logp),\n", " cmap=\"coolwarm\",\n", " )\n", " \n", " plt.xlim([-2., 2.])\n", " plt.ylim([-2., 2.])\n", " \n", " plt.xticks(fontsize=16)\n", " plt.yticks(np.linspace(-2., 2., 5), fontsize=16)\n", " plt.xlabel('$x_1$', fontsize=20)\n", " \n", " if i > 0:\n", " plt.yticks([])\n", " \n", " else:\n", " plt.ylabel('$x_2$', fontsize=20)\n", " \n", " plt.title(f'Iteration {i}', fontsize=22)\n", " \n", " # Compute Stein variational gradient\n", " svg = svgd_grad(x, logprob=logprob, kernel=kernel)\n", " \n", " # Adapt particle locations using the SVG\n", " x = x + stepsize * svg.numpy()\n", " \n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "This section presented SVGD, a very interesting general-purpose algorithm for approximate inference.\n", "SVGD works by simulating a set of particles, regarded as an empirical approximation of a distribution $q$ which itself approximates the target distribution $p$.\n", "By evolving $q$ according to a sequence of transformations, each of which is determined as the direction of steepest decrease in the KL between $q$ and $p$, SVGD can produce a flexible approximation to the target $p$." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "\n", "```{bibliography}\n", ":filter: docname in docnames\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "venv-rw", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 4 }