Variational inference for diffusion processes
Typically, the posterior of a stochastic differential equation (SDE) model cannot be computed in closed form. If we are modelling temporal data with an SDE and want to compute the posterior of the model, we therefore have to resort to approximations. Here we discuss a method [ACOST07] [AOS+08] for approximating the posterior of an SDE by fitting another, approximating SDE to it.
Two SDE models
Suppose we wish to model some data using an SDE with noisy observations of the form
\[
dx = f(x, t)\,dt + \Sigma^{1/2}\,d\beta, \qquad y_k = x(t_k) + \epsilon_k, \quad \epsilon_k \sim \mathcal{N}(0, R),
\]
where \(\beta\) is a standard Brownian motion. Given observations \(\mathcal{D} = \{t_k, y_k\}_{k = 1}^K\), we are interested in approximating the posterior process \(p(x, t | \mathcal{D})\). For most choices of \(f(x, t)\), working with this SDE analytically is impossible. Instead, we will approximate the posterior of this SDE with another stochastic process, described by the SDE
\[
dx = g(x, t)\,dt + \Sigma^{1/2}\,d\beta, \qquad g(x, t) = -A(t)\,x + b(t).
\]
The motivation for this particular form is that, because this SDE is linear, it corresponds to a Gaussian process (GP) whose marginal mean and covariance can be treated almost entirely analytically. In particular, the mean and covariance can be shown to obey two ODEs exactly. These ODEs generally cannot be solved in closed form, but they can be approximated accurately by numerical integration.
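The claim that the mean and variance obey ODEs can be checked numerically. The sketch below makes a simplifying assumption: a scalar state with constant \(A > 0\) and \(b\), whereas in the method \(A(t)\) and \(b(t)\) vary in time. Under that assumption the ODEs \(\frac{dm}{dt} = -A m + b\) and \(\frac{dS}{dt} = -2 A S + \Sigma\) have closed-form solutions. The sketch compares these with Monte Carlo moments of Euler-Maruyama sample paths. The helper names are hypothetical.

```python
import numpy as np


def moments_closed_form(A, b, m0, S0, Sigma, t):
    """Mean and variance of dx = (-A x + b) dt + sqrt(Sigma) dbeta at time t,
    from the closed-form solution of the moment ODEs (scalar, constant A > 0, b)."""
    decay = np.exp(-A * t)
    m = b / A + (m0 - b / A) * decay
    S = Sigma / (2 * A) + (S0 - Sigma / (2 * A)) * decay ** 2
    return m, S


def simulate_final_states(A, b, m0, S0, Sigma, T, num_steps, num_paths, seed=0):
    """Euler-Maruyama simulation of many independent paths; returns samples of x(T)."""
    rng = np.random.default_rng(seed)
    dt = T / num_steps
    x = m0 + np.sqrt(S0) * rng.standard_normal(num_paths)
    for _ in range(num_steps):
        noise = np.sqrt(Sigma * dt) * rng.standard_normal(num_paths)
        x = x + (-A * x + b) * dt + noise
    return x


A, b, m0, S0, Sigma, T = 2.0, 1.0, -1.0, 0.1, 1.0, 1.0
x_T = simulate_final_states(A, b, m0, S0, Sigma, T, num_steps=1000, num_paths=20000)
m_T, S_T = moments_closed_form(A, b, m0, S0, Sigma, T)
print(f"mean: Monte Carlo {x_T.mean():.3f}, ODE {m_T:.3f}")
print(f"variance: Monte Carlo {x_T.var():.3f}, ODE {S_T:.3f}")
```

The two sets of moments should agree up to Monte Carlo error and a small time-discretisation bias.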
Free Energy approximation
Because the approximating SDE has an affine drift function \(g\) and a constant noise term, it is a GP. The marginal mean and covariance of this GP, \(m(t)\) and \(S(t)\), satisfy [SarkkaS19] the ODEs
\[
\frac{dm}{dt} = -A m + b, \qquad \frac{dS}{dt} = -A S - S A^\top + \Sigma.
\]
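We are therefore interested in finding \(A(t)\) and \(b(t)\) such that the distribution of \(x\) under this SDE approximates the distribution of \(x\) under the posterior process. Let us write \(q(x, t)\) for the distribution of \(x\) at time \(t\) according to the approximating SDE, and \(p(x, t)\) for the prior of the exact SDE. We will often abbreviate these by \(q(x)\) and \(p(x)\). We will fit \(A, b\) by minimising the free energy
\[
\mathcal{F}(q) = \mathrm{KL}\left[q \,\|\, p\right] - \sum_{k=1}^K \left\langle \log \mathcal{N}\left(y_k; x(t_k), R\right) \right\rangle_{q(x, t_k)},
\]
where the KL-divergence is between the approximating and prior processes over the time interval \([t_0, t_1]\). Since \(\mathcal{F} = \mathrm{KL}\left[q \,\|\, p(\cdot | \mathcal{D})\right] - \log p(\mathcal{D})\), minimising \(\mathcal{F}\) is equivalent to minimising the KL-divergence from \(q\) to the posterior.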
We cannot compute the KL-divergence term of this free energy directly, so we approximate it using a time discretisation. Letting the discretisation step go to \(0\) then yields an expression for the KL-divergence between the prior and the approximating SDE as a time integral of expectations under \(q\).
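Lemma (KL-divergence between \(p\) and \(q\)) The KL-divergence between the prior and approximating SDEs can be written as
\[
\mathrm{KL}\left[q \,\|\, p\right] = \mathrm{KL}\left[q(x, t_0) \,\|\, p(x, t_0)\right] + \int_{t_0}^{t_1} \frac{1}{2} \left\langle \left(f(x, t) - g(x, t)\right)^\top \Sigma^{-1} \left(f(x, t) - g(x, t)\right) \right\rangle_{q(x, t)} dt.
\]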
Proof: KL-divergence between \(p\) and \(q\)
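Consider splitting the time interval \([t_0, t_1]\) into \(N\) equal segments of size \(\Delta t = \frac{t_1 - t_0}{N}\), with endpoints \(t_n = t_0 + n \Delta t\). Let the values of the SDE at these endpoints be \(x_0, x_1, ..., x_N\). Then, to first order in \(\Delta t\) (the Euler-Maruyama discretisation), the joint distributions at these points in time under the prior and approximating SDEs are
\[
p(x_0, \dots, x_N) = p(x_0) \prod_{n=0}^{N-1} \mathcal{N}\left(x_{n+1}; x_n + f(x_n, t_n) \Delta t, \Sigma \Delta t\right),
\]
\[
q(x_0, \dots, x_N) = q(x_0) \prod_{n=0}^{N-1} \mathcal{N}\left(x_{n+1}; x_n + g(x_n, t_n) \Delta t, \Sigma \Delta t\right).
\]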
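We can then write the KL-divergence between these two multivariate distributions as
\[
\mathrm{KL}\left[q(x_{0:N}) \,\|\, p(x_{0:N})\right] = \mathrm{KL}\left[q(x_0) \,\|\, p(x_0)\right] + \sum_{n=0}^{N-1} \left\langle \mathrm{KL}\left[q(x_{n+1} | x_n) \,\|\, p(x_{n+1} | x_n)\right] \right\rangle_{q(x_n)}.
\]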
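Each conditional KL-divergence is between two Gaussians with the same covariance \(\Sigma \Delta t\), so it equals \(\frac{\Delta t}{2} (f - g)^\top \Sigma^{-1} (f - g)\), with \(f\) and \(g\) evaluated at \((x_n, t_n)\). Evaluating the second term in this way and taking the \(\Delta t \to 0\) limit, we obtain
\[
\sum_{n=0}^{N-1} \frac{\Delta t}{2} \left\langle \left(f - g\right)^\top \Sigma^{-1} \left(f - g\right) \right\rangle_{q(x_n)} \;\longrightarrow\; \int_{t_0}^{t_1} \frac{1}{2} \left\langle \left(f - g\right)^\top \Sigma^{-1} \left(f - g\right) \right\rangle_{q(x, t)} dt,
\]
which proves the lemma.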
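The key step in the proof is a property of Gaussians. Two Gaussians with equal covariance \(\Sigma \Delta t\), whose means differ by \((g - f)\Delta t\), have KL-divergence \(\frac{\Delta t}{2}(f - g)^\top \Sigma^{-1}(f - g)\). Below is a scalar numerical check using the general Gaussian KL formula. The drifts and parameter values are illustrative assumptions.

```python
import numpy as np


def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """KL[N(mu_q, var_q) || N(mu_p, var_p)] for scalar Gaussians."""
    return 0.5 * (np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)


# Illustrative drifts: OU prior drift f and a linear approximating drift g
gamma, A, b, Sigma = 2.0, 1.5, 0.3, 1.0


def f(x):
    return -gamma * x


def g(x):
    return -A * x + b


x_n, dt = 0.7, 1e-3
# KL between the two one-step Euler-Maruyama transitions from the same state x_n
kl_step = gaussian_kl(x_n + g(x_n) * dt, Sigma * dt, x_n + f(x_n) * dt, Sigma * dt)
# The expression used in the proof: (dt / 2) (f - g)^2 / Sigma
kl_formula = 0.5 * dt * (f(x_n) - g(x_n)) ** 2 / Sigma
print(kl_step, kl_formula)
```

Averaging this per-step quantity over \(q(x_n)\) and summing over \(n\) gives a Riemann sum, which becomes the time integral in the lemma.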
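Therefore, the free energy is equal to
\[
\mathcal{F} = \mathrm{KL}\left[q(x, t_0) \,\|\, p(x, t_0)\right] + \int_{t_0}^{t_1} \left(E_{sde}(t) + E_{obs}(t)\right) dt + \text{const.},
\]
where
\[
E_{sde}(t) = \frac{1}{2} \left\langle (f - g)^\top \Sigma^{-1} (f - g) \right\rangle_{q(x, t)}, \qquad
E_{obs}(t) = \frac{1}{2} \sum_{k=1}^K \delta(t - t_k) \left\langle (y_k - x)^\top R^{-1} (y_k - x) \right\rangle_{q(x, t)}.
\]
The constant, \(\frac{K}{2} \log \det (2 \pi R)\), does not depend on \(A, b\). The first term depends only on the initial conditions \(m(t_0), S(t_0)\).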
and we look to optimise this w.r.t. \(A, b\).
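Consider the scalar OU example: prior drift \(f(x) = -\gamma x\), \(\Sigma = \sigma^2\) and observation noise variance \(r^2\). There, every term of this free energy is available in closed form from \(m, S, A, b\) on the grid, which makes it a convenient quantity for monitoring the optimisation. The minimal sketch below rests on three assumptions:

- a stationary prior at \(t_0\), \(p(x, t_0) = \mathcal{N}(0, \sigma^2 / (2\gamma))\);
- a left Riemann sum for the time integral;
- observations that fall on grid points, given by `obs_idx`.

The function name and signature are hypothetical.

```python
import numpy as np


def ou_free_energy(t_grid, m, S, A, b, obs_idx, y_obs, sigma, gamma, r):
    """Hypothetical helper: free energy F for the scalar OU example.

    m, S, A, b are 1D arrays on t_grid, and obs_idx holds the grid index of each
    observation. Assumes a stationary OU prior at t0, N(0, sigma^2 / (2 gamma)).
    """
    # KL[q(x, t0) || p(x, t0)] between two scalar Gaussians
    S_p0 = sigma ** 2 / (2 * gamma)
    kl0 = 0.5 * (np.log(S_p0 / S[0]) + (S[0] + m[0] ** 2) / S_p0 - 1.0)
    # E_sde(t) = <(f - g)^2> / (2 sigma^2) with f = -gamma x and g = -A x + b
    c = A - gamma
    e_sde = (c ** 2 * (m ** 2 + S) - 2 * b * c * m + b ** 2) / (2 * sigma ** 2)
    # Left Riemann sum, consistent with the Euler scheme used on this page
    int_e_sde = np.sum(e_sde[:-1] * np.diff(t_grid))
    # Expected negative Gaussian log-likelihood of each observation under q
    m_k, S_k = m[obs_idx], S[obs_idx]
    e_obs = ((y_obs - m_k) ** 2 + S_k) / (2 * r ** 2) + 0.5 * np.log(2 * np.pi * r ** 2)
    return kl0 + int_e_sde + np.sum(e_obs)
```

As a simple sanity check, set \(A = \gamma\) and \(b = 0\), so that the approximating drift equals the prior drift and \(E_{sde}\) vanishes. If in addition \(q(x, t_0)\) equals the stationary prior, only the observation terms remain.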
Minimising the Free Energy
Minimising \(\mathcal{F}\) w.r.t. \(A, b\) directly is tricky because \(m, S\) depend on \(A, b\). Changing \(A, b\) changes \(m, S\) in a complicated nonlinear way, and we would have to account for this when computing the gradients of \(\mathcal{F}\) w.r.t. \(A, b\). More specifically, since the approximating SDE is linear, its mean \(m(t)\) and covariance \(S(t)\) exactly follow
\[
\frac{dm}{dt} = -A m + b, \qquad \frac{dS}{dt} = -A S - S A^\top + \Sigma,
\]
linking the values of \(m, S\) at later times to the values of \(A, b\) at earlier times. To avoid computing these gradients explicitly, we can treat \(m, S, A, b\) as independent variables and instead introduce the ODEs above as constraints on the optimisation problem. Any solution of the constrained problem then respects these constraints, so \(m, S, A, b\) will be consistent. Introducing Lagrange multipliers \(\lambda\) (a vector) and \(\Psi\) (a symmetric matrix) to enforce these constraints, we can write the Lagrangian \(\mathcal{L}\) as
\[
\mathcal{L} = \mathcal{F} - \int_{t_0}^{t_1} \lambda^\top \left(\frac{dm}{dt} + A m - b\right) dt - \int_{t_0}^{t_1} \mathrm{tr}\left[\Psi \left(\frac{dS}{dt} + A S + S A^\top - \Sigma\right)\right] dt.
\]
Note that \(\lambda(t_1)\) and \(\Psi(t_1)\) can be set to \(0\), since we only need the ODE constraints to be satisfied for \(t_0 \leq t < t_1\). Taking variational derivatives, we obtain the four equations presented below; we omit the dependence on \(t\) to lighten the notation. The first two equations are
\[
\nabla_A \mathcal{L} = \nabla_A E_{sde} - \lambda m^\top - 2 \Psi S, \qquad \nabla_b \mathcal{L} = \nabla_b E_{sde} + \lambda,
\]
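and give the gradients of the Lagrangian with respect to \(A, b\). Taking variational derivatives with respect to \(m\) and \(S\) and setting these equal to \(0\), we obtain the ODEs for \(\lambda\) and \(\Psi\):
\[
\frac{d\lambda}{dt} = A^\top \lambda - \nabla_m E, \qquad \frac{d\Psi}{dt} = \Psi A + A^\top \Psi - \nabla_S E,
\]
where \(E = E_{sde} + E_{obs}\).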
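The \(\lambda\) equation follows from an integration by parts. The sketch below makes two assumptions. First, the constraints enter the Lagrangian as \(-\int \lambda^\top (\frac{dm}{dt} + A m - b)\, dt - \int \mathrm{tr}[\Psi (\frac{dS}{dt} + A S + S A^\top - \Sigma)]\, dt\). Second, \(E = E_{sde} + E_{obs}\) denotes the part of the free energy that depends on \(m, S\). With these, the terms of \(\mathcal{L}\) that involve \(m\) are

\[
\int_{t_0}^{t_1} \left[ E - \lambda^\top \left(\frac{dm}{dt} + A m\right) \right] dt
= \int_{t_0}^{t_1} \left[ E + \left(\frac{d\lambda}{dt}\right)^{\top} m - \lambda^\top A m \right] dt - \lambda(t_1)^\top m(t_1) + \lambda(t_0)^\top m(t_0).
\]

The boundary terms do not contribute, because \(\lambda(t_1) = 0\) and \(m(t_0) = m_0\) is fixed. Setting the variational derivative of the remaining integral to zero gives

\[
\frac{\delta \mathcal{L}}{\delta m} = \nabla_m E + \frac{d\lambda}{dt} - A^\top \lambda = 0
\quad \Longrightarrow \quad
\frac{d\lambda}{dt} = A^\top \lambda - \nabla_m E.
\]

The same argument applies to \(S\), using \(\nabla_S \mathrm{tr}[\Psi (A S + S A^\top)] = A^\top \Psi + \Psi A\) for symmetric \(\Psi\). It gives \(\frac{d\Psi}{dt} = \Psi A + A^\top \Psi - \nabla_S E\), which reduces to \(2 \Psi A - \nabla_S E\) in the scalar case.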
Now, to optimise the Lagrangian, we first solve forwards for \(m, S\) given \(A, b\) and the initial conditions \(m(t_0) = m_0, S(t_0) = S_0\). We then solve backwards for \(\lambda(t), \Psi(t)\) given \(m(t), S(t)\) and \(A(t), b(t)\), with the final conditions \(\lambda(t_1) = 0, \Psi(t_1) = 0\). Lastly, we compute the gradients \(\nabla_A \mathcal{L}\) and \(\nabla_b \mathcal{L}\) and adjust \(A, b\) by taking a gradient step [AOS+08] with step size \(\omega\), that is
\[
A \leftarrow A - \omega \nabla_A \mathcal{L}, \qquad b \leftarrow b - \omega \nabla_b \mathcal{L}.
\]
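For the scalar OU example (\(f(x) = -\gamma x\), \(\Sigma = \sigma^2\)), the terms of the Lagrangian integrand that depend on \(A\) and \(b\) at a single time are \(E_{sde} - \lambda (A m - b) - 2 \Psi A S\). This lets us write \(\nabla_A \mathcal{L}\) and \(\nabla_b \mathcal{L}\) out explicitly. The sketch below is a hypothetical, scalar-only helper. It assumes the Lagrangian convention \(\mathcal{L} = \mathcal{F} - \int \lambda^\top (\frac{dm}{dt} + A m - b)\, dt - \int \mathrm{tr}[\Psi (\frac{dS}{dt} + A S + S A^\top - \Sigma)]\, dt\). It checks the analytic gradients against central finite differences and then takes one gradient step. Note that the implementation on this page uses the fixed-point update instead of gradient steps.

```python
import numpy as np


def ou_local_lagrangian(A, b, m, S, lam, psi, sigma, gamma):
    """Terms of the Lagrangian integrand that depend on (A, b) at one time,
    for the scalar OU example (hypothetical helper, scalar case only)."""
    c = A - gamma
    e_sde = (c ** 2 * (m ** 2 + S) - 2 * b * c * m + b ** 2) / (2 * sigma ** 2)
    return e_sde - lam * (A * m - b) - psi * (2 * A * S)


def ou_lagrangian_grads(A, b, m, S, lam, psi, sigma, gamma):
    """Analytic gradients of the Lagrangian w.r.t. A and b (scalar case)."""
    c = A - gamma
    grad_A = (c * (m ** 2 + S) - b * m) / sigma ** 2 - lam * m - 2 * psi * S
    grad_b = -(c * m - b) / sigma ** 2 + lam
    return grad_A, grad_b


def gradient_step(A, b, grad_A, grad_b, omega):
    """One gradient-descent step with step size omega."""
    return A - omega * grad_A, b - omega * grad_b


# Compare analytic gradients with central finite differences at arbitrary values
args = dict(m=0.4, S=0.2, lam=-0.3, psi=0.1, sigma=1.0, gamma=2.0)
A, b, h = 1.2, 0.5, 1e-6
grad_A, grad_b = ou_lagrangian_grads(A, b, **args)
fd_A = (ou_local_lagrangian(A + h, b, **args) - ou_local_lagrangian(A - h, b, **args)) / (2 * h)
fd_b = (ou_local_lagrangian(A, b + h, **args) - ou_local_lagrangian(A, b - h, **args)) / (2 * h)
print(grad_A, fd_A)
print(grad_b, fd_b)

A_new, b_new = gradient_step(A, b, grad_A, grad_b, omega=0.1)
```

Because the integrand is quadratic in \((A, b)\), central differences match the analytic gradients up to rounding error.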
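Alternatively [AOS+08], we can set these gradients to \(0\) and solve for \(A, b\), giving a new set of parameters \(\tilde{A}, \tilde{b}\):
\[
\tilde{A} = -\left\langle \nabla_x f \right\rangle_{q} + 2 \Sigma \Psi, \qquad \tilde{b} = \left\langle f \right\rangle_{q} + \tilde{A} m - \Sigma \lambda.
\]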
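Both closed-form updates come from setting the gradients to zero. The algebra is sketched below under three assumptions: \(q(x, t) = \mathcal{N}(m, S)\), \(g(x, t) = -A x + b\), and \(E_{sde} = \frac{1}{2} \langle (f - g)^\top \Sigma^{-1} (f - g) \rangle_q\). All expectations are under \(q\) at a fixed time. The \(b\)-gradient gives

\[
\nabla_b \mathcal{L} = -\Sigma^{-1} \left( \langle f \rangle + A m - b \right) + \lambda = 0
\quad \Longrightarrow \quad
b = \langle f \rangle + A m - \Sigma \lambda.
\]

Substitute this into the \(A\)-gradient,

\[
\nabla_A \mathcal{L} = \Sigma^{-1} \left( \langle f x^\top \rangle + A \left(S + m m^\top\right) - b m^\top \right) - \lambda m^\top - 2 \Psi S = 0,
\]

and multiply by \(\Sigma\). This leaves \(\langle f (x - m)^\top \rangle + A S - 2 \Sigma \Psi S = 0\). For Gaussian \(q\), Stein's lemma gives \(\langle f (x - m)^\top \rangle = \langle \nabla_x f \rangle S\), where \(\nabla_x f\) is the Jacobian of \(f\). So, for invertible \(S\), we recover \(\tilde{A} = -\langle \nabla_x f \rangle + 2 \Sigma \Psi\), and then \(\tilde{b} = \langle f \rangle + \tilde{A} m - \Sigma \lambda\).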
Implementation
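We now implement the algorithm, breaking it into a forward solve, a backward solve and a helper which interleaves these two steps a number of times. We will apply the algorithm to the scalar Ornstein-Uhlenbeck (OU) SDE with a linear drift, observed with Gaussian noise,
\[
dx = -\gamma x\,dt + \sigma\,d\beta, \qquad y_k = x(t_k) + \epsilon_k, \quad \epsilon_k \sim \mathcal{N}(0, r^2),
\]
that is, \(f(x, t) = -\gamma x\), \(\Sigma = \sigma^2\) and \(R = r^2\).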
This model was used in [ACOST07] because its posterior is available in closed form: it corresponds to a GP with zero mean and covariance function
\[
K(t, t') = \frac{\sigma^2}{2 \gamma} \exp\left(-\gamma |t - t'|\right).
\]
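This kernel is the stationary covariance of the OU SDE. The variance ODE \(\frac{dS}{dt} = -2\gamma S + \sigma^2\) has fixed point \(\frac{\sigma^2}{2\gamma}\), and the mean ODE \(\frac{dm}{dt} = -\gamma m\) decays as \(e^{-\gamma \tau}\), which matches the kernel's correlation at lag \(\tau\). The small numerical check below uses the ornstein_uhlenbeck function from the demonstration section, copied here so the block runs on its own. The step size and lag are illustrative choices.

```python
import numpy as np


def ornstein_uhlenbeck(sigma, gamma, t, t_):
    # Same covariance function as in the demonstration code on this page
    coeff = 0.5 * sigma ** 2 / gamma
    return coeff * np.exp(-gamma * np.abs(t[..., :, None] - t_[..., None, :]))


sigma, gamma, dt = 1.0, 2.0, 1e-4

# Euler-integrate the variance ODE dS/dt = -2 gamma S + sigma^2 from S(0) = 0
S = 0.0
for _ in range(int(5.0 / dt)):
    S = S + (-2 * gamma * S + sigma ** 2) * dt
k0 = ornstein_uhlenbeck(sigma, gamma, np.array([0.0]), np.array([0.0]))[0, 0]

# Euler-integrate the mean ODE dm/dt = -gamma m from m(0) = 1 up to lag tau
tau = 0.3
m = 1.0
for _ in range(int(round(tau / dt))):
    m = m - gamma * m * dt
ratio = ornstein_uhlenbeck(sigma, gamma, np.array([tau]), np.array([0.0]))[0, 0] / k0
print(S, k0)
print(m, ratio)
```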
The forward step¶
In the forward step we solve the ODEs
\[
\frac{dm}{dt} = -A m + b, \qquad \frac{dS}{dt} = -A S - S A^\top + \Sigma
\]
numerically for the mean \(m\) and covariance \(S\) of the approximating SDE, using the Euler method. Since the forward solve does not depend on the SDE being approximated, none of the details of the OU SDE enter the forward function.
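```python
import numpy as np


def forward(m, S, b, A, Sigma, dt):
    # Solve dm/dt = -A m + b and dS/dt = -A S - S A^T + Sigma forwards in time,
    # starting from m[0] and S[0]; the arrays m and S are filled in place
    for i in range(len(b) - 1):
        # Euler step for m and S ODEs
        m[i + 1] = m[i] - (np.dot(A[i], m[i]) - b[i]) * dt
        S[i + 1] = S[i] - (np.dot(A[i], S[i]) + np.dot(S[i], A[i].T) - Sigma) * dt
    return m, S
```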
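The forward function uses the explicit Euler method, which is first-order accurate: halving dt should roughly halve the error. This is one reason a fine grid is needed. The minimal check below uses the scalar mean ODE with constant \(A\) and \(b\), whose exact solution is known. The helper is hypothetical and independent of the page's forward function.

```python
import numpy as np


def euler_mean(A, b, m0, T, num_steps):
    """Euler solve of dm/dt = -A m + b for scalar, constant A and b."""
    dt = T / num_steps
    m = m0
    for _ in range(num_steps):
        m = m - (A * m - b) * dt
    return m


A, b, m0, T = 2.0, 1.0, -1.0, 1.0
exact = b / A + (m0 - b / A) * np.exp(-A * T)
errors = [abs(euler_mean(A, b, m0, T, n) - exact) for n in (100, 200, 400)]
# For a first-order method, the error ratio between successive halvings is close to 2
print(errors, [errors[i] / errors[i + 1] for i in range(2)])
```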
The backward step¶
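In the backward step we first solve the ODEs
\[
\frac{d\lambda}{dt} = A^\top \lambda - \nabla_m E_{sde}, \qquad \frac{d\Psi}{dt} = \Psi A + A^\top \Psi - \nabla_S E_{sde}
\]
backwards in time between observation times. We then compute the variational derivatives w.r.t. \(A, b\) and set these to zero to compute the updates \(\tilde{A}, \tilde{b}\), which for the OU drift \(f(x) = -\gamma x\) are
\[
\tilde{A} = \gamma + 2 \Sigma \Psi, \qquad \tilde{b} = -\gamma m + \tilde{A} m - \Sigma \lambda.
\]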
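For this SDE, with \(\Sigma = \sigma^2\) and \(R = r^2\), the \(E_{sde}\) and \(E_{obs}\) quantities are
\[
E_{sde} = \frac{(A - \gamma)^2 (m^2 + S) - 2 b (A - \gamma) m + b^2}{2 \sigma^2}, \qquad
E_{obs} = \sum_{k=1}^K \delta(t - t_k) \frac{(y_k - m)^2 + S}{2 r^2},
\]
so that \(\partial E_{sde} / \partial S = (A - \gamma)^2 / (2 \sigma^2)\) and \(\partial E_{sde} / \partial m = \left((A - \gamma)^2 m - (A - \gamma) b\right) / \sigma^2\). These are the dEdS and dEdm quantities computed in the code.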
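Note that because the observation term of the free energy involves \(\delta\)-functions at the observation times, the following jump conditions must be applied there when solving backwards:
\[
\lambda(t_k^-) = \lambda(t_k^+) - \frac{y_k - m(t_k)}{r^2}, \qquad \Psi(t_k^-) = \Psi(t_k^+) + \frac{1}{2 r^2}.
\]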
```python
def backward(t_grid, A, b, m, S, Sigma, gamma, r, psi, lamda, t_dict, x, dt):
    # S is not needed here: for the OU drift, dE_sde/dS and dE_sde/dm do not depend on S
    # Arrays for storing the updates for A and b
    A_ = np.zeros_like(A)
    b_ = np.zeros_like(b)
    # Solve the lambda and psi ODEs backwards from lambda(t1) = psi(t1) = 0
    for i in range(len(b) - 1, 0, -1):
        # Compute dEdS and dEdm, the derivatives of E_sde for f(x) = -gamma x
        coeff = (A[i] - gamma) ** 2 / Sigma
        dEdS = 0.5 * coeff
        dEdm = coeff * m[i] - b[i] * (A[i] - gamma) / Sigma
        # Euler step for lambda and psi ODEs (2 psi A equals psi A + A^T psi in 1D)
        lamda[i - 1] = lamda[i] - (np.dot(A[i].T, lamda[i]) - dEdm) * dt
        psi[i - 1] = psi[i] - (2 * np.dot(psi[i], A[i]) - dEdS) * dt
        # Handle jump conditions at locations of the data
        if t_grid[i - 1] in t_dict:
            psi[i - 1] = psi[i - 1] + 0.5 * r ** -2
            lamda[i - 1] = lamda[i - 1] - r ** -2 * (x[t_dict[t_grid[i - 1]]] - m[i - 1])
    # Fixed-point updates A~ = gamma + 2 Sigma psi and b~ = -gamma m + A~ m - Sigma lambda
    for i in range(len(b) - 1, -1, -1):
        A_[i] = gamma + 2 * np.dot(Sigma, psi[i])
        b_[i] = -gamma * m[i] + np.dot(A_[i], m[i]) - np.dot(Sigma, lamda[i])
    return psi, lamda, b_, A_
```
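backward hard-codes the OU expectations \(\langle f \rangle = -\gamma m\) and \(\langle \partial f / \partial x \rangle = -\gamma\). For a nonlinear drift, the case raised at the end of the demonstration, these Gaussian expectations have to be computed numerically. The scalar sketch below assumes that Gauss-Hermite quadrature (our choice, not something prescribed by the source) is accurate enough for the drift at hand. The helper names are hypothetical. The \(E_{sde}\) derivatives in the backward ODEs would need the same treatment, which is not shown.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def gaussian_expectation(h, m, S, num_points=20):
    """<h(x)> under N(m, S) for scalar x, by Gauss-Hermite quadrature.
    hermegauss uses the weight exp(-x^2 / 2), whose weights sum to sqrt(2 pi)."""
    nodes, weights = hermegauss(num_points)
    return np.sum(weights * h(m + np.sqrt(S) * nodes)) / np.sqrt(2 * np.pi)


def fixed_point_update(f, df, m, S, lam, psi, Sigma):
    """Scalar updates A~ = -<df/dx> + 2 Sigma psi and b~ = <f> + A~ m - Sigma lam."""
    A_new = -gaussian_expectation(df, m, S) + 2 * Sigma * psi
    b_new = gaussian_expectation(f, m, S) + A_new * m - Sigma * lam
    return A_new, b_new


# Sanity check: the OU drift f(x) = -gamma x should reproduce the updates in backward
gamma, m, S, lam, psi, Sigma = 2.0, 0.4, 0.2, -0.3, 0.1, 1.0
A_ou, b_ou = fixed_point_update(lambda x: -gamma * x,
                                lambda x: -gamma * np.ones_like(x),
                                m, S, lam, psi, Sigma)
print(A_ou, gamma + 2 * Sigma * psi)
print(b_ou, -gamma * m + A_ou * m - Sigma * lam)

# A nonlinear (double-well) drift f(x) = x - x^3, as an illustration
A_dw, b_dw = fixed_point_update(lambda x: x - x ** 3, lambda x: 1 - 3 * x ** 2,
                                m, S, lam, psi, Sigma)
```

For polynomial drifts, a 20-point rule is exact, so the double-well result can be compared with the Gaussian moments \(\langle x^2 \rangle = m^2 + S\) and \(\langle x^3 \rangle = m^3 + 3 m S\).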
Putting the two together¶
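The smoothing function below puts the forward and backward steps together, executing them in an interleaved fashion a specified number of times. After each pass, \(A\) and \(b\) are moved a fraction \(\omega\) of the way towards the fixed-point values \(\tilde{A}, \tilde{b}\).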
```python
def smoothing(t_obs, t_grid, y_obs, num_passes, omega, Sigma, gamma, r, dt, m0, S0):
    grid_size = t_grid.shape[0]
    # Dictionary mapping from times to indices for array x
    t_dict = dict(zip(t_obs, np.arange(0, len(t_obs))))
    # Initialise the drift parameters of the approximating SDE to zero
    b = np.zeros((grid_size, 1))
    A = np.zeros((grid_size, 1, 1))
    for _ in range(num_passes):
        # Reset the multipliers, and the moments to the fixed initial values m0, S0
        lamda = np.zeros((grid_size, 1))
        psi = np.zeros((grid_size, 1, 1))
        m = m0 * np.ones((grid_size, 1))
        S = S0 * np.ones((grid_size, 1, 1))
        # Forward pass to compute m, S
        m, S = forward(m=m, S=S, b=b, A=A, Sigma=Sigma, dt=dt)
        # Backward pass to compute psi, lamda, b_, A_
        psi, lamda, b_, A_ = backward(t_grid=t_grid,
                                      A=A,
                                      b=b,
                                      m=m,
                                      S=S,
                                      Sigma=Sigma,
                                      gamma=gamma,
                                      r=r,
                                      psi=psi,
                                      lamda=lamda,
                                      t_dict=t_dict,
                                      x=y_obs,
                                      dt=dt)
        # Damped update: move b and A a fraction omega towards the fixed-point values
        b = b + omega * (b_ - b)
        A = A + omega * (A_ - A)
    # Note: m, S are from the last forward pass, i.e. computed before the final update
    return b, A, m, S, psi, lamda
```
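The t_dict lookup in smoothing and backward relies on exact floating-point equality between observation times and grid times. That holds in the demonstration only because t_obs is sliced from t_grid. If the observation times come from elsewhere, one option (our suggestion, not part of the original method) is to snap each observation to its nearest grid point once, up front, and key the lookup by integer grid index. This assumes observations are at least one grid step apart, so no two map to the same index. The helper below is hypothetical.

```python
import numpy as np


def observation_indices(t_grid, t_obs):
    """Map each observation time to the index of the nearest grid point
    (hypothetical alternative to the float-keyed t_dict lookup)."""
    idx = np.searchsorted(t_grid, t_obs)
    idx = np.clip(idx, 1, len(t_grid) - 1)
    # Pick whichever neighbour (idx - 1 or idx) is closer to each observation
    left_closer = (t_obs - t_grid[idx - 1]) < (t_grid[idx] - t_obs)
    return np.where(left_closer, idx - 1, idx)


t_grid = np.linspace(0.0, 5.0, 10000)
# Observation times that do not lie exactly on the grid
t_obs = np.array([0.8331, 1.6662, 2.4])
obs_idx = observation_indices(t_grid, t_obs)
# Map grid index -> observation index, for use inside a backward loop
idx_to_obs = {int(i): k for k, i in enumerate(obs_idx)}
print(obs_idx, t_grid[obs_idx])
```

Inside a backward loop, the membership test would then become `if i - 1 in idx_to_obs:`, with `x[idx_to_obs[i - 1]]` in place of the float-keyed lookup.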
Demonstration
We now apply the algorithm to toy data drawn from the exact GP model, shown below.
```python
import matplotlib.pyplot as plt


def ornstein_uhlenbeck(sigma, gamma, t, t_):
    # Stationary OU covariance sigma^2 / (2 gamma) * exp(-gamma |t - t'|)
    coeff = 0.5 * sigma ** 2 / gamma
    exp = np.exp(-gamma * np.abs(t[..., :, None] - t_[..., None, :]))
    return coeff * exp


# Initial and final time to approximate
t0 = 0.
t1 = 5.

# OU parameters sigma and gamma, and observation noise level
sigma = 1.
gamma = 2.
r = 1e-1

# Number of observations and integration grid size
num_obs = 5
grid_size = 10000

# Stride, in grid points, between consecutive observations
interval_size = int(grid_size / (num_obs + 1) - 1e-6)

# Grid of times used for integration and grid of times for observations
t_grid = np.linspace(t0, t1, grid_size)
t_obs = t_grid[::interval_size][1:-1]

# Set random seed
np.random.seed(1)

# Zero mean vector and OU covariance matrix (plus observation noise)
# for sampling noisy observations from the SDE prior
y_mean = np.zeros((num_obs,))
y_cov = ornstein_uhlenbeck(sigma=sigma, gamma=gamma, t=t_obs, t_=t_obs)
y_cov = y_cov + r ** 2 * np.eye(num_obs)
y_obs = np.random.multivariate_normal(mean=y_mean, cov=y_cov)

# Plot data
plt.figure(figsize=(8, 3))
plt.scatter(t_obs, y_obs, marker='x', color='red', zorder=2)

# Format plot
plt.title('Example data', fontsize=18)
plt.xlabel('$t$', fontsize=16)
plt.ylabel('$x$', fontsize=16)
plt.xticks(np.linspace(0, 5, 6))
plt.yticks(np.linspace(-1, 1, 5))
plt.show()
```
```python
# Algorithm parameters (omega is the damping factor for the A, b updates)
num_passes = 20
Sigma = sigma ** 2 * np.eye(1)
omega = 5e-1
m0 = 0.
S0 = 1e-1
# Step size equal to the spacing of t_grid (grid_size points span grid_size - 1 intervals)
dt = t_grid[1] - t_grid[0]

# Run the smoothing algorithm
b, A, m, S, psi, lamda = smoothing(t_obs=t_obs,
                                   t_grid=t_grid,
                                   y_obs=y_obs,
                                   num_passes=num_passes,
                                   omega=omega,
                                   Sigma=Sigma,
                                   gamma=gamma,
                                   r=r,
                                   dt=dt,
                                   m0=m0,
                                   S0=S0)
```
```python
# Compute the exact posterior of the OU process
t_post = np.linspace(t0, t1, 100)

# OU covariance matrices between observation times and prediction times t_post
Ktt = ornstein_uhlenbeck(sigma=sigma, gamma=gamma, t=t_obs, t_=t_obs)
Ktt = Ktt + r ** 2 * np.eye(num_obs)
Kt_t = ornstein_uhlenbeck(sigma=sigma, gamma=gamma, t=t_post, t_=t_obs)
Kt_t_ = ornstein_uhlenbeck(sigma=sigma, gamma=gamma, t=t_post, t_=t_post)

# Exact posterior mean and variance
post_mean = Kt_t @ np.linalg.solve(Ktt, y_obs)
post_var = np.diag(Kt_t_ - Kt_t @ np.linalg.solve(Ktt, Kt_t.T))

# Plot data, approximate and exact posterior
plt.figure(figsize=(8, 3))

# Observed data
plt.scatter(t_obs, y_obs, marker='x', color='red', zorder=3)

# Approximate posterior
plt.plot(t_grid, m[:, 0], color='k', zorder=2)
plt.fill_between(t_grid,
                 m[:, 0] - S[:, 0, 0] ** 0.5,
                 m[:, 0] + S[:, 0, 0] ** 0.5,
                 color='gray',
                 alpha=0.5,
                 zorder=1,
                 label='Approximate posterior')

# Exact posterior
plt.plot(t_post, post_mean, color='green', zorder=1, label='Exact posterior')
plt.plot(t_post, post_mean - post_var ** 0.5, '--', color='green', zorder=1)
plt.plot(t_post, post_mean + post_var ** 0.5, '--', color='green', zorder=1)

# Format plot
plt.title('Exact and approximate posterior', fontsize=18)
plt.xlabel('$t$', fontsize=16)
plt.ylabel('$x$', fontsize=16)
plt.xticks(np.linspace(0, 5, 6))
plt.yticks(np.linspace(-1, 1, 5))
plt.xlim([0, 5])
plt.legend()
plt.show()
```
Note that since we did not optimise the initial mean and variance \(m_0, S_0\), the approximate posterior does not precisely match the exact posterior. Overall, however, the approximation looks quite good. This is expected, since this was a relatively easy example in which the SDE being approximated was linear, so a linear approximating SDE can represent the exact posterior process. It would be interesting to test this algorithm in scenarios where the underlying SDE is non-linear or where the posterior process is multimodal.
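To put a number on "quite good", one could compare the two posteriors on a common time grid. The sketch below is a hypothetical helper, and its metric (the maximum absolute difference in means and in standard deviations) is our choice. It linearly interpolates the approximate posterior onto t_post with np.interp.

```python
import numpy as np


def compare_posteriors(t_grid, m, S, t_post, post_mean, post_var):
    """Hypothetical helper: interpolate the approximate posterior mean and standard
    deviation onto t_post and return the largest absolute gaps to the exact posterior."""
    m_interp = np.interp(t_post, t_grid, m)
    sd_interp = np.interp(t_post, t_grid, np.sqrt(S))
    max_mean_gap = np.max(np.abs(m_interp - post_mean))
    max_sd_gap = np.max(np.abs(sd_interp - np.sqrt(post_var)))
    return max_mean_gap, max_sd_gap


# With the arrays from the demonstration above, this would be called as
# compare_posteriors(t_grid, m[:, 0], S[:, 0, 0], t_post, post_mean, post_var)
```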
Conclusions
The method presented in this page provides a way of approximating the posteriors of constant-noise SDEs using a variational approach. To achieve this, the posterior is approximated by a linear SDE with a time-varying drift, adapted so that the approximation minimises the free energy. One shortcoming of this method is that its computational cost scales as \(\mathcal{O}(T)\) with the length \(T\) of the time interval. Moreover, since the ODEs must be solved numerically on a fine grid, this cost is far from negligible. Furthermore, this method does not currently support adaptive step-size ODE solvers, since the forward and backward steps require access to the values of related quantities on the same grid of points. Lastly, the approximation must be recomputed for each new dataset; it would be interesting to investigate whether amortisation can be incorporated into this method.
References
- [AOS+08] Cédric Archambeau, Manfred Opper, Yuan Shen, Dan Cornford, and John Shawe-Taylor. Variational inference for diffusion processes. In J. Platt, D. Koller, Y. Singer, and S. Roweis, editors, Advances in Neural Information Processing Systems, volume 20, 2008.
- [ACOST07] Cedric Archambeau, Dan Cornford, Manfred Opper, and John Shawe-Taylor. Gaussian process approximations of stochastic differential equations. In Gaussian Processes in Practice, 1–16. PMLR, 2007.
- [SarkkaS19] Simo Särkkä and Arno Solin. Applied stochastic differential equations. Volume 10. Cambridge University Press, 2019.