Caching

This page covers phasic's caching system, which provides 10-1000x speedups for repeated model evaluations through multiple complementary caching mechanisms.

Graph Cache

What It Caches

The graph cache stores fully constructed Graph objects to disk, keyed by:

- Callback function source code
- All construction parameters
- Initial probability vector (IPV)

This is perfect for expensive graph constructions that take seconds or minutes to build.
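
For intuition, such a key can be derived by hashing those ingredients together. The sketch below is purely illustrative (the function name and scheme are hypothetical; phasic's actual key derivation may differ):

import hashlib
import inspect
import json

def illustrative_cache_key(callback, ipv, **params):
    """Hypothetical key: hash callback source + parameters + IPV.
    Shows the idea only; not phasic's actual scheme."""
    payload = json.dumps({
        'source': inspect.getsource(callback),  # invalidates when the callback changes
        'params': params,                       # all construction parameters
        'ipv': list(ipv),                       # initial probability vector
    }, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()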

Basic Usage

import phasic
from phasic import Graph

# Define model callback
@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    rate = n * (n - 1) / 2 * theta
    return [[[n - 1], rate]]

# Enable caching with cache=True
graph = Graph(coalescent, theta=2.0, cache=True)  # Builds + caches (slow)
graph = Graph(coalescent, theta=2.0, cache=True)  # Loads from cache (instant!)

# Default is cache=False (no caching)
graph = Graph(coalescent, theta=2.0)  # Always builds from scratch
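
To verify a cache hit, time both constructions; the second call should return almost immediately. (Standard-library timing; the exact numbers depend on your model size.)

import time

t0 = time.perf_counter()
graph = Graph(coalescent, theta=2.0, cache=True)  # cold: builds and saves
print(f"first build: {time.perf_counter() - t0:.3f}s")

t0 = time.perf_counter()
graph = Graph(coalescent, theta=2.0, cache=True)  # warm: loads from disk
print(f"cached load: {time.perf_counter() - t0:.3f}s")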

Cache Inspection

from phasic import print_graph_cache_info, get_graph_cache_stats

# Print formatted stats
print_graph_cache_info()
# Output:
# ======================================================================
# GRAPH CACHE INFO
# ======================================================================
# Cache directory: /Users/you/.phasic_cache/graphs
# Cached graphs: 3
# Total size: 1.23 MB

# Programmatic access
stats = get_graph_cache_stats()
print(f"Cached: {stats['num_graphs']} graphs")

Custom Class Serialization

To cache graphs with custom parameters, implement to_dict() and from_dict():

from phasic.state_indexing import StateIndexer, Property

class MyModel:
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2

    def to_dict(self):
        return {'param1': self.param1, 'param2': self.param2}

    @classmethod
    def from_dict(cls, data):
        return cls(data['param1'], data['param2'])

# Now cacheable!
model = MyModel(1.0, 2.0)
graph = Graph(callback, model=model, cache=True)

# StateIndexer already implements to_dict/from_dict
indexer = StateIndexer(lineage=[Property('descendants', max_value=10)])
graph = Graph(callback, indexer=indexer, cache=True)  # Works!

Trace Cache

What It Caches

The trace cache stores elimination traces recorded during graph elimination, so later computations can replay them instead of re-eliminating. It is used with hierarchical=True mode:

# Build graph with hierarchical traces
g = Graph(callback, hierarchical=True)

# First call: records the elimination trace (slow)
mean1 = g.moments()[0]  # ~10s

# Second call: replays the cached trace (fast!)
mean2 = g.moments()[0]  # <1ms

Cache Inspection

from phasic import print_trace_cache_info, get_trace_cache_stats

# Print formatted stats
print_trace_cache_info()
# Output:
# ======================================================================
# TRACE CACHE INFO
# ======================================================================
# Cache directory: /Users/you/.phasic_cache/traces
# Cached traces: 5
# Total size: 3.45 MB

# Programmatic access
stats = get_trace_cache_stats()
print(f"Cached: {stats['total_files']} traces")

JAX Compilation Cache

What It Caches

JAX caches compiled XLA code based on:

- Function structure (HLO graph)
- Input shapes
- Device configuration
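
This keying can be observed with plain jax.jit: the first call compiles for a given input shape, repeat calls with that shape reuse the compiled executable, and a new shape triggers recompilation:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(x ** 2)

f(jnp.ones(100))  # compiles for shape (100,)
f(jnp.ones(100))  # same shape: reuses the cached executable
f(jnp.ones(200))  # new shape: triggers a fresh compilation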

Basic Configuration

from phasic.jax_config import CompilationConfig

# Balanced preset (default)
config = CompilationConfig.balanced()
config.apply()

# Maximum performance
config = CompilationConfig.max_performance()
config.apply()

# Fast compilation (for development)
config = CompilationConfig.fast_compile()
config.apply()

Cache Location

Default: ~/.jax_cache

Override with environment variable:

export JAX_COMPILATION_CACHE_DIR=/fast/ssd/jax_cache
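
The same directory can be set from Python via JAX's own config API (available in recent JAX versions; run it before the session's first compiled call):

import jax

jax.config.update("jax_compilation_cache_dir", "/fast/ssd/jax_cache")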

Quick Start

Single Machine Workflow

import phasic
from phasic import Graph, print_all_cache_info, clear_caches

# 1. Build model with caching
@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    return [[[n - 1], n * (n - 1) / 2 * theta]]

# Option A: Graph cache (expensive construction)
graph = Graph(coalescent, theta=2.0, cache=True)  # Instant 2nd time!

# Option B: Hierarchical traces (repeated evaluations)
graph = Graph(coalescent, hierarchical=True)  # Fast moments()

# 2. Check all caches
print_all_cache_info()

# 3. Clear all caches
clear_caches()

Unified Cache API

phasic provides a consistent API for all cache types:

# ============================================================
# INSPECTION
# ============================================================

# Print formatted info
from phasic import (
    print_all_cache_info,      # All caches
    print_jax_cache_info,       # JAX only
    print_graph_cache_info,     # Graph only
    print_trace_cache_info      # Trace only
)

print_all_cache_info()  # Shows JAX + Graph + Trace

# Get stats programmatically
from phasic import (
    get_all_cache_stats,        # All caches
    cache_info,                 # JAX only (alias: get_jax_cache_stats)
    get_graph_cache_stats,      # Graph only
    get_trace_cache_stats       # Trace only
)

stats = get_all_cache_stats()
# Returns: {'jax': {...}, 'graph': {...}, 'trace': {...}}

# ============================================================
# CLEARING
# ============================================================

from phasic import (
    clear_caches,         # Clear ALL caches
    clear_jax_cache,      # JAX only
    clear_model_cache     # Graph + Trace (legacy name)
)

clear_caches()  # Recommended: clears everything

Advanced Usage

Low-Level Graph Cache API

from phasic import GraphCache

cache = GraphCache()

# Get or build graph
graph = cache.get_or_build(callback, theta=2.0, nr_samples=10)

# Manual cache management
cache.save_graph(graph, callback, theta=2.0)
graph = cache.load_graph(callback, theta=2.0)

# Stats
stats = cache.get_cache_stats()

# Clear
cache.clear_graph_cache()

Distributing Caches

# On machine with pre-computed caches
cd ~/.phasic_cache
tar -czf my_caches.tar.gz graphs/ traces/

# Transfer to other machines
scp my_caches.tar.gz user@cluster:~

# On destination machine
cd ~/.phasic_cache
tar -xzf ~/my_caches.tar.gz

Best Practices

Use Cases

Graph Cache (cache=True):

- Expensive graph construction (>1 second)
- Same callback + parameters used repeatedly
- Sharing pre-built graphs across sessions
- ❌ Rapidly changing callbacks (cache invalidation overhead)

Trace Cache (hierarchical=True):

- Computing moments/expectations repeatedly
- SVGD or MCMC with many iterations
- Same graph structure, different parameters
- ❌ Single-use graphs

JAX Cache:

- Always enable (default)
- JAX-based models (pmf_from_graph)
- Production deployments

The graph and trace caches compose; see the sketch below.
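
A minimal sketch combining both flags, assuming they compose as the "Combined Impact" benchmarks at the end of this page imply:

# Cached construction + cached elimination traces
graph = Graph(coalescent, theta=2.0, cache=True, hierarchical=True)
mean = graph.moments()[0]  # trace recorded once, replayed on later calls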

Cache Hygiene

# Regular cleanup
from phasic import get_all_cache_stats, clear_caches

stats = get_all_cache_stats()

# Check total size
total_mb = (
    stats['jax']['total_size_mb'] +
    stats['graph']['total_size_mb'] +
    stats['trace']['total_mb']
)

if total_mb > 1000:  # > 1GB
    print("Warning: Caches are large")
    # Clear old caches or specific types
    clear_caches()

Troubleshooting

Graph Cache Not Working?

Symptom: cache=True doesn’t save/load graphs

Solutions:

  1. Check that the callback's source code is retrievable (lambdas and functions defined in interactive sessions may have no recoverable source):

    # Won't cache:
    graph = Graph(lambda state: ..., cache=True)  # No source code!
    
    # Will cache:
    def callback(state):
        ...
    graph = Graph(callback, cache=True)  # Source code available
  2. Ensure custom parameters implement to_dict() / from_dict()

  3. Check cache directory exists and is writable:

    from phasic import get_graph_cache_stats
    stats = get_graph_cache_stats()
    print(stats['cache_dir'])
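
    # Confirm the directory is writable (standard-library check;
    # 'cache_dir' is the stats key shown above)
    import os
    print(os.access(stats['cache_dir'], os.W_OK))  # True if writable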

Trace Cache Not Working?

Symptom: hierarchical=True recomputes every time

Solution: Check trace cache:

from phasic import print_trace_cache_info
print_trace_cache_info()

Out of Disk Space?

from phasic import clear_caches
clear_caches(verbose=True)  # Clear all caches

Performance Benchmarks

Graph Cache Impact

Model Size        No Cache   With Cache   Speedup
100 vertices      2 s        instant      ∞ (cached)
1,000 vertices    15 s       instant      ∞ (cached)
10,000 vertices   120 s      instant      ∞ (cached)

Combined Impact

Workflow: MCMC with 1,000 iterations on 100-vertex graph

  • No caching: 2 s build × 1,000 iterations ≈ 33 minutes
  • Graph cache: instant build + 1,000 × fast eval ≈ 10 seconds
  • Both caches: instant build + instant eval = <1 second

Total speedup: ~2,000x
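
The arithmetic behind these figures, using the illustrative numbers above (not measurements):

build_s, iterations, cached_run_s = 2.0, 1_000, 1.0

no_cache_s = build_s * iterations   # 2,000 s of rebuilding alone
print(no_cache_s / 60)              # ≈ 33.3 minutes
print(no_cache_s / cached_run_s)    # ≈ 2,000x speedup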