OptaxOptimizer
phasic.optax_wrapper.OptaxOptimizer(optax_optimizer)

Wrapper to use Optax optimizers with phasic’s SVGD interface.
This class adapts Optax’s functional optimizer interface (init/update) to the object-oriented interface (reset/step) that phasic’s SVGD expects.
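To make the adaptation concrete, the sketch below bridges the two interfaces by hand. The class body is hypothetical and only illustrative: init and update are Optax’s real entry points, and reset and step are the methods documented on this page, but the internal details (including the sign handling) are assumptions, not phasic’s actual implementation.

>>> import jax.numpy as jnp
>>> import optax
>>>
>>> class MinimalOptaxWrapper:  # hypothetical sketch, not phasic's class
...     def __init__(self, optax_optimizer):
...         self.optimizer = optax_optimizer  # optax.GradientTransformation
...         self.state = None                 # filled in by reset()
...     def reset(self, shape, params=None):
...         # Optax builds its state from concrete values; fall back to zeros.
...         init = params if params is not None else jnp.zeros(shape)
...         self.state = self.optimizer.init(init)
...     def step(self, phi, params=None, particles=None):
...         params = params if params is not None else particles
...         # phi is an ascent direction; Optax expects loss gradients, so
...         # negate on the way in (sign handling is an assumption here).
...         updates, self.state = self.optimizer.update(-phi, self.state, params)
...         return updates  # add to particles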
Parameters
optax_optimizer : optax.GradientTransformation
An Optax optimizer or chain of gradient transformations.
Attributes
optimizer : optax.GradientTransformation
The underlying Optax optimizer.
state : optax.OptState or None
The optimizer state, initialized by reset().
Examples
>>> import optax
>>> from phasic.optax_wrapper import OptaxOptimizer
>>>
>>> # Wrap any Optax optimizer
>>> optimizer = OptaxOptimizer(optax.adam(learning_rate=0.001))
>>>
>>> # Or use with gradient clipping
>>> optimizer = OptaxOptimizer(optax.chain(
... optax.clip_by_global_norm(1.0),
... optax.adam(0.001)
... ))

Methods
| Name | Description |
|---|---|
| reset | Initialize optimizer state for given particle shape. |
| step | Compute update using Optax optimizer. |
reset
phasic.optax_wrapper.OptaxOptimizer.reset(shape, params=None)

Initialize optimizer state for given particle shape.
Called at the start of optimization to initialize the Optax state.
Parameters
shape : tuple[int, ...]
Shape of particles array (n_particles, theta_dim).
params : jax.Array or None = None
Initial parameter values. If provided, used for optimizers that need current params (e.g., adamw for weight decay).
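For illustration, a hedged usage sketch of reset() with a weight-decay optimizer; the particle shape here is arbitrary:

>>> import jax.numpy as jnp
>>> import optax
>>> from phasic.optax_wrapper import OptaxOptimizer
>>>
>>> optimizer = OptaxOptimizer(optax.adamw(learning_rate=0.001))
>>> particles = jnp.zeros((50, 4))  # hypothetical: 50 particles, theta_dim = 4
>>> # Pass params so weight decay sees the actual starting values.
>>> optimizer.reset(particles.shape, params=particles)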
step
phasic.optax_wrapper.OptaxOptimizer.step(phi, params=None, particles=None)

Compute update using Optax optimizer.
Parameters
phi : jax.Array
SVGD gradient direction of shape (n_particles, theta_dim): (K @ grad_log_p + sum(grad_K)) / n_particles.
params : jax.Array or None = None
Current parameter values. Required for optimizers with weight decay (e.g., adamw). If not provided, uses internally tracked params.
particles : jax.Array or None = None
Alias for params, for compatibility with phasic optimizer interface.
Returns
: jax.Array
Scaled update of shape (n_particles, theta_dim) to add to particles.
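Putting reset() and step() together, a hedged end-to-end sketch; compute_phi below is a hypothetical stand-in for the SVGD direction defined above, not a phasic function:

>>> import jax
>>> import optax
>>> from phasic.optax_wrapper import OptaxOptimizer
>>>
>>> optimizer = OptaxOptimizer(optax.adam(learning_rate=0.001))
>>> particles = jax.random.normal(jax.random.PRNGKey(0), (50, 4))
>>> optimizer.reset(particles.shape, params=particles)
>>> for _ in range(100):
...     phi = compute_phi(particles)  # hypothetical: (K @ grad_log_p + sum(grad_K)) / n_particles
...     particles = particles + optimizer.step(phi, params=particles)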