from phasic import (
Graph, with_ipv,
init_parallel,
detect_environment, get_parallel_config,
parallel_config, disable_parallel,
EnvironmentInfo, ParallelConfig,
set_log_level
)
import jax
import numpy as np
import time
from vscodenb import set_vscode_theme
set_vscode_theme()

Distributed computing
Phasic automatically configures JAX for multi-CPU parallelism at import time, and provides utilities for scaling to multi-node SLURM clusters. This tutorial covers:
- Automatic multi-CPU setup — how phasic detects and uses available CPUs
- Parallelism in SVGD — how particles are distributed across devices
- Manual configuration — overriding defaults and context managers
- SLURM clusters — single-node and multi-node distributed computing
The parallelization strategy depends on the number of available JAX devices:
| Devices | Strategy | Description |
|---|---|---|
| >1 | pmap | Particles distributed across devices |
| 1 (multi-CPU) | vmap | Vectorized computation on single device |
| 1 (single-CPU) | none | Sequential execution |
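As a rough illustration of that dispatch, the choice could be sketched as follows (a hypothetical helper, not phasic's actual internals):

import os
import jax

def pick_strategy():
    # Hypothetical sketch of the dispatch in the table above
    if jax.device_count() > 1:
        return 'pmap'    # particles distributed across devices
    if (os.cpu_count() or 1) > 1:
        return 'vmap'    # vectorized on a single device
    return 'none'        # sequential execution

print(pick_strategy())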
We use a simple parameterized coalescent model throughout:
nr_samples = 4
@with_ipv([nr_samples] + [0] * (nr_samples - 1))
def coalescent(state):
transitions = []
for i in range(state.size):
for j in range(i, state.size):
same = int(i == j)
if same and state[i] < 2:
continue
if not same and (state[i] < 1 or state[j] < 1):
continue
new = state.copy()
new[i] -= 1
new[j] -= 1
new[i + j + 1] += 1
transitions.append((new, state[i] * (state[j] - same) / (1 + same)))
    return transitions

Automatic multi-CPU setup
When you import phasic, it automatically:
- Detects available CPUs — on Apple Silicon, it uses only performance cores; otherwise, os.cpu_count()
- Sets XLA_FLAGS — configures --xla_force_host_platform_device_count so JAX creates one virtual device per CPU
- Enables 64-bit precision — sets JAX_ENABLE_X64=1 for numerical accuracy
- Sets platform to CPU — via JAX_PLATFORMS=cpu
This happens before JAX is imported, which is why phasic must be imported before JAX:
from phasic import Graph # Sets XLA_FLAGS, then imports JAX
import jax # Picks up the flags phasic set

If JAX is already imported, phasic raises an ImportError with instructions.
You can check the current device setup:
devices = jax.devices()
print(f"Number of JAX devices: {len(devices)}")
print(f"Device type: {devices[0].platform}")Overriding CPU count
To change the number of CPUs phasic uses, set the PTDALG_CPUS environment variable before importing phasic:
export PTDALG_CPUS=4
python my_script.py

Or in a notebook (before any imports):
import os
os.environ['PTDALG_CPUS'] = '4'
from phasic import Graph

Inspecting the detected environment
The detect_environment() function returns an EnvironmentInfo object describing the current execution context:
env = detect_environment()
print(f"Environment type: {env.env_type}")
print(f"Interactive: {env.is_interactive}")
print(f"Available CPUs: {env.available_cpus}")
print(f"SLURM detected: {env.slurm_info is not None}")
print(f"JAX imported: {env.jax_already_imported}")Explicit initialization with init_parallel()
For more control, call init_parallel() to explicitly configure parallelism. This is especially useful when you want to set a specific CPU count or when running on a SLURM cluster where automatic detection should be triggered:
config = init_parallel()
print(f"Device count: {config.device_count}")
print(f"Local device count: {config.local_device_count}")
print(f"Strategy: {config.strategy}")
print(f"Environment: {config.env_info.env_type}")You can pass an explicit CPU count:
config = init_parallel(cpus=8) # Use exactly 8 devices

The returned ParallelConfig is stored globally and used by graph.svgd() to select its parallelization strategy.
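The stored configuration can also be inspected later with get_parallel_config(), using the same fields shown above:

current = get_parallel_config()
if current:
    print(f"Active strategy: {current.strategy} on {current.device_count} device(s)")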
Parallelism in SVGD
The graph.svgd() method automatically parallelizes particle updates across available devices. The parallel parameter controls the strategy:
| Value | Behavior |
|---|---|
| None (default) | Auto-select: pmap if multiple devices, vmap otherwise |
| 'pmap' | Distribute particles across devices (multiple CPUs/GPUs) |
| 'vmap' | Vectorize particles on a single device |
| 'none' | Sequential execution (useful for debugging) |
With pmap, particles are split evenly across devices. Each device computes log-likelihood and gradients for its particles independently. The SVGD kernel computation and particle updates are also parallelized.
The n_devices parameter limits how many devices pmap uses (default: all available).
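To illustrate the data layout pmap expects, here is a standalone JAX sketch (independent of phasic's internals) of splitting a particle array evenly across devices:

import jax
import jax.numpy as jnp

n_devices = jax.device_count()

# Particle count chosen as a multiple of the device count (pmap requires this)
particles = jnp.ones((n_devices * 10, 1))

# Reshape to (n_devices, per_device, dim): pmap maps over the leading axis
sharded = particles.reshape(n_devices, -1, particles.shape[-1])

def per_particle(p):
    # Stand-in for a per-particle log-likelihood
    return -jnp.sum(p ** 2)

# Each device applies the vmapped per-particle function to its own shard
values = jax.pmap(jax.vmap(per_particle))(sharded)
print(values.shape) # (n_devices, 10)

The SVGD example below relies on graph.svgd() to do this sharding internally.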
graph = Graph(coalescent)
# Simulate data
graph.update_weights([7.0])
observed_data = graph.sample(500)
# Auto-parallelized SVGD (uses pmap if multiple devices)
start = time.time()
result = graph.svgd(
observed_data,
n_particles=50,
n_iterations=50,
progress=False
)
elapsed = time.time() - start
print(f"SVGD completed in {elapsed:.1f}s")
print(f"Posterior mean: {result['theta_mean']}")
print(f"Posterior std: {result['theta_std']}")To explicitly control the strategy:
# Force pmap with specific device count
result_pmap = graph.svgd(
observed_data,
n_particles=50,
n_iterations=50,
parallel='pmap',
n_devices=len(jax.devices()),
progress=False
)
print(f"pmap posterior mean: {result_pmap['theta_mean']}")
# Force vmap (single-device vectorization)
result_vmap = graph.svgd(
observed_data,
n_particles=50,
n_iterations=50,
parallel='vmap',
progress=False
)
print(f"vmap posterior mean: {result_vmap['theta_mean']}")Parallelism outside SVGD
Graph construction and moment/expectation computation are not automatically parallelized — they run on a single CPU. The graph construction callback is inherently sequential (state-space exploration), and Gaussian elimination is a single dense computation.
However, if you need to evaluate a model at many parameter values (outside SVGD), you can use JAX’s vmap or pmap directly with trace-based evaluation:
import jax
import jax.numpy as jnp
# Build and record trace
graph = Graph(coalescent)
model = graph.pmf_from_graph()
# Evaluate at many parameter values in parallel
theta_grid = jnp.linspace(0.5, 10.0, 100).reshape(-1, 1)
times = jnp.array([1.0, 2.0, 3.0])
# vmap over parameter vectors
batch_model = jax.vmap(lambda theta: model(theta, times))
all_pdfs = batch_model(theta_grid) # Shape: (100, 3)

For multi-device parallelism, replace vmap with pmap and reshape the input to (n_devices, batch_per_device, ...).
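For instance, assuming the grid size is a multiple of the device count, the vmap call above could be replaced with something like:

# Sketch of the pmap variant, reusing `model` and `times` from the cell above
n_devices = len(jax.devices())
n_theta = n_devices * 25
theta_grid = jnp.linspace(0.5, 10.0, n_theta).reshape(n_devices, -1, 1)

# pmap over devices, vmap over each device's batch of parameter vectors
batch_model = jax.pmap(jax.vmap(lambda theta: model(theta, times)))
all_pdfs = batch_model(theta_grid) # Shape: (n_devices, 25, 3)
all_pdfs = all_pdfs.reshape(-1, times.shape[0]) # Back to (n_theta, 3)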
Context managers
Phasic provides context managers for temporarily changing the parallelization strategy.
Disabling parallelism
The disable_parallel() context manager forces sequential execution, which is useful for debugging:
with disable_parallel():
result = graph.svgd(
observed_data,
n_particles=20,
n_iterations=10,
progress=False
)
print(f"Sequential posterior mean: {result['theta_mean']}")
# Outside the context, parallelism is restored
current = get_parallel_config()
if current:
print(f"Strategy restored to: {current.strategy}")Custom parallel configuration
The parallel_config() context manager allows temporarily switching to any strategy:
# Temporarily switch to vmap (single-device vectorization)
with parallel_config(strategy='vmap'):
result = graph.svgd(
observed_data,
n_particles=20,
n_iterations=10,
progress=False
)
print(f"vmap posterior mean: {result['theta_mean']}")SLURM clusters
Phasic detects SLURM environments automatically and configures JAX accordingly. There are two modes:
| Mode | SLURM config | Use case |
|---|---|---|
| Single-node | --cpus-per-task=N | Multiple CPUs on one machine |
| Multi-node | --nodes=N --ntasks-per-node=1 | Distribute across machines |
Single-node SLURM
For single-node jobs, phasic reads SLURM_CPUS_PER_TASK and creates that many JAX devices. The SLURM script is straightforward:
#!/bin/bash
#SBATCH --job-name=svgd_inference
#SBATCH --cpus-per-task=16
#SBATCH --mem-per-cpu=4G
#SBATCH --time=01:00:00
python my_inference.py

The Python script needs no special configuration — init_parallel() detects the SLURM allocation:
from phasic import Graph, init_parallel
config = init_parallel() # Detects 16 CPUs from SLURM
print(f"Using {config.device_count} devices") # 16
graph = Graph(my_model)
graph.update_weights([7.0])
observed_data = graph.sample(1000)
# SVGD automatically uses pmap across 16 devices
result = graph.svgd(observed_data, n_particles=160)

Multi-node SLURM
For multi-node jobs, phasic uses jax.distributed.initialize() to coordinate computation across machines. The SLURM batch script sets up the coordinator address and launches one process per node via srun:
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --time=01:00:00
# First node becomes the coordinator
COORDINATOR_NODE=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export SLURM_COORDINATOR_ADDRESS=$COORDINATOR_NODE
export JAX_COORDINATOR_PORT=12345
# Configure XLA for local CPUs
export XLA_FLAGS="--xla_force_host_platform_device_count=$SLURM_CPUS_PER_TASK"
export JAX_PLATFORMS=cpu
export JAX_ENABLE_X64=1
# Launch one process per node
srun python my_distributed_inference.py

With 4 nodes x 8 CPUs each, this creates 32 global JAX devices.
Multi-node Python script
The Python script uses phasic’s distributed utilities to initialize JAX across nodes:
from phasic import (
Graph, init_parallel,
detect_slurm_environment,
get_coordinator_address,
initialize_jax_distributed
)
import jax
# Step 1: Detect SLURM environment
slurm_env = detect_slurm_environment()
print(f"Process {slurm_env['process_id']}/{slurm_env['num_processes']}")
print(f"CPUs per task: {slurm_env['cpus_per_task']}")
print(f"Nodes: {slurm_env['node_count']}")
# Step 2: Get coordinator address (first node)
coordinator = get_coordinator_address(slurm_env)
print(f"Coordinator: {coordinator}")
# Step 3: Initialize JAX distributed
initialize_jax_distributed(
coordinator_address=coordinator,
num_processes=slurm_env['num_processes'],
process_id=slurm_env['process_id']
)
print(f"Local devices: {len(jax.local_devices())}")
print(f"Global devices: {len(jax.devices())}")
# Step 4: Build model and run SVGD
graph = Graph(my_model)
graph.update_weights([7.0])
observed_data = graph.sample(1000)
# SVGD distributes particles across all 32 global devices
result = graph.svgd(
observed_data,
n_particles=320, # Must be divisible by global device count
parallel='pmap'
)

SLURM detection API
The detect_slurm_environment() function parses SLURM environment variables and returns a dictionary with the allocation details:
from phasic import detect_slurm_environment
slurm_env = detect_slurm_environment()
print(f"Running under SLURM: {slurm_env['is_slurm']}")
if slurm_env['is_slurm']:
print(f"Job ID: {slurm_env['job_id']}")
print(f"Process ID: {slurm_env['process_id']}")
print(f"Total processes: {slurm_env['num_processes']}")
print(f"CPUs per task: {slurm_env['cpus_per_task']}")
print(f"Nodes: {slurm_env['node_count']}")Parallelism and graph construction
Graph construction explores the state space by calling the user-provided callback function repeatedly. This is inherently sequential — there is no automatic parallelization of graph construction.
However, if you need to build multiple graphs (e.g., for different sample sizes), you can parallelize at the Python level:
from concurrent.futures import ProcessPoolExecutor
def build_graph(n):
@with_ipv([n] + [0] * (n - 1))
def coalescent(state):
# ... callback ...
return transitions
return Graph(coalescent, cache_graph=True)
# Build graphs for different sample sizes in parallel
with ProcessPoolExecutor(max_workers=4) as pool:
    graphs = list(pool.map(build_graph, [4, 6, 8, 10]))

Using cache_graph=True ensures that each graph is built only once and loaded from cache on subsequent runs.
Parallelism and trace evaluation
Gaussian elimination (trace recording) is O(n³) and runs on a single CPU. However, the resulting elimination trace can be evaluated in parallel at different parameter values using JAX’s vmap or pmap.
When you call graph.svgd(), phasic automatically:
- Records the elimination trace once (sequential, O(n³))
- Creates a JIT-compiled function that evaluates the trace (O(n) per evaluation)
- Uses pmap or vmap to evaluate this function across particles in parallel
This means the expensive elimination happens only once, and all subsequent evaluations during SVGD iterations benefit from parallelism.
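The same record-once, evaluate-many pattern can be reproduced by hand with the trace-backed model from earlier (a sketch that loosely mirrors what graph.svgd() does internally, reusing graph and times from the cells above):

# Record the trace once, then evaluate it in parallel at many parameter values
model = graph.pmf_from_graph() # elimination trace recorded here (O(n³), once)
eval_fn = jax.jit(jax.vmap(lambda theta: model(theta, times)))

# Each evaluation now only replays the recorded trace, one per particle
particles = jnp.linspace(1.0, 9.0, 32).reshape(-1, 1)
pdfs = eval_fn(particles) # Shape: (32, len(times))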
Computing expectations in parallel
Moment and expectation computations (graph.moments(), graph.expected_waiting_time()) operate on a single graph with fixed parameters and are not parallelized.
To evaluate expectations at many parameter values in parallel, use the model function with vmap:
# Create parameterized model
model = graph.pmf_and_moments_from_graph(nr_moments=2)
# Evaluate at 100 parameter values in parallel
theta_values = jnp.linspace(0.5, 10.0, 100).reshape(-1, 1)
times = jnp.array([1.0, 2.0, 3.0])
def eval_at_theta(theta):
pmf, moments = model(theta, times)
return moments
all_moments = jax.vmap(eval_at_theta)(theta_values) # (100, 2)

Environment variables
Phasic recognizes these environment variables for configuring parallelism:
| Variable | Description |
|---|---|
| PTDALG_CPUS | Override the number of CPUs used (set before importing phasic) |
| XLA_FLAGS | JAX/XLA flags; phasic sets --xla_force_host_platform_device_count automatically |
| JAX_PLATFORMS | Platform selection; phasic sets to cpu by default |
| JAX_ENABLE_X64 | Enable 64-bit precision; phasic enables by default |
| SLURM_CPUS_PER_TASK | Read by phasic on SLURM clusters to set device count |
| SLURM_COORDINATOR_ADDRESS | Manual override for coordinator node address |
| JAX_COORDINATOR_PORT | Port for JAX distributed coordinator (default: 12345) |
Tips and troubleshooting
Particle count: When using pmap, the number of particles must be divisible by the number of devices. If not, SVGD will raise a ValueError.
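A simple guard (plain arithmetic, not a phasic feature) is to round the particle count up to the next multiple of the device count:

# Round the requested particle count up to a multiple of the device count
n_devices = len(jax.devices())
requested = 50
n_particles = ((requested + n_devices - 1) // n_devices) * n_devices
result = graph.svgd(observed_data, n_particles=n_particles, parallel='pmap')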
Import order: Always import phasic before JAX. If you see an ImportError about JAX being imported first, restart your kernel and fix the import order.
Debugging: Use parallel='none' or the disable_parallel() context manager to get readable error messages. Parallel execution can obscure the source of errors.
Performance: For small models (< 100 vertices), the overhead of pmap may outweigh the benefit. Use parallel='vmap' or let the auto-detection choose.
SLURM proxies: On some HPC systems, HTTP proxy variables can interfere with JAX distributed initialization. Phasic’s initialize_jax_distributed() temporarily unsets proxy variables during initialization.
Graph construction scaling: Graph construction is CPU-bound and single-threaded. For very large state spaces, consider using cache_graph=True to avoid rebuilding across sessions.
Summary
| Operation | Parallelism | How |
|---|---|---|
| Graph construction | Sequential | Cache with cache_graph=True |
| Gaussian elimination | Sequential | Cache with cache_trace=True |
| Trace evaluation | vmap/pmap | Automatic in SVGD; manual via jax.vmap |
| SVGD particle updates | pmap/vmap/none | graph.svgd(parallel=...) |
| Moment computation | Sequential | Parallelize manually via jax.vmap over parameters |

| Environment | Setup |
|---|---|
| Local (notebook/script) | Automatic at import; override with PTDALG_CPUS |
| SLURM single-node | #SBATCH --cpus-per-task=N + init_parallel() |
| SLURM multi-node | Batch script with coordinator + initialize_jax_distributed() |