from phasic import (
Graph, Property, StateIndexer, set_log_level,
clear_caches, clear_model_cache,
cache_info, get_all_cache_stats, print_all_cache_info,
get_graph_cache_stats, print_graph_cache_info,
get_trace_cache_stats, print_trace_cache_info,
)
from phasic.trace_cache import list_cached_traces, cleanup_old_traces
import numpy as np
import time
%config InlineBackend.figure_format = 'svg'
from vscodenb import set_vscode_theme
set_vscode_theme()
Multi-level caching
Phasic uses a three-layer caching system to avoid repeating expensive computations. Each layer targets a different bottleneck in the pipeline from model definition to inference:
| Layer | What is cached | Location | Speedup |
|---|---|---|---|
| Graph cache | Fully constructed Graph objects | ~/.phasic_cache/graphs/ | Avoids callback-based construction |
| Trace cache | Elimination traces from Gaussian elimination | ~/.phasic_cache/traces/ | Avoids O(n³) elimination |
| JAX compilation cache | JIT-compiled XLA code | ~/.jax_cache/ | Avoids recompilation on restart |
All three caches are persistent — they survive across Python sessions and restarts. Cache correctness is ensured by SHA-256 content hashing: the same graph structure always produces the same hash, and any structural change automatically invalidates the entry.
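The content-hashing idea can be sketched in a few lines. This is a hypothetical illustration, not phasic's actual serialization: a canonical serialization of the structure is hashed, so logically equal structures map to the same key and any structural change produces a new one.

```python
import hashlib
import json

def content_key(structure: dict) -> str:
    # Hypothetical sketch: serialize the structure canonically
    # (sorted keys) so that logically equal inputs hash identically,
    # then use the SHA-256 digest as the cache key.
    blob = json.dumps(structure, sort_keys=True).encode()
    return hashlib.sha256(blob).hexdigest()

# Same structure, different key order -> same cache key
a = content_key({"vertices": 3, "edges": [[0, 1], [1, 2]]})
b = content_key({"edges": [[0, 1], [1, 2]], "vertices": 3})
assert a == b

# Any structural change -> different key (entry invalidated)
c = content_key({"vertices": 3, "edges": [[0, 1], [1, 2], [2, 0]]})
assert c != a
```

Because the key is derived purely from content, no manual invalidation is ever needed: a changed structure simply hashes to a different entry.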
We use the two-locus ARG with two parameters as the example model. I have added a dummy keyword argument (which does nothing) for demonstration purposes:
nr_samples = 6
indexer = StateIndexer(descendants=[
    Property('loc1', max_value=nr_samples),
    Property('loc2', max_value=nr_samples)
])
initial = [0] * indexer.state_length
initial[indexer.props_to_index(loc1=1, loc2=1)] = nr_samples
def two_locus_arg_2param(state, indexer=None, dummy=None):
    transitions = []
    if state.sum() <= 1: return transitions
    for i in range(indexer.state_length):
        if state[i] == 0: continue
        pi = indexer.index_to_props(i)
        for j in range(i, indexer.state_length):
            if state[j] == 0: continue
            pj = indexer.index_to_props(j)
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            child = state.copy()
            child[i] -= 1
            child[j] -= 1
            loc1 = pi.descendants.loc1 + pj.descendants.loc1
            loc2 = pi.descendants.loc2 + pj.descendants.loc2
            if loc1 <= nr_samples and loc2 <= nr_samples:
                child[indexer.props_to_index(loc1=loc1, loc2=loc2)] += 1
            transitions.append([child, [state[i]*(state[j]-same)/(1+same), 0]])
        if state[i] > 0 and pi.descendants.loc1 > 0 and pi.descendants.loc2 > 0:
            child = state.copy()
            child[i] -= 1
            child[indexer.props_to_index(loc1=pi.descendants.loc1, loc2=0)] += 1
            child[indexer.props_to_index(loc1=0, loc2=pi.descendants.loc2)] += 1
            transitions.append([child, [0, 1]])
    return transitions
Start from a clean slate and enable info logging so the examples below show cache misses and hits clearly:
set_log_level('INFO')
clear_caches(verbose=True)
[INFO] phasic.graph_cache: Cleared 0 cached graphs
Removed 15 file(s), preserved directory structure
Graph cache
Building a graph from a callback function requires exploring the full state space, creating vertices and edges, and can take seconds to minutes for large models. The graph cache stores fully constructed Graph objects on disk so that the same model can be loaded instantly on subsequent calls.
The cache key is a SHA-256 hash of:
- The callback function’s AST (abstract syntax tree), so whitespace/comment changes are ignored but code changes invalidate the cache
- All construction parameters (ipv, nr_samples, keyword arguments)
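The AST-based part of the key can be illustrated with a small sketch (hypothetical, not phasic's actual implementation): hashing the parsed AST rather than the raw source makes the key insensitive to comments and whitespace but sensitive to code changes.

```python
import ast
import hashlib

def ast_hash(source: str) -> str:
    # Hash the AST dump instead of the raw text, so formatting and
    # comments do not affect the key, but any code change does.
    tree = ast.parse(source)
    return hashlib.sha256(ast.dump(tree).encode()).hexdigest()

original  = "def f(x):\n    return x + 1\n"
commented = "def f(x):  # a comment\n    return x + 1\n"
changed   = "def f(x):\n    return x + 2\n"

assert ast_hash(original) == ast_hash(commented)  # cache hit
assert ast_hash(original) != ast_hash(changed)    # cache miss
```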
Enable the graph cache by passing graph_cache=True to Graph(). The first build constructs the graph from the callback and saves it to the cache:
%%time
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
              graph_cache=True)
[INFO] phasic.graph_cache: Saved graph to cache: 7d8de78485fc9b6b... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
CPU times: user 3.77 s, sys: 23.4 ms, total: 3.79 s
Wall time: 3.78 s
A second build with identical code and parameters would load from the cache. Here, however, we pass a new keyword argument (dummy=42), which changes the cache key, so the graph is rebuilt and saved under a new hash:
%%time
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
              graph_cache=True, dummy=42)
[INFO] phasic.graph_cache: Saved graph to cache: 1c6b4850af250c3a... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
CPU times: user 3.83 s, sys: 27.5 ms, total: 3.86 s
Wall time: 3.85 s
If you modify the callback function or pass different parameters, the cache misses and the graph is rebuilt. Even though our dummy keyword arg does nothing, passing a new value triggers a rebuild of the graph:
%%time
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
              graph_cache=True, dummy=99)
[INFO] phasic.graph_cache: Saved graph to cache: d3cee6407f2ba48f... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
CPU times: user 3.8 s, sys: 29.5 ms, total: 3.83 s
Wall time: 3.82 s
clear_caches(verbose=True)
[INFO] phasic.graph_cache: Cleared 0 cached graphs
Removed 3 file(s), preserved directory structure
Trace cache
When computing moments or running SVGD inference, phasic performs Gaussian elimination on the graph to record an elimination trace — a linear sequence of operations that can be replayed cheaply with different parameter values. Recording the trace is O(n³) and is the most expensive step for large models.
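The record/replay idea behind traces can be illustrated with a toy dense Gaussian elimination. This is a hypothetical sketch, unrelated to phasic's actual trace format: recording captures which row operations occur, and replay repeats them on a matrix with the same structure but different values, without re-deriving the operation order.

```python
import numpy as np

def record_elimination(A):
    # Run forward elimination once, recording the (row, pivot)
    # pairs as a "trace" of the operation sequence.
    A = A.astype(float).copy()
    n = A.shape[0]
    trace = []
    for k in range(n):
        for i in range(k + 1, n):
            if A[i, k] != 0.0:
                trace.append((i, k))
                A[i, k:] -= (A[i, k] / A[k, k]) * A[k, k:]
    return trace, A

def replay_elimination(A, trace):
    # Replay the recorded operations on new values: no structural
    # work, just a linear pass over the trace.
    A = A.astype(float).copy()
    for i, k in trace:
        A[i, k:] -= (A[i, k] / A[k, k]) * A[k, k:]
    return A

A = np.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
trace, U = record_elimination(A)

# Same sparsity pattern, different rates: replaying the trace
# matches re-recording from scratch, at a fraction of the cost.
B = np.array([[4.0, 2.0, 0.0], [1.0, 5.0, 1.0], [0.0, 2.0, 3.0]])
assert np.allclose(replay_elimination(B, trace), record_elimination(B)[1])
```

This is what makes trace caching pay off: the cubic cost is paid once per graph structure, while each new parameter vector only pays for a linear replay.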
The trace cache stores these elimination traces on disk, keyed by a SHA-256 hash of the graph structure (vertices, edges, rates — but not the specific parameter values). Enable trace caching by passing cache_trace=True to Graph(). Build graph with trace caching enabled:
graph_w_trace = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
                      cache_trace=True)
graph_w_trace.update_weights([2, 5])
graph_w_trace.expectation()
[INFO] phasic.hierarchical_trace_cache: Using hierarchical SCC subdivision (graph=1044 vertices, min_size=50)
[INFO] phasic.hierarchical_trace_cache: SCC grouping: 0 large SCCs (≥50 vertices), 123 small SCCs (<50 vertices, 1044 total vertices)
[INFO] phasic.hierarchical_trace_cache: All SCCs are below min_size=50, recording full graph directly (no subdivision)
[INFO] phasic.hierarchical_trace_cache: Parallelization: Auto-selected 'vmap' (multiprocessing with 10 CPUs)
[INFO] phasic.hierarchical_trace_cache: VMAP: Using multiprocessing with 1 workers over 1 work units
[INFO] phasic.trace_elimination: Trace recording complete: 1044 vertices, 469891 operations, phase 2, param_length=2, reward_length=0
[INFO] phasic.hierarchical_trace_cache: SCC 1/1: 1044 vertices, 469891 operations, param_length=2
[INFO] phasic.hierarchical_trace_cache: ✓ Hierarchical trace computation complete (no stitching)
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 8.61e+12 > threshold 1.00e+12)
[INFO] phasic.c: Computing MPFR graph with 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
0.6872139668331487
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
              cache_trace=True)
graph.update_weights([2, 5])
The trace is computed lazily when first needed and cached:
graph.vertices_length()
1044
graph.expectation()
[INFO] phasic.hierarchical_trace_cache: ✓ Full graph cache HIT: returning cached trace
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 8.61e+12 > threshold 1.00e+12)
[INFO] phasic.c: Computing MPFR graph with 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
0.6872139668331488
From now on the trace is read from cache:
%%time
graph.variance()
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 8.61e+12 > threshold 1.00e+12)
[INFO] phasic.c: Computing MPFR graph with 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 8.61e+12 > threshold 1.00e+12)
[INFO] phasic.c: MPFR computation successful - returning high-precision results
CPU times: user 646 ms, sys: 6.29 ms, total: 652 ms
Wall time: 652 ms
0.2403942575637451
If needed, you can compute and cache the trace directly:
trace = graph.compute_trace()
You can check whether a trace is cached for the graph and whether it is still valid (the graph did not change since the trace was computed):
graph.cache_trace, graph.trace_valid
(True, True)
Because the trace cache is persistent on disk (~/.phasic_cache/traces/), the cached trace is available even if you restart Python and construct the same graph structure again. This makes the trace cache especially valuable for iterative development and parameter exploration.
JAX compilation cache
When running SVGD inference, JAX JIT-compiles the log-likelihood, kernel, and update functions the first time they are called. This compilation can take 1–10 seconds. The JAX compilation cache stores the compiled XLA code on disk so that subsequent Python sessions skip recompilation entirely.
This cache is managed by JAX itself and is enabled automatically by phasic at import time. The cache key is based on the function structure and input shapes (not values), so different parameter vectors reuse the same compiled code.
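The shape-keyed caching idea can be mimicked in a few lines. This is a toy illustration, not how JAX actually works internally: the expensive "compile" step runs once per input shape, and calls with new values but the same shape reuse the compiled artifact.

```python
compile_count = 0
_compiled = {}

def fake_compile(shape):
    # Stand-in for expensive XLA compilation: runs once per shape.
    global compile_count
    compile_count += 1
    return lambda xs: sum(x * x for x in xs)

def jit_like(xs):
    # Key the cache on the input's shape (here, just its length),
    # never on the values it contains.
    key = len(xs)
    if key not in _compiled:
        _compiled[key] = fake_compile(key)
    return _compiled[key](xs)

assert jit_like([1.0, 2.0, 3.0]) == 14.0   # compiles for shape (3,)
assert jit_like([4.0, 5.0, 6.0]) == 77.0   # same shape: cache hit
assert compile_count == 1
assert jit_like([1.0, 2.0]) == 5.0         # new shape: recompiles
assert compile_count == 2
```

This is why SVGD runs with different parameter vectors of the same length never trigger recompilation.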
Configuration
The default cache directory is ~/.jax_cache/. You can change it via environment variable before importing JAX:
export JAX_COMPILATION_CACHE_DIR=/fast/ssd/jax_cache
Or programmatically with the CompilationConfig class:
from phasic import CompilationConfig
config = CompilationConfig.balanced() # sensible defaults
config.apply()
JAX caches compiled XLA code keyed on the function structure (HLO graph), the input shapes, and the device configuration. The CompilationConfig presets trade compilation speed against runtime performance:
from phasic.jax_config import CompilationConfig

# Balanced preset (default)
config = CompilationConfig.balanced()
config.apply()

# Maximum performance
config = CompilationConfig.max_performance()
config.apply()

# Fast compilation (for development)
config = CompilationConfig.fast_compile()
config.apply()
Inspecting caches
Phasic provides a unified API for inspecting all three cache layers.
Overview of all caches
print_all_cache_info()
Cache directory: /Users/kmt/.jax_cache
Cached compilations: 0
Total size: 0.0 MB
Cache directory: /Users/kmt/.phasic_cache/graphs
Status: No cached graphs
Cache directory: /Users/kmt/.phasic_cache/traces
Status: No cached traces
Individual cache layers
Each layer has its own inspection functions:
# Graph cache
print_graph_cache_info()
Cache directory: /Users/kmt/.phasic_cache/graphs
Status: No cached graphs
# Trace cache
print_trace_cache_info()
Cache directory: /Users/kmt/.phasic_cache/traces
Status: No cached traces
# JAX compilation cache
jax_info = cache_info()
print(f"JAX cache: {jax_info['num_files']} files, {jax_info['total_size_mb']:.1f} MB")
JAX cache: 0 files, 0.0 MB
Programmatic access
For scripting, get_all_cache_stats() returns a dictionary with statistics for each layer:
stats = get_all_cache_stats()
for name, layer_stats in stats.items():
    print(f"{name}: {layer_stats}")
jax: {'exists': True, 'path': '/Users/kmt/.jax_cache', 'num_files': 0, 'total_size_mb': 0.0, 'files': []}
graph: {'num_graphs': 0, 'total_size_mb': 0.0, 'cache_dir': '/Users/kmt/.phasic_cache/graphs'}
trace: {'total_files': 0, 'total_bytes': 0, 'total_mb': 0.0, 'cache_dir': '/Users/kmt/.phasic_cache/traces'}
Listing individual trace entries
You can list the cached traces with metadata about each entry:
for entry in list_cached_traces():
    print(f"Hash: {entry['hash'][:16]}... "
          f"Size: {entry.get('size_kb', 0):.1f} KB "
          f"Vertices: {entry.get('n_vertices', 'N/A')}")
Clearing caches
Phasic provides several functions for clearing caches at different granularities:
| Function | What it clears |
|---|---|
| clear_caches() | All three cache layers |
| clear_model_cache() | Graph cache + trace cache |
| clear_jax_cache() | JAX compilation cache only |
# Clear only model-related caches (graph + trace)
clear_model_cache(verbose=True)
# Verify they are empty
print(f"\nGraph cache: {get_graph_cache_stats()['num_graphs']} graphs")
print(f"Trace cache: {get_trace_cache_stats()['total_files']} files")
[INFO] phasic.graph_cache: Cleared 0 cached graphs
Removed 1 file(s), preserved directory structure
Graph cache: 0 graphs
Trace cache: 0 files
Selective cleanup
For production environments where caches grow over time, you can prune old or oversized entries from the trace cache:
# Remove traces older than 30 days or enforce a 100 MB size limit
removed = cleanup_old_traces(max_size_mb=100.0, max_age_days=30)
print(f"Removed {removed} old trace entries")
Removed 0 old trace entries
You can also remove a specific trace by hash:
from phasic.trace_cache import remove_cached_trace
removed = remove_cached_trace("abc123def456...")  # Returns True/False
Or clear everything from the command line:
# Clear all phasic caches
rm -rf ~/.phasic_cache/ ~/.jax_cache/