```python
from phasic import (
    Graph, with_ipv,
    clear_caches, clear_jax_cache, clear_model_cache,
    cache_info, get_all_cache_stats, print_all_cache_info,
    get_graph_cache_stats, print_graph_cache_info,
    get_trace_cache_stats, print_trace_cache_info,
    set_log_level
)
from phasic.trace_cache import list_cached_traces, cleanup_old_traces
import numpy as np
import time

from vscodenb import set_vscode_theme
set_vscode_theme()
```

# Multi-level caching
Phasic uses a three-layer caching system to avoid repeating expensive computations. Each layer targets a different bottleneck in the pipeline from model definition to inference:
| Layer | What is cached | Location | Speedup |
|---|---|---|---|
| Graph cache | Fully constructed `Graph` objects | `~/.phasic_cache/graphs/` | Avoids callback-based construction |
| Trace cache | Elimination traces from Gaussian elimination | `~/.phasic_cache/traces/` | Avoids O(n³) elimination |
| JAX compilation cache | JIT-compiled XLA code | `~/.jax_cache/` | Avoids recompilation on restart |
All three caches are persistent — they survive across Python sessions and restarts. Cache correctness is ensured by SHA-256 content hashing: the same graph structure always produces the same hash, and any structural change automatically invalidates the entry.
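The invalidation rule can be sketched in a few lines of plain Python. This is an illustration of content hashing only — `structure_hash` and its serialization format are made up for the example, not phasic's internal representation:

```python
import hashlib
import json

def structure_hash(vertices, edges):
    """Hash a graph structure deterministically (illustrative sketch)."""
    # Canonical serialization: sorted elements, fixed separators, so the
    # same structure always serializes to the same bytes
    payload = json.dumps({"vertices": sorted(vertices),
                          "edges": sorted(edges)},
                         separators=(",", ":"))
    return hashlib.sha256(payload.encode()).hexdigest()

h1 = structure_hash(["a", "b", "c"], [("a", "b"), ("b", "c")])
h2 = structure_hash(["c", "b", "a"], [("b", "c"), ("a", "b")])  # same structure
h3 = structure_hash(["a", "b", "c"], [("a", "c")])              # changed edges

assert h1 == h2   # identical structure -> identical hash -> cache hit
assert h1 != h3   # any structural change -> new hash -> cache miss
```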
We will use a simple coalescent model throughout this tutorial:
```python
nr_samples = 4

@with_ipv([nr_samples] + [0] * (nr_samples - 1))
def coalescent(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            transitions.append((new, state[i] * (state[j] - same) / (1 + same)))
    return transitions
```

Start from a clean slate so the examples below show cache misses and hits clearly:
```python
clear_caches(verbose=True)
```

## Layer 1: Graph cache
Building a graph from a callback function requires exploring the full state space, creating vertices and edges, and can take seconds to minutes for large models. The graph cache stores fully constructed Graph objects on disk so that the same model can be loaded instantly on subsequent calls.
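The load-instantly-on-second-call mechanism can be sketched with a toy disk cache. Here `CACHE_DIR`, `cached_build`, and the pickle-file layout are illustrative stand-ins, not phasic's actual cache format:

```python
import hashlib
import pickle
import tempfile
from pathlib import Path

CACHE_DIR = Path(tempfile.mkdtemp())  # stand-in for ~/.phasic_cache/graphs/

def cached_build(key_material: bytes, build):
    """Load an object from disk if its key is cached, else build and save it."""
    key = hashlib.sha256(key_material).hexdigest()
    path = CACHE_DIR / f"{key}.pkl"
    if path.exists():                      # cache hit: skip construction
        return pickle.loads(path.read_bytes())
    obj = build()                          # cache miss: expensive construction
    path.write_bytes(pickle.dumps(obj))
    return obj

calls = []
build = lambda: calls.append(1) or {"vertices": 10}
g1 = cached_build(b"model-v1", build)   # builds and saves
g2 = cached_build(b"model-v1", build)   # loaded from disk, build() not called
assert g1 == g2 and len(calls) == 1
```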
The cache key is a SHA-256 hash of:

- The callback function’s AST (abstract syntax tree), so whitespace/comment changes are ignored but code changes invalidate the cache
- All construction parameters (`ipv`, `nr_samples`, keyword arguments)
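The AST trick can be demonstrated directly with the standard library. This is a sketch of the idea only; phasic's real key also folds in the construction parameters listed above:

```python
import ast
import hashlib

def source_hash(src: str) -> str:
    """Hash the AST of source code; comments and whitespace vanish at parse time."""
    return hashlib.sha256(ast.dump(ast.parse(src)).encode()).hexdigest()

v1 = source_hash("def f(x):\n    return x + 1\n")
v2 = source_hash("def f(x):  # cosmetic comment\n\n    return x + 1\n")
v3 = source_hash("def f(x):\n    return x + 2\n")

assert v1 == v2   # comments and blank lines are ignored
assert v1 != v3   # a real code change invalidates the cache
```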
Enable the graph cache by passing `cache_graph=True` to `Graph()`:
```python
# First build: constructs graph from callback and saves to cache
start = time.time()
graph = Graph(coalescent, cache_graph=True)
t1 = time.time() - start
print(f"First build: {t1*1000:.1f} ms ({graph.vertices_length()} vertices)")

# Second build: loaded from cache
start = time.time()
graph = Graph(coalescent, cache_graph=True)
t2 = time.time() - start
print(f"Second build: {t2*1000:.1f} ms (from cache)")

if t2 > 0:
    print(f"Speedup: {t1/t2:.1f}x")
```

If you modify the callback function or pass different parameters, the cache misses and the graph is rebuilt:
```python
# Different sample size → different hash → cache miss
# (coalescent_5_samples: the same callback as above, defined with nr_samples = 5)
graph_5 = Graph(coalescent_5_samples, cache_graph=True)  # builds fresh
```

The graph cache is most useful for large models where construction takes seconds or more.
## Layer 2: Trace cache
When computing moments or running SVGD inference, phasic performs Gaussian elimination on the graph to record an elimination trace — a linear sequence of operations that can be replayed cheaply with different parameter values. Recording the trace is O(n³) and is the most expensive step for large models.
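The record-once, replay-cheaply idea can be illustrated with a toy dense solver. This sketch is not phasic's trace format — a real trace operates on the graph and exploits its sparsity — but it shows why replaying is cheap: the expensive pivot bookkeeping is decided once, and each replay is a linear pass over the recorded operations:

```python
import numpy as np

def record_trace(n):
    """Record the elimination order for an n x n system once.
    The trace is just a list of (row, pivot) operations."""
    trace = []
    for k in range(n):
        for i in range(k + 1, n):
            trace.append((i, k))   # "subtract a multiple of row k from row i"
    return trace

def replay(trace, A, b):
    """Replay the recorded operations on concrete values (A, b):
    no decisions, just arithmetic over the trace."""
    A, b = A.astype(float).copy(), b.astype(float).copy()
    for i, k in trace:
        m = A[i, k] / A[k, k]
        A[i] -= m * A[k]
        b[i] -= m * b[k]
    n = len(b)
    x = np.zeros(n)
    for k in range(n - 1, -1, -1):          # back-substitution
        x[k] = (b[k] - A[k, k + 1:] @ x[k + 1:]) / A[k, k]
    return x

trace = record_trace(3)                     # recorded once, then cached
A = np.array([[4.0, 1, 0], [1, 3, 1], [0, 1, 2]])
b = np.array([1.0, 2, 3])
x = replay(trace, A, b)                     # replayed per parameter setting
assert np.allclose(A @ x, b)
```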
The trace cache stores these elimination traces on disk, keyed by a SHA-256 hash of the graph structure (vertices, edges, rates — but not the specific parameter values). Enable trace caching by passing `cache_trace=True` to `Graph()`:
```python
# Build graph with trace caching enabled
graph = Graph(coalescent, cache_trace=True)

# First moments call: records elimination trace and caches it
start = time.time()
m1 = graph.moments(2)
t1 = time.time() - start
print(f"First moments(): {t1*1000:.1f} ms")

# Second moments call: reuses cached trace
start = time.time()
m2 = graph.moments(2)
t2 = time.time() - start
print(f"Second moments(): {t2*1000:.1f} ms")

print(f"\nFirst two moments: {m1}")
```

Because the trace cache is persistent on disk (`~/.phasic_cache/traces/`), the cached trace is available even if you restart Python and construct the same graph structure again. This makes the trace cache especially valuable for iterative development and parameter exploration.
## Layer 3: JAX compilation cache
When running SVGD inference, JAX JIT-compiles the log-likelihood, kernel, and update functions the first time they are called. This compilation can take 1–10 seconds. The JAX compilation cache stores the compiled XLA code on disk so that subsequent Python sessions skip recompilation entirely.
This cache is managed by JAX itself and is enabled automatically by phasic at import time. The cache key is based on the function structure and input shapes (not values), so different parameter vectors reuse the same compiled code.
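The shape-based keying rule can be mimicked with a toy wrapper. `jit_like` is hypothetical and only illustrates *when* a compile is triggered; real JAX lowers to XLA and persists the result to disk:

```python
import numpy as np

compilations = []                 # record each time a "compile" happens

def jit_like(fn):
    """Toy wrapper keyed like JAX's cache: by shape and dtype, not values."""
    compiled = {}                 # stand-in for the on-disk XLA cache
    def wrapper(x):
        key = (x.shape, x.dtype.str)
        if key not in compiled:
            compilations.append(key)      # the expensive step happens here
            compiled[key] = fn
        return compiled[key](x)
    return wrapper

@jit_like
def loglik(theta):
    return -0.5 * np.sum(theta ** 2)

loglik(np.ones(3))          # compiles for shape (3,)
loglik(np.full(3, 7.0))     # same shape, different values -> cache hit
loglik(np.ones(5))          # new shape -> compiles again

assert len(compilations) == 2
```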
### Configuration
The default cache directory is ~/.jax_cache/. You can change it via environment variable before importing JAX:
```shell
export JAX_COMPILATION_CACHE_DIR=/fast/ssd/jax_cache
```

Or programmatically with the `CompilationConfig` class:
```python
from phasic import CompilationConfig

config = CompilationConfig.balanced()  # sensible defaults
config.apply()
```

## Inspecting caches
Phasic provides a unified API for inspecting all three cache layers.
### Overview of all caches
```python
print_all_cache_info()
```

### Individual cache layers
Each layer has its own inspection functions:
```python
# Graph cache
print_graph_cache_info()

# Trace cache
print_trace_cache_info()

# JAX compilation cache
jax_info = cache_info()
print(f"JAX cache: {jax_info['num_files']} files, {jax_info['total_size_mb']:.1f} MB")
```

### Programmatic access
For scripting, `get_all_cache_stats()` returns a dictionary with statistics for each layer:
```python
stats = get_all_cache_stats()
for name, layer_stats in stats.items():
    print(f"{name}: {layer_stats}")
```

### Listing individual trace entries
You can list the cached traces with metadata about each entry:
```python
for entry in list_cached_traces():
    print(f"Hash: {entry['hash'][:16]}... "
          f"Size: {entry.get('size_kb', 0):.1f} KB "
          f"Vertices: {entry.get('n_vertices', 'N/A')}")
```

## Clearing caches
Phasic provides several functions for clearing caches at different granularities:
| Function | What it clears |
|---|---|
| `clear_caches()` | All three cache layers |
| `clear_model_cache()` | Graph cache + trace cache |
| `clear_jax_cache()` | JAX compilation cache only |
```python
# Clear only model-related caches (graph + trace)
clear_model_cache(verbose=True)

# Verify they are empty
print(f"\nGraph cache: {get_graph_cache_stats()['num_graphs']} graphs")
print(f"Trace cache: {get_trace_cache_stats()['total_files']} files")
```

### Selective cleanup
For production environments where caches grow over time, you can prune old or oversized entries from the trace cache:
```python
# Remove traces older than 30 days or enforce a 100 MB size limit
removed = cleanup_old_traces(max_size_mb=100.0, max_age_days=30)
print(f"Removed {removed} old trace entries")
```

You can also remove a specific trace by hash:
```python
from phasic.trace_cache import remove_cached_trace

removed = remove_cached_trace("abc123def456...")  # Returns True/False
```

Or clear everything from the command line:
```shell
# Clear all phasic caches
rm -rf ~/.phasic_cache/ ~/.jax_cache/
```

## Cache logging
To see cache hits and misses in real time, enable debug-level logging:
```python
set_log_level('DEBUG')

# This will log cache hit/miss for graph and trace
graph = Graph(coalescent, cache_graph=True, cache_trace=True)
_ = graph.moments(2)

set_log_level('WARNING')
```

## Combined caching for inference
All three cache layers work together transparently during SVGD inference. On a cold start (empty caches), the first run triggers graph construction, trace recording, and JIT compilation. On a warm start (populated caches), the same code runs orders of magnitude faster:
```python
# Build graph with both cache layers enabled
graph = Graph(coalescent, cache_graph=True, cache_trace=True)

# Simulate data
graph.update_weights([7.0])
observed_data = graph.sample(1000)

# First SVGD run: populates JAX compilation cache
start = time.time()
svgd1 = graph.svgd(observed_data, n_iterations=50)
t1 = time.time() - start
print(f"First SVGD run: {t1:.1f}s")

# Second SVGD run: all caches populated
start = time.time()
svgd2 = graph.svgd(observed_data, n_iterations=50)
t2 = time.time() - start
print(f"Second SVGD run: {t2:.1f}s")
```

## Summary
| Cache layer | Enable with | Inspect | Clear |
|---|---|---|---|
| Graph cache | `Graph(..., cache_graph=True)` | `print_graph_cache_info()` | `clear_model_cache()` |
| Trace cache | `Graph(..., cache_trace=True)` | `print_trace_cache_info()` | `clear_model_cache()` |
| JAX cache | Automatic at import | `cache_info()` | `clear_jax_cache()` |
| All | | `print_all_cache_info()` | `clear_caches()` |
All caches use content-addressable hashing (SHA-256), so they invalidate automatically when the underlying model changes. There is no risk of stale results — if the code or parameters change, the hash changes, and the cache misses.