Caching
This page covers phasic's caching system, which can provide 10-1000x speedups for repeated model evaluations through several complementary caching mechanisms.
Graph Cache
The graph cache allows fast reconstruction of graphs that are expensive to build. The cache is keyed by the callback function's source code together with all of its construction parameters and the initial probability vector (IPV). To enable caching, pass cache=True to the Graph constructor:
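The exact key derivation is internal to phasic, but the idea can be sketched with the standard library alone. Everything below is illustrative: `make_cache_key` is a hypothetical helper, not part of the phasic API.

```python
import hashlib
import inspect
import json

def make_cache_key(callback, ipv, **params):
    """Hypothetical sketch: derive a deterministic key from the
    callback's source code, its construction parameters, and the IPV."""
    try:
        source = inspect.getsource(callback)
    except OSError:
        # Source unavailable (e.g. defined interactively) -- fall back to bytecode
        source = callback.__code__.co_code.hex()
    payload = json.dumps(
        {"source": source, "params": params, "ipv": list(ipv)},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()

def coalescent(state, theta=1.0):
    n = state[0]
    return [] if n <= 1 else [[[n - 1], n * (n - 1) / 2 * theta]]

key = make_cache_key(coalescent, ipv=[10], theta=2.0)
# Same source + parameters + IPV -> same key, so a cached graph can be
# reused; changing any of the three produces a different key -> rebuild.
```

This also shows why editing the callback body, even a comment, invalidates the cache: the source text itself is part of the key.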
Basic Usage
import phasic
from phasic import Graph

# Define model callback
@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    rate = n * (n - 1) / 2 * theta
    return [[[n - 1], rate]]

# Enable caching with cache=True
graph = Graph(coalescent, theta=2.0, cache=True)  # Builds + caches (slow)
graph = Graph(coalescent, theta=2.0, cache=True)  # Loads from cache (instant!)

# Default is cache=False (no caching)
graph = Graph(coalescent, theta=2.0)  # Always builds from scratch

Cache Inspection
from phasic import print_graph_cache_info, get_graph_cache_stats
# Print formatted stats
print_graph_cache_info()
# Output:
# ======================================================================
# GRAPH CACHE INFO
# ======================================================================
# Cache directory: /Users/you/.phasic_cache/graphs
# Cached graphs: 3
# Total size: 1.23 MB
# Programmatic access
stats = get_graph_cache_stats()
print(f"Cached: {stats['num_graphs']} graphs")

Custom Class Serialization
To cache graphs with custom parameters, implement to_dict() and from_dict():
from phasic.state_indexing import StateIndexer, Property

class MyModel:
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2

    def to_dict(self):
        return {'param1': self.param1, 'param2': self.param2}

    @classmethod
    def from_dict(cls, data):
        return cls(data['param1'], data['param2'])

# Now cacheable!
model = MyModel(1.0, 2.0)
graph = Graph(callback, model=model, cache=True)

# StateIndexer already implements to_dict/from_dict
indexer = StateIndexer(lineage=[Property('descendants', max_value=10)])
graph = Graph(callback, indexer=indexer, cache=True)  # Works!

Trace Cache
What It Caches
The trace cache stores pre-computed elimination traces produced during graph elimination. These are used in hierarchical=True mode:
# Build graph with hierarchical traces
g = Graph(callback, hierarchical=True)
# First moment computation: Records trace (slow)
mean1 = g.moments()[0] # ~10s
# Second moment computation: Uses cached trace (fast!)
mean2 = g.moments()[0]  # <1ms

Cache Inspection
from phasic import print_trace_cache_info, get_trace_cache_stats
# Print formatted stats
print_trace_cache_info()
# Output:
# ======================================================================
# TRACE CACHE INFO
# ======================================================================
# Cache directory: /Users/you/.phasic_cache/traces
# Cached traces: 5
# Total size: 3.45 MB
# Programmatic access
stats = get_trace_cache_stats()
print(f"Cached: {stats['total_files']} traces")

JAX Compilation Cache
What It Caches
JAX caches compiled XLA code based on:
- Function structure (HLO graph)
- Input shapes
- Device configuration
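The effect of shape-based keying can be illustrated with a small stdlib-only sketch. This is not how JAX itself is implemented (the real cache keys on the HLO graph and device configuration); `compile_cached` and the counter are illustrative stand-ins.

```python
compile_count = 0
_cache = {}

def compile_cached(fn, *arrays):
    """Sketch of shape-keyed compilation caching: "recompile" only
    when the function or the input shapes change."""
    global compile_count
    key = (fn.__name__, tuple(len(a) for a in arrays))  # "shape" = length here
    if key not in _cache:
        compile_count += 1  # stand-in for an expensive XLA compilation
        _cache[key] = fn
    return _cache[key](*arrays)

def total(xs):
    return sum(xs)

compile_cached(total, [1.0, 2.0, 3.0])  # new shape -> compiles
compile_cached(total, [4.0, 5.0, 6.0])  # same shape -> cache hit
compile_cached(total, [1.0, 2.0])       # new shape -> compiles again
# compile_count == 2
```

The practical consequence is the same as with jax.jit: calling a model with a new input shape pays the compilation cost once, and every later call with that shape is fast.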
Basic Configuration
from phasic.jax_config import CompilationConfig
# Balanced preset (default)
config = CompilationConfig.balanced()
config.apply()
# Maximum performance
config = CompilationConfig.max_performance()
config.apply()
# Fast compilation (for development)
config = CompilationConfig.fast_compile()
config.apply()

Cache Location
Default: ~/.jax_cache
Override with environment variable:
export JAX_COMPILATION_CACHE_DIR=/fast/ssd/jax_cache

Quick Start
Single Machine Workflow
import phasic
from phasic import Graph, print_all_cache_info, clear_caches

# 1. Build model with caching
@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    return [[[n - 1], n * (n - 1) / 2 * theta]]

# Option A: Graph cache (expensive construction)
graph = Graph(coalescent, theta=2.0, cache=True)  # Instant 2nd time!

# Option B: Hierarchical traces (repeated evaluations)
graph = Graph(coalescent, hierarchical=True)  # Fast moments()

# 2. Check all caches
print_all_cache_info()

# 3. Clear all caches
clear_caches()

Unified Cache API
phasic provides a consistent API for all cache types:
# ============================================================
# INSPECTION
# ============================================================
# Print formatted info
from phasic import (
print_all_cache_info, # All caches
print_jax_cache_info, # JAX only
print_graph_cache_info, # Graph only
print_trace_cache_info # Trace only
)
print_all_cache_info() # Shows JAX + Graph + Trace
# Get stats programmatically
from phasic import (
get_all_cache_stats, # All caches
cache_info, # JAX only (alias: get_jax_cache_stats)
get_graph_cache_stats, # Graph only
get_trace_cache_stats # Trace only
)
stats = get_all_cache_stats()
# Returns: {'jax': {...}, 'graph': {...}, 'trace': {...}}
# ============================================================
# CLEARING
# ============================================================
from phasic import (
clear_caches, # Clear ALL caches
clear_jax_cache, # JAX only
clear_model_cache # Graph + Trace (legacy name)
)
clear_caches()  # Recommended: clears everything

Advanced Usage
Low-Level Graph Cache API
from phasic import GraphCache
cache = GraphCache()
# Get or build graph
graph = cache.get_or_build(callback, theta=2.0, nr_samples=10)
# Manual cache management
cache.save_graph(graph, callback, theta=2.0)
graph = cache.load_graph(callback, theta=2.0)
# Stats
stats = cache.get_cache_stats()
# Clear
cache.clear_graph_cache()

Distributing Caches
# On machine with pre-computed caches
cd ~/.phasic_cache
tar -czf my_caches.tar.gz graphs/ traces/
# Transfer to other machines
scp my_caches.tar.gz user@cluster:~
# On destination machine
cd ~/.phasic_cache
tar -xzf ~/my_caches.tar.gz

Best Practices
Use cases
Graph Cache (cache=True):
- Expensive graph construction (>1 second)
- Same callback + parameters used repeatedly
- Sharing pre-built graphs across sessions

Trace Cache (hierarchical=True):
- Computing moments/expectations repeatedly
- SVGD or MCMC with many iterations
- Same graph structure, different parameters

JAX Cache:
- Always enable (default)
- JAX-based models (pmf_from_graph)
- Production deployments
Cache Hygiene
# Regular cleanup
from phasic import get_all_cache_stats, clear_caches
stats = get_all_cache_stats()
# Check total size
total_mb = (
stats['jax']['total_size_mb'] +
stats['graph']['total_size_mb'] +
stats['trace']['total_mb']
)
if total_mb > 1000: # > 1GB
print("Warning: Caches are large")
# Clear old caches or specific types
clear_caches()

Troubleshooting
Graph Cache Not Working?
Symptom: cache=True doesn’t save/load graphs
Solutions:

1. Check that the callback's source code is retrievable (no lambdas in interactive sessions):

# Won't cache:
graph = Graph(lambda state: ..., cache=True)  # No source code!

# Will cache:
def callback(state):
    ...
graph = Graph(callback, cache=True)  # Source code available

2. Ensure custom parameters implement to_dict()/from_dict().

3. Check the cache directory exists and is writable:

from phasic import get_graph_cache_stats
stats = get_graph_cache_stats()
print(stats['cache_dir'])
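If the directory check is inconclusive, a stdlib probe can confirm writability directly. This is a generic sketch, not a phasic API; the path shown is the default graph-cache location mentioned above.

```python
import os
import tempfile

def is_writable_dir(path):
    """Return True if `path` exists, is a directory, and files can be created in it."""
    if not os.path.isdir(path):
        return False
    try:
        # Creating (and auto-deleting) a temp file is the definitive write test
        with tempfile.NamedTemporaryFile(dir=path):
            return True
    except OSError:
        return False

# Example: probe the default graph-cache location
cache_dir = os.path.expanduser("~/.phasic_cache/graphs")
print(is_writable_dir(cache_dir))
```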
Trace Cache Not Working?
Symptom: hierarchical=True recomputes every time
Solution: Check trace cache:
from phasic import print_trace_cache_info
print_trace_cache_info()

Out of Disk Space?
from phasic import clear_caches
clear_caches(verbose=True) # Clear all caches
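Before clearing everything, it can help to see which cache directory is actually consuming the space. A stdlib sketch (the paths are the default locations mentioned above; adjust if you have overridden them):

```python
from pathlib import Path

def dir_size_mb(path):
    """Total size of all files under `path`, in megabytes (0.0 if missing)."""
    p = Path(path).expanduser()
    if not p.exists():
        return 0.0
    return sum(f.stat().st_size for f in p.rglob("*") if f.is_file()) / 1e6

# Report each cache directory's footprint
for name, path in [("graphs", "~/.phasic_cache/graphs"),
                   ("traces", "~/.phasic_cache/traces"),
                   ("jax", "~/.jax_cache")]:
    print(f"{name}: {dir_size_mb(path):.2f} MB")
```

With the per-directory numbers in hand, you can clear only the offending cache type instead of everything.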