Bayesian Inference with SVGD (Stein Variational Gradient Descent)
Having explored how to construct models, compute their properties, evaluate them efficiently, and scale computations across distributed clusters, we now turn to a fundamental question: given observed data, how do we estimate the parameters of our model? This is the domain of statistical inference. In a Bayesian framework, we seek not just point estimates but entire posterior distributions over parameters, quantifying our uncertainty about parameter values given the data we have observed.
Traditional Markov Chain Monte Carlo (MCMC) methods like Metropolis-Hastings and Hamiltonian Monte Carlo have been the workhorses of Bayesian inference for decades. However, these methods face challenges with high-dimensional parameter spaces, complex posterior geometries, and the need for many sequential samples to achieve convergence. Stein Variational Gradient Descent (SVGD) offers a compelling alternative. Instead of generating samples sequentially through a Markov chain, SVGD represents the posterior with a set of particles (parameter vectors) and updates these particles iteratively to move them toward the posterior distribution. The method combines ideas from optimization, kernel methods, and functional analysis to create a deterministic algorithm that is highly parallelizable and works well in high dimensions.
SVGD begins with an initial set of particles (typically drawn from the prior or from a simple distribution) and iteratively transports these particles toward the posterior using gradient information. Each particle interacts with all other particles through a kernel function, with the interaction strength depending on the distance between particles. This interaction ensures that particles spread out to cover the posterior distribution, allowing extraction of not only the most likely parameter value (the MAP estimate) but also a credible interval for the estimate.
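For reference, the update applied to each particle $x_i$ at every iteration is the kernelized Stein direction of Liu and Wang (2016); this is the standard SVGD rule, stated here for orientation rather than as phasic's exact implementation:

$$
x_i \leftarrow x_i + \epsilon\,\hat{\phi}(x_i), \qquad
\hat{\phi}(x) = \frac{1}{n}\sum_{j=1}^{n}\Big[\,k(x_j, x)\,\nabla_{x_j}\log p(x_j \mid \mathcal{D}) + \nabla_{x_j} k(x_j, x)\,\Big]
$$

The first term pulls particles toward regions of high posterior density, while the second acts as a repulsive force that keeps the particles spread out over the posterior.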
SVGD is particularly well suited to our implementation of phase-type distributions, because the PDF needed for gradient computation is readily computed, and because any number of exact moments can be exploited for regularization.
Tip: This should not be possible
The complexity of evaluating phase-type distributions in their matrix formulation is O(n^3), prohibiting their evaluation if state spaces are large, let alone evaluating them across many SVGD particles and iterations. However, hidden from the user, phasic implements a series of innovations that make inference practically possible:
- Backend implemented in C
- Graph algorithm for Gaussian elimination that fully exploits the sparseness of the state space, making evaluation closer to O(n^x) than O(n^3)
- Hierarchical decomposition of strongly connected components for caching and parallel computation
- Functional trace recording of elimination in parameterized graphs, making evaluations after reparametrization O(n^x)
- Graph algorithms for computing moments in O(n^x)
- Graph algorithms for computing the PDF in O(n^x)
- JAX JIT-compiled auto-differentiation
- Three levels of caching: RAM cache, persistent JAX cache, persistent elimination-trace cache
- Community trace repository and torrent-based trace sharing
Let's make a coalescent model with one parameter, the coalescent rate:
```python
from phasic import Graph, with_ipv  # import path assumed

nr_samples = 4

@with_ipv([nr_samples] + [0] * (nr_samples - 1))  # initial state: nr_samples singleton lineages
def coalescent_1param(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            # need two lineages of the same size, or one of each size, to coalesce
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1  # merged lineage subtends (i+1) + (j+1) samples
            transitions.append([new, [state[i] * (state[j] - same) / (1 + same)]])
    return transitions
```
Construct the graph and set some value of the model parameter:
```python
graph = Graph(coalescent_1param)
graph.plot()
```
With a large set of observed times to the most recent common ancestor (TMRCA), we can fit our model to find the most likely value of the model parameter. For this tutorial we will just set the model parameter to 7 and then sample from the model to get such a data set. This has the advantage that we know what the model parameter is supposed to be.
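A minimal sketch of these two steps is shown below. The `graph.svgd` call matches its use later in this tutorial, but the `sample` method and its arguments are assumptions made for illustration:

```python
true_theta = 7.0

# Draw a synthetic data set of TMRCA observations at the known parameter value.
# NOTE: `sample` and its signature are hypothetical, for illustration only.
observed_data = graph.sample(theta=true_theta, n=1000)

# Fit the model with SVGD (graph.svgd is used later in this tutorial)
svgd = graph.svgd(observed_data)
```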
| Parameter | Fixed | MAP | Mean | SD | HPD 95% lo | HPD 95% hi |
|---|---|---|---|---|---|---|
| 0 | No | 6.9808 | 6.8792 | 0.1764 | 6.6874 | 7.2724 |

Particles: 24, Iterations: 200
```python
svgd.learning_rate
```

```
0.0001
```
However, you should always evaluate convergence of the SVGD optimization. The plot_trace method shows the parameter values of each particle across iterations:
```python
svgd.plot_trace();
```
The distribution of particle values describes the posterior distribution, from which a mean and credible interval (CI) can be computed:
```python
svgd.plot_ci(true_theta=true_theta);
```
A nice property of SVGD is that you can easily evaluate whether the optimization has converged. In this case it clearly has not. Fortunately, there are several easy remedies, as shown below.
Prior
The prior distribution used in Bayesian inference represents our prior belief, i.e., our knowledge about what range of values the parameter can reasonably take. Phasic comes with two types of priors that are easily defined from basic knowledge about the parameter. The default behaviour of Phasic is to use a Gaussian prior centered on a rough method-of-moments estimate of each parameter. To see the prior used, you can plot it like this:
```python
svgd.prior.plot();
```
or reconstruct the GaussPrior from its properties:
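A sketch of what that reconstruction might look like; the import path and the `mean` and `sd` property names are assumptions:

```python
from phasic import GaussPrior  # import path assumed

# Rebuild the default prior from its fitted properties (property names assumed)
prior = GaussPrior(mean=svgd.prior.mean, sd=svgd.prior.sd)
```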
The GaussPrior and HalfCauchyPrior classes provide the log-probabilities needed for SVGD. If you want to provide your own prior function, it must return such log-probabilities.
Let's add the prior to the SVGD inference and rerun it.
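A hedged sketch of the rerun, assuming the svgd call accepts the prior via a `prior` keyword (the parameter name is an assumption):

```python
svgd = graph.svgd(observed_data, prior=prior)  # `prior=` keyword assumed
```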
```
CONVERGED (iteration 50/102)
Mean stabilized at iteration 50
Std stabilized at iteration 50

Detected Issues:
ℹ Converged at 49.0% of iterations - could reduce n_iterations

Learning Rate: Converged early (iteration 50) - could converge faster
    0.00015000000000000001
Particles: High ESS ratio (1.00) - could reduce particles
    n_particles=20
Iterations: Converged early
    Could reduce to n_iterations=60
```
That is better, but it still does not converge to the true theta value. We can change the learning rate from its default value of 0.01:
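For example, rerunning with a different step size (the `learning_rate` keyword is referenced elsewhere in this tutorial; the value here is illustrative):

```python
svgd = graph.svgd(observed_data, learning_rate=0.001)  # illustrative value
```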
Phasic provides several optimizers for SVGD; three of them (Adam, RMSprop, and Adagrad) offer adaptive per-parameter learning rates. These can be particularly useful when gradients have vastly different scales across parameters or when a fixed step size causes oscillation.
Optimizer Comparison:

| Optimizer | Adaptive LR | Momentum | Best For |
|---|---|---|---|
| Fixed LR | No | No | Simple, well-tuned problems |
| Adam | Yes | Yes | General purpose (recommended) |
| SGDMomentum | No | Yes | Accelerating convergence |
| RMSprop | Yes | No | Non-stationary objectives |
| Adagrad | Yes | No | Sparse gradients |
Adam
When using Adam, the learning_rate parameter passed to SVGD is ignored in favor of the optimizer’s learning rate.
Adam (Adaptive Moment Estimation) maintains running estimates of two quantities for each parameter:

- First moment (m): an exponentially weighted average of past gradients (momentum)
- Second moment (v): an exponentially weighted average of past squared gradients (gradient variance)
Update rule:
$$
\begin{align}
m &\leftarrow \beta_1 \cdot m + (1 - \beta_1) \cdot g \\
v &\leftarrow \beta_2 \cdot v + (1 - \beta_2) \cdot g^2 \\[0.5em]
\hat{m} &= \frac{m}{1 - \beta_1^t} \\[0.5em]
\hat{v} &= \frac{v}{1 - \beta_2^t} \\[0.5em]
\theta &\leftarrow \theta - \alpha \cdot \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon}
\end{align}
$$
Each parameter gets its own effective learning rate:
- Parameters with consistently large gradients $\rightarrow$ larger $\hat{v}$ $\rightarrow$ smaller effective step
- Parameters with consistently small gradients $\rightarrow$ smaller $\hat{v}$ $\rightarrow$ larger effective step
This automatic scaling helps when different parameters naturally have different gradient magnitudes.
Adam vs Fixed Learning Rate
| Fixed Learning Rate | Adam |
|---|---|
| Same step size for all parameters | Adaptive step size per parameter |
| Can oscillate when gradients vary widely | Dampens oscillations via momentum |
| Requires careful tuning | More robust to initial choice |
| Simpler, less overhead | Tracks additional state (m, v) |
When fixed learning rate works well:

- Well-behaved optimization landscapes
- Parameters with similar gradient scales
- When you've tuned the learning rate carefully
When Adam helps:

- Large datasets causing large gradient magnitudes
- Parameters with vastly different scales
- "Shark teeth" oscillation patterns in convergence
- When you want reasonable results without extensive tuning
Tuning Guidelines:

- `learning_rate`: Start with 0.001-0.01. If convergence is too slow, increase; if unstable, decrease.
- `beta1`: 0.9 works well for most cases. Lower values (e.g. 0.8) give less momentum.
- `beta2`: 0.999 is standard. Lower values (e.g. 0.99) adapt faster to gradient changes.
- `epsilon`: Rarely needs adjustment. Increase to 1e-7 if you see numerical issues.
Consider Adam when you observe:

- Oscillating convergence ("shark teeth" pattern in the loss)
- Different parameters converging at different rates
- Sensitivity to the learning rate choice
- Large datasets causing large gradient magnitudes
Stick with a fixed learning rate when:

- Convergence is already smooth and fast
- You've tuned the learning rate well for your problem
- You want minimal computational overhead
- The optimization landscape is well-behaved
Kingma and Ba (2014) - Adam: A Method for Stochastic Optimization. arXiv:1412.6980.
```python
from phasic import Adam

optimizer = Adam(
    learning_rate=0.001,  # Base learning rate, scaled per-parameter by Adam (alpha)
    beta1=0.9,            # Decay rate for first moment (momentum). Higher = more smoothing
    beta2=0.999,          # Decay rate for second moment. Higher = longer memory of gradient magnitudes
    epsilon=1e-8          # Small constant to prevent division by zero
)
```
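The optimizer object is then passed to the svgd call, mirroring the RMSprop example later in this section:

```python
svgd = graph.svgd(observed_data, optimizer=optimizer)
```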
| Parameter | Fixed | MAP | Mean | SD | HPD 95% lo | HPD 95% hi |
|---|---|---|---|---|---|---|
| 0 | No | 6.9434 | 6.9406 | 0.0266 | 6.9133 | 6.9979 |

Particles: 24, Iterations: 100
```python
svgd.plot_trace()
svgd.plot_convergence();
```
SGDMomentum
SGDMomentum (SGD with momentum) accumulates velocity along directions where the gradient is consistent, accelerating convergence and damping oscillations. Update rule:

$$
\begin{align}
v &\leftarrow m \cdot v + g \\
\theta &\leftarrow \theta + l \cdot v
\end{align}
$$

where $g$ is the gradient, $l$ is the learning rate, $m$ is the momentum coefficient, and $v$ is the accumulated velocity.
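For completeness, a construction sketch in the same style as the Adam and RMSprop examples; the import path and the `learning_rate` and `momentum` parameter names are assumptions:

```python
from phasic.svgd import SGDMomentum  # import path assumed

optimizer = SGDMomentum(
    learning_rate=0.001,  # step size l in the update rule above (name assumed)
    momentum=0.9,         # momentum coefficient m (name assumed)
)
svgd = graph.svgd(observed_data, optimizer=optimizer)
```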
RMSprop

RMSprop (Hinton, 2012) divides the learning rate by an exponentially decaying average of squared gradients, adapting it per parameter. Update rule:

$$
\begin{align}
v &\leftarrow d \cdot v + (1 - d) \cdot g^2 \\
\theta &\leftarrow \theta + \frac{l \cdot g}{\sqrt{v} + \epsilon}
\end{align}
$$

where $g$ is the gradient, $l$ is the learning rate, $d$ is the decay rate, and $v$ is the running average of squared gradients.
```python
from phasic.svgd import RMSprop

optimizer = RMSprop(
    learning_rate=1,   # Base learning rate (default 0.001)
    decay=0.99,        # Decay rate for squared gradient average
    epsilon=1e-8       # Numerical stability
)

svgd = graph.svgd(
    observed_data,
    optimizer=optimizer,
)
```
```python
svgd.plot_ci(true_theta=true_theta)
```
Adagrad
Adagrad (Duchi et al., 2011) accumulates the sum of squared gradients, giving smaller effective learning rates to parameters with large accumulated gradients. Update rule:
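In the same notation as the rules above, the standard Adagrad update is:

$$
\begin{align}
G &\leftarrow G + g^2 \\
\theta &\leftarrow \theta + \frac{l \cdot g}{\sqrt{G} + \epsilon}
\end{align}
$$

where $g$ is the gradient, $l$ is the learning rate, $G$ is the accumulated sum of squared gradients, and $\epsilon$ ensures numerical stability. Because $G$ only grows, the effective step size shrinks over time, which suits sparse gradients but can stall long runs.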