GaussPrior
phasic.svgd.GaussPrior(mean=None, std=None, ci=None, prob=0.95)

Gaussian prior distribution.
The prior is defined in THETA space (the natural parameter space). When it is used with positive_params=True, SVGD automatically handles the transformation to PHI space and applies the corresponding Jacobian correction.
The prior can be specified either by mean and std or by a credible interval (ci).
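The Jacobian correction mentioned above can be illustrated with a short sketch. It assumes the positive-parameter transform is theta = exp(phi); the transform SVGD actually uses may differ (for example, a softplus). Under that assumption the change of variables adds log|dtheta/dphi| = phi to the THETA-space log-density. The helper names gauss_logpdf and phi_logpdf are hypothetical.

```python
import math

def gauss_logpdf(x, mean, std):
    """Log-density of N(mean, std**2) evaluated at x (THETA space)."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)

def phi_logpdf(phi, mean, std):
    """Log-density of the same prior expressed in PHI space.

    ASSUMPTION: theta = exp(phi), so the Jacobian term is
    log|d theta / d phi| = log(exp(phi)) = phi.
    """
    theta = math.exp(phi)
    return gauss_logpdf(theta, mean, std) + phi

# Prior N(5, 2^2) evaluated at theta = 5, i.e. phi = log(5)
print(phi_logpdf(math.log(5.0), 5.0, 2.0))
```

Because theta = exp(phi) is always positive, this PHI-space density only covers the part of the Gaussian with theta > 0; it integrates to that mass rather than to 1.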
Parameters
mean : float = None
Prior mean in THETA space. Required if std is provided.
std : float = None
Prior standard deviation in THETA space. Required if mean is provided.
ci : tuple of (float, float) = None
Credible interval (low, high) in THETA space. Alternative to mean/std.
prob : float = 0.95
Probability mass in the credible interval (only used with ci).
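To see how a credible interval relates to mean and std, here is a minimal sketch that assumes ci is a central (equal-tailed) interval holding prob of the mass, so each tail holds (1 - prob) / 2. The helper ci_to_mean_std is hypothetical and is not GaussPrior's own conversion code.

```python
from statistics import NormalDist

def ci_to_mean_std(ci, prob=0.95):
    """Convert a central credible interval into (mean, std) for a Gaussian.

    ASSUMPTION: the interval is symmetric around the mean and contains
    `prob` probability mass, i.e. each tail holds (1 - prob) / 2.
    """
    low, high = ci
    if not low < high:
        raise ValueError("ci must satisfy low < high")
    if not 0.0 < prob < 1.0:
        raise ValueError("prob must be in (0, 1)")
    z = NormalDist().inv_cdf(0.5 + prob / 2.0)  # about 1.96 for prob=0.95
    mean = 0.5 * (low + high)
    std = (high - low) / (2.0 * z)
    return mean, std

mean, std = ci_to_mean_std((2.0, 8.0))
print(mean, round(std, 3))  # 5.0 1.531
```

Under this assumption, ci=(2.0, 8.0) gives a Gaussian centered at 5 with a standard deviation of about 1.53.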
Examples
>>> # Specify via mean and std
>>> prior = GaussPrior(mean=5.0, std=2.0)
>>>
>>> # Specify via 95% credible interval
>>> prior = GaussPrior(ci=(2.0, 8.0))
>>>
>>> # Plot to verify prior matches your beliefs
>>> prior.plot() # Shows Gaussian centered at 5
>>>
>>> # Use in SVGD - transformations handled automatically
>>> svgd = graph.svgd(data, theta_dim=1, prior=prior)

Methods
| Name | Description |
|---|---|
| plot | Plot the Gaussian prior distribution in THETA space. |
| sample | Sample from the prior. |
plot
phasic.svgd.GaussPrior.plot(log=False, ax=None, return_ax=False, **kwargs)

Plot the Gaussian prior distribution in THETA space.
Parameters
log : bool = False
If True, plot the log-probability instead of the probability density.
ax : matplotlib.axes.Axes = None
Axes to plot on. If None, a new figure is created.
return_ax : bool = False
If True, return ax. If False, call plt.show() instead.
**kwargs
Additional keyword arguments passed to the plot function.
Returns
matplotlib.axes.Axes
The axes with the plot (only if return_ax=True).
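The log option switches between plotting the density and the log-density. The sketch below uses only the standard library to compute the curve such a plot would draw. prior_curve is a hypothetical helper, and the grid range of mean ± 4 std is an arbitrary choice, not necessarily what GaussPrior.plot uses.

```python
import math
from statistics import NormalDist

def prior_curve(mean, std, log=False, n=201, width=4.0):
    """Hypothetical helper: x/y values a Gaussian-prior plot could draw.

    x spans mean +/- width*std on an evenly spaced grid of n points;
    y is the density, or the log-density if log=True.
    """
    dist = NormalDist(mean, std)
    lo, hi = mean - width * std, mean + width * std
    xs = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    if log:
        ys = [math.log(dist.pdf(x)) for x in xs]
    else:
        ys = [dist.pdf(x) for x in xs]
    return xs, ys

xs, ys = prior_curve(5.0, 2.0, log=True)
```

With log=True the curve is a downward parabola peaking at the mean. The xs and ys lists could then be drawn with matplotlib's ax.plot(xs, ys).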
sample
phasic.svgd.GaussPrior.sample(key, shape)

Sample from the prior.
When the internal _transform attribute is set, samples are drawn in THETA space and converted to PHI space.
Parameters
key : jax.random.PRNGKey
JAX random key.
shape : tuple
Shape of the samples, (n_particles, theta_dim).
Returns
array
Samples (in PHI space if the transform is set).
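The shape contract of sample can be illustrated with a stdlib-only sketch that covers the untransformed case only: an integer seed stands in for the JAX PRNGKey, and the THETA-to-PHI conversion is omitted. sample_gauss_prior is a hypothetical stand-in, not the phasic implementation. In JAX itself, the analogous draw would be jax.random.normal(key, shape) * std + mean.

```python
import random

def sample_gauss_prior(seed, shape, mean, std):
    """Hypothetical stand-in for GaussPrior.sample without a transform.

    Returns a nested list of shape (n_particles, theta_dim) filled with
    i.i.d. N(mean, std**2) draws in THETA space.
    """
    n_particles, theta_dim = shape
    rng = random.Random(seed)  # same seed -> same samples, like reusing a key
    return [[rng.gauss(mean, std) for _ in range(theta_dim)]
            for _ in range(n_particles)]

particles = sample_gauss_prior(0, (4, 1), 5.0, 2.0)
print(particles)
```

Each row is one SVGD particle and each column one parameter dimension, matching the (n_particles, theta_dim) layout described above.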