phasic optimizes performance through a three-layer caching system; each layer targets a different computational bottleneck:
```mermaid
flowchart TB
    A[User Code] --> B[Layer 1: Trace Cache]
    B --> C[Layer 2: SVGD Compilation Cache]
    C --> D[Layer 3: JAX Compilation Cache]
    D --> E[Execution]
    B -.->|Hit: 0.1-1ms| E
    B -.->|Miss: 10-1000ms| C
    C -.->|Hit: instant| E
    C -.->|Miss: 1-60s| D
    D -.->|Hit: instant| E
    D -.->|Miss: compile| E
    style B fill:#e1f5dd
    style C fill:#fff4dd
    style D fill:#ffe4dd
```
Why Three Layers?
For parameterized models, computation happens in stages:

1. Graph elimination (structure-only), cached by the trace cache
2. SVGD gradient compilation, cached by the SVGD compilation cache
3. JAX/XLA compilation, cached by the JAX compilation cache

By caching all three stages, repeated evaluations become nearly instant!
Layer 1: Trace Cache
What It Caches
The trace cache stores pre-computed graph elimination traces from record_elimination_trace(). The cache key is a SHA-256 hash of:
- Graph topology (vertices, edges)
- State space dimensions
- Parameterization patterns (edge coefficients)
- NOT actual parameter values (structure-only)

Location: `~/.phasic_cache/traces/*.json`
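The exact serialization format is internal to phasic, but conceptually the key computation looks like this minimal sketch (the `graph_structure` dict and its field names are illustrative assumptions, not phasic's actual schema):

```python
import hashlib
import json

# Illustrative structure-only description of a graph: vertices, edges,
# and symbolic edge coefficients, but no concrete parameter values
graph_structure = {
    "vertices": [0, 1, 2],
    "edges": [[0, 1], [1, 2]],
    "state_dims": [3],
    "edge_coefficients": ["n*(n-1)/2", "1.0"],
}

# Deterministic serialization, then SHA-256
serialized = json.dumps(graph_structure, sort_keys=True)
cache_key = hashlib.sha256(serialized.encode()).hexdigest()
print(cache_key)  # used as the trace cache filename
```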
How It Works
```mermaid
flowchart LR
    A[Graph Structure] --> B[Serialize & Hash]
    B --> C{Cache Hit?}
    C -->|Yes| D[Load Trace<br/>0.1-1ms]
    C -->|No| E[Perform Elimination<br/>10-1000ms]
    E --> F[Save Trace]
    F --> G[Return Result]
    D --> G
```
Algorithm:

1. Graph structure is serialized and hashed (SHA-256)
2. Check `~/.phasic_cache/traces/{hash}.json`
3. Hit: Load trace and skip elimination (0.1-1ms)
4. Miss: Perform elimination (10-1000ms), save trace
5. Future builds of the same structure: instant
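A minimal sketch of this check-load-save flow; the `hash_structure` and `run_elimination` callables are hypothetical stand-ins for phasic's internal hashing and elimination routines:

```python
import json
from pathlib import Path

def get_or_compute_trace(graph_structure, hash_structure, run_elimination):
    # hash_structure and run_elimination are hypothetical stand-ins
    # for phasic's internal hashing and elimination routines
    cache_dir = Path.home() / '.phasic_cache' / 'traces'
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_file = cache_dir / f"{hash_structure(graph_structure)}.json"

    if cache_file.exists():
        # Hit: load the pre-computed trace (0.1-1ms)
        return json.loads(cache_file.read_text())

    # Miss: perform elimination (10-1000ms), then save the trace
    trace = run_elimination(graph_structure)
    cache_file.write_text(json.dumps(trace))
    return trace
```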
Basic Usage
The trace cache is used automatically when building graphs:
```python
from phasic import Graph
import numpy as np

# Define callback for parameterized graph
def coalescent_callback(state, nr_samples=3):
    if len(state) == 0:
        return [(np.array([nr_samples]), 1.0, [1.0])]
    if state[0] > 1:
        n = state[0]
        return [(np.array([n - 1]), 0.0, [n * (n - 1) / 2])]
    return []

# Build graph (first time: slow, subsequent: fast)
g = Graph(coalescent_callback)
# First build: 10-1000ms (eliminates and caches)
# Second build: 0.1-1ms (loads from cache)
```
Cache Inspection
```python
from pathlib import Path

# Check trace cache directory
trace_cache_dir = Path.home() / '.phasic_cache' / 'traces'
num_traces = len(list(trace_cache_dir.glob('*.json'))) if trace_cache_dir.exists() else 0
print(f"Cached traces: {num_traces}")
print(f"Location: {trace_cache_dir}")
```
Performance Impact
| Model Size | No Cache | With Cache | Speedup |
|---|---|---|---|
| 37 vertices | 45ms | 1.3ms | 35x |
| 67 vertices | 250ms | 2.1ms | 120x |
| 100+ vertices | 500ms+ | 5ms | 100x+ |
Layer 2: SVGD Compilation Cache
What It Caches
The SVGD cache stores JIT-compiled gradient functions during SVGD initialization. The cache key is:
`(model_id, theta_shape, n_particles)`

Location:
- Memory: Python dict (reliable, session-based)
- Disk: Pickle files (unreliable, ~80% failure rate due to JAX limitations)
Warning: Disk Cache Limitation
The disk component often fails due to pickle limitations with JAX JIT functions. Memory cache works reliably within a session.
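Within a session, the memory layer behaves like a dict keyed on the tuple above. The sketch below illustrates the get-or-compile pattern; the cache dict and function names are illustrative, not phasic's internals:

```python
import jax

# Illustrative in-memory cache; phasic's internal structure may differ
_gradient_cache = {}

def get_gradient_fn(model_id, log_prob_fn, theta_shape, n_particles):
    # log_prob_fn must return a scalar for jax.grad to apply
    key = (model_id, theta_shape, n_particles)
    if key not in _gradient_cache:
        # Miss: JIT-compile the gradient (1-60s the first time)
        _gradient_cache[key] = jax.jit(jax.grad(log_prob_fn))
    # Hit: reuse the compiled gradient (instant)
    return _gradient_cache[key]
```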
How It Works
```mermaid
flowchart LR
    A[SVGD Init] --> B{Memory Cache?}
    B -->|Hit| C[Reuse Gradient<br/>instant]
    B -->|Miss| D{Disk Cache?}
    D -->|Hit| E[Load Gradient<br/>fast]
    D -->|Miss| F[Compile with<br/>jit-grad-log-prob<br/>1-60s]
    F --> G[Save to Memory]
    G --> C
    E --> C
```
Layer 3: JAX Compilation Cache
```mermaid
flowchart LR
    A[jit-f-x call] --> B[Compute Cache Key<br/>signature + shapes]
    B --> C{Cache Hit?}
    C -->|Yes| D[Load Compiled<br/>instant]
    C -->|No| E[Compile with XLA<br/>1-10s]
    E --> F[Save to Cache]
    F --> G[Execute]
    D --> G
```
JAX manages this automatically - no user code needed!
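For example, a plain `jax.jit` function compiles once per input signature (shape and dtype) and is then served from cache. This small standalone timing demo (not phasic-specific) makes the effect visible:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(x ** 2)

x = jnp.ones(100)

start = time.time()
f(x).block_until_ready()        # First call: traces and compiles
print(f"First call:  {time.time() - start:.4f}s")

start = time.time()
f(x + 1.0).block_until_ready()  # Same shape/dtype: cache hit, no recompile
print(f"Second call: {time.time() - start:.4f}s")
```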
Configuration
Environment Variable (Before Import)
```python
import os
os.environ['JAX_COMPILATION_CACHE_DIR'] = '/fast/ssd/jax_cache'

# THEN import JAX
import jax
from phasic import Graph
```
Using CompilationConfig
```python
from phasic.jax_config import CompilationConfig

# Balanced preset (default)
config = CompilationConfig.balanced()
config.apply()

# Maximum performance
config = CompilationConfig.max_performance()
config.apply()

# Fast compilation (development)
config = CompilationConfig.fast_compile()
config.apply()

# Custom settings
config = CompilationConfig(
    cache_dir='/scratch/jax_cache',
    optimization_level=3,
    parallel_compile=True
)
config.apply()
```
Cache Management
Note: Consolidated Cache Management (October 2025)
All cache management functions now use CacheManager internally as a single source of truth. No code duplication, 100% backward compatible.
```python
from phasic.cache_manager import CacheManager

manager = CacheManager()

# Export cache for distribution
manager.export_cache('jax_cache_backup.tar.gz')

# Import cache
manager.import_cache('jax_cache_backup.tar.gz')

# Cleanup old entries
manager.vacuum(max_age_days=30, max_size_gb=10.0)

# Sync from shared filesystem
manager.sync_from_remote('/shared/project/jax_cache')
```
Performance Impact
| Operation | No Cache | With Cache | Speedup |
|---|---|---|---|
| First compile | 5-10s | 5-10s | 1x |
| Same shape | 5-10s | <1ms | >5,000x |
| Different params | 5-10s | <1ms | >5,000x |
Quick Start
Single Machine Workflow
```python
from phasic import Graph, SVGD, CompilationConfig
import jax.numpy as jnp
import numpy as np

# 1. Configure JAX cache (once per session, optional)
config = CompilationConfig.balanced()
config.apply()

# 2. Build model (trace cache used automatically)
def my_callback(state):
    # Your model definition
    ...

g = Graph(my_callback)
model = Graph.pmf_from_graph(g)  # Trace cache: auto

# 3. Run SVGD (all caches work together)
theta = jnp.array([1.0])
times = jnp.linspace(0.1, 5, 50)
observed_data = model(theta, times) + np.random.normal(0, 0.01, 50)

# First run: Slow (trace load + SVGD compile + JAX compile)
svgd = SVGD(
    model=model,
    observed_data=observed_data,
    theta_dim=1,
    n_particles=100,
    n_iterations=1000
)
svgd.fit()  # ~10-40 seconds first time

# Second run: Fast (all caches populated)
svgd2 = SVGD(model=model, observed_data=observed_data, theta_dim=1,
             n_particles=100, n_iterations=1000)
svgd2.fit()  # <10 seconds (mostly SVGD iterations)
```
Pre-warming Cache for Production
```python
import jax.numpy as jnp
from phasic.cache_manager import CacheManager

manager = CacheManager()

# Define expected input shapes
theta_samples = [
    jnp.ones(1),
    jnp.ones(2),
    jnp.ones(5)
]
time_grids = [
    jnp.linspace(0.1, 5, 20),
    jnp.linspace(0.1, 5, 50),
    jnp.linspace(0.1, 5, 100)
]

# Pre-compile for all combinations
manager.prewarm_model(model, theta_samples, time_grids)
# Pre-warming JAX cache for 9 combinations...
# [1/9] theta_shape=(1,), times_shape=(20,)... ✓
# ...
# ✓ Pre-warming complete in 45.2s
# Cache size: 123.4 MB (9 files)

# Now production queries are instant!
pdf = model(jnp.ones(2), jnp.linspace(0.1, 5, 50))  # <1ms
```
For clusters, a layered configuration checks a fast local cache first and falls back to a read-only shared cache before compiling:

```python
from phasic.jax_config import CompilationConfig

config = CompilationConfig(
    cache_dir='/home/user/.jax_cache',             # Local (fast)
    shared_cache_dir='/shared/project/jax_cache',  # Shared (read-only)
    cache_strategy='layered'
)
config.apply()
# JAX checks: local → shared → compile
```
Cache Synchronization
```python
from phasic.cache_manager import CacheManager

manager = CacheManager(cache_dir='/home/user/.jax_cache')

# Pull updates from shared cache
manager.sync_from_remote('/shared/project/jax_cache')

# Dry run to preview
manager.sync_from_remote('/shared/project/jax_cache', dry_run=True)
```
SLURM Example
```bash
#!/bin/bash
#SBATCH --job-name=ptd_inference
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8

# Shared cache on network storage
export SHARED_CACHE=/shared/project/phasic_cache
export LOCAL_CACHE=$HOME/.jax_cache

# Sync cache at job start
python -c "
from phasic.cache_manager import CacheManager
manager = CacheManager(cache_dir='$LOCAL_CACHE')
manager.sync_from_remote('$SHARED_CACHE')
"

# Run job (uses local cache)
srun python my_inference.py

# Optionally sync new compilations back
rsync -av $LOCAL_CACHE/ $SHARED_CACHE/
```
Testing the Cache System
Comprehensive Test Suite
Run the full cache testing suite:
```bash
python tests/test_svgd_jax.py
# See Test 7 for comprehensive cache layer testing
```
This demonstrates:
- Trace cache testing with timing
- SVGD compilation cache behavior
- JAX compilation cache management
- Cache management functions
- Full pipeline integration
Manual Cache Testing
```python
from phasic import Graph, cache_info, print_cache_info
import time

# Test trace cache
print("[1] Trace Cache Test")
g1 = Graph(my_callback)
start = time.time()
g2 = Graph(my_callback)
elapsed_ms = (time.time() - start) * 1000
print(f"  Second build: {elapsed_ms:.1f}ms (cache hit)")

# Test JAX cache
print("\n[2] JAX Cache Test")
info_before = cache_info()
print(f"  Before: {info_before['num_files']} files")

# Run something that compiles
model = Graph.pmf_from_graph(g1, discrete=False, theta_dim=1)
# ... use model ...
info_after = cache_info()
print(f"  After: {info_after['num_files']} files")
print(f"  New compilations: {info_after['num_files'] - info_before['num_files']}")

# Print detailed cache info
print("\n[3] Detailed Cache Info")
print_cache_info(max_files=5)
```
Best Practices
DO: Always Enable Caching
```python
# Good: Use default caching (trace cache automatic)
g = Graph(my_callback)
model = Graph.pmf_from_graph(g)  # Cache enabled

# Bad: Unnecessary re-computation
# (Note: Graph doesn't have a use_cache parameter; caching is automatic)
```
DO: Configure JAX Early
```python
# BEFORE importing JAX
import os
os.environ['JAX_COMPILATION_CACHE_DIR'] = '/fast/storage'

# THEN import
import jax
from phasic import Graph
```
DO: Pre-warm for Production
```python
from phasic.cache_manager import CacheManager

manager = CacheManager()
manager.prewarm_model(model, expected_shapes, expected_grids)
# Now production queries are instant
```
DO: Monitor Cache Size
```python
from phasic import cache_info
from phasic.cache_manager import CacheManager

# Check size regularly
info = cache_info()
if info['total_size_mb'] > 5000:  # 5 GB
    print("Warning: JAX cache is large")
    # Clean up old entries
    manager = CacheManager()
    manager.vacuum(max_age_days=30, max_size_gb=10.0)
```
DON'T: Rely on Disk SVGD Cache
```python
# Don't expect SVGD disk cache to work reliably
# It fails ~80% of the time due to pickle limitations
# Use memory cache within session, JAX cache across sessions
```
Expected behavior: the first run performs elimination and compilation, so it is slow by design.

To speed up:
1. Use cached traces (build the same structure multiple times)
2. Pre-warm the JAX cache with prewarm_model()
3. Use smaller models for testing, production models for deployment
Performance Benchmarks
Trace Cache Impact
| Model Size | No Cache | With Cache | Speedup |
|---|---|---|---|
| 37 vertices | 45ms | 1.3ms | 35x |
| 67 vertices | 250ms | 2.1ms | 120x |
| 100 vertices | 500ms | 5ms | 100x |
JAX Cache Impact
| Operation | No Cache | With Cache | Speedup |
|---|---|---|---|
| First compile | 5-10s | 5-10s | 1x |
| Same shape | 5-10s | <1ms | >5,000x |
| Different params | 5-10s | <1ms | >5,000x |
Combined Impact: SVGD with 1000 Iterations
Workflow: Repeated SVGD runs for parameter exploration
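A sketch of that workflow, reusing the `model` and `SVGD` setup from the Quick Start above (the `datasets` list is an illustrative assumption):

```python
# Illustrative parameter-exploration loop: the first fit pays the full
# trace + SVGD + JAX compilation cost; later fits with the same shapes
# reuse all three cache layers and spend their time on SVGD iterations.
for observed_data in datasets:  # datasets: list of arrays, same shape
    svgd = SVGD(
        model=model,
        observed_data=observed_data,
        theta_dim=1,
        n_particles=100,
        n_iterations=1000
    )
    svgd.fit()
```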