OptaxOptimizer

phasic.optax_wrapper.OptaxOptimizer(optax_optimizer)

Wrapper to use Optax optimizers with phasic’s SVGD interface.

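This class adapts Optax’s functional optimizer interface, where the state is passed explicitly to init and update, to the object-oriented interface expected by phasic’s SVGD, where the optimizer object holds its own state.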

Parameters

optax_optimizer : optax.GradientTransformation

An Optax optimizer or chain of gradient transformations.

Attributes

optimizer : optax.GradientTransformation

The underlying Optax optimizer.

state : optax.OptState or None

The Optax optimizer state; None until reset() initializes it.

Examples

>>> import optax
>>> from phasic.optax_wrapper import OptaxOptimizer
>>>
>>> # Wrap any Optax optimizer
>>> optimizer = OptaxOptimizer(optax.adam(learning_rate=0.001))
>>>
>>> # Or use with gradient clipping
>>> optimizer = OptaxOptimizer(optax.chain(
...     optax.clip_by_global_norm(1.0),
...     optax.adam(0.001)
... ))

Methods

Name | Description
--- | ---
reset | Initialize the optimizer state for the given particle shape.
step | Compute an update using the Optax optimizer.

reset

phasic.optax_wrapper.OptaxOptimizer.reset(shape, params=None)

Initialize the optimizer state for the given particle shape.

Called at the start of optimization to initialize the Optax state.

Parameters

shape : tuple[int, …]

Shape of the particles array, (n_particles, theta_dim).

params : jax.Array or None = None

Initial parameter values. If provided, they are used by optimizers that need the current parameters (e.g., adamw, for weight decay).

step

phasic.optax_wrapper.OptaxOptimizer.step(phi, params=None, particles=None)

Compute an update using the Optax optimizer.

Parameters

phi : jax.Array

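SVGD update direction of shape (n_particles, theta_dim), computed as (K @ grad_log_p + sum(grad_K)) / n_particles, where K is the kernel matrix over the particles, grad_log_p holds the gradient of the log target density at each particle, and sum(grad_K) is the kernel gradient (the repulsive term) summed over particles.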

params : jax.Array or None = None

Current parameter values. Required for optimizers with weight decay (e.g., adamw). If not provided, the internally tracked params are used.

particles : jax.Array or None = None

Alias for params, for compatibility with phasic’s optimizer interface.
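For reference, the phi expression above is the standard SVGD direction phi(x_i) = (1/n) Σ_j [k(x_j, x_i) ∇ log p(x_j) + ∇_{x_j} k(x_j, x_i)]. The sketch below computes it in JAX under an assumed RBF kernel k(x, y) = exp(-||x - y||² / h); the kernel choice, the bandwidth h, and the function names rbf_kernel and svgd_direction are illustrative and not taken from phasic.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, bandwidth):
    """K[j, i] = exp(-||x_j - x_i||^2 / bandwidth) (assumed kernel)."""
    sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / bandwidth)


def svgd_direction(x, grad_log_p, bandwidth=1.0):
    """phi = (K @ grad_log_p + sum_j grad_{x_j} k(x_j, x_i)) / n."""
    n = x.shape[0]
    K = rbf_kernel(x, bandwidth)
    # Attractive term: kernel-weighted average of the score.
    attractive = K @ grad_log_p
    # For the RBF kernel, grad_{x_j} k(x_j, x_i) = (2 / h) (x_i - x_j) K[j, i];
    # summing over j gives the repulsive term in closed form.
    repulsive = (2.0 / bandwidth) * (x * K.sum(axis=0)[:, None] - K.T @ x)
    return (attractive + repulsive) / n


# The function is pure, so it can be jit-compiled.
svgd_direction_jit = jax.jit(svgd_direction)
```

The closed-form repulsive term avoids differentiating the kernel with autodiff on every step; jax.grad gives the same result and is a convenient check.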

Returns

: jax.Array

Scaled update of shape (n_particles, theta_dim) to add to particles.