Caching
This page covers phasic's caching system, which can provide 10-1000x speedups for repeated model evaluations through several complementary caching mechanisms.
Graph Cache
What It Caches
The graph cache stores fully constructed Graph objects to disk, keyed by:
- Callback function source code
- All construction parameters
- Initial probability vector (IPV)
This is ideal for expensive graphs whose construction takes seconds or minutes.
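In other words, changing any of these three inputs produces a distinct cache entry. A minimal sketch of that keying behavior (the toy callback below is illustrative, following the same callback convention as the examples in this section):

import phasic
from phasic import Graph

@phasic.with_ipv([5])
def cb(state, rate=1.0):
    # Toy model: step down by one until absorption at 0
    if state[0] <= 0:
        return []
    return [[[state[0] - 1], rate]]

g1 = Graph(cb, rate=1.0, cache=True)  # miss: builds, then writes to disk
g2 = Graph(cb, rate=1.0, cache=True)  # hit: same source + parameters + IPV
g3 = Graph(cb, rate=2.0, cache=True)  # miss: parameter changed -> new cache key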
Basic Usage
import phasic
from phasic import Graph

# Define model callback
@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    rate = n * (n - 1) / 2 * theta
    return [[[n - 1], rate]]

# Enable caching with cache=True
graph = Graph(coalescent, theta=2.0, cache=True)  # Builds + caches (slow)
graph = Graph(coalescent, theta=2.0, cache=True)  # Loads from cache (instant!)

# Default is cache=False (no caching)
graph = Graph(coalescent, theta=2.0)  # Always builds from scratch

Cache Inspection
from phasic import print_graph_cache_info, get_graph_cache_stats
# Print formatted stats
print_graph_cache_info()
# Output:
# ======================================================================
# GRAPH CACHE INFO
# ======================================================================
# Cache directory: /Users/you/.phasic_cache/graphs
# Cached graphs: 3
# Total size: 1.23 MB
# Programmatic access
stats = get_graph_cache_stats()
print(f"Cached: {stats['num_graphs']} graphs")Custom Class Serialization
To cache graphs with custom parameters, implement to_dict() and from_dict():
from phasic.state_indexing import StateIndexer, Property

class MyModel:
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2

    def to_dict(self):
        return {'param1': self.param1, 'param2': self.param2}

    @classmethod
    def from_dict(cls, data):
        return cls(data['param1'], data['param2'])

# Now cacheable!
model = MyModel(1.0, 2.0)
graph = Graph(callback, model=model, cache=True)

# StateIndexer already implements to_dict/from_dict
indexer = StateIndexer(lineage=[Property('descendants', max_value=10)])
graph = Graph(callback, indexer=indexer, cache=True)  # Works!
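When writing your own to_dict()/from_dict(), a quick round-trip check (illustrative) helps catch fields the serialization forgets:

# Sanity check: serialization should round-trip losslessly
model = MyModel(1.0, 2.0)
restored = MyModel.from_dict(model.to_dict())
assert restored.to_dict() == model.to_dict()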
Trace Cache

What It Caches
The trace cache stores the elimination traces recorded during graph elimination. These are used in hierarchical=True mode:
# Build graph with hierarchical traces
g = Graph(callback, hierarchical=True)
# First moment computation: Records trace (slow)
mean1 = g.moments()[0] # ~10s
# Second moment computation: Uses cached trace (fast!)
mean2 = g.moments()[0]  # <1ms

Cache Inspection
from phasic import print_trace_cache_info, get_trace_cache_stats
# Print formatted stats
print_trace_cache_info()
# Output:
# ======================================================================
# TRACE CACHE INFO
# ======================================================================
# Cache directory: /Users/you/.phasic_cache/traces
# Cached traces: 5
# Total size: 3.45 MB
# Programmatic access
stats = get_trace_cache_stats()
print(f"Cached: {stats['total_files']} traces")JAX Compilation Cache
What It Caches
JAX caches compiled XLA code based on:
- Function structure (HLO graph)
- Input shapes
- Device configuration
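That keying is visible in plain JAX, independent of phasic: a jitted function compiles once per input shape, and repeat calls with the same shape reuse the compiled executable. A small illustration:

import jax
import jax.numpy as jnp

@jax.jit
def sumsq(x):
    return jnp.sum(x ** 2)

sumsq(jnp.ones(10))  # compiles for shape (10,)
sumsq(jnp.ones(10))  # cache hit: same HLO graph and shape
sumsq(jnp.ones(20))  # new input shape -> recompilation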
Basic Configuration
from phasic.jax_config import CompilationConfig
# Balanced preset (default)
config = CompilationConfig.balanced()
config.apply()
# Maximum performance
config = CompilationConfig.max_performance()
config.apply()
# Fast compilation (for development)
config = CompilationConfig.fast_compile()
config.apply()

Cache Location
Default: ~/.jax_cache
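The directory can also be set programmatically via JAX's standard configuration API (this is plain JAX, not a phasic helper):

import jax

# Persistent compilation cache location (standard JAX option)
jax.config.update("jax_compilation_cache_dir", "/fast/ssd/jax_cache")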
Or override it with an environment variable:
export JAX_COMPILATION_CACHE_DIR=/fast/ssd/jax_cache

Quick Start
Single Machine Workflow
import phasic
from phasic import Graph, print_all_cache_info, clear_caches

# 1. Build model with caching
@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    return [[[n - 1], n * (n - 1) / 2 * theta]]

# Option A: Graph cache (expensive construction)
graph = Graph(coalescent, theta=2.0, cache=True)  # Instant 2nd time!

# Option B: Hierarchical traces (repeated evaluations)
graph = Graph(coalescent, hierarchical=True)  # Fast moments()

# 2. Check all caches
print_all_cache_info()

# 3. Clear all caches
clear_caches()

Unified Cache API
phasic provides a consistent API for all cache types:
# ============================================================
# INSPECTION
# ============================================================

# Print formatted info
from phasic import (
    print_all_cache_info,    # All caches
    print_jax_cache_info,    # JAX only
    print_graph_cache_info,  # Graph only
    print_trace_cache_info   # Trace only
)

print_all_cache_info()  # Shows JAX + Graph + Trace

# Get stats programmatically
from phasic import (
    get_all_cache_stats,    # All caches
    cache_info,             # JAX only (alias: get_jax_cache_stats)
    get_graph_cache_stats,  # Graph only
    get_trace_cache_stats   # Trace only
)

stats = get_all_cache_stats()
# Returns: {'jax': {...}, 'graph': {...}, 'trace': {...}}

# ============================================================
# CLEARING
# ============================================================

from phasic import (
    clear_caches,      # Clear ALL caches
    clear_jax_cache,   # JAX only
    clear_model_cache  # Graph + Trace (legacy name)
)

clear_caches()  # Recommended: clears everything

Advanced Usage
Low-Level Graph Cache API
from phasic import GraphCache
cache = GraphCache()
# Get or build graph
graph = cache.get_or_build(callback, theta=2.0, nr_samples=10)
# Manual cache management
cache.save_graph(graph, callback, theta=2.0)
graph = cache.load_graph(callback, theta=2.0)
# Stats
stats = cache.get_cache_stats()
# Clear
cache.clear_graph_cache()

Distributing Caches
# On machine with pre-computed caches
cd ~/.phasic_cache
tar -czf my_caches.tar.gz graphs/ traces/
# Transfer to other machines
scp my_caches.tar.gz user@cluster:~
# On destination machine
cd ~/.phasic_cache
tar -xzf ~/my_caches.tar.gz
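After unpacking, a quick check confirms the transferred caches are visible (using the inspection helpers shown earlier):

from phasic import print_all_cache_info

print_all_cache_info()  # should now list the transferred graphs and traces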
Best Practices

Use Cases
Graph Cache (cache=True):
- Expensive graph construction (>1 second)
- Same callback + parameters used repeatedly
- Sharing pre-built graphs across sessions
- ❌ Rapidly changing callbacks (cache invalidation overhead)

Trace Cache (hierarchical=True):
- Computing moments/expectations repeatedly
- SVGD or MCMC with many iterations
- Same graph structure, different parameters
- ❌ Single-use graphs

JAX Cache:
- Always enable (default)
- JAX-based models (pmf_from_graph)
- Production deployments
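The graph and trace caches are independent, so for a model that is both expensive to build and evaluated repeatedly, it is natural to use both. A sketch, assuming cache=True and hierarchical=True compose as separate keyword arguments:

import phasic
from phasic import Graph

@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    return [[[n - 1], n * (n - 1) / 2 * theta]]

# Cached construction plus reusable elimination traces
g = Graph(coalescent, theta=2.0, cache=True, hierarchical=True)
mean = g.moments()[0]  # trace recorded once, then reused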
Cache Hygiene
# Regular cleanup
from phasic import get_all_cache_stats, clear_caches

stats = get_all_cache_stats()

# Check total size
total_mb = (
    stats['jax']['total_size_mb'] +
    stats['graph']['total_size_mb'] +
    stats['trace']['total_mb']
)

if total_mb > 1000:  # > 1GB
    print("Warning: Caches are large")
    # Clear old caches or specific types
    clear_caches()

Troubleshooting
Graph Cache Not Working?
Symptom: cache=True doesn’t save/load graphs
Solutions:
1. Check that the callback's source code is retrievable (no lambdas or interactively defined functions):

# Won't cache:
graph = Graph(lambda state: ..., cache=True)  # No source code!

# Will cache:
def callback(state):
    ...

graph = Graph(callback, cache=True)  # Source code available

2. Ensure custom parameters implement to_dict()/from_dict().

3. Check the cache directory exists and is writable:

from phasic import get_graph_cache_stats

stats = get_graph_cache_stats()
print(stats['cache_dir'])
Trace Cache Not Working?
Symptom: hierarchical=True recomputes every time
Solution: Check trace cache:
from phasic import print_trace_cache_info
print_trace_cache_info()

Out of Disk Space?
from phasic import clear_caches
clear_caches(verbose=True)  # Clear all caches

Performance Benchmarks
Graph Cache Impact
| Model Size | Build Time (No Cache) | Build Time (With Cache) | Speedup |
|---|---|---|---|
| 100 vertices | 2s | instant | ∞ (cached) |
| 1,000 vertices | 15s | instant | ∞ (cached) |
| 10,000 vertices | 120s | instant | ∞ (cached) |
Combined Impact
Workflow: MCMC with 1,000 iterations on 100-vertex graph
- No caching: 2s build × 1,000 = 33 minutes
- Graph cache: instant build + 1,000 × fast eval = 10 seconds
- Both caches: instant build + instant eval = <1 second
Total speedup: ~2,000x
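A quick way to reproduce this kind of measurement on your own model; a sketch using only the Graph API shown above (timings will vary by machine):

import time
import phasic
from phasic import Graph, clear_caches

@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    return [[[n - 1], n * (n - 1) / 2 * theta]]

clear_caches()  # start from a cold cache

t0 = time.perf_counter()
Graph(coalescent, theta=2.0, cache=True)  # cold: builds and stores
t1 = time.perf_counter()
Graph(coalescent, theta=2.0, cache=True)  # warm: loads from disk
t2 = time.perf_counter()

print(f"cold build: {t1 - t0:.3f}s, warm load: {t2 - t1:.3f}s")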