optax_chain

phasic.optax_wrapper.optax_chain(*transforms)

Chain several Optax gradient transformations into a single optimizer wrapped for use with phasic.

This allows combining multiple gradient transformations, such as gradient clipping followed by an optimizer.
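To make the chaining idea concrete, here is a minimal pure-Python sketch of how a chain composes gradient transformations. Each transform has an `init` and an `update` function. The chain keeps one state per transform and feeds each transform's output updates into the next. This mirrors the general structure of `optax.chain`, but `Transform`, `chain`, `clip_by_global_norm`, and `sgd` below are hypothetical, simplified stand-ins that operate on plain lists. They are not Optax's or phasic's actual code.

```python
from typing import Any, Callable, NamedTuple


class Transform(NamedTuple):
    """Hypothetical, simplified stand-in for optax.GradientTransformation."""
    init: Callable[[Any], Any]
    update: Callable[[Any, Any], tuple]


def chain(*transforms: Transform) -> Transform:
    """Compose transforms so each one's output updates feed the next."""
    def init(params):
        # One state per transform, stored in chain order.
        return tuple(t.init(params) for t in transforms)

    def update(updates, state):
        new_states = []
        for t, s in zip(transforms, state):
            updates, s = t.update(updates, s)
            new_states.append(s)
        return updates, tuple(new_states)

    return Transform(init, update)


def clip_by_global_norm(max_norm: float) -> Transform:
    """Rescale updates so their global L2 norm is at most max_norm."""
    def update(updates, state):
        norm = sum(u * u for u in updates) ** 0.5
        factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
        return [u * factor for u in updates], state
    return Transform(lambda params: (), update)


def sgd(learning_rate: float) -> Transform:
    """Plain gradient descent: negate and scale the incoming updates."""
    def update(updates, state):
        return [-learning_rate * u for u in updates], state
    return Transform(lambda params: (), update)


# Clipping runs first, then the optimizer step, as in "clipping followed by an optimizer".
opt = chain(clip_by_global_norm(1.0), sgd(0.1))
state = opt.init([0.0, 0.0])
updates, state = opt.update([3.0, 4.0], state)
print(updates)  # approximately [-0.06, -0.08]
```

The gradient `[3.0, 4.0]` has norm 5, so clipping scales it to `[0.6, 0.8]` before the SGD step multiplies by `-0.1`.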

Parameters

*transforms : optax.GradientTransformation

Optax gradient transformations to chain together.

Returns

: OptaxOptimizer

Wrapped chained optimizer compatible with phasic SVGD.

Examples

>>> import optax
>>> from phasic import optax_chain
>>>
>>> # Adam with gradient clipping
>>> optimizer = optax_chain(
...     optax.clip_by_global_norm(1.0),
...     optax.adam(0.001)
... )
>>>
>>> # Learning rate warmup with Adam
>>> schedule = optax.warmup_cosine_decay_schedule(
...     init_value=0.0,
...     peak_value=0.01,
...     warmup_steps=100,
...     decay_steps=1000
... )
>>> optimizer = optax_chain(
...     optax.adam(schedule)
... )