Multi-level caching

Phasic uses a three-layer caching system to avoid repeating expensive computations. Each layer targets a different bottleneck in the pipeline from model definition to inference:

| Layer | What is cached | Location | Speedup |
|---|---|---|---|
| Graph cache | Fully constructed Graph objects | ~/.phasic_cache/graphs/ | Avoids callback-based construction |
| Trace cache | Elimination traces from Gaussian elimination | ~/.phasic_cache/traces/ | Avoids O(n³) elimination |
| JAX compilation cache | JIT-compiled XLA code | ~/.jax_cache/ | Avoids recompilation on restart |

All three caches are persistent — they survive across Python sessions and restarts. Cache correctness is ensured by SHA-256 content hashing: the same graph structure always produces the same hash, and any structural change automatically invalidates the entry.
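
As a rough illustration of content hashing, the sketch below serializes a toy edge list in a canonical order and hashes it with SHA-256. The helper `structure_hash` and the edge format are hypothetical and chosen only for this example; phasic's own serialization may differ.

```python
import hashlib
import json

def structure_hash(n_vertices, edges):
    """Return a SHA-256 hex digest of a toy graph structure.

    `edges` is an iterable of (source, target, coefficients) tuples.
    Hypothetical helper for illustration; phasic's real hashing may differ.
    """
    # Sort the edges so the digest does not depend on insertion order.
    canonical = {
        "n_vertices": n_vertices,
        "edges": sorted([src, dst, list(coeffs)] for src, dst, coeffs in edges),
    }
    payload = json.dumps(canonical, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()

edges = [(0, 1, (1.0, 0.0)), (1, 2, (0.0, 1.0))]
same = structure_hash(3, list(reversed(edges)))             # same structure, other order
changed = structure_hash(3, edges + [(0, 2, (1.0, 0.0))])   # one extra edge

print(structure_hash(3, edges) == same)     # True
print(structure_hash(3, edges) == changed)  # False
```

Because the digest depends only on the canonical content, an unchanged structure maps to the same cache entry in every session, while any added or removed edge yields a new key.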

from phasic import (
    Graph, Property, StateIndexer, set_log_level,
    clear_caches, clear_model_cache,
    cache_info, get_all_cache_stats, print_all_cache_info,
    get_graph_cache_stats, print_graph_cache_info,
    get_trace_cache_stats, print_trace_cache_info,    
)
from phasic.trace_cache import list_cached_traces, cleanup_old_traces
import numpy as np
import time
%config InlineBackend.figure_format = 'svg'
from vscodenb import set_vscode_theme

set_vscode_theme()

We use the two-locus ARG with two parameters as an example model. The callback takes a dummy keyword argument that does nothing; it is there only to demonstrate how keyword arguments affect caching:

nr_samples = 6
indexer = StateIndexer(descendants=[
    Property('loc1', max_value=nr_samples),
    Property('loc2', max_value=nr_samples)
])

initial = [0] * indexer.state_length
initial[indexer.props_to_index(loc1=1, loc2=1)] = nr_samples

def two_locus_arg_2param(state, indexer=None, dummy=None):

    transitions = []
    if state.sum() <= 1: return transitions

    for i in range(indexer.state_length):
        if state[i] == 0: continue
        pi = indexer.index_to_props(i)

        for j in range(i, indexer.state_length):
            if state[j] == 0: continue
            pj = indexer.index_to_props(j)
            
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue 
            child = state.copy()
            child[i] -= 1
            child[j] -= 1
            loc1 = pi.descendants.loc1 + pj.descendants.loc1
            loc2 = pi.descendants.loc2 + pj.descendants.loc2
            if loc1 <= nr_samples and loc2 <= nr_samples:
                child[indexer.props_to_index(loc1=loc1, loc2=loc2)] += 1
                transitions.append([child, [state[i]*(state[j]-same)/(1+same), 0]]) 

        if state[i] > 0 and pi.descendants.loc1 > 0 and pi.descendants.loc2 > 0:
            child = state.copy()
            child[i] -= 1
            child[indexer.props_to_index(loc1=pi.descendants.loc1, loc2=0)] += 1
            child[indexer.props_to_index(loc1=0, loc2=pi.descendants.loc2)] += 1
            transitions.append([child, [0, 1]])                                

    return transitions

Start from a clean slate and enable info logging so the examples below show cache misses and hits clearly:

set_log_level('INFO')
clear_caches(verbose=True)
[INFO] phasic.graph_cache: Cleared 0 cached graphs
  Removed 15 file(s), preserved directory structure

Graph cache

Building a graph from a callback function requires exploring the full state space and creating every vertex and edge, which can take seconds to minutes for large models. The graph cache stores fully constructed Graph objects on disk so that the same model can be loaded instantly on subsequent calls.

The cache key is a SHA-256 hash of:

  • The callback function’s AST (abstract syntax tree), so whitespace/comment changes are ignored but code changes invalidate the cache
  • All construction parameters (ipv, nr_samples, keyword arguments)
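
The idea of keying on the AST rather than on the raw source text can be sketched with the standard library. The helper `callback_cache_key` below is hypothetical and is not phasic's implementation; it takes source strings directly, whereas a real cache would presumably obtain them with something like `inspect.getsource`.

```python
import ast
import hashlib
import json

def callback_cache_key(source, **params):
    """Hash a callback's AST together with its construction parameters.

    Hypothetical helper for illustration; not phasic's implementation.
    """
    # ast.dump omits line and column information by default, so layout
    # and comments do not affect the result.
    tree_repr = ast.dump(ast.parse(source))
    param_repr = json.dumps(params, sort_keys=True, default=repr)
    return hashlib.sha256((tree_repr + param_repr).encode("utf-8")).hexdigest()

original = "def rate(n):\n    return n * (n - 1) / 2\n"
reformatted = "def rate(n):\n\n    # pairwise coalescence rate\n    return n*(n-1)/2\n"
modified = "def rate(n):\n    return n * (n + 1) / 2\n"

key = callback_cache_key(original, nr_samples=6)
print(key == callback_cache_key(reformatted, nr_samples=6))  # True
print(key == callback_cache_key(modified, nr_samples=6))     # False
print(key == callback_cache_key(original, nr_samples=7))     # False
```

The reformatted source parses to the same tree and so reuses the key, while a change to either the code or a parameter produces a different one.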

Enable the graph cache by passing graph_cache=True to Graph(). The first build constructs the graph from the callback and saves it to the cache:

%%time 
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
    graph_cache=True)
[INFO] phasic.graph_cache: Saved graph to cache: 7d8de78485fc9b6b... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
CPU times: user 3.77 s, sys: 23.4 ms, total: 3.79 s
Wall time: 3.78 s

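A repeated build with identical arguments is loaded from the cache. The call below, however, adds dummy=42, which changes the cache key, so the log shows a new entry being saved rather than a cache hit: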

%%time
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
    graph_cache=True, dummy=42)
[INFO] phasic.graph_cache: Saved graph to cache: 1c6b4850af250c3a... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
CPU times: user 3.83 s, sys: 27.5 ms, total: 3.86 s
Wall time: 3.85 s

If you modify the callback function or pass different parameters, the cache misses and the graph is rebuilt. Even though our dummy keyword arg does nothing, passing a new value triggers a rebuild of the graph:

%%time
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
    graph_cache=True, dummy=99)
[INFO] phasic.graph_cache: Saved graph to cache: d3cee6407f2ba48f... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
CPU times: user 3.8 s, sys: 29.5 ms, total: 3.83 s
Wall time: 3.82 s
clear_caches(verbose=True)
[INFO] phasic.graph_cache: Cleared 0 cached graphs
  Removed 3 file(s), preserved directory structure
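
The hit-or-miss behaviour described in this section can be reproduced with a toy persistent cache. This is a minimal sketch under stated assumptions: `cached_build` and `build_states` are hypothetical names, the key is derived from the function name and keyword arguments only, and nothing here reflects how phasic stores graphs.

```python
import hashlib
import json
import pickle
import tempfile
from pathlib import Path

def cached_build(build, cache_dir, **kwargs):
    """Return (result, was_hit), building and storing the result on a miss.

    Hypothetical toy cache for illustration; not phasic's implementation.
    """
    key_src = build.__name__ + json.dumps(kwargs, sort_keys=True)
    key = hashlib.sha256(key_src.encode("utf-8")).hexdigest()
    path = Path(cache_dir) / f"{key}.pkl"
    if path.exists():
        return pickle.loads(path.read_bytes()), True   # cache hit
    result = build(**kwargs)                            # cache miss: do the work
    path.write_bytes(pickle.dumps(result))
    return result, False

def build_states(n=3, dummy=None):
    # Stand-in for an expensive state-space construction; `dummy` is unused.
    return [(i, j) for i in range(n) for j in range(n)]

with tempfile.TemporaryDirectory() as tmp:
    _, hit_first = cached_build(build_states, tmp, n=3)
    _, hit_repeat = cached_build(build_states, tmp, n=3)
    _, hit_dummy = cached_build(build_states, tmp, n=3, dummy=42)

print(hit_first, hit_repeat, hit_dummy)  # False True False
```

Only the call with identical arguments hits. Adding an unused keyword argument changes the key and forces a rebuild, just as with the dummy argument in the examples above.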

Trace cache

When computing moments or running SVGD inference, phasic performs Gaussian elimination on the graph to record an elimination trace — a linear sequence of operations that can be replayed cheaply with different parameter values. Recording the trace is O(n³) and is the most expensive step for large models.
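
The record-once, replay-many idea can be shown on a deliberately tiny case. The sketch below does not perform Gaussian elimination and does not use phasic's trace format; the operation names and the functions `record_trace` and `replay` are hypothetical. It records the arithmetic for the expected time to pass through two consecutive states with exit rates 3·θ₀ and 2·θ₁, then replays that recording for different parameter vectors.

```python
def record_trace():
    """Record the operations for 1/(3*theta0) + 1/(2*theta1) once.

    Each entry is a tuple; operands refer to earlier entries by index.
    """
    trace = []

    def emit(*op):
        trace.append(op)
        return len(trace) - 1  # index of the value this operation produces

    t0 = emit("param", 0)
    t1 = emit("param", 1)
    r0 = emit("scale", 3.0, t0)   # exit rate of the first state: 3 * theta0
    r1 = emit("scale", 2.0, t1)   # exit rate of the second state: 2 * theta1
    w0 = emit("inv", r0)          # mean waiting time in the first state
    w1 = emit("inv", r1)          # mean waiting time in the second state
    emit("add", w0, w1)           # total expected time
    return trace

def replay(trace, theta):
    """Evaluate a recorded trace for one parameter vector."""
    values = []
    for op in trace:
        kind = op[0]
        if kind == "param":
            values.append(theta[op[1]])
        elif kind == "scale":
            values.append(op[1] * values[op[2]])
        elif kind == "inv":
            values.append(1.0 / values[op[1]])
        elif kind == "add":
            values.append(values[op[1]] + values[op[2]])
        else:
            raise ValueError(f"unknown operation: {kind}")
    return values[-1]

trace = record_trace()            # done once per graph structure
for theta in ([1.0, 1.0], [2.0, 5.0]):
    print(theta, replay(trace, theta))
```

Replaying is a single linear pass over the recorded operations, which is why it pays to cache the recording whenever producing it is expensive.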

The trace cache stores these elimination traces on disk, keyed by a SHA-256 hash of the graph structure (vertices, edges, and rates, but not the specific parameter values). Enable trace caching by passing cache_trace=True to Graph(). Build a graph with trace caching enabled:

graph_w_trace = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
    cache_trace=True)
graph_w_trace.update_weights([2, 5])
graph_w_trace.expectation()
[INFO] phasic.hierarchical_trace_cache: Using hierarchical SCC subdivision (graph=1044 vertices, min_size=50)
[INFO] phasic.hierarchical_trace_cache: SCC grouping: 0 large SCCs (≥50 vertices), 123 small SCCs (<50 vertices, 1044 total vertices)
[INFO] phasic.hierarchical_trace_cache: All SCCs are below min_size=50, recording full graph directly (no subdivision)
[INFO] phasic.hierarchical_trace_cache: Parallelization: Auto-selected 'vmap' (multiprocessing with 10 CPUs)
[INFO] phasic.hierarchical_trace_cache: VMAP: Using multiprocessing with 1 workers over 1 work units
[INFO] phasic.trace_elimination: Trace recording complete: 1044 vertices, 469891 operations, phase 2, param_length=2, reward_length=0
[INFO] phasic.hierarchical_trace_cache:   SCC 1/1: 1044 vertices, 469891 operations, param_length=2
[INFO] phasic.hierarchical_trace_cache: ✓ Hierarchical trace computation complete (no stitching)
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 8.61e+12 > threshold 1.00e+12)
[INFO] phasic.c: Computing MPFR graph with 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
0.6872139668331487

Now construct the same graph a second time:

graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
    cache_trace=True)
graph.update_weights([2, 5])

The trace is computed lazily when first needed and cached:

graph.vertices_length()
1044
graph.expectation()
[INFO] phasic.hierarchical_trace_cache: ✓ Full graph cache HIT: returning cached trace
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 8.61e+12 > threshold 1.00e+12)
[INFO] phasic.c: Computing MPFR graph with 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
0.6872139668331488

From now on the trace is read from cache:

%%time
graph.variance()
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 8.61e+12 > threshold 1.00e+12)
[INFO] phasic.c: Computing MPFR graph with 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 8.61e+12 > threshold 1.00e+12)
[INFO] phasic.c: MPFR computation successful - returning high-precision results
CPU times: user 646 ms, sys: 6.29 ms, total: 652 ms
Wall time: 652 ms
0.2403942575637451

If needed, you can compute and cache the trace directly:

trace = graph.compute_trace()

You can check whether a trace is cached for the graph and whether it is still valid (that is, the graph has not changed since the trace was computed):

graph.cache_trace, graph.trace_valid 
(True, True)

Because the trace cache is persistent on disk (~/.phasic_cache/traces/), the cached trace is available even if you restart Python and construct the same graph structure again. This makes the trace cache especially valuable for iterative development and parameter exploration.

JAX compilation cache

When running SVGD inference, JAX JIT-compiles the log-likelihood, kernel, and update functions the first time they are called. This compilation can take 1–10 seconds. The JAX compilation cache stores the compiled XLA code on disk so that subsequent Python sessions skip recompilation entirely.

This cache is managed by JAX itself and is enabled automatically by phasic at import time. The cache key is based on the function structure and input shapes (not values), so different parameter vectors reuse the same compiled code.

Configuration

The default cache directory is ~/.jax_cache/. You can change it via an environment variable, which must be set before JAX is imported:

export JAX_COMPILATION_CACHE_DIR=/fast/ssd/jax_cache
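
In a notebook, where exporting a shell variable is inconvenient, the same variable can be set from Python. This is a small sketch that assumes the assignment runs before the first import of JAX (and presumably of phasic, since the cache is enabled at import time). The directory shown is a hypothetical location.

```python
import os
from pathlib import Path

# Hypothetical location; substitute any writable directory.
cache_dir = Path.home() / ".cache" / "jax_compilation"

# Must run before JAX is imported for the setting to be picked up.
os.environ["JAX_COMPILATION_CACHE_DIR"] = str(cache_dir)
print(os.environ["JAX_COMPILATION_CACHE_DIR"])
```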

Or programmatically with the CompilationConfig class:

from phasic import CompilationConfig

config = CompilationConfig.balanced()   # sensible defaults
config.apply()




JAX keys the compiled XLA code on the function structure (the HLO graph), the input shapes, and the device configuration. CompilationConfig provides three presets:

from phasic.jax_config import CompilationConfig

# Balanced preset (default)
config = CompilationConfig.balanced()
config.apply()

# Maximum performance
config = CompilationConfig.max_performance()
config.apply()

# Fast compilation (for development)
config = CompilationConfig.fast_compile()
config.apply()

Inspecting caches

Phasic provides a unified API for inspecting all three cache layers.

Overview of all caches

print_all_cache_info()
Cache directory: /Users/kmt/.jax_cache
Cached compilations: 0
Total size: 0.0 MB

Cache directory: /Users/kmt/.phasic_cache/graphs
Status: No cached graphs

Cache directory: /Users/kmt/.phasic_cache/traces
Status: No cached traces

Individual cache layers

Each layer has its own inspection functions:

# Graph cache
print_graph_cache_info()
Cache directory: /Users/kmt/.phasic_cache/graphs
Status: No cached graphs
# Trace cache
print_trace_cache_info()
Cache directory: /Users/kmt/.phasic_cache/traces
Status: No cached traces
# JAX compilation cache
jax_info = cache_info()
print(f"JAX cache: {jax_info['num_files']} files, {jax_info['total_size_mb']:.1f} MB")
JAX cache: 0 files, 0.0 MB

Programmatic access

For scripting, get_all_cache_stats() returns a dictionary with statistics for each layer:

stats = get_all_cache_stats()
for name, layer_stats in stats.items():
    print(f"{name}: {layer_stats}")
jax: {'exists': True, 'path': '/Users/kmt/.jax_cache', 'num_files': 0, 'total_size_mb': 0.0, 'files': []}
graph: {'num_graphs': 0, 'total_size_mb': 0.0, 'cache_dir': '/Users/kmt/.phasic_cache/graphs'}
trace: {'total_files': 0, 'total_bytes': 0, 'total_mb': 0.0, 'cache_dir': '/Users/kmt/.phasic_cache/traces'}
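
Note that the layers report their size under different keys. A small helper can total them. The sketch below assumes the key names visible in the output above ('total_size_mb' for the JAX and graph layers, 'total_mb' for traces). The function `total_cache_mb` and the sample values are hypothetical.

```python
def total_cache_mb(stats):
    """Sum the on-disk size in MB across cache layers.

    Assumes the per-layer keys shown in the output above:
    'total_size_mb' for the JAX and graph layers, 'total_mb' for traces.
    """
    total = 0.0
    for layer_stats in stats.values():
        total += layer_stats.get("total_size_mb", layer_stats.get("total_mb", 0.0))
    return total

# Hypothetical values, shaped like the dictionary get_all_cache_stats() returns.
example_stats = {
    "jax": {"num_files": 4, "total_size_mb": 12.5},
    "graph": {"num_graphs": 2, "total_size_mb": 3.0},
    "trace": {"total_files": 1, "total_bytes": 1572864, "total_mb": 1.5},
}
print(total_cache_mb(example_stats))  # 17.0
```

With phasic installed, passing the result of get_all_cache_stats() instead of `example_stats` should work the same way, provided the key names match those shown.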

Listing individual trace entries

You can list the cached traces with metadata about each entry:

for entry in list_cached_traces():
    print(f"Hash: {entry['hash'][:16]}...  "
          f"Size: {entry.get('size_kb', 0):.1f} KB  "
          f"Vertices: {entry.get('n_vertices', 'N/A')}")

Clearing caches

Phasic provides several functions for clearing caches at different granularities:

| Function | What it clears |
|---|---|
| clear_caches() | All three cache layers |
| clear_model_cache() | Graph cache + trace cache |
| clear_jax_cache() | JAX compilation cache only |

# Clear only model-related caches (graph + trace)
clear_model_cache(verbose=True)

# Verify they are empty
print(f"\nGraph cache: {get_graph_cache_stats()['num_graphs']} graphs")
print(f"Trace cache: {get_trace_cache_stats()['total_files']} files")
[INFO] phasic.graph_cache: Cleared 0 cached graphs
  Removed 1 file(s), preserved directory structure

Graph cache: 0 graphs
Trace cache: 0 files

Selective cleanup

For production environments where caches grow over time, you can prune old or oversized entries from the trace cache:

# Remove traces older than 30 days or enforce a 100 MB size limit
removed = cleanup_old_traces(max_size_mb=100.0, max_age_days=30)
print(f"Removed {removed} old trace entries")
Removed 0 old trace entries
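
To make the two limits concrete, here is one plausible pruning policy written against a plain directory: delete files older than the age limit, then delete the oldest remaining files until the total size fits under the cap. The function `prune_cache_dir` is a hypothetical sketch, and cleanup_old_traces may combine its limits differently.

```python
import os
import tempfile
import time
from pathlib import Path

def prune_cache_dir(cache_dir, max_size_mb=None, max_age_days=None):
    """Delete old files, then the oldest remaining ones until under the size cap.

    Returns the number of files removed. Hypothetical sketch; phasic's
    cleanup_old_traces may apply its limits differently.
    """
    now = time.time()
    # Oldest first, by modification time.
    files = sorted((p for p in Path(cache_dir).iterdir() if p.is_file()),
                   key=lambda p: p.stat().st_mtime)
    removed = 0
    kept = []
    for path in files:
        age_days = (now - path.stat().st_mtime) / 86400
        if max_age_days is not None and age_days > max_age_days:
            path.unlink()
            removed += 1
        else:
            kept.append(path)
    if max_size_mb is not None:
        limit = max_size_mb * 1024 * 1024
        total = sum(p.stat().st_size for p in kept)
        for path in kept:              # still ordered oldest first
            if total <= limit:
                break
            total -= path.stat().st_size
            path.unlink()
            removed += 1
    return removed

with tempfile.TemporaryDirectory() as tmp:
    old, new = Path(tmp) / "old.trace", Path(tmp) / "new.trace"
    old.write_bytes(b"x" * 10)
    new.write_bytes(b"x" * 10)
    forty_days_ago = time.time() - 40 * 86400
    os.utime(old, (forty_days_ago, forty_days_ago))  # backdate one file
    removed = prune_cache_dir(tmp, max_size_mb=100.0, max_age_days=30)
    remaining = sorted(p.name for p in Path(tmp).iterdir())

print(removed, remaining)  # 1 ['new.trace']
```

Removing the oldest entries first keeps recently used models warm, which suits iterative work where the same few graphs are rebuilt repeatedly.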

You can also remove a specific trace by hash:

from phasic.trace_cache import remove_cached_trace
removed = remove_cached_trace("abc123def456...")  # Returns True/False

Or clear everything from the command line:

# Clear all phasic caches
rm -rf ~/.phasic_cache/ ~/.jax_cache/