Hierarchical symbolic trace

Warning: Under Construction

This notebook is under development

Summary: Graph(hierarchical=True) Implementation

The Graph constructor now accepts a hierarchical=True flag, which enables trace-based computation for the moments, expectation, and variance methods.

Key Features

  1. Constructor flag: Graph(…, hierarchical=True)
  2. Lazy trace computation: the trace is computed on the first call to compute_trace(), or the first time a method needs it
  3. Non-destructive: the graph is cloned before the trace is recorded, so the original is preserved
  4. Trace invalidation: the _invalidates_trace decorator invalidates the trace automatically whenever the graph structure changes
  5. Theta caching: update_weights(theta) caches theta for trace-based evaluation
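
The interplay of lazy computation, invalidation, and theta caching listed above can be illustrated with a small self-contained sketch. This is not phasic's implementation: ToyGraph, invalidates_trace, total_rate, and trace_builds are hypothetical names invented for the illustration, and the "trace" here is just a frozen copy of the edge list.

```python
import copy
import functools


def invalidates_trace(method):
    """Hypothetical decorator: drop the cached trace after a structural change."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)
        self._trace = None  # the next compute_trace() call must rebuild
        return result
    return wrapper


class ToyGraph:
    """Hypothetical stand-in for a graph with a lazily cached trace."""

    def __init__(self):
        self.edges = []        # (source, target, coefficient) triples
        self._trace = None     # cached trace, built on demand
        self._theta = None     # cached parameter vector
        self.trace_builds = 0  # counts how often the trace was recorded

    @invalidates_trace
    def add_edge(self, src, dst, coeff):
        # Structural change: the decorator marks the trace as stale.
        self.edges.append((src, dst, coeff))

    def update_weights(self, theta):
        # Parameter change only: cache theta, keep the trace.
        self._theta = list(theta)

    @property
    def trace_valid(self):
        return self._trace is not None

    def compute_trace(self):
        # Lazy and non-destructive: record from a copy, then reuse the result.
        if self._trace is None:
            snapshot = copy.deepcopy(self.edges)
            self._trace = tuple(snapshot)
            self.trace_builds += 1
        return self._trace

    def total_rate(self):
        # Evaluate the cached trace with the cached theta.
        return sum(coeff * self._theta[0] for _, _, coeff in self.compute_trace())


g_toy = ToyGraph()
g_toy.add_edge(0, 1, 3.0)
g_toy.update_weights([2.0])
first = g_toy.total_rate()       # builds the trace
g_toy.update_weights([3.0])
second = g_toy.total_rate()      # reuses the trace, new theta
builds_before_change = g_toy.trace_builds
g_toy.add_edge(1, 2, 1.0)        # structural change invalidates the trace
valid_after_change = g_toy.trace_valid
third = g_toy.total_rate()       # rebuilds the trace
```

The point of the design is that update_weights never touches the trace, so a parameter sweep pays the recording cost once, while any method that edits the structure forces a rebuild on the next evaluation.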

API

from phasic import Graph

from phasic import (
    Graph, with_ipv, 
    clear_caches, clear_jax_cache, clear_model_cache, 
    cache_info, cache_manager
)
from vscodenb import set_vscode_theme
set_vscode_theme()

import numpy as np
nr_samples = 4

@with_ipv([nr_samples]+[0]*(nr_samples-1))
def coalescent_1param(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):            
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue 
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i+j+1] += 1
            transitions.append([new, [state[i]*(state[j]-same)/(1+same)]])
    return transitions
g = Graph(coalescent_1param, cache_trace=True)
g.update_weights([2])
g.expectation()
0.75
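
The value 0.75 can be cross-checked by hand, under two assumptions not stated in this notebook: that expectation() returns the expected time to absorption (the time to the most recent common ancestor), and that every edge weight is the coefficient from coalescent_1param multiplied by theta. With k lineages the total coalescence rate is then theta * C(k, 2), and the expected waiting times add up. The helper names below are hypothetical and independent of phasic.

```python
from math import comb


def expected_tmrca(n, theta):
    """Sum of expected waiting times while k = n, ..., 2 lineages remain."""
    return sum(1.0 / (theta * comb(k, 2)) for k in range(2, n + 1))


def expected_tmrca_closed_form(n, theta):
    """Telescoped form of the same sum: (2 / theta) * (1 - 1 / n)."""
    return (2.0 / theta) * (1.0 - 1.0 / n)


# Same setting as the cell above: nr_samples = 4, theta = 2.
by_sum = expected_tmrca(4, 2.0)
by_formula = expected_tmrca_closed_form(4, 2.0)
```

For four samples and theta = 2 the terms are 1/12 + 1/6 + 1/2, which agrees with the 0.75 reported above.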

Exploration loop - fast parameter updates:

for theta in [1, 2, 3]:
    g.update_weights([theta])      # Caches theta, updates edges
    mean = g.expectation()
    var = g.variance()

Explicit trace computation:

trace = g.compute_trace()        # Non-destructive, returns cached trace

Check trace status:

g.cache_trace    # True
True
g.trace_valid     # True if trace exists and not invalidated
True