Caching

This page covers phasic's caching system, which provides 10-1000x speedups for repeated model evaluations through several complementary caching mechanisms.

Graph Cache

The graph cache allows fast reconstruction of graphs that are expensive to build. The cache is keyed by the callback function's source code together with all its construction parameters and the initial probability vector (IPV). To enable caching, pass cache=True to the Graph constructor:
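The key derivation can be sketched roughly as follows. This is an illustration only: graph_cache_key is a hypothetical helper, not phasic's actual implementation, but it shows why the same source + parameters + IPV always hit the same cache entry.

```python
import hashlib
import inspect

def graph_cache_key(callback, ipv, **params):
    """Hypothetical sketch: hash callback source, construction
    parameters, and the IPV into a stable cache key."""
    source = inspect.getsource(callback)           # callback source code
    payload = repr((source, sorted(params.items()), ipv))
    return hashlib.sha256(payload.encode()).hexdigest()

def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    return [[[n - 1], n * (n - 1) / 2 * theta]]

key = graph_cache_key(coalescent, ipv=[10], theta=2.0)
```

Identical inputs always produce the same key, while changing any construction parameter (e.g. theta) produces a different one, which is why a parameter change triggers a rebuild.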

Basic Usage

import phasic
from phasic import Graph

# Define model callback
@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    rate = n * (n - 1) / 2 * theta
    return [[[n - 1], rate]]

# Enable caching with cache=True
graph = Graph(coalescent, theta=2.0, cache=True)  # Builds + caches (slow)
graph = Graph(coalescent, theta=2.0, cache=True)  # Loads from cache (instant!)

# Default is cache=False (no caching)
graph = Graph(coalescent, theta=2.0)  # Always builds from scratch

Cache Inspection

from phasic import print_graph_cache_info, get_graph_cache_stats

# Print formatted stats
print_graph_cache_info()
# Output:
# ======================================================================
# GRAPH CACHE INFO
# ======================================================================
# Cache directory: /Users/you/.phasic_cache/graphs
# Cached graphs: 3
# Total size: 1.23 MB

# Programmatic access
stats = get_graph_cache_stats()
print(f"Cached: {stats['num_graphs']} graphs")

Custom Class Serialization

To cache graphs with custom parameters, implement to_dict() and from_dict():

from phasic.state_indexing import StateIndexer, Property

class MyModel:
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2

    def to_dict(self):
        return {'param1': self.param1, 'param2': self.param2}

    @classmethod
    def from_dict(cls, data):
        return cls(data['param1'], data['param2'])

# Now cacheable!
model = MyModel(1.0, 2.0)
graph = Graph(callback, model=model, cache=True)

# StateIndexer already implements to_dict/from_dict
indexer = StateIndexer(lineage=[Property('descendants', max_value=10)])
graph = Graph(callback, indexer=indexer, cache=True)  # Works!
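Before relying on the cache, it is worth checking that your class survives a serialization round trip. A sketch using the MyModel class from above (round-tripping through JSON is an assumption here; the point is only that from_dict(to_dict()) must reproduce the object):

```python
import json

class MyModel:
    """MyModel as defined above."""
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2

    def to_dict(self):
        return {'param1': self.param1, 'param2': self.param2}

    @classmethod
    def from_dict(cls, data):
        return cls(data['param1'], data['param2'])

model = MyModel(1.0, 2.0)
# Round-trip through JSON, as a file-backed cache would
restored = MyModel.from_dict(json.loads(json.dumps(model.to_dict())))
assert restored.to_dict() == model.to_dict()
```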

Trace Cache

What It Caches

The trace cache stores pre-computed elimination traces recorded during graph elimination. These are reused in hierarchical=True mode:

# Build graph with hierarchical traces
g = Graph(callback, hierarchical=True)

# First moment computation: Records trace (slow)
mean1 = g.moments()[0]  # ~10s

# Second moment computation: Uses cached trace (fast!)
mean2 = g.moments()[0]  # <1ms

Cache Inspection

from phasic import print_trace_cache_info, get_trace_cache_stats

# Print formatted stats
print_trace_cache_info()
# Output:
# ======================================================================
# TRACE CACHE INFO
# ======================================================================
# Cache directory: /Users/you/.phasic_cache/traces
# Cached traces: 5
# Total size: 3.45 MB

# Programmatic access
stats = get_trace_cache_stats()
print(f"Cached: {stats['total_files']} traces")

JAX Compilation Cache

What It Caches

JAX caches compiled XLA code based on:

- Function structure (HLO graph)
- Input shapes
- Device configuration

Basic Configuration

from phasic.jax_config import CompilationConfig

# Balanced preset (default)
config = CompilationConfig.balanced()
config.apply()

# Maximum performance
config = CompilationConfig.max_performance()
config.apply()

# Fast compilation (for development)
config = CompilationConfig.fast_compile()
config.apply()

Cache Location

Default: ~/.jax_cache

Override with environment variable:

export JAX_COMPILATION_CACHE_DIR=/fast/ssd/jax_cache
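The cache directory can also be set from Python via JAX's config API (requires jax to be installed; a config fragment, shown for completeness):

```python
import jax

# Equivalent to the environment variable above
jax.config.update("jax_compilation_cache_dir", "/fast/ssd/jax_cache")
```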

Quick Start

Single Machine Workflow

import phasic
from phasic import Graph, print_all_cache_info, clear_caches

# 1. Build model with caching
@phasic.with_ipv([10])
def coalescent(state, theta=1.0):
    n = state[0]
    if n <= 1:
        return []
    return [[[n - 1], n * (n - 1) / 2 * theta]]

# Option A: Graph cache (expensive construction)
graph = Graph(coalescent, theta=2.0, cache=True)  # Instant 2nd time!

# Option B: Hierarchical traces (repeated evaluations)
graph = Graph(coalescent, hierarchical=True)  # Fast moments()

# 2. Check all caches
print_all_cache_info()

# 3. Clear all caches
clear_caches()

Unified Cache API

phasic provides a consistent API for all cache types:

# ============================================================
# INSPECTION
# ============================================================

# Print formatted info
from phasic import (
    print_all_cache_info,      # All caches
    print_jax_cache_info,       # JAX only
    print_graph_cache_info,     # Graph only
    print_trace_cache_info      # Trace only
)

print_all_cache_info()  # Shows JAX + Graph + Trace

# Get stats programmatically
from phasic import (
    get_all_cache_stats,        # All caches
    cache_info,                 # JAX only (alias: get_jax_cache_stats)
    get_graph_cache_stats,      # Graph only
    get_trace_cache_stats       # Trace only
)

stats = get_all_cache_stats()
# Returns: {'jax': {...}, 'graph': {...}, 'trace': {...}}
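A small sketch of walking the nested stats dict, shown here with a mocked dict in the shape returned above (the exact per-cache keys are assumptions; note the trace cache reports 'total_mb' while the others report 'total_size_mb'):

```python
# Mocked stats in the shape returned by get_all_cache_stats()
stats = {
    'jax':   {'total_size_mb': 12.3},
    'graph': {'total_size_mb': 1.2},
    'trace': {'total_mb': 3.4},
}

def total_cache_mb(stats):
    """Sum size entries across cache types (key names vary per cache)."""
    total = 0.0
    for entry in stats.values():
        for key, value in entry.items():
            if key.endswith('_mb'):
                total += value
    return total

print(f"Total: {total_cache_mb(stats):.1f} MB")
```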

# ============================================================
# CLEARING
# ============================================================

from phasic import (
    clear_caches,         # Clear ALL caches
    clear_jax_cache,      # JAX only
    clear_model_cache     # Graph + Trace (legacy name)
)

clear_caches()  # Recommended: clears everything

Advanced Usage

Low-Level Graph Cache API

from phasic import GraphCache

cache = GraphCache()

# Get or build graph
graph = cache.get_or_build(callback, theta=2.0, nr_samples=10)

# Manual cache management
cache.save_graph(graph, callback, theta=2.0)
graph = cache.load_graph(callback, theta=2.0)

# Stats
stats = cache.get_cache_stats()

# Clear
cache.clear_graph_cache()

Distributing Caches

# On machine with pre-computed caches
cd ~/.phasic_cache
tar -czf my_caches.tar.gz graphs/ traces/

# Transfer to other machines
scp my_caches.tar.gz user@cluster:~

# On destination machine
cd ~/.phasic_cache
tar -xzf ~/my_caches.tar.gz
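The same packaging step can be done from Python with the standard library's tarfile module, e.g. as part of a deployment script (a sketch; the default ~/.phasic_cache location is taken from the cache-info output above):

```python
import tarfile
from pathlib import Path

cache_root = Path.home() / ".phasic_cache"   # default cache location
archive = Path("my_caches.tar.gz")

# Package the graph and trace caches (skip subdirs that don't exist)
with tarfile.open(archive, "w:gz") as tar:
    for sub in ("graphs", "traces"):
        d = cache_root / sub
        if d.exists():
            tar.add(d, arcname=sub)

# On the destination machine, extract into the cache root
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall(cache_root)
```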

Best Practices

Use cases

Graph Cache (cache=True):

- Expensive graph construction (>1 second)
- Same callback + parameters used repeatedly
- Sharing pre-built graphs across sessions

Trace Cache (hierarchical=True):

- Computing moments/expectations repeatedly
- SVGD or MCMC with many iterations
- Same graph structure, different parameters

JAX Cache:

- Always enable (default)
- JAX-based models (pmf_from_graph)
- Production deployments

Cache Hygiene

# Regular cleanup
from phasic import get_all_cache_stats, clear_caches

stats = get_all_cache_stats()

# Check total size
total_mb = (
    stats['jax']['total_size_mb'] +
    stats['graph']['total_size_mb'] +
    stats['trace']['total_mb']
)

if total_mb > 1000:  # > 1GB
    print("Warning: Caches are large")
    # Clear old caches or specific types
    clear_caches()
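If you prefer to measure disk usage directly rather than going through the stats API, a generic helper can sum file sizes under the cache directories (dir_size_mb is a hypothetical helper; the ~/.phasic_cache layout is taken from the cache-info output above):

```python
from pathlib import Path

def dir_size_mb(path):
    """Total size in MB of all files under `path` (0 if missing)."""
    p = Path(path)
    if not p.exists():
        return 0.0
    return sum(f.stat().st_size for f in p.rglob("*") if f.is_file()) / 1e6

cache_root = Path.home() / ".phasic_cache"
total = dir_size_mb(cache_root / "graphs") + dir_size_mb(cache_root / "traces")
print(f"phasic caches: {total:.2f} MB")
```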

Troubleshooting

Graph Cache Not Working?

Symptom: cache=True doesn’t save/load graphs

Solutions:

  1. Check callback is hashable (no lambdas in interactive sessions):

    # Won't cache:
    graph = Graph(lambda state: ..., cache=True)  # No source code!
    
    # Will cache:
    def callback(state):
        ...
    graph = Graph(callback, cache=True)  # Source code available
  2. Ensure custom parameters implement to_dict() / from_dict()

  3. Check cache directory exists and is writable:

    from phasic import get_graph_cache_stats
    stats = get_graph_cache_stats()
    print(stats['cache_dir'])
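The lambda limitation above comes down to source availability: a source-keyed cache cannot hash what it cannot retrieve. A small standalone illustration using Python's inspect module, with a function defined dynamically (as happens in some interactive sessions):

```python
import inspect

# A function created without a backing source file, e.g. via exec,
# has no retrievable source code to derive a cache key from.
namespace = {}
exec("def dynamic(state): return []", namespace)

try:
    inspect.getsource(namespace["dynamic"])
    has_source = True
except OSError:
    has_source = False

print(has_source)  # False: no source, so nothing to key the cache on
```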

Trace Cache Not Working?

Symptom: hierarchical=True recomputes every time

Solution: Check trace cache:

from phasic import print_trace_cache_info
print_trace_cache_info()

Out of Disk Space?

from phasic import clear_caches
clear_caches(verbose=True)  # Clear all caches