Graph API Call Tree

This document explains the complete call hierarchy when users interact with the Graph class in the phasic library. Each callable is documented with its purpose, functionality, and what it calls. Callouts are color-coded to show implementation type. Yellow callouts indicate pure Python code. Blue callouts indicate pybind11 bindings to C or C++ code. Red callouts indicate ctypes bindings to compiled C++ libraries. Green callouts indicate JAX FFI bindings for distributed computing support.

Graph Construction

def __init__(self, state_length:int=None, callback:Callable=None,
             parameterized:bool=False, **kwargs) -> Graph

This is the primary entry point for creating phase-type distribution graphs. It serves as the user-facing constructor that validates inputs and routes to the appropriate C++ implementation through pybind11. The function validates that exactly one of state_length or callback is provided, raising an assertion error if both or neither are given.

When a callback function is provided with parameterized set to True, the constructor wraps the callback using functools.partial to bind any additional keyword arguments, then passes it to the parent class constructor via super().__init__ with callback_tuples_parameterized. This path is used for graphs where edge weights are parameterized functions of theta parameters with coefficient vectors. When parameterized is False, the callback is wrapped and passed via callback_tuples for graphs with fixed edge weights. If no callback is provided, only state_length is passed to create an empty graph structure that can be manually populated with vertices and edges.
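The validation and routing described above can be sketched in plain Python. This is an illustration only: `route_constructor` and `toy_callback` are hypothetical names, and the real constructor forwards the result to the pybind11 parent class rather than returning it.

```python
from functools import partial
from typing import Callable, Optional


def route_constructor(state_length: Optional[int] = None,
                      callback: Optional[Callable] = None,
                      parameterized: bool = False, **kwargs):
    """Return (keyword_name, value) naming the parent-constructor path taken."""
    # Exactly one of state_length / callback must be supplied.
    assert (state_length is None) != (callback is None), \
        "provide exactly one of state_length or callback"
    if callback is None:
        return ("state_length", state_length)
    # Bind any extra keyword arguments to the callback.
    bound = partial(callback, **kwargs)
    key = "callback_tuples_parameterized" if parameterized else "callback_tuples"
    return (key, bound)


def toy_callback(state, rate=1.0):
    # Hypothetical callback: a single successor state with a fixed weight.
    return [([0], rate)]


key, fn = route_constructor(callback=toy_callback, parameterized=True, rate=2.0)
```

Here `key` names the keyword the wrapped callback would be passed under, and `fn` is the callback with `rate=2.0` already bound.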

The parent class constructor is implemented in C++ through pybind11, which provides the actual graph data structure and algorithms. The Python wrapper handles parameter validation and routing while the C++ implementation performs the computationally intensive state space exploration and graph construction.

phasic_pybind.cpp::Graph.__init__

Graph(py::function callback_tuples_parameterized)
Graph(py::function callback_tuples)
Graph(int state_length)

These C++ constructors build phase-type graphs using different initialization strategies. The callback-based constructors explore the state space by iteratively calling the Python callback function to generate reachable states and edges. The callback_tuples_parameterized variant expects tuples of the form (next_state, base_weight, edge_state) where edge_state is a coefficient vector enabling linear parameterization of edge weights as weight equals base_weight plus the dot product of coefficients and theta parameters. The callback_tuples variant expects simpler tuples (next_state, weight) where weight is a fixed numeric value. The state_length constructor creates an empty graph with just a starting vertex, allowing manual graph construction through the Vertex API.
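As a small illustration of the linear parameterization, the weight of a parameterized edge can be evaluated as follows. The helper name `edge_weight` is hypothetical; only the formula comes from the description above.

```python
from typing import Sequence


def edge_weight(base_weight: float, edge_state: Sequence[float],
                theta: Sequence[float]) -> float:
    """weight = base_weight + dot(edge_state, theta)."""
    if len(edge_state) != len(theta):
        raise ValueError("edge_state and theta must have the same length")
    return base_weight + sum(c * t for c, t in zip(edge_state, theta))


# An edge whose weight depends only on the first parameter.
w = edge_weight(0.5, [1.0, 0.0], [2.0, 3.0])
```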

The callback-based constructors start from an empty state vector and explore the state space breadth-first, calling the callback with each discovered state to find its successors. Vertices are created for each unique state encountered and edges are added with their weights or parameterization coefficients. The C++ implementation uses efficient hash maps and AVL trees to track visited states and maintain the graph structure. For parameterized graphs, edge coefficient vectors are stored for later use when evaluating the graph with specific theta values.
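The breadth-first exploration can be sketched in pure Python for the non-parameterized `(next_state, weight)` tuple form. This is a simplified stand-in, not the C++ implementation: a dictionary replaces the hash maps and AVL trees, and both `explore` and the `countdown` callback are hypothetical.

```python
from collections import deque
from typing import Callable, Dict, List, Tuple

State = Tuple[int, ...]


def explore(callback: Callable[[List[int]], list]):
    """Breadth-first state-space exploration driven by a callback."""
    start: State = ()                      # the empty starting state
    index: Dict[State, int] = {start: 0}   # state -> vertex index
    edges: List[Tuple[int, int, float]] = []
    queue = deque([start])
    while queue:
        state = queue.popleft()
        for next_state, weight in callback(list(state)):
            nxt = tuple(next_state)
            if nxt not in index:           # first visit: create a vertex
                index[nxt] = len(index)
                queue.append(nxt)
            edges.append((index[state], index[nxt], weight))
    return index, edges


def countdown(state):
    # Hypothetical model: start at 3 and count down to 0 with rate = current value.
    if not state:
        return [([3], 1.0)]
    if state[0] > 0:
        return [([state[0] - 1], float(state[0]))]
    return []                              # absorbing state: no successors


index, edges = explore(countdown)
```

Each unique state gets exactly one vertex, and the callback is called once per vertex, which is why the callback must be a pure function of the state.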

The Graph.__init__ function returns a Graph object with its complete state space structure, vertices, and edges ready for computation through methods like pdf, moments, sample, or for conversion to JAX-compatible models.

JAX Model Creation

@classmethod
def pmf_from_graph(cls, graph: Graph, discrete: bool = False,
                   use_cache: bool = True, param_length: int = None) -> Callable

This class method converts a Python-built Graph into a JAX-compatible function with full gradient support through custom vector-Jacobian product rules. It automatically detects whether the graph has parameterized edges by checking for edge state vectors and generates optimized C++ code accordingly. For parameterized graphs it produces a function with signature (theta, times) that supports JAX transformations including gradients, vmap, and jit compilation. For non-parameterized graphs it produces a simpler function (times) that maintains backward compatibility.

The method serializes the graph structure to JSON and checks a symbolic DAG cache to see if an equivalent graph structure has been processed before. If cached, it reuses the previous computation to avoid expensive symbolic elimination. If not cached, it generates custom C++ code implementing the forward algorithm for the specific graph structure. The generated code is compiled into a shared library and loaded via ctypes, then wrapped with jax.pure_callback to enable JAX integration while calling the fast C++ implementation.
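One way such a structure cache could be keyed is by hashing a canonical JSON form of the serialized graph. The sketch below assumes this scheme; the actual cache key, the dictionary field names, and the helpers `structure_cache_key` and `get_or_build` are all hypothetical.

```python
import hashlib
import json

_cache = {}


def structure_cache_key(serialized: dict) -> str:
    """Hash a canonical JSON encoding so equal structures share one key."""
    canonical = json.dumps(serialized, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


def get_or_build(serialized: dict, build):
    """Return the cached artifact, calling build() only on a cache miss."""
    key = structure_cache_key(serialized)
    if key not in _cache:
        _cache[key] = build(serialized)
    return _cache[key]


calls = []


def fake_build(s):
    # Stand-in for the expensive code-generation and compilation step.
    calls.append(1)
    return "compiled-artifact"


# Field names here are illustrative, not the library's serialization format.
g1 = {"states": [[0], [1]], "edges": [[0, 1, 2.0]]}
g2 = {"edges": [[0, 1, 2.0]], "states": [[0], [1]]}  # same content, other key order
first = get_or_build(g1, fake_build)
second = get_or_build(g2, fake_build)
```

Sorting keys before hashing makes the key independent of dictionary insertion order, so the second call is a cache hit.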

def _generate_cpp_from_graph(serialized: dict) -> str

This function generates C++ source code implementing the forward algorithm for a specific graph structure. It takes a serialized graph representation containing states, edges, and topology information and produces complete C++ code with computation functions for both PDF and moments. The generated code includes optimized loops and data structures tailored to the exact graph structure, avoiding generic graph traversal overhead.

The generator creates functions compute_pmf_nonparam and compute_dph_pmf_nonparam for continuous and discrete distributions respectively. For parameterized graphs it generates versions that accept theta parameters and evaluate edge weights using the stored coefficient vectors. The code implements uniformization-based forward algorithms with granularity control for numerical stability. Matrix representations are also generated for alternative computation methods when beneficial.

def _compile_wrapper_library(wrapper_code: str, lib_name: str,
                              extra_includes: List[str]) -> str

This function compiles generated C++ code into a shared library that can be loaded and called from Python. It writes the C++ source code to a temporary file, invokes the system C++ compiler with appropriate flags for the platform, and returns the path to the compiled library. The compilation includes the phasic C headers for access to core data structures and algorithms.

The compiler is invoked with optimization flags and platform-specific settings to generate efficient machine code. Error handling captures compilation failures and reports them with context about what code failed to compile. The resulting shared library exposes C-compatible function interfaces that ctypes can call directly.
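A minimal sketch of the compile step follows, assuming a GCC- or Clang-style compiler on the path. The specific flags, the default compiler name, and the helper names are assumptions for illustration and are not taken from the library.

```python
import subprocess
import sys
from typing import List


def build_compile_command(source_path: str, lib_path: str,
                          extra_includes: List[str],
                          compiler: str = "c++",
                          platform: str = sys.platform) -> List[str]:
    """Assemble a shared-library compile command (flags are assumed, not phasic's)."""
    cmd = [compiler, "-O3", "-std=c++17", "-fPIC"]
    # macOS toolchains conventionally use -dynamiclib; elsewhere -shared.
    cmd.append("-dynamiclib" if platform == "darwin" else "-shared")
    cmd += [f"-I{inc}" for inc in extra_includes]
    cmd += [source_path, "-o", lib_path]
    return cmd


def compile_library(source_path: str, lib_path: str,
                    extra_includes: List[str]) -> str:
    """Run the compiler and surface its stderr on failure."""
    cmd = build_compile_command(source_path, lib_path, extra_includes)
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"compilation of {source_path} failed:\n{result.stderr}")
    return lib_path


linux_cmd = build_compile_command("model.cpp", "model.so", ["/opt/include"],
                                  platform="linux")
```

Separating command construction from execution keeps the platform logic easy to inspect without needing a compiler installed.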

def _compute_pmf_from_ctypes(theta, times, compute_func, graph_data,
                              granularity, discrete) -> np.ndarray

This function provides the bridge between JAX arrays and the compiled C++ library using ctypes. Because it runs inside jax.pure_callback, it receives concrete array values rather than tracers; it converts them to numpy arrays, sets up ctypes pointer arguments matching the C function signature, calls the compiled function, and returns the results as numpy arrays that JAX can wrap.

The function handles both discrete and continuous distributions with appropriate type conversions. For parameterized graphs it passes theta arrays, while for non-parameterized graphs theta is unused. Arrays are passed via their ctypes data_as method to get C-compatible pointers. The compiled C++ function writes results directly into the output array buffer.
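The pointer-passing pattern can be shown with the standard library alone. In this sketch a Python function wrapped with `ctypes.CFUNCTYPE` stands in for the compiled C++ entry point, and the C signature is an assumption. With numpy arrays, the equivalent pointers would come from `arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))`.

```python
import ctypes

DoublePtr = ctypes.POINTER(ctypes.c_double)

# Assumed C signature: void compute(const double* times, int n, double* out)
ComputeFn = ctypes.CFUNCTYPE(None, DoublePtr, ctypes.c_int, DoublePtr)


def _fake_compute(times, n, out):
    # Stand-in for the compiled function: writes results into the output buffer.
    for i in range(n):
        out[i] = 2.0 * times[i]


compute_func = ComputeFn(_fake_compute)


def call_compiled(times):
    """Marshal inputs to C buffers, call the function, and read the output back."""
    n = len(times)
    times_buf = (ctypes.c_double * n)(*times)   # input buffer
    out_buf = (ctypes.c_double * n)()           # zero-initialized output buffer
    compute_func(times_buf, n, out_buf)         # callee fills out_buf in place
    return list(out_buf)


result = call_compiled([0.5, 1.0])
```

The caller owns both buffers, so no memory is allocated or freed on the C side.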

The pmf_from_graph method returns a JAX-compatible callable that internally routes through the compiled C++ code for performance while maintaining JAX transformation compatibility through custom differentiation rules.

@classmethod
def pmf_and_moments_from_graph(cls, graph: Graph, nr_moments: int = 2,
                                discrete: bool = False, use_ffi: bool = False,
                                param_length: int = None) -> Callable

This class method creates a function computing both PMF or PDF values and distribution moments simultaneously. Computing both together is more efficient than separate calls because the graph is built only once with specific theta parameters. The method supports two implementation backends chosen by the use_ffi parameter, trading off between simplicity and performance depending on whether JAX FFI is available.

When use_ffi is True, the method uses the JAX Foreign Function Interface to call the C++ GraphBuilder directly through XLA without Python callbacks. This enables full integration with JAX transformations including pmap for multi-device parallelization. The FFI approach passes graph structure as static attributes and theta, times, rewards as dynamic buffers that XLA can batch and distribute across devices. When use_ffi is False, it falls back to jax.pure_callback wrapping the pybind11 GraphBuilder interface. This works for jit and vmap but cannot use pmap for distributed computing.

def compute_pmf_and_moments_ffi(structure_json: str, theta: jax.Array,
                                 times: jax.Array, nr_moments: int,
                                 discrete: bool = False, granularity: int = 0,
                                 rewards: jax.Array = None) -> Tuple[jax.Array, jax.Array]

This function provides JAX FFI integration for computing PMF and moments through XLA. It registers the FFI target ptd_compute_pmf_and_moments with XLA if not already registered, then calls jax.ffi.ffi_call to invoke the C++ implementation. The graph structure JSON is passed as a static attribute while theta, times, and rewards are dynamic buffers that can be batched via vmap or distributed via pmap.

The FFI call specifies vmap_method as expand_dims to enable automatic batching across theta particles in SVGD. XLA handles moving data to appropriate devices and scheduling computation. The C++ handler deserializes the graph structure, builds the graph with provided theta parameters, computes the PDF using the forward algorithm, computes moments using reward transformation, and writes results to output buffers that XLA wraps as JAX arrays.

phasic_pybind.cpp::GraphBuilder.compute_pmf_and_moments

std::pair<py::array_t<double>, py::array_t<double>>
compute_pmf_and_moments(py::array_t<double> theta, py::array_t<double> times,
                        int nr_moments, bool discrete = false,
                        int granularity = 100, py::object rewards = py::none())

This C++ method is the core implementation computing both PMF and moments efficiently. It builds the graph once using the provided theta parameters to evaluate parameterized edge weights, then performs both PDF computation via the forward algorithm and moment computation via repeated calls to expected_waiting_time. The implementation releases the Python GIL during computation to enable true parallelism.

For continuous distributions it calls the forward algorithm with uniformization to compute PDF values at each requested time point. For discrete distributions it uses the discrete phase-type algorithm. Moments are computed by iteratively calling expected_waiting_time with reward vectors, using the formula E[T^k] equals k factorial times the expected waiting time with appropriate reward structure. When rewards are provided, moments represent the reward-transformed distribution E[(R dot T)^k]. The method returns both results as numpy arrays.

def _compute_pmf_and_moments_cached(theta_np, times_np, rewards_np) -> Tuple[np.ndarray, np.ndarray]

This internal function provides the non-FFI implementation by directly calling the pybind11 GraphBuilder interface. It creates a GraphBuilder instance from the serialized graph structure, then calls compute_pmf_and_moments with the numpy arrays. The GraphBuilder caches the structure allowing repeated calls with different theta values to skip structure parsing.

Results are returned as numpy arrays that jax.pure_callback wraps into JAX arrays. This approach works with jit and vmap through pure_callback mechanisms but cannot leverage pmap for distributed execution.

The pmf_and_moments_from_graph method returns a JAX-compatible function with signature (theta, times, rewards=None) that produces a tuple of PMF values and moment values, enabling efficient SVGD inference with moment-based regularization.

@classmethod
def pmf_and_moments_from_graph_multivariate(cls, graph: Graph, nr_moments: int = 2,
                                             discrete: bool = False, use_ffi: bool = False,
                                             param_length: int = None) -> Callable

This class method extends pmf_and_moments_from_graph to handle multivariate phase-type distributions where observations have multiple feature dimensions. Each feature uses the same base graph structure but with different reward vectors to extract different marginal distributions. The method creates a model function accepting 2D rewards arrays of shape (n_vertices, n_features) and 2D observation arrays of shape (n_times, n_features).

The implementation uses jax.lax.scan to loop over feature dimensions in compiled code rather than Python loops that would break JIT compilation. For each feature column it calls the base model with the corresponding reward vector, collecting PMF and moment results. The scan approach keeps all computation inside the XLA compiled graph for stability and performance. PMF results are stacked into a 2D array and moments are averaged across features for regularization.

The returned function signature is (theta, times, rewards) where times and rewards are both 2D arrays. Missing observations are represented as NaN values which are filtered out during log-likelihood computation. This enables learning from datasets where different observations have different features measured.
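The NaN-masking idea can be sketched in plain Python. The helper `masked_log_likelihood` is hypothetical; inside JIT-compiled JAX code the same effect would typically be obtained with a mask (for example `jnp.where`) rather than a Python branch.

```python
import math


def masked_log_likelihood(pmf_values, observations):
    """Sum log-probabilities over entries whose observation is not NaN."""
    total = 0.0
    for row_p, row_x in zip(pmf_values, observations):
        for p, x in zip(row_p, row_x):
            if math.isnan(x):      # missing feature: contributes nothing
                continue
            total += math.log(p)
    return total


nan = float("nan")
pmf = [[0.5, 0.25], [0.5, 0.1]]    # shape (n_times, n_features)
obs = [[1.0, nan], [2.0, 3.0]]     # second feature missing in the first row
ll = masked_log_likelihood(pmf, obs)
```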

SVGD Inference

def svgd(self, observed_data: ArrayLike, discrete: bool = False,
         prior: Optional[Callable] = None, n_particles: int = 50,
         n_iterations: int = 1000, learning_rate: float = 0.001,
         bandwidth: str = 'median', theta_init: Optional[ArrayLike] = None,
         theta_dim: Optional[int] = None, return_history: bool = True,
         seed: int = 42, verbose: bool = True, jit: Optional[bool] = None,
         parallel: Optional[str] = None, n_devices: Optional[int] = None,
         precompile: bool = True, compilation_config: Optional[object] = None,
         regularization: float = 0.0, nr_moments: int = 2,
         positive_params: bool = True, param_transform: Optional[Callable] = None,
         rewards: Optional[ArrayLike] = None) -> Dict

This method provides a high-level interface for running Stein Variational Gradient Descent to perform Bayesian parameter inference on phase-type distributions. It automatically creates the appropriate model function based on whether rewards are provided for multivariate distributions, configures the SVGD optimizer with the specified settings, and runs the inference to approximate the posterior distribution over parameters.

The method first determines whether to use pmf_and_moments_from_graph or pmf_and_moments_from_graph_multivariate based on the dimensionality of rewards and observed_data. It serializes the graph and creates a model function with moment computation for optional regularization. The model is then passed to an SVGD instance along with the observed data, prior, and optimization settings. If parallel is set to pmap and multiple devices are available, it configures multi-device parallelization for faster inference.

def __init__(self, model: Callable, observed_data: ArrayLike, theta_dim: int,
             prior: Optional[Callable] = None, n_particles: int = 50,
             learning_rate: float = 0.001, bandwidth: str = 'median',
             seed: int = 42, positive_params: bool = True,
             param_transform: Optional[Callable] = None, rewards: Optional[ArrayLike] = None)

This class implements Stein Variational Gradient Descent for Bayesian inference. It initializes a swarm of particles in parameter space and iteratively updates them to approximate the posterior distribution. The constructor stores the model function, observed data, and inference configuration. It initializes particles either from provided theta_init or randomly from a standard normal distribution transformed to be positive if positive_params is True.

def optimize(self, n_iterations: int = 1000, return_history: bool = True,
             verbose: bool = True, precompile: bool = True) -> Dict

This method runs the SVGD optimization loop to approximate the posterior. It iteratively computes log-likelihood gradients for each particle, evaluates the Stein kernel to capture interactions between particles, and updates particle positions using the SVGD update rule. The kernel bandwidth is computed using the median heuristic unless specified otherwise. After optimization it returns summary statistics including posterior mean, standard deviation, and optionally the full history of particle positions.
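For intuition, one SVGD update for scalar particles can be written in pure Python. This sketch uses an RBF kernel and one common form of the median heuristic; whether phasic uses exactly this bandwidth formula is an assumption, and `svgd_step` is a hypothetical name.

```python
import math
from statistics import median


def svgd_step(particles, grad_log_p, learning_rate=0.1):
    """One SVGD update for 1-D particles with an RBF kernel exp(-d^2 / h2)."""
    n = len(particles)
    sq = [(a - b) ** 2 for i, a in enumerate(particles) for b in particles[i + 1:]]
    med = median(sq) if sq else 1.0
    h2 = max(med / math.log(n + 1), 1e-12)   # median heuristic (one common variant)
    grads = [grad_log_p(x) for x in particles]
    updated = []
    for xi in particles:
        phi = 0.0
        for xj, gj in zip(particles, grads):
            k = math.exp(-(xj - xi) ** 2 / h2)
            # Attraction toward high density plus repulsion between particles;
            # the second term is the derivative of k with respect to xj.
            phi += k * gj + (-2.0 * (xj - xi) / h2) * k
        updated.append(xi + learning_rate * phi / n)
    return updated


# Target: standard normal, so grad log p(x) = -x.
start = [3.0, 3.5, 4.0, 4.5, 5.0]
after = svgd_step(start, lambda x: -x)
```

The repulsion terms cancel when summed over all particles, so the particle mean moves only under the kernel-weighted gradient term, while the repulsion keeps particles from collapsing onto a single mode.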

The svgd method returns a dictionary containing the SVGD results including estimated posterior mean and standard deviation over parameters, the final particle positions, and optionally the full optimization history showing how particles evolved.

Direct Graph Operations

phasic_pybind.cpp::Graph.pdf

double pdf(float time, int granularity = 0)

This method computes the probability density function at a specific time point using the forward algorithm with uniformization. It creates or reuses a probability distribution context structure that maintains state for iterative PDF computation. The granularity parameter controls the discretization fineness for numerical stability, with 0 triggering automatic selection based on maximum vertex rate.

The forward algorithm discretizes continuous time into steps of size 1/lambda where lambda is the uniformization rate. At each discrete step it propagates probability mass through the graph according to transition probabilities, tracking the probability of absorption at each step. The PDF value at the requested time is extracted from the cached results. The context is reused across multiple calls for efficiency when computing PDF at many time points.
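The textbook uniformization formula behind this description can be sketched with a small dense matrix. This is not the library's graph-based implementation and ignores the granularity parameter; it uses the largest exit rate as the uniformization rate and truncates the Poisson sum at an assumed `n_terms`.

```python
import math


def pdf_uniformization(alpha, S, t, n_terms=200):
    """Density at time t via uniformization: sum_n Poisson(n; lam*t) * alpha P^n s."""
    m = len(alpha)
    lam = max(-S[i][i] for i in range(m))            # uniformization rate
    # Transition matrix of the uniformized discrete chain: P = I + S / lam.
    P = [[(1.0 if i == j else 0.0) + S[i][j] / lam for j in range(m)]
         for i in range(m)]
    exit_rates = [-sum(S[i]) for i in range(m)]      # rate of absorption per state
    v = list(alpha)                                  # state distribution after n jumps
    poisson = math.exp(-lam * t)                     # Poisson(0; lam * t)
    density = 0.0
    for n in range(n_terms):
        density += poisson * sum(v[i] * exit_rates[i] for i in range(m))
        v = [sum(v[i] * P[i][j] for i in range(m)) for j in range(m)]
        poisson *= lam * t / (n + 1)
    return density


# Hypoexponential chain: rate 2 then rate 1 until absorption.
S = [[-2.0, 2.0], [0.0, -1.0]]
value = pdf_uniformization([1.0, 0.0], S, t=1.0)
```

For this chain the closed-form density is 2(e^{-t} - e^{-2t}), which gives a convenient check of the sketch.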

phasic_pybind.cpp::Graph.dph_pmf

double dph_pmf(int jumps)

This method computes the probability mass function for discrete phase-type distributions at a specific number of jumps. It uses the discrete phase-type algorithm which propagates probability mass through discrete time steps corresponding to Markov chain jumps rather than continuous time. The method creates or reuses a discrete probability distribution context and steps it forward until reaching the requested number of jumps, then returns the probability of absorption at that point.
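The stepping scheme can be illustrated with a dense transition matrix over transient states. The function below is a pure-Python sketch with a hypothetical name, and it assumes no probability mass is absorbed before the first jump.

```python
def dph_pmf_sketch(alpha, T, jumps):
    """P(absorption happens exactly at jump number `jumps`)."""
    if jumps < 1:
        return 0.0
    m = len(alpha)
    exit_prob = [1.0 - sum(T[i]) for i in range(m)]   # per-state absorption probability
    v = list(alpha)
    for _ in range(jumps - 1):                        # advance the transient distribution
        v = [sum(v[i] * T[i][j] for i in range(m)) for j in range(m)]
    return sum(v[i] * exit_prob[i] for i in range(m))


# Geometric distribution with success probability 0.25 as a one-state example.
p3 = dph_pmf_sketch([1.0], [[0.75]], 3)
```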

phasic_pybind.cpp::Graph.moments

py::array_t<double> moments(int power, py::object rewards_obj = py::none())

This method computes moments of the phase-type distribution up to the specified power. Without rewards it computes standard moments E[T], E[T^2], through E[T^power]. With rewards it computes moments of the reward-transformed distribution E[R·T], E[(R·T)^2], through E[(R·T)^power]. The implementation calls ptd_expected_waiting_time iteratively with appropriately constructed reward vectors, using the formula E[T^k] equals k factorial times the result.

For reward-transformed moments the reward vector is applied at each iteration scaled by previous moment values. This efficiently computes higher moments through the phase-type structure without requiring explicit distribution integration. Results are returned as a numpy array of moment values.
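The iteration can be sketched for an acyclic graph whose transient states are listed in topological order. The functions below are pure-Python stand-ins, not the ptd_expected_waiting_time implementation, and the list-based graph encoding is an assumption.

```python
import math


def expected_waiting_time(rates, successors, rewards):
    """Expected accumulated reward until absorption, per starting state."""
    m = len(rates)
    w = [0.0] * m
    for i in reversed(range(m)):                 # successors have larger indices
        w[i] = rewards[i] / rates[i] + sum(p * w[j] for j, p in successors[i])
    return w


def moments_sketch(alpha, rates, successors, power, rewards=None):
    """E[Y^k] for k = 1..power, where Y is the (reward-transformed) absorption time."""
    m = len(rates)
    base = [1.0] * m if rewards is None else list(rewards)
    current = list(base)
    out = []
    for k in range(1, power + 1):
        w = expected_waiting_time(rates, successors, current)
        out.append(math.factorial(k) * sum(a * x for a, x in zip(alpha, w)))
        # Next reward vector: base rewards scaled by the previous result.
        current = [b * x for b, x in zip(base, w)]
    return out


# Erlang-2 with rate 2: state 0 -> state 1 -> absorption.
rates = [2.0, 2.0]
successors = [[(1, 1.0)], []]
plain = moments_sketch([1.0, 0.0], rates, successors, 2)
scaled = moments_sketch([1.0, 0.0], rates, successors, 2, rewards=[2.0, 2.0])
```

For this example the first two moments are 1 and 1.5; with every reward equal to 2 the variable is doubled, so the moments become 2 and 6.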

phasic_pybind.cpp::Graph.sample

double sample(py::object rewards_obj = py::none())

This method generates a random sample from the phase-type distribution using the Gillespie algorithm. It simulates a path through the Markov chain by drawing waiting times from exponential distributions and choosing transitions according to edge probabilities. The method accumulates the total time until absorption and returns it as the sample value.

When rewards are provided, the method multiplies each waiting time by the corresponding reward value before accumulating, effectively sampling from the reward-transformed distribution. This enables sampling from multivariate phase-type distributions by using different reward vectors for different features.
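A pure-Python sketch of this simulation follows. The graph encoding (per-state total rates plus successor probabilities, with any leftover probability meaning absorption) and the name `sample_time` are assumptions made for illustration.

```python
import random


def sample_time(rates, successors, start, rng, rewards=None):
    """Simulate one path to absorption and return the (reward-weighted) total time."""
    total = 0.0
    state = start
    while state is not None:
        wait = rng.expovariate(rates[state])      # exponential holding time
        total += wait * (1.0 if rewards is None else rewards[state])
        u = rng.random()
        next_state = None                         # None means absorption
        cumulative = 0.0
        for target, prob in successors[state]:
            cumulative += prob
            if u < cumulative:
                next_state = target
                break
        state = next_state
    return total


rates = [2.0, 1.0]
successors = [[(1, 1.0)], []]                     # state 0 -> state 1 -> absorption
plain = sample_time(rates, successors, 0, random.Random(7))
doubled = sample_time(rates, successors, 0, random.Random(7), rewards=[2.0, 2.0])
```

With the same seed the simulated path and holding times are identical, so a uniform reward of 2 exactly doubles the sampled value.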

phasic_pybind.cpp::Graph.serialize

py::dict serialize(int param_length = -1)

This method serializes the complete graph structure into a Python dictionary containing all information needed to reconstruct the graph or generate computation code. The dictionary includes arrays of vertex states, edge connectivity, edge weights or coefficient vectors for parameterized edges, and metadata like state_length and param_length. The serialization format is designed to be JSON-compatible for caching and code generation.

For parameterized graphs the method attempts to auto-detect param_length by examining edge state vectors unless explicitly provided. This avoids potential issues with uninitialized memory when edge state vectors have not been properly allocated. The serialized format is used by pmf_from_graph and related methods to generate optimized C++ code.

phasic_pybind.cpp::Graph.as_matrices

MatrixRepresentation as_matrices()

This method extracts the phase-type distribution’s matrix representation including the initial probability vector, sub-intensity matrix, and state matrix. It performs graph elimination to reduce cycles and computes the canonical form matrices. The result includes the states array with one row per vertex, the sub-intensity matrix SIM encoding transition rates between transient states, and the initial probability vector IPV giving the initial distribution over states.

This representation enables alternative computation methods based on matrix exponentiation and provides a standard format for comparison with other phase-type libraries. The elimination process may produce a different but equivalent representation of the same distribution.