optax_chain
phasic.optax_wrapper.optax_chain(*transforms)
Create chained Optax transforms wrapped for phasic.
This allows combining multiple gradient transformations, such as gradient clipping followed by an optimizer.
Parameters
*transforms : optax.GradientTransformation = ()
Optax gradient transformations to chain together.
Returns
: OptaxOptimizer
Wrapped chained optimizer compatible with phasic SVGD.
Examples
>>> import optax
>>> from phasic import optax_chain
>>>
>>> # Adam with gradient clipping
>>> optimizer = optax_chain(
...     optax.clip_by_global_norm(1.0),
...     optax.adam(0.001)
... )
>>>
>>> # Learning rate warmup with Adam
>>> schedule = optax.warmup_cosine_decay_schedule(
...     init_value=0.0,
...     peak_value=0.01,
...     warmup_steps=100,
...     decay_steps=1000
... )
>>> optimizer = optax_chain(
...     optax.adam(schedule)
... )