Multi-level caching

Phasic uses a three-layer caching system to avoid repeating expensive computations. Each layer targets a different bottleneck in the pipeline from model definition to inference:

Layer | What is cached | Location | What it avoids
--- | --- | --- | ---
Graph cache | Fully constructed Graph objects | ~/.phasic_cache/graphs/ | Callback-based graph construction
Trace cache | Elimination traces from Gaussian elimination | ~/.phasic_cache/traces/ | O(n³) elimination
JAX compilation cache | JIT-compiled XLA code | ~/.jax_cache/ | Recompilation after a restart

All three caches are persistent: they survive across Python sessions and restarts. Entries are keyed by SHA-256 content hashes, so the same graph structure always maps to the same key, and any structural change produces a new key and therefore a cache miss.
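The idea behind content-addressed keys can be illustrated with the standard library alone. The sketch below is not phasic's hashing code: the canonical JSON encoding and the `structure_hash` helper are assumptions made for illustration. It shows the two properties described above: the same content (in any order) yields the same digest, and a structural change yields a different one.

```python
import hashlib
import json

def structure_hash(vertices, edges):
    """Illustrative key: SHA-256 of a canonical (sorted, compact) JSON encoding."""
    canonical = json.dumps(
        {"vertices": sorted(vertices), "edges": sorted(edges)},
        separators=(",", ":"),
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

# States are tuples of lineage counts; edges are (from_state, to_state) pairs.
vertices = [(4, 0, 0, 0), (2, 1, 0, 0), (0, 2, 0, 0)]
edges = [((4, 0, 0, 0), (2, 1, 0, 0)), ((2, 1, 0, 0), (0, 2, 0, 0))]

h_original = structure_hash(vertices, edges)
h_reordered = structure_hash(vertices[::-1], edges[::-1])  # same content, new order
h_changed = structure_hash(
    vertices + [(1, 0, 1, 0)],
    edges + [((2, 1, 0, 0), (1, 0, 1, 0))],                 # one extra vertex and edge
)
print(h_original == h_reordered)  # True
print(h_original == h_changed)    # False
```

Sorting before serializing is what makes the key independent of insertion order; without a canonical form, two identical graphs built in a different order would miss each other's cache entries.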

from phasic import (
    Graph, with_ipv,
    clear_caches, clear_jax_cache, clear_model_cache,
    cache_info, get_all_cache_stats, print_all_cache_info,
    get_graph_cache_stats, print_graph_cache_info,
    get_trace_cache_stats, print_trace_cache_info,
    set_log_level
)
from phasic.trace_cache import list_cached_traces, cleanup_old_traces
import numpy as np
import time

We will use a simple coalescent model throughout this tutorial:

nr_samples = 4

@with_ipv([nr_samples] + [0] * (nr_samples - 1))
def coalescent(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            transitions.append((new, state[i] * (state[j] - same) / (1 + same)))
    return transitions

Start from a clean slate so the examples below show cache misses and hits clearly:

clear_caches(verbose=True)

Layer 1: Graph cache

Building a graph from a callback function means exploring the full state space and creating every vertex and edge; for large models this can take seconds to minutes. The graph cache stores fully constructed Graph objects on disk so that the same model loads almost instantly on subsequent calls.
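To make "exploring the full state space" concrete, here is a minimal breadth-first sketch in plain Python. It re-implements the tutorial's coalescent transitions on tuples instead of phasic's state arrays and collects vertices and edges itself; it illustrates the idea and is not phasic's construction code. The function names `coalescent_transitions` and `explore` are hypothetical.

```python
from collections import deque

def coalescent_transitions(state):
    """Same rules as the tutorial's callback, but on tuples."""
    transitions = []
    n = len(state)
    for i in range(n):
        for j in range(i, n):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = list(state)
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            rate = state[i] * (state[j] - same) / (1 + same)
            transitions.append((tuple(new), rate))
    return transitions

def explore(initial, callback):
    """Breadth-first search over all states reachable from `initial`."""
    index = {initial: 0}           # state -> vertex id
    edges = []                     # (from_id, to_id, rate)
    queue = deque([initial])
    while queue:
        state = queue.popleft()
        for new, rate in callback(state):
            if new not in index:   # first visit: create a vertex and schedule it
                index[new] = len(index)
                queue.append(new)
            edges.append((index[state], index[new], rate))
    return index, edges

states, edges = explore((4, 0, 0, 0), coalescent_transitions)
print(len(states), "states,", len(edges), "edges")  # 5 states, 5 edges
```

Every reachable state triggers one callback call, so construction cost grows with the size of the state space; that repeated work is what the graph cache saves.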

The cache key is a SHA-256 hash of:

  • The callback function’s AST (abstract syntax tree), so whitespace/comment changes are ignored but code changes invalidate the cache
  • All construction parameters (ipv, nr_samples, keyword arguments)
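A key with these two properties can be sketched with the standard library's `ast` module: `ast.dump` omits comments and, by default, line and column positions, so reformatting the source does not change it. The `callback_key` helper below is hypothetical, not phasic's implementation, and it hashes source strings directly for brevity (for a real function, `inspect.getsource` would supply the text).

```python
import ast
import hashlib
import json

def callback_key(source, **params):
    """Hypothetical cache key: SHA-256 over the callback's AST dump plus its parameters."""
    tree_dump = ast.dump(ast.parse(source))  # comments and layout are not part of the AST
    payload = json.dumps({"ast": tree_dump, "params": params}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

src_a = "def f(state):\n    return state * 2\n"
src_b = "def f(state):  # same logic, new comment\n\n    return state*2\n"
src_c = "def f(state):\n    return state * 3\n"

k_a = callback_key(src_a, ipv=[4, 0, 0, 0])
k_b = callback_key(src_b, ipv=[4, 0, 0, 0])     # only formatting/comments differ
k_c = callback_key(src_c, ipv=[4, 0, 0, 0])     # code change
k_d = callback_key(src_a, ipv=[5, 0, 0, 0, 0])  # parameter change
print(k_a == k_b, k_a == k_c, k_a == k_d)
```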

Enable the graph cache by passing cache_graph=True to Graph():

# First build: constructs graph from callback and saves to cache
start = time.time()
graph = Graph(coalescent, cache_graph=True)
t1 = time.time() - start
print(f"First build:  {t1*1000:.1f} ms  ({graph.vertices_length()} vertices)")

# Second build: loaded from cache
start = time.time()
graph = Graph(coalescent, cache_graph=True)
t2 = time.time() - start
print(f"Second build: {t2*1000:.1f} ms  (from cache)")

if t2 > 0:
    print(f"Speedup: {t1/t2:.1f}x")
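Conceptually, `cache_graph=True` wraps construction in a get-or-build step. The following is a minimal sketch assuming pickle serialization and one file per key; phasic's actual on-disk format is not described here, and `cached_build` and `slow_build` are hypothetical names.

```python
import hashlib
import pickle
import tempfile
import time
from pathlib import Path

def cached_build(key_material, build, cache_dir):
    """Return (result, hit): load from cache_dir/<sha256>.pkl if present, else build and store."""
    key = hashlib.sha256(key_material.encode("utf-8")).hexdigest()
    path = Path(cache_dir) / f"{key}.pkl"
    if path.exists():                    # cache hit: deserialize from disk
        with path.open("rb") as fh:
            return pickle.load(fh), True
    result = build()                     # cache miss: do the expensive work
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(".tmp")
    with tmp.open("wb") as fh:           # write to a temporary file, then rename,
        pickle.dump(result, fh)          # so readers never see a partial entry
    tmp.replace(path)
    return result, False

def slow_build():
    time.sleep(0.05)                     # stand-in for state-space exploration
    return {"n_vertices": 5}

with tempfile.TemporaryDirectory() as d:
    first, hit1 = cached_build("coalescent|ipv=[4,0,0,0]", slow_build, d)
    second, hit2 = cached_build("coalescent|ipv=[4,0,0,0]", slow_build, d)
    print(hit1, hit2, first == second)
```

The write-then-rename step matters for a cache shared between processes: a concurrent reader sees either no file or a complete one.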

If you modify the callback function or pass different parameters, the cache misses and the graph is rebuilt:

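# Different sample size → different ipv → different hash → cache miss.
# coalescent_5_samples is the same callback decorated with @with_ipv([5, 0, 0, 0, 0])
# (its definition is not shown here).
graph_5 = Graph(coalescent_5_samples, cache_graph=True)  # builds fresh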

The graph cache is most useful for large models where construction takes seconds or more.

Layer 2: Trace cache

When computing moments or running SVGD inference, phasic performs Gaussian elimination on the graph and records an elimination trace: a linear sequence of operations that can be replayed cheaply with different parameter values. Recording the trace costs O(n³) in the number of vertices n and is the most expensive step for large models.
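The record-once, replay-many idea can be shown on a much smaller computation than Gaussian elimination. In the toy below (all names hypothetical), a `Var` placeholder logs each arithmetic step on a `Tape` while computing the expected time to the most recent common ancestor, the sum over k of 1 / (s * k(k-1)/2) for a rate scale s. `replay` then re-evaluates that fixed operation list for any s without re-deriving it. phasic's traces record elimination steps on the graph instead; only the replay principle carries over.

```python
import operator

class Tape:
    """Records operations; operands are ("param", i), ("const", c) or ("slot", k)."""
    def __init__(self):
        self.ops = []

    def emit(self, op, a, b):
        self.ops.append((op, a, b))
        return Var(self, ("slot", len(self.ops) - 1))

class Var:
    """Placeholder value that appends every arithmetic operation to its tape."""
    def __init__(self, tape, ref):
        self.tape, self.ref = tape, ref

    def _ref(self, other):
        return other.ref if isinstance(other, Var) else ("const", float(other))

    def __add__(self, other):
        return self.tape.emit("add", self.ref, self._ref(other))

    def __mul__(self, other):
        return self.tape.emit("mul", self.ref, self._ref(other))

    def __truediv__(self, other):
        return self.tape.emit("div", self.ref, self._ref(other))

    def __rtruediv__(self, other):
        return self.tape.emit("div", self._ref(other), self.ref)

    __radd__ = __add__   # addition and multiplication are commutative
    __rmul__ = __mul__

OPS = {"add": operator.add, "mul": operator.mul, "div": operator.truediv}

def replay(tape, params):
    """Re-run the recorded operations on concrete parameter values."""
    slots = []
    def value(ref):
        kind, x = ref
        if kind == "param":
            return params[x]
        return x if kind == "const" else slots[x]
    for op, a, b in tape.ops:
        slots.append(OPS[op](value(a), value(b)))
    return slots[-1]

def expected_tmrca(scale, n=4):
    """E[time to MRCA] when k lineages coalesce at rate scale * k(k-1)/2."""
    total = 0.0
    for k in range(n, 1, -1):
        total = total + 1 / (scale * (k * (k - 1) / 2))
    return total

tape = Tape()
expected_tmrca(Var(tape, ("param", 0)))  # record once
for s in (1.0, 2.0, 0.5):                # replay cheaply for many parameter values
    print(s, replay(tape, [s]), expected_tmrca(s))
```

Because the tape depends only on the structure of the computation, not on s, it can be stored under a structural key and reused for every parameter value.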

The trace cache stores these elimination traces on disk, keyed by a SHA-256 hash of the graph structure (vertices, edges, and rates, but not the specific parameter values). Enable trace caching by passing cache_trace=True to Graph():

# Build graph with trace caching enabled
graph = Graph(coalescent, cache_trace=True)

# First moments call: records elimination trace and caches it
start = time.time()
m1 = graph.moments(2)
t1 = time.time() - start
print(f"First moments():  {t1*1000:.1f} ms")

# Second moments call: reuses cached trace
start = time.time()
m2 = graph.moments(2)
t2 = time.time() - start
print(f"Second moments(): {t2*1000:.1f} ms")

print(f"\nFirst two moments: {m1}")

Because the trace cache is persistent on disk (~/.phasic_cache/traces/), the cached trace is available even if you restart Python and construct the same graph structure again. This makes the trace cache especially valuable for iterative development and parameter exploration.

Layer 3: JAX compilation cache

When running SVGD inference, JAX JIT-compiles the log-likelihood, kernel, and update functions the first time they are called. This compilation can take 1–10 seconds. The JAX compilation cache stores the compiled XLA executables on disk, so subsequent Python sessions skip XLA compilation; JAX still traces the functions, but tracing is usually much cheaper than compiling.

This cache is managed by JAX itself and is enabled automatically by phasic at import time. The cache key depends on the function's structure and on input shapes and dtypes, not on input values, so different parameter vectors of the same shape reuse the same compiled code.
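Shape-based keying can be mimicked in plain Python. The `toy_jit` decorator below is a hypothetical stand-in, not JAX: it "compiles" (here, just records) once per function and input length, so calls that differ only in their values reuse the cached entry. Real JAX cache keys also account for dtypes and the compilation environment.

```python
from collections import Counter

compilations = Counter()   # how many times each signature was "compiled"
_cache = {}

def toy_jit(fn):
    """Cache one 'compiled' entry per (function name, input length) signature."""
    def wrapper(xs):
        key = (fn.__qualname__, len(xs))   # depends on the shape, never on the values
        if key not in _cache:
            compilations[key] += 1         # expensive compilation would happen here
            _cache[key] = fn
        return _cache[key](xs)
    return wrapper

@toy_jit
def neg_half_sq_norm(xs):
    return -sum(x * x for x in xs) / 2

neg_half_sq_norm([0.1, 0.2, 0.3])
neg_half_sq_norm([1.0, 2.0, 3.0])   # same shape, new values: cache hit
neg_half_sq_norm([1.0, 2.0])        # new shape: second compilation
print(dict(compilations))
```

In practice this means an SVGD run recompiles when the number of particles or the data size changes, but not when particle values are updated between iterations.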

Configuration

The default cache directory is ~/.jax_cache/. You can change it by setting an environment variable before JAX is imported:

export JAX_COMPILATION_CACHE_DIR=/fast/ssd/jax_cache
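Inside a notebook or script, the same variable can be set from Python. JAX initializes its options from upper-cased environment variables, so this only takes effect if it runs before JAX is imported. The sketch assumes that phasic keeps a directory set this way rather than applying its own default. If JAX is already imported, JAX's own `jax.config.update("jax_compilation_cache_dir", path)` is the runtime equivalent.

```python
import os

# Set before JAX is first imported; assuming phasic imports JAX when it is
# imported, this must also come before `import phasic`.
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/fast/ssd/jax_cache"

# import phasic  # JAX reads JAX_COMPILATION_CACHE_DIR when it initializes its config
```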

Or programmatically with the CompilationConfig class:

from phasic import CompilationConfig

config = CompilationConfig.balanced()   # sensible defaults
config.apply()

Inspecting caches

Phasic provides a unified API for inspecting all three cache layers.

Overview of all caches

print_all_cache_info()

Individual cache layers

Each layer has its own inspection functions:

# Graph cache
print_graph_cache_info()
# Trace cache
print_trace_cache_info()
# JAX compilation cache
jax_info = cache_info()
print(f"JAX cache: {jax_info['num_files']} files, {jax_info['total_size_mb']:.1f} MB")

Programmatic access

For scripting, get_all_cache_stats() returns a dictionary with statistics for each layer:

stats = get_all_cache_stats()
for name, layer_stats in stats.items():
    print(f"{name}: {layer_stats}")
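If you need raw numbers without going through phasic (for example in a monitoring script), a standard-library walk over the default directories gives a rough equivalent. This assumes the default locations listed at the top of this page and that each cache is plain files on disk. The `dir_stats` helper is hypothetical; its dictionary keys mirror those returned by cache_info(), and sizes are computed in units of 2^20 bytes.

```python
from pathlib import Path

def dir_stats(root):
    """Count files and total size (MB, 2**20 bytes) under a directory; missing dirs count as empty."""
    root = Path(root).expanduser()
    files = [p for p in root.rglob("*") if p.is_file()] if root.exists() else []
    total_bytes = sum(p.stat().st_size for p in files)
    return {"num_files": len(files), "total_size_mb": total_bytes / (1024 * 1024)}

for name, path in [("graphs", "~/.phasic_cache/graphs"),
                   ("traces", "~/.phasic_cache/traces"),
                   ("jax", "~/.jax_cache")]:
    print(name, dir_stats(path))
```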

Listing individual trace entries

You can list the cached traces with metadata about each entry:

for entry in list_cached_traces():
    print(f"Hash: {entry['hash'][:16]}...  "
          f"Size: {entry.get('size_kb', 0):.1f} KB  "
          f"Vertices: {entry.get('n_vertices', 'N/A')}")

Clearing caches

Phasic provides several functions for clearing caches at different granularities:

Function | What it clears
--- | ---
clear_caches() | All three cache layers
clear_model_cache() | Graph cache + trace cache
clear_jax_cache() | JAX compilation cache only

# Clear only model-related caches (graph + trace)
clear_model_cache(verbose=True)

# Verify they are empty
print(f"\nGraph cache: {get_graph_cache_stats()['num_graphs']} graphs")
print(f"Trace cache: {get_trace_cache_stats()['total_files']} files")

Selective cleanup

For production environments where caches grow over time, you can prune old or oversized entries from the trace cache:

# Remove traces older than 30 days and keep the trace cache under 100 MB
removed = cleanup_old_traces(max_size_mb=100.0, max_age_days=30)
print(f"Removed {removed} old trace entries")
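The exact eviction policy of cleanup_old_traces() is not spelled out here. A common way to combine an age limit with a size limit is to expire old files first and then evict the least recently modified files until the directory fits the budget. The `prune` helper below is a hypothetical sketch of that policy, using file modification time as the age.

```python
import time
from pathlib import Path

def prune(cache_dir, max_size_mb, max_age_days, now=None):
    """Delete files older than max_age_days, then oldest-first until under max_size_mb."""
    now = time.time() if now is None else now
    files = [p for p in Path(cache_dir).iterdir() if p.is_file()]
    removed = 0
    keep = []
    for p in files:                                   # pass 1: age-based expiry
        if now - p.stat().st_mtime > max_age_days * 86400:
            p.unlink()
            removed += 1
        else:
            keep.append(p)
    keep.sort(key=lambda p: p.stat().st_mtime)        # oldest first
    total = sum(p.stat().st_size for p in keep)
    limit = max_size_mb * 1024 * 1024
    while keep and total > limit:                     # pass 2: enforce the size cap
        p = keep.pop(0)
        total -= p.stat().st_size
        p.unlink()
        removed += 1
    return removed

# Example (hypothetical path usage):
# prune(Path("~/.phasic_cache/traces").expanduser(), max_size_mb=100.0, max_age_days=30)
```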

You can also remove a specific trace by hash:

from phasic.trace_cache import remove_cached_trace
removed = remove_cached_trace("abc123def456...")  # Returns True/False

Or clear everything from the command line:

# Delete the graph and trace caches (~/.phasic_cache/) and the JAX cache (~/.jax_cache/)
rm -rf ~/.phasic_cache/ ~/.jax_cache/

Cache logging

To see cache hits and misses in real time, enable debug-level logging:

set_log_level('DEBUG')

# This will log cache hit/miss for graph and trace
graph = Graph(coalescent, cache_graph=True, cache_trace=True)
_ = graph.moments(2)

set_log_level('WARNING')

Combined caching for inference

All three cache layers work together transparently during SVGD inference. On a cold start (empty caches), the first run pays for graph construction, trace recording, and JIT compilation. On a warm start (populated caches), these one-time costs are skipped and only the SVGD iterations themselves remain, which for large models can make the run much faster:

# Build graph with both cache layers enabled
graph = Graph(coalescent, cache_graph=True, cache_trace=True)

# Simulate data
graph.update_weights([7.0])
observed_data = graph.sample(1000)

# First SVGD run: populates JAX compilation cache
start = time.time()
svgd1 = graph.svgd(observed_data, n_iterations=50)
t1 = time.time() - start
print(f"First SVGD run:  {t1:.1f}s")

# Second SVGD run: all caches populated
start = time.time()
svgd2 = graph.svgd(observed_data, n_iterations=50)
t2 = time.time() - start
print(f"Second SVGD run: {t2:.1f}s")

Summary

Cache layer | Enable with | Inspect | Clear
--- | --- | --- | ---
Graph cache | Graph(..., cache_graph=True) | print_graph_cache_info() | clear_model_cache()
Trace cache | Graph(..., cache_trace=True) | print_trace_cache_info() | clear_model_cache()
JAX cache | Automatic at import | cache_info() | clear_jax_cache()
All | (n/a) | print_all_cache_info() | clear_caches()

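All caches use content-addressable hashing (SHA-256), so they invalidate automatically when the model changes: if the callback code or the construction parameters change, the hash changes and the cache misses. One caveat follows from how the graph key is built: it covers the callback's own AST, so a change in code the callback only calls (for example a helper function) may leave the hash unchanged. In that case, clear the caches explicitly.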