Caching System

This guide covers the comprehensive three-layer caching system in phasic, which provides 10-1000x speedups for repeated model evaluations.

Tip: Quick Summary

phasic uses three complementary cache layers:

  1. Trace Cache - Graph elimination operations (10-1000x speedup)
  2. SVGD Compilation Cache - JIT-compiled gradients (instant on hit)
  3. JAX Compilation Cache - XLA compilations (seconds → instant)


Overview

Three-Layer Architecture

phasic optimizes performance through a three-layer caching system, with each layer targeting a different computational bottleneck:

flowchart TB
    A[User Code] --> B[Layer 1: Trace Cache]
    B --> C[Layer 2: SVGD Compilation Cache]
    C --> D[Layer 3: JAX Compilation Cache]
    D --> E[Execution]

    B -.->|Hit: 0.1-1ms| E
    B -.->|Miss: 10-1000ms| C
    C -.->|Hit: instant| E
    C -.->|Miss: 1-60s| D
    D -.->|Hit: instant| E
    D -.->|Miss: compile| E

    style B fill:#e1f5dd
    style C fill:#fff4dd
    style D fill:#ffe4dd

Why Three Layers?

For parameterized models, computation happens in stages:

Graph Build → [Trace Cache: Elimination] → Eliminated Graph →
    [SVGD Cache: Gradient] → Compiled Gradient →
    [JAX Cache: XLA] → Optimized Code → Result

Computational Costs:

  • Graph Elimination (10-1000ms): O(n³), structure-dependent
  • Gradient Compilation (1-60s): first SVGD initialization
  • XLA Compilation (1-10s): shape-dependent
  • Evaluation (<1ms): parameter-dependent

By caching all three stages, repeated evaluations become nearly instant!


Layer 1: Trace Cache

What It Caches

The trace cache stores pre-computed graph elimination traces from record_elimination_trace(). The cache key is a SHA-256 hash of:

  • Graph topology (vertices, edges)
  • State space dimensions
  • Parameterization patterns (edge coefficients)
  • NOT actual parameter values (structure-only)

Location: ~/.phasic_cache/traces/*.json
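As an illustration of the structure-only key, the hashing step might look like this (a minimal sketch; `graph_structure_key` and its arguments are hypothetical stand-ins, not phasic's actual serialization):

```python
import hashlib
import json

def graph_structure_key(vertices, edges, state_dims, param_pattern):
    """Hash the graph *structure* only. Parameter values are excluded,
    so graphs that differ only in parameters share one cached trace."""
    payload = {
        "vertices": sorted(vertices),
        "edges": sorted(edges),            # (src, dst) pairs
        "state_dims": list(state_dims),
        "param_pattern": param_pattern,    # which edges carry coefficients
    }
    blob = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()

# Same structure -> same key, regardless of parameter values
k1 = graph_structure_key([0, 1, 2], [(0, 1), (1, 2)], (3,), ["theta0", None])
k2 = graph_structure_key([0, 1, 2], [(0, 1), (1, 2)], (3,), ["theta0", None])
assert k1 == k2
```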

How It Works

flowchart LR
    A[Graph Structure] --> B[Serialize & Hash]
    B --> C{Cache Hit?}
    C -->|Yes| D[Load Trace<br/>0.1-1ms]
    C -->|No| E[Perform Elimination<br/>10-1000ms]
    E --> F[Save Trace]
    F --> G[Return Result]
    D --> G

Algorithm:

  1. Graph structure is serialized and hashed (SHA-256)
  2. Check ~/.phasic_cache/traces/{hash}.json
  3. Hit: load the trace and skip elimination (0.1-1ms)
  4. Miss: perform elimination (10-1000ms), save the trace
  5. Future builds of the same structure: instant
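At heart this is a file-backed memoizer; a minimal sketch under the assumption of a JSON-serializable trace (`cached_elimination` and the `eliminate` callable are illustrative stand-ins for phasic's internals):

```python
import json
from pathlib import Path

def cached_elimination(structure_hash, eliminate, cache_dir=None):
    """Return the elimination trace for `structure_hash`: load it from
    disk on a hit, otherwise compute it via `eliminate()` and save it."""
    cache_dir = Path(cache_dir or Path.home() / ".phasic_cache" / "traces")
    cache_dir.mkdir(parents=True, exist_ok=True)
    trace_path = cache_dir / f"{structure_hash}.json"
    if trace_path.exists():                  # hit: ~0.1-1ms
        return json.loads(trace_path.read_text())
    trace = eliminate()                      # miss: 10-1000ms of elimination
    trace_path.write_text(json.dumps(trace))
    return trace
```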

Basic Usage

The trace cache is used automatically when building graphs:

from phasic import Graph
import numpy as np

# Define callback for parameterized graph
def coalescent_callback(state, nr_samples=3):
    if len(state) == 0:
        return [(np.array([nr_samples]), 1.0, [1.0])]
    if state[0] > 1:
        n = state[0]
        return [(np.array([n - 1]), 0.0, [n * (n - 1) / 2])]
    return []

# Build graph (first time: slow, subsequent: fast)
g = Graph(coalescent_callback)
# First build: 10-1000ms (eliminates and caches)
# Second build: 0.1-1ms (loads from cache)

Cache Inspection

from pathlib import Path

# Check trace cache directory
trace_cache_dir = Path.home() / '.phasic_cache' / 'traces'
num_traces = len(list(trace_cache_dir.glob('*.json'))) if trace_cache_dir.exists() else 0

print(f"Cached traces: {num_traces}")
print(f"Location: {trace_cache_dir}")

Performance Impact

Model Size      No Cache   With Cache   Speedup
37 vertices     45ms       1.3ms        35x
67 vertices     250ms      2.1ms        120x
100+ vertices   500ms+     5ms          100x+

Layer 2: SVGD Compilation Cache

What It Caches

The SVGD cache stores JIT-compiled gradient functions during SVGD initialization. The cache key is:

(model_id, theta_shape, n_particles)

Location:

  • Memory: Python dict (reliable, session-based)
  • Disk: pickle files (unreliable, ~80% failure rate due to JAX limitations)

Warning: Disk Cache Limitation

The disk component often fails due to pickle limitations with JAX JIT functions. Memory cache works reliably within a session.

How It Works

flowchart LR
    A[SVGD Init] --> B{Memory Cache?}
    B -->|Hit| C[Reuse Gradient<br/>instant]
    B -->|Miss| D{Disk Cache?}
    D -->|Hit| E[Load Gradient<br/>fast]
    D -->|Miss| F["Compile with<br/>jit(grad(log_prob))<br/>1-60s"]
    F --> G[Save to Memory]
    G --> C
    E --> C

Basic Usage

Caching happens automatically in SVGD:

from phasic import SVGD
import jax.numpy as jnp

# Build model
model = Graph.pmf_from_graph(g, discrete=False, theta_dim=1)
observed_data = g.sample(100)

# First SVGD: Compiles gradients (slow)
svgd1 = SVGD(
    model=model,
    observed_data=observed_data,
    theta_dim=1,
    n_particles=100
)
# Initialization: 1-60s (compilation)

# Second SVGD (same config): Uses cache (fast)
svgd2 = SVGD(
    model=model,
    observed_data=observed_data,
    theta_dim=1,
    n_particles=100
)
# Initialization: <100ms (cache hit)

Cache Behavior

Memory cache (reliable):

  • Lives for the duration of the Python session
  • Shared across SVGD instances
  • Instant reuse on cache hit

Disk cache (unreliable):

  • Often fails to pickle JAX JIT functions
  • Failures are silently ignored
  • Not recommended for critical workflows

Recommendation: Rely on the memory cache within a session and on the JAX cache across sessions.
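The reliable memory layer boils down to a session-lived dict keyed by `(model_id, theta_shape, n_particles)`; a sketch of the pattern (names are illustrative, not phasic's API):

```python
_gradient_cache = {}  # lives for the Python session, like phasic's memory layer

def get_compiled_gradient(model_id, theta_shape, n_particles, compile_fn):
    """Reuse a compiled gradient when (model_id, theta_shape, n_particles)
    has been seen before; otherwise compile once and remember it."""
    key = (model_id, tuple(theta_shape), n_particles)
    if key not in _gradient_cache:
        _gradient_cache[key] = compile_fn()  # the expensive 1-60s step
    return _gradient_cache[key]
```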


Layer 3: JAX Compilation Cache

What It Caches

JAX caches compiled XLA code based on:

  • Function structure (HLO graph)
  • Input shapes (not values)
  • Device configuration

Location: ~/.jax_cache/ (or $JAX_COMPILATION_CACHE_DIR)

How It Works

flowchart LR
    A["jit(f)(x) call"] --> B[Compute Cache Key<br/>signature + shapes]
    B --> C{Cache Hit?}
    C -->|Yes| D[Load Compiled<br/>instant]
    C -->|No| E[Compile with XLA<br/>1-10s]
    E --> F[Save to Cache]
    F --> G[Execute]
    D --> G

JAX manages this automatically - no user code needed!
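Because the key includes input shapes but not values, calling with new values of a known shape is a hit, while a new shape triggers a fresh compile. That shape-keyed behavior can be mimicked in plain Python (a toy stand-in for what JAX/XLA does, not real JAX code):

```python
compile_count = 0
_xla_cache = {}

def fake_jit(f):
    """Shape-keyed memoization mimicking JAX's compilation cache:
    compile once per input shape, reuse instantly afterwards."""
    def wrapper(x):
        global compile_count
        shape = (len(x),)                    # stand-in for x.shape
        if (f, shape) not in _xla_cache:
            compile_count += 1               # "compile" once per new shape
            _xla_cache[(f, shape)] = f
        return _xla_cache[(f, shape)](x)
    return wrapper

double = fake_jit(lambda xs: [v * 2 for v in xs])
double([1, 2, 3])   # new shape (3,): compiles
double([4, 5, 6])   # same shape: cache hit, no compile
double([1, 2])      # new shape (2,): compiles again
```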

Configuration

Environment Variable (Before Import)

import os
os.environ['JAX_COMPILATION_CACHE_DIR'] = '/fast/ssd/jax_cache'

# THEN import JAX
import jax
from phasic import Graph

Using CompilationConfig

from phasic.jax_config import CompilationConfig

# Balanced preset (default)
config = CompilationConfig.balanced()
config.apply()

# Maximum performance
config = CompilationConfig.max_performance()
config.apply()

# Fast compilation (development)
config = CompilationConfig.fast_compile()
config.apply()

# Custom settings
config = CompilationConfig(
    cache_dir='/scratch/jax_cache',
    optimization_level=3,
    parallel_compile=True
)
config.apply()

Cache Management

NoteConsolidated Cache Management (October 2025)

All cache management functions now use CacheManager internally as a single source of truth. No code duplication, 100% backward compatible.

from phasic import cache_info, print_cache_info, clear_cache

# Get cache statistics
info = cache_info()
print(f"Files: {info['num_files']}, Size: {info['total_size_mb']:.1f} MB")

# Pretty-print cache info
print_cache_info()
# Output:
# ======================================================================
# JAX COMPILATION CACHE INFO
# ======================================================================
# Path: /Users/you/.jax_cache
# Cached compilations: 47
# Total size: 234.5 MB
# Most recent files (showing 10/47):
#   2025-10-19T16:30:45 |   2847.3 KB | jax_cache_f3a8b2...
#   ...

# Clear cache
clear_cache()  # Clears default cache
clear_cache('/custom/cache')  # Clear specific cache
clear_cache(verbose=False)  # Silent mode

Advanced Management

from phasic.cache_manager import CacheManager

manager = CacheManager()

# Export cache for distribution
manager.export_cache('jax_cache_backup.tar.gz')

# Import cache
manager.import_cache('jax_cache_backup.tar.gz')

# Cleanup old entries
manager.vacuum(max_age_days=30, max_size_gb=10.0)

# Sync from shared filesystem
manager.sync_from_remote('/shared/project/jax_cache')

Performance Impact

Operation          No Cache   With Cache   Speedup
First compile      5-10s      5-10s        1x
Same shape         5-10s      <1ms         >5,000x
Different params   5-10s      <1ms         >5,000x

Quick Start

Single Machine Workflow

from phasic import Graph, SVGD, CompilationConfig
import jax.numpy as jnp
import numpy as np

# 1. Configure JAX cache (once per session, optional)
config = CompilationConfig.balanced()
config.apply()

# 2. Build model (trace cache used automatically)
def my_callback(state):
    # Your model definition
    ...

g = Graph(my_callback)
model = Graph.pmf_from_graph(g)  # Trace cache: auto

# 3. Run SVGD (all caches work together)
theta = jnp.array([1.0])
times = jnp.linspace(0.1, 5, 50)
observed_data = model(theta, times) + np.random.normal(0, 0.01, 50)

# First run: Slow (trace load + SVGD compile + JAX compile)
svgd = SVGD(
    model=model,
    observed_data=observed_data,
    theta_dim=1,
    n_particles=100,
    n_iterations=1000
)
svgd.fit()  # ~10-40 seconds first time

# Second run: Fast (all caches populated)
svgd2 = SVGD(model=model, observed_data=observed_data,
            theta_dim=1, n_particles=100, n_iterations=1000)
svgd2.fit()  # <10 seconds (mostly SVGD iterations)

Pre-warming Cache for Production

from phasic.cache_manager import CacheManager

manager = CacheManager()

# Define expected input shapes
theta_samples = [
    jnp.ones(1),
    jnp.ones(2),
    jnp.ones(5)
]

time_grids = [
    jnp.linspace(0.1, 5, 20),
    jnp.linspace(0.1, 5, 50),
    jnp.linspace(0.1, 5, 100)
]

# Pre-compile for all combinations
manager.prewarm_model(model, theta_samples, time_grids)
# Pre-warming JAX cache for 9 combinations...
# [1/9] theta_shape=(1,), times_shape=(20,)... ✓
# ...
# ✓ Pre-warming complete in 45.2s
# Cache size: 123.4 MB (9 files)

# Now production queries are instant!
pdf = model(jnp.ones(2), jnp.linspace(0.1, 5, 50))  # <1ms

Advanced Usage

Full Pipeline Integration

All three caches work together seamlessly:

from phasic import Graph, SVGD
import jax.numpy as jnp
import time

# Build graph
print("[1] Building graph")
start = time.time()
g = Graph(my_callback)
t_graph = time.time() - start
print(f"    → TRACE CACHE: {t_graph*1000:.1f}ms")

# Create model
print("[2] Creating model")
start = time.time()
model = Graph.pmf_from_graph(g, discrete=False, theta_dim=1)
t_model = time.time() - start
print(f"    → Serialization: {t_model*1000:.1f}ms")

# Initialize SVGD
print("[3] Initializing SVGD")
start = time.time()
svgd = SVGD(model=model, observed_data=data, theta_dim=1, n_particles=100)
t_init = time.time() - start
print(f"    → SVGD COMPILATION CACHE: {t_init:.2f}s")

# Run SVGD
print("[4] Running SVGD.fit()")
start = time.time()
svgd.fit()
t_fit = time.time() - start
print(f"    → JAX COMPILATION CACHE: {t_fit:.2f}s")

# Summary
total = t_graph + t_model + t_init + t_fit
print(f"\nTotal: {total:.2f}s")
print(f"  Graph build: {t_graph/total*100:.1f}%")
print(f"  SVGD init:   {t_init/total*100:.1f}%")
print(f"  SVGD fit:    {t_fit/total*100:.1f}%")

Distributed Computing

Layered Cache Strategy

For clusters with shared filesystem:

from phasic.jax_config import CompilationConfig

config = CompilationConfig(
    cache_dir='/home/user/.jax_cache',          # Local (fast)
    shared_cache_dir='/shared/project/jax_cache', # Shared (read-only)
    cache_strategy='layered'
)
config.apply()

# JAX checks: local → shared → compile

Cache Synchronization

from phasic.cache_manager import CacheManager

manager = CacheManager(cache_dir='/home/user/.jax_cache')

# Pull updates from shared cache
manager.sync_from_remote('/shared/project/jax_cache')

# Dry run to preview
manager.sync_from_remote('/shared/project/jax_cache', dry_run=True)

SLURM Example

#!/bin/bash
#SBATCH --job-name=ptd_inference
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8

# Shared cache on network storage
export SHARED_CACHE=/shared/project/phasic_cache
export LOCAL_CACHE=$HOME/.jax_cache

# Sync cache at job start
python -c "
from phasic.cache_manager import CacheManager
manager = CacheManager(cache_dir='$LOCAL_CACHE')
manager.sync_from_remote('$SHARED_CACHE')
"

# Run job (uses local cache)
srun python my_inference.py

# Optionally sync new compilations back
rsync -av $LOCAL_CACHE/ $SHARED_CACHE/

Testing the Cache System

Comprehensive Test Suite

Run the full cache testing suite:

python tests/test_svgd_jax.py
# See Test 7 for comprehensive cache layer testing

This demonstrates:

  • Trace cache testing with timing
  • SVGD compilation cache behavior
  • JAX compilation cache management
  • Cache management functions
  • Full pipeline integration

Manual Cache Testing

from phasic import Graph, cache_info, print_cache_info
import time

# Test trace cache
print("[1] Trace Cache Test")
g1 = Graph(my_callback)
start = time.time()
g2 = Graph(my_callback)
elapsed_ms = (time.time() - start) * 1000
print(f"    Second build: {elapsed_ms:.1f}ms (cache hit)")

# Test JAX cache
print("\n[2] JAX Cache Test")
info_before = cache_info()
print(f"    Before: {info_before['num_files']} files")

# Run something that compiles
model = Graph.pmf_from_graph(g1, discrete=False, theta_dim=1)
# ... use model ...

info_after = cache_info()
print(f"    After: {info_after['num_files']} files")
print(f"    New compilations: {info_after['num_files'] - info_before['num_files']}")

# Print detailed cache info
print("\n[3] Detailed Cache Info")
print_cache_info(max_files=5)

Best Practices

✅ DO: Always Enable Caching

# Good: Use default caching (trace cache automatic)
g = Graph(my_callback)
model = Graph.pmf_from_graph(g)  # Cache enabled

# Bad: Unnecessary re-computation
# (Note: Graph doesn't have use_cache parameter, caching is automatic)

✅ DO: Configure JAX Early

# BEFORE importing JAX
import os
os.environ['JAX_COMPILATION_CACHE_DIR'] = '/fast/storage'

# THEN import
import jax
from phasic import Graph

✅ DO: Pre-warm for Production

from phasic.cache_manager import CacheManager

manager = CacheManager()
manager.prewarm_model(model, expected_shapes, expected_grids)
# Now production queries are instant

✅ DO: Monitor Cache Size

from phasic import cache_info
from phasic.cache_manager import CacheManager

# Check size regularly
info = cache_info()
if info['total_size_mb'] > 5000:  # 5 GB
    print("Warning: JAX cache is large")

    # Clean up old entries
    manager = CacheManager()
    manager.vacuum(max_age_days=30, max_size_gb=10.0)

❌ DON’T: Rely on Disk SVGD Cache

# Don't expect SVGD disk cache to work reliably
# It fails ~80% of the time due to pickle limitations
# Use memory cache within session, JAX cache across sessions

❌ DON’T: Configure JAX After Import

# Wrong: Too late!
import jax
from phasic.jax_config import CompilationConfig
config = CompilationConfig()
config.apply()  # Ignored - JAX already imported

# Correct: Configure first
from phasic.jax_config import CompilationConfig
config = CompilationConfig()
config.apply()
import jax

Troubleshooting

Cache Not Working?

Symptom: Model recompiles every time

Solutions:

  1. Check cache directory exists and is writable:

    import os
    from pathlib import Path
    
    cache_dir = os.environ.get('JAX_COMPILATION_CACHE_DIR',
                                str(Path.home() / '.jax_cache'))
    cache_path = Path(cache_dir).expanduser()
    
    print(f"Cache dir: {cache_path}")
    print(f"Exists: {cache_path.exists()}")
    if cache_path.exists():
        print(f"Writable: {os.access(cache_path, os.W_OK)}")
  2. Ensure cache configured BEFORE JAX import:

    # Restart Python and try:
    import os
    os.environ['JAX_COMPILATION_CACHE_DIR'] = '/your/path'
    import jax  # Now cache config is active
  3. Check trace cache:

    from pathlib import Path
    trace_dir = Path.home() / '.phasic_cache' / 'traces'
    if not trace_dir.exists():
        trace_dir.mkdir(parents=True, exist_ok=True)

Out of Disk Space?

Solution: Regular cleanup

from phasic.cache_manager import CacheManager

manager = CacheManager()
manager.vacuum(max_age_days=7, max_size_gb=5.0)

Or manual cleanup:

# Clear JAX cache
rm -rf ~/.jax_cache/*

# Clear trace cache
rm -rf ~/.phasic_cache/traces/*

Slow First Run?

Expected behavior: First run does elimination + compilation

To speed up:

  1. Reuse cached traces (build the same structure multiple times)
  2. Pre-warm the JAX cache with prewarm_model()
  3. Use smaller models for testing and production models for deployment


Performance Benchmarks

Trace Cache Impact

Model Size     No Cache   With Cache   Speedup
37 vertices    45ms       1.3ms        35x
67 vertices    250ms      2.1ms        120x
100 vertices   500ms      5ms          100x

JAX Cache Impact

Operation          No Cache   With Cache   Speedup
First compile      5-10s      5-10s        1x
Same shape         5-10s      <1ms         >5,000x
Different params   5-10s      <1ms         >5,000x

Combined Impact: SVGD with 1000 Iterations

Workflow: Repeated SVGD runs for parameter exploration

  • No caching: 10s compile × 1000 = 2.7 hours
  • JAX cache only: 10s compile + 1000 × 1ms = 11 seconds
  • All caches: 1ms load + 1000 × 1ms = 1.0 seconds

Total speedup: ~9,900x
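As a back-of-envelope check of these numbers (using the estimates above, in seconds):

```python
compile_s, eval_s, load_s, n = 10.0, 0.001, 0.001, 1000

no_cache   = compile_s * n              # recompile on every evaluation
jax_only   = compile_s + n * eval_s     # compile once, then cached evals
all_caches = load_s + n * eval_s        # cached trace load + cached evals

print(f"no cache:   {no_cache / 3600:.1f} h")   # ~2.8 h
print(f"JAX cache:  {jax_only:.0f} s")          # 11 s
print(f"all caches: {all_caches:.2f} s")        # ~1.0 s
print(f"speedup:    {no_cache / all_caches:,.0f}x")
```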

Real-World Example

Scenario: MCMC parameter exploration with 10,000 model evaluations

# Without caching
# Each evaluation: 10-1000ms elimination + 1-10s compile = ~11s
# Total: 11s × 10,000 = 30.5 hours

# With all caches
# Elimination: 1ms (cached)
# Compilation: <1ms (cached)
# Evaluation: <1ms
# Total per eval: ~3ms
# Total: 3ms × 10,000 = 30 seconds

Speedup: 3,660x