Distributed computing

Phasic automatically configures JAX for multi-CPU parallelism at import time, and provides utilities for scaling to multi-node SLURM clusters. This tutorial covers:

  1. Automatic multi-CPU setup — how phasic detects and uses available CPUs
  2. Parallelism in SVGD — how particles are distributed across devices
  3. Manual configuration — overriding defaults and context managers
  4. SLURM clusters — single-node and multi-node distributed computing

The parallelization strategy depends on the number of available JAX devices:

Devices           Strategy   Description
>1                pmap       Particles distributed across devices
1 (multi-CPU)     vmap       Vectorized computation on single device
1 (single-CPU)    none       Sequential execution
from phasic import (
    Graph, with_ipv,
    init_parallel,
    detect_environment, get_parallel_config,
    parallel_config, disable_parallel,
    EnvironmentInfo, ParallelConfig,
    set_log_level
)
import jax
import numpy as np
import time
from vscodenb import set_vscode_theme

set_vscode_theme()

We use a simple parameterized coalescent model throughout:

nr_samples = 4

@with_ipv([nr_samples] + [0] * (nr_samples - 1))
def coalescent(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            transitions.append((new, state[i] * (state[j] - same) / (1 + same)))
    return transitions

Automatic multi-CPU setup

When you import phasic, it automatically:

  1. Detects available CPUs — on Apple Silicon, it uses only performance cores; otherwise, os.cpu_count()
  2. Sets XLA_FLAGS — configures --xla_force_host_platform_device_count so JAX creates one virtual device per CPU
  3. Enables 64-bit precision — sets JAX_ENABLE_X64=1 for numerical accuracy
  4. Sets platform to CPU — via JAX_PLATFORMS=cpu

This happens before JAX is imported, which is why phasic must be imported before JAX:

from phasic import Graph   # Sets XLA_FLAGS, then imports JAX
import jax                  # Picks up the flags phasic set

If JAX is already imported, phasic raises an ImportError with instructions.
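
For reference, the import-time setup described above is roughly equivalent to exporting the same variables yourself before the first JAX import. The following is a sketch of the effect, not phasic's actual implementation:

import os

n_cpus = os.cpu_count()  # phasic may use only performance cores on Apple Silicon
os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={n_cpus}'
os.environ['JAX_ENABLE_X64'] = '1'
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax  # imported only after the variables are set, so it picks them up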

You can check the current device setup:

devices = jax.devices()
print(f"Number of JAX devices: {len(devices)}")
print(f"Device type: {devices[0].platform}")

Overriding CPU count

To change the number of CPUs phasic uses, set the PTDALG_CPUS environment variable before importing phasic:

export PTDALG_CPUS=4
python my_script.py

Or in a notebook (before any imports):

import os
os.environ['PTDALG_CPUS'] = '4'
from phasic import Graph

Inspecting the detected environment

The detect_environment() function returns an EnvironmentInfo object describing the current execution context:

env = detect_environment()
print(f"Environment type:  {env.env_type}")
print(f"Interactive:       {env.is_interactive}")
print(f"Available CPUs:    {env.available_cpus}")
print(f"SLURM detected:   {env.slurm_info is not None}")
print(f"JAX imported:      {env.jax_already_imported}")

Explicit initialization with init_parallel()

For more control, call init_parallel() to configure parallelism explicitly. This is especially useful when you want to set a specific CPU count, or when running on a SLURM cluster, where calling it triggers detection of the job allocation:

config = init_parallel()
print(f"Device count:       {config.device_count}")
print(f"Local device count: {config.local_device_count}")
print(f"Strategy:           {config.strategy}")
print(f"Environment:        {config.env_info.env_type}")

You can pass an explicit CPU count:

config = init_parallel(cpus=8)  # Use exactly 8 devices

The returned ParallelConfig is stored globally and used by graph.svgd() to select its parallelization strategy.
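
You can read the stored configuration back at any later point with get_parallel_config(), which returns the same ParallelConfig (or None if nothing has been configured). A minimal check using only the fields printed above:

stored = get_parallel_config()
if stored is not None:
    print(f"SVGD will default to the '{stored.strategy}' strategy on {stored.device_count} device(s)")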

Parallelism in SVGD

The graph.svgd() method automatically parallelizes particle updates across available devices. The parallel parameter controls the strategy:

Value            Behavior
None (default)   Auto-select: pmap if multiple devices, vmap otherwise
'pmap'           Distribute particles across devices (multiple CPUs/GPUs)
'vmap'           Vectorize particles on a single device
'none'           Sequential execution (useful for debugging)

With pmap, particles are split evenly across devices. Each device computes log-likelihood and gradients for its particles independently. The SVGD kernel computation and particle updates are also parallelized.

The n_devices parameter limits how many devices pmap uses (default: all available).

graph = Graph(coalescent)

# Simulate data
graph.update_weights([7.0])
observed_data = graph.sample(500)

# Auto-parallelized SVGD (uses pmap if multiple devices)
start = time.time()
result = graph.svgd(
    observed_data,
    n_particles=50,
    n_iterations=50,
    progress=False
)
elapsed = time.time() - start
print(f"SVGD completed in {elapsed:.1f}s")
print(f"Posterior mean: {result['theta_mean']}")
print(f"Posterior std:  {result['theta_std']}")

To explicitly control the strategy:

# Force pmap with specific device count
result_pmap = graph.svgd(
    observed_data,
    n_particles=50,
    n_iterations=50,
    parallel='pmap',
    n_devices=len(jax.devices()),
    progress=False
)
print(f"pmap posterior mean: {result_pmap['theta_mean']}")

# Force vmap (single-device vectorization)
result_vmap = graph.svgd(
    observed_data,
    n_particles=50,
    n_iterations=50,
    parallel='vmap',
    progress=False
)
print(f"vmap posterior mean: {result_vmap['theta_mean']}")

Parallelism outside SVGD

Graph construction and moment/expectation computation are not automatically parallelized — they run on a single CPU. The graph construction callback is inherently sequential (state-space exploration), and Gaussian elimination is a single dense computation.

However, if you need to evaluate a model at many parameter values (outside SVGD), you can use JAX’s vmap or pmap directly with trace-based evaluation:

import jax
import jax.numpy as jnp

# Build and record trace
graph = Graph(coalescent)
model = graph.pmf_from_graph()

# Evaluate at many parameter values in parallel
theta_grid = jnp.linspace(0.5, 10.0, 100).reshape(-1, 1)
times = jnp.array([1.0, 2.0, 3.0])

# vmap over parameter vectors
batch_model = jax.vmap(lambda theta: model(theta, times))
all_pdfs = batch_model(theta_grid)  # Shape: (100, 3)

For multi-device parallelism, replace vmap with pmap and reshape the input to (n_devices, batch_per_device, ...).
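
A sketch of that pmap variant, reusing model, theta_grid, and times from the cell above and assuming the batch divides evenly across devices:

n_devices = len(jax.devices())
batch_per_device = theta_grid.shape[0] // n_devices

# Drop any remainder so the grid reshapes cleanly to (n_devices, batch_per_device, 1)
theta_sharded = theta_grid[: n_devices * batch_per_device].reshape(n_devices, batch_per_device, 1)

# pmap over the device axis, vmap over each device's batch
pmap_model = jax.pmap(jax.vmap(lambda theta: model(theta, times)))
all_pdfs_pmap = pmap_model(theta_sharded)  # Shape: (n_devices, batch_per_device, 3)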

Context managers

Phasic provides context managers for temporarily changing the parallelization strategy.

Disabling parallelism

The disable_parallel() context manager forces sequential execution, which is useful for debugging:

with disable_parallel():
    result = graph.svgd(
        observed_data,
        n_particles=20,
        n_iterations=10,
        progress=False
    )
    print(f"Sequential posterior mean: {result['theta_mean']}")

# Outside the context, parallelism is restored
current = get_parallel_config()
if current:
    print(f"Strategy restored to: {current.strategy}")

Custom parallel configuration

The parallel_config() context manager allows temporarily switching to any strategy:

# Temporarily switch to vmap (single-device vectorization)
with parallel_config(strategy='vmap'):
    result = graph.svgd(
        observed_data,
        n_particles=20,
        n_iterations=10,
        progress=False
    )
    print(f"vmap posterior mean: {result['theta_mean']}")

SLURM clusters

Phasic detects SLURM environments automatically and configures JAX accordingly. There are two modes:

Mode          SLURM config                      Use case
Single-node   --cpus-per-task=N                 Multiple CPUs on one machine
Multi-node    --nodes=N --ntasks-per-node=1     Distribute across machines

Single-node SLURM

For single-node jobs, phasic reads SLURM_CPUS_PER_TASK and creates that many JAX devices. The SLURM script is straightforward:

#!/bin/bash
#SBATCH --job-name=svgd_inference
#SBATCH --cpus-per-task=16
#SBATCH --mem-per-cpu=4G
#SBATCH --time=01:00:00

python my_inference.py

The Python script needs no special configuration — init_parallel() detects the SLURM allocation:

from phasic import Graph, init_parallel

config = init_parallel()  # Detects 16 CPUs from SLURM
print(f"Using {config.device_count} devices")  # 16

graph = Graph(my_model)
graph.update_weights([7.0])
observed_data = graph.sample(1000)

# SVGD automatically uses pmap across 16 devices
result = graph.svgd(observed_data, n_particles=160)

Multi-node SLURM

For multi-node jobs, phasic uses jax.distributed.initialize() to coordinate computation across machines. The SLURM batch script sets up the coordinator address and launches one process per node via srun:

#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --time=01:00:00

# First node becomes the coordinator
COORDINATOR_NODE=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export SLURM_COORDINATOR_ADDRESS=$COORDINATOR_NODE
export JAX_COORDINATOR_PORT=12345

# Configure XLA for local CPUs
export XLA_FLAGS="--xla_force_host_platform_device_count=$SLURM_CPUS_PER_TASK"
export JAX_PLATFORMS=cpu
export JAX_ENABLE_X64=1

# Launch one process per node
srun python my_distributed_inference.py

With 4 nodes x 8 CPUs each, this creates 32 global JAX devices.

Multi-node Python script

The Python script uses phasic’s distributed utilities to initialize JAX across nodes:

from phasic import (
    Graph, init_parallel,
    detect_slurm_environment,
    get_coordinator_address,
    initialize_jax_distributed
)
import jax

# Step 1: Detect SLURM environment
slurm_env = detect_slurm_environment()
print(f"Process {slurm_env['process_id']}/{slurm_env['num_processes']}")
print(f"CPUs per task: {slurm_env['cpus_per_task']}")
print(f"Nodes: {slurm_env['node_count']}")

# Step 2: Get coordinator address (first node)
coordinator = get_coordinator_address(slurm_env)
print(f"Coordinator: {coordinator}")

# Step 3: Initialize JAX distributed
initialize_jax_distributed(
    coordinator_address=coordinator,
    num_processes=slurm_env['num_processes'],
    process_id=slurm_env['process_id']
)

print(f"Local devices:  {len(jax.local_devices())}")
print(f"Global devices: {len(jax.devices())}")

# Step 4: Build model and run SVGD
graph = Graph(my_model)
graph.update_weights([7.0])
observed_data = graph.sample(1000)

# SVGD distributes particles across all 32 global devices
result = graph.svgd(
    observed_data,
    n_particles=320,    # Must be divisible by global device count
    parallel='pmap'
)

SLURM detection API

The detect_slurm_environment() function parses SLURM environment variables and returns a dictionary with the allocation details:

from phasic import detect_slurm_environment

slurm_env = detect_slurm_environment()
print(f"Running under SLURM: {slurm_env['is_slurm']}")

if slurm_env['is_slurm']:
    print(f"Job ID:           {slurm_env['job_id']}")
    print(f"Process ID:       {slurm_env['process_id']}")
    print(f"Total processes:  {slurm_env['num_processes']}")
    print(f"CPUs per task:    {slurm_env['cpus_per_task']}")
    print(f"Nodes:            {slurm_env['node_count']}")

Parallelism and graph construction

Graph construction explores the state space by calling the user-provided callback function repeatedly. This is inherently sequential — there is no automatic parallelization of graph construction.

However, if you need to build multiple graphs (e.g., for different sample sizes), you can parallelize at the Python level:

from concurrent.futures import ProcessPoolExecutor

def build_graph(n):
    @with_ipv([n] + [0] * (n - 1))
    def coalescent(state):
        # ... same transition logic as the coalescent callback defined earlier ...
        return transitions
    return Graph(coalescent, cache_graph=True)

# Build graphs for different sample sizes in parallel
with ProcessPoolExecutor(max_workers=4) as pool:
    graphs = list(pool.map(build_graph, [4, 6, 8, 10]))

Using cache_graph=True ensures that each graph is built only once and loaded from cache on subsequent runs.

Parallelism and trace evaluation

Gaussian elimination (trace recording) is O(n³) and runs on a single CPU. However, the resulting elimination trace can be evaluated in parallel at different parameter values using JAX’s vmap or pmap.

When you call graph.svgd(), phasic automatically:

  1. Records the elimination trace once (sequential, O(n³))
  2. Creates a JIT-compiled function that evaluates the trace (O(n) per evaluation)
  3. Uses pmap or vmap to evaluate this function across particles in parallel

This means the expensive elimination happens only once, and all subsequent evaluations during SVGD iterations benefit from parallelism.
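
The same record-once, evaluate-many pattern is available outside SVGD. A sketch using the model built earlier with graph.pmf_from_graph() and the times array from that cell (jit and vmap are standard JAX transforms; the parameter values are arbitrary):

# The elimination trace is already recorded inside `model`; jit compiles its evaluation once
eval_fn = jax.jit(lambda theta: model(theta, times))

# Vectorize the compiled evaluation over a batch of parameter vectors
thetas = jnp.linspace(1.0, 9.0, 50).reshape(-1, 1)
pdfs = jax.vmap(eval_fn)(thetas)  # Shape: (50, 3)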

Computing expectations in parallel

Moment and expectation computations (graph.moments(), graph.expected_waiting_time()) operate on a single graph with fixed parameters and are not parallelized.

To evaluate expectations at many parameter values in parallel, use the model function with vmap:

# Create parameterized model
model = graph.pmf_and_moments_from_graph(nr_moments=2)

# Evaluate at 100 parameter values in parallel
theta_values = jnp.linspace(0.5, 10.0, 100).reshape(-1, 1)
times = jnp.array([1.0, 2.0, 3.0])

def eval_at_theta(theta):
    pmf, moments = model(theta, times)
    return moments

all_moments = jax.vmap(eval_at_theta)(theta_values)  # (100, 2)

Environment variables

Phasic recognizes these environment variables for configuring parallelism:

Variable                    Description
PTDALG_CPUS                 Override the number of CPUs used (set before importing phasic)
XLA_FLAGS                   JAX/XLA flags; phasic sets --xla_force_host_platform_device_count automatically
JAX_PLATFORMS               Platform selection; phasic sets to cpu by default
JAX_ENABLE_X64              Enable 64-bit precision; phasic enables by default
SLURM_CPUS_PER_TASK         Read by phasic on SLURM clusters to set device count
SLURM_COORDINATOR_ADDRESS   Manual override for coordinator node address
JAX_COORDINATOR_PORT        Port for JAX distributed coordinator (default: 12345)

Tips and troubleshooting

Particle count: When using pmap, the number of particles must be divisible by the number of devices. If not, SVGD will raise a ValueError.
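
One simple way to guarantee this is to derive the particle count from the device count; a quick sketch (the multiplier is arbitrary):

n_devices = len(jax.devices())
n_particles = 16 * n_devices  # any multiple of the device count works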

Import order: Always import phasic before JAX. If you see an ImportError about JAX being imported first, restart your kernel and fix the import order.

Debugging: Use parallel='none' or the disable_parallel() context manager to get readable error messages. Parallel execution can obscure the source of errors.

Performance: For small models (< 100 vertices), the overhead of pmap may outweigh the benefit. Use parallel='vmap' or let the auto-detection choose.

SLURM proxies: On some HPC systems, HTTP proxy variables can interfere with JAX distributed initialization. Phasic’s initialize_jax_distributed() temporarily unsets proxy variables during initialization.

Graph construction scaling: Graph construction is CPU-bound and single-threaded. For very large state spaces, consider using cache_graph=True to avoid rebuilding across sessions.

Summary

Operation               Parallelism      How
Graph construction      Sequential       Cache with cache_graph=True
Gaussian elimination    Sequential       Cache with cache_trace=True
Trace evaluation        vmap/pmap        Automatic in SVGD; manual via jax.vmap
SVGD particle updates   pmap/vmap/none   graph.svgd(parallel=...)
Moment computation      Sequential       Parallelize manually via jax.vmap over parameters

Environment               Setup
Local (notebook/script)   Automatic at import; override with PTDALG_CPUS
SLURM single-node         #SBATCH --cpus-per-task=N + init_parallel()
SLURM multi-node          Batch script with coordinator + initialize_jax_distributed()