Phase-type distributions can be extended to model joint probabilities of multiple random variables. This is particularly important in applications like population genetics, where we might be interested in the joint distribution of coalescence times for multiple lineages, or in reliability theory, where we might want the joint distribution of failure times for multiple components. This section shows how to compute exact joint probabilities.
from phasic import Graph, with_ipv  # ALWAYS import phasic first to set jax backend correctly
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns
from typing import Optional
from tqdm.auto import tqdm
from vscodenb import set_vscode_theme, vscode_theme

# Reproducible runs and consistent notebook styling.
np.random.seed(42)
set_vscode_theme()
sns.set_palette('tab10')
Overriding theme from NOTEBOOK_THEME environment variable. <phasic._DeviceListFilter object at 0x1538b9050>
nr_samples = 4

@with_ipv([nr_samples] + [0] * (nr_samples - 1))
def coalescent(state):
    """Coalescent transition function over counts of i-ton lineage classes.

    Merging one lineage from class i with one from class j produces a
    lineage in class i+j+1 (0-based classes). Returns a list of
    (new_state, rate) pairs.
    """
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            # A within-class merge needs at least two lineages in class i;
            # a between-class merge needs one lineage in each of i and j.
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            # n_i * n_j pairs between classes; n_i * (n_i - 1) / 2 within a class.
            transitions.append((new, state[i] * (state[j] - same) / (1 + same)))
    return transitions
def joint_prob_reward_callback(state, current_rewards=None, mutation_rate=1,
                               reward_limit=10, tot_reward_limit=np.inf):
    """Return per-feature mutation rates with a trailing "trash" rate.

    For each feature i the candidate rate is state[i] * mutation_rate. The
    rate is kept if incrementing reward i stays within both the per-feature
    limits and the total-reward limit; otherwise it is diverted to the trash
    rate. Returns an array of length len(state) + 1 (trash rate appended).
    """
    # reward_limits = np.append(np.repeat(reward_limit, len(state)-1), 0)
    reward_limits = np.repeat(reward_limit, len(state))
    reward_dims = len(reward_limits)
    if current_rewards is None:
        current_rewards = np.zeros(reward_dims)
    reward_rates = np.zeros(reward_dims)
    trash_rate = 0
    for i in range(reward_dims):
        rate = state[i] * mutation_rate
        bump = np.zeros(reward_dims)
        bump[i] = 1
        candidate = current_rewards + bump
        if np.all(candidate <= reward_limits) and np.sum(candidate) <= tot_reward_limit:
            reward_rates[i] = rate
        else:
            trash_rate = trash_rate + rate
    return np.append(reward_rates, trash_rate)


def joint_prob_graph(graph, reward_rates_callback, mutation_rate: Optional[float] = None,
                     reward_limit: Optional[int] = 0, tot_reward_limit: Optional[float] = np.inf):
    """Augment `graph` with discrete reward dimensions to track joint probabilities.

    Builds a new Graph whose state vectors are the original state vector with
    one accumulated-reward counter per feature appended. States whose rewards
    would exceed the limits feed a "trash" state pair, and states that were
    absorbing in the base graph are wired to a single new absorbing state.
    """
    starting_vertex = graph.starting_vertex()
    # a bit of a hack: -1 so the appended trash rate is not counted as a dimension.
    reward_dims = len(reward_rates_callback(starting_vertex.state(),
                                            mutation_rate=mutation_rate,
                                            reward_limit=reward_limit,
                                            tot_reward_limit=tot_reward_limit)) - 1
    orig_state_vector_length = len(graph.vertex_at(1).state())
    state_vector_length = orig_state_vector_length + reward_dims
    # Index helpers into the augmented state vector: base part vs reward part.
    state_indices = np.arange(orig_state_vector_length)
    joint_reward_state_indices = np.arange(orig_state_vector_length, state_vector_length)
    new_graph = Graph(state_vector_length)
    new_starting_vertex = new_graph.starting_vertex()
    null_rewards = np.zeros(reward_dims)
    index = 0
    # add edges from starting vertex (IPV)
    for edge in starting_vertex.edges():
        new_starting_vertex.add_edge(
            new_graph.find_or_create_vertex(
                np.append(edge.to().state(), null_rewards).astype(int)), 1)
    prev_completion = 0
    pbar = tqdm(position=0, total=1, miniters=0, desc='visited/created',
                bar_format='{l_bar}{bar}')
    index = index + 1
    trash_rates = {}
    t_vertex_indices = np.array([], dtype=int)
    # Breadth-first expansion: vertices are created ahead of `index`, so the
    # loop runs until no unvisited vertex remains.
    while index < new_graph.vertices_length():
        new_vertex = new_graph.vertex_at(index)
        new_state = new_vertex.state()
        state = new_vertex.state()[state_indices]
        vertex = graph.find_vertex(state)
        # non-mutation transitions (coalescence)
        for edge in vertex.edges():
            new_child_state = np.append(
                edge.to().state(), new_state[joint_reward_state_indices])
            if np.all(new_state == new_child_state):
                continue
            new_child_vertex = new_graph.find_or_create_vertex(new_child_state)
            new_vertex.add_edge(new_child_vertex, edge.weight())
            # if new child was absorbing in base_graph, record it as "t-state":
            if not graph.find_vertex(new_child_state[state_indices]).edges():
                t_vertex_indices = np.append(t_vertex_indices, new_child_vertex.index())
        # mutation transitions
        current_state = new_state[state_indices]
        current_rewards = new_state[joint_reward_state_indices]
        # list of all allowed mutation transition rates with trash rate appended
        rates = reward_rates_callback(current_state, current_rewards,
                                      mutation_rate=mutation_rate,
                                      reward_limit=reward_limit,
                                      tot_reward_limit=tot_reward_limit)
        trash_rates[index] = rates[reward_dims]
        for i in range(reward_dims):
            rate = rates[i]
            if rate > 0:
                new_rewards = current_rewards.copy()
                new_rewards[i] = new_rewards[i] + 1
                new_child_state = np.append(current_state, new_rewards)
                # if new child was absorbing in base_graph, do not add any mutation children
                if not graph.find_vertex(new_child_state[state_indices]).edges():
                    continue
                # NOTE: create_vertex instead of find_or_create_vertex would
                # make the vertex unfindable later via find_vertex.
                new_child_vertex = new_graph.find_or_create_vertex(new_child_state)
                new_vertex.add_edge(new_child_vertex, rate)
        index = index + 1
        completion = index / new_graph.vertices_length()
        pbar.update(completion - prev_completion)
        prev_completion = completion
    pbar.close()
    # trash states: a two-vertex loop so excess probability mass is absorbed
    # but contributes infinite expected time (see notes below).
    trash_vertex = new_graph.find_or_create_vertex(np.repeat(0, state_vector_length))
    trash_loop_vertex = new_graph.create_vertex(np.repeat(0, state_vector_length))
    trash_vertex.add_edge(trash_loop_vertex, 1)
    trash_loop_vertex.add_edge(trash_vertex, 1)
    # add trash edges
    for i, rate in trash_rates.items():
        if rate > 0:
            new_graph.vertex_at(i).add_edge(trash_vertex, rate)
    # add edges from t-states to new final absorbing
    new_absorbing = new_graph.create_vertex(np.repeat(0, state_vector_length))
    t_vertex_indices = np.unique(t_vertex_indices)
    for i in t_vertex_indices:
        new_graph.vertex_at(i).add_edge(new_absorbing, 1)
    # normalize graph
    weights_were_multiplied_with = new_graph.normalize()
    return new_graph
Now let's construct the joint probability graph. In addition to a mutation rate, you must specify an upper bound on the number of discrete rewards you want to account for. You can specify a maximum discrete value of each feature (how many tons of each kind you want to account for) using the reward_limit keyword argument. This accepts either a scalar, applying the same maximum to all features, or an array with a separate maximum for each feature. You can also use the tot_reward_limit keyword argument to specify a maximum on the total reward summed across features. The two arguments may be combined to define the upper bound.
Since we only model a subset of feature combinations, the joint distribution is deficient. The deficit is the total probability of feature combinations beyond the upper bound. You can easily compute this as one minus the sum of joint probabilities.
# Build the base coalescent graph and visualize it.
base_graph = Graph(coalescent)
base_graph.plot()
Overriding theme from NOTEBOOK_THEME environment variable. <phasic._DeviceListFilter object at 0x1538b9050>
Overriding theme from NOTEBOOK_THEME environment variable. <phasic._DeviceListFilter object at 0x1538b9050>
Graph has too many nodes (168). Please set max_nodes to a higher value.
Expectation is inf because of the infinite loop between the trash states:
# Expected time to absorption: returns inf here because the trash states
# form an infinite loop that never reaches the absorbing state.
joint_graph.expectation()
[ERROR] phasic.c: Failed to parse multiplier '@.Inf@e182798644813286610' at command 169
[WARNING] phasic.c: MPFR execution failed - falling back to double precision
inf
The joint probabilities are extracted as the sojourn times of appropriate transient states connecting only to the absorbing state:
def joint_prob_table(joint_graph, obs2idx):
    """Tabulate joint probabilities from transient-state sojourn times.

    `obs2idx` maps each observed feature combination (tuple) to the index of
    the t-state that represents it. Returns a DataFrame with one column per
    feature plus a 'prob' column.
    """
    idx2obs = {idx: obs for obs, idx in obs2idx.items()}
    assert len(idx2obs) == len(obs2idx)  # mapping must be one-to-one
    t_indices = list(idx2obs)
    sojourn_times = joint_graph.expected_sojourn_time(t_indices)
    assert len(sojourn_times) == len(t_indices)
    records = [[*idx2obs[idx], prob] for idx, prob in zip(t_indices, sojourn_times)]
    return pd.DataFrame(records, columns=list(range(1, nr_samples + 1)) + ['prob'])
Overriding theme from NOTEBOOK_THEME environment variable. <phasic._DeviceListFilter object at 0x1538b9050>
Compute the marginals to verify it matches the standard SFS:
# Marginal expected count for each i-ton class: sum of value * probability.
tons = np.arange(1, nr_samples)
marginals = [(joint[t] * joint['prob']).sum() for t in tons]
sns.barplot(x=tons, y=marginals);
Joint depletion times in the rabbit-island model
# Construct a rabbit model that tracks time to depletion for each island separately
def construct_joint_time_model(nr_rabbits, flood_left, flood_right):
    """
    Construct a model tracking joint distribution of depletion times.

    State: [rabbits_left, rabbits_right, left_depleted, right_depleted]
    """
    graph = Graph(4)
    initial_state = [nr_rabbits, 0, 0, 0]  # Start with all rabbits on left
    graph.starting_vertex().add_edge(graph.find_or_create_vertex(initial_state), 1)
    index = 1
    # Expand the state space breadth-first until all reachable states exist.
    while index < graph.vertices_length():
        vertex = graph.vertex_at(index)
        state = list(vertex.state())
        # If left island has rabbits and not yet depleted
        if state[0] > 0 and state[2] == 0:
            # Jump to right
            child_state = [state[0] - 1, state[1] + 1, state[2], state[3]]
            if child_state[0] == 0:
                child_state[2] = 1  # Mark left as depleted
            vertex.add_edge(graph.find_or_create_vertex(child_state), 1)
            # Left flooding: all left rabbits die, mark depleted
            child_state = [0, state[1], 1, state[3]]
            vertex.add_edge(graph.find_or_create_vertex(child_state), flood_left)
        # If right island has rabbits and not yet depleted
        if state[1] > 0 and state[3] == 0:
            # Jump to left
            child_state = [state[0] + 1, state[1] - 1, state[2], state[3]]
            if child_state[1] == 0:
                child_state[3] = 1  # Mark right as depleted
            vertex.add_edge(graph.find_or_create_vertex(child_state), 1)
            # Right flooding: all right rabbits die, mark depleted
            child_state = [state[0], 0, state[2], 1]
            vertex.add_edge(graph.find_or_create_vertex(child_state), flood_right)
        index += 1
    return graph


# Create the joint model
joint_graph = construct_joint_time_model(3, 2.0, 4.0)
print(f"Created joint model with {joint_graph.vertices_length()} states")
print(f"\nFirst few states (rabbits_left, rabbits_right, left_depleted, right_depleted):")
for i in range(min(10, joint_graph.vertices_length())):
    print(f" State {i}: {joint_graph.vertex_at(i).state()}")
Created joint model with 24 states
First few states (rabbits_left, rabbits_right, left_depleted, right_depleted):
State 0: [0 0 0 0]
State 1: [3 0 0 0]
State 2: [2 1 0 0]
State 3: [0 0 1 0]
State 4: [1 2 0 0]
State 5: [0 1 1 0]
State 6: [3 0 0 1]
State 7: [2 0 0 1]
State 8: [0 3 1 0]
State 9: [0 2 1 0]
Now we can use rewards to extract information about the joint distribution. By defining rewards that are non-zero only until each island is depleted, we can compute the marginal time until depletion for each island. By looking at the joint accumulated rewards, we can explore the correlation between these depletion times.
# Define rewards: earn reward while island is not depleted
states = joint_graph.states()
# Reward 1: time spent before left depletion (left_depleted == 0)
reward_before_left_depletion = (states[:, 2] == 0).astype(float)
# Reward 2: time spent before right depletion (right_depleted == 0)
reward_before_right_depletion = (states[:, 3] == 0).astype(float)

# Compute expectations
E_time_to_left_depletion = joint_graph.expectation(reward_before_left_depletion)
E_time_to_right_depletion = joint_graph.expectation(reward_before_right_depletion)
print(f"Expected time until left island depleted: {E_time_to_left_depletion:.4f}")
print(f"Expected time until right island depleted: {E_time_to_right_depletion:.4f}")

# Compute covariance between the two times, then the correlation coefficient.
cov = joint_graph.covariance(reward_before_left_depletion, reward_before_right_depletion)
var_left = joint_graph.variance(reward_before_left_depletion)
var_right = joint_graph.variance(reward_before_right_depletion)
correlation = cov / np.sqrt(var_left * var_right)
print(f"\nCovariance between depletion times: {cov:.6f}")
print(f"Correlation: {correlation:.4f}")
print("The positive correlation indicates that when the left island takes longer to deplete,")
print("the right island also tends to take longer (rabbits jumping back and forth prolongs both)")
Expected time until left island depleted: 0.4836
Expected time until right island depleted: 0.4017
Covariance between depletion times: 0.131722
Correlation: 0.7704
The positive correlation indicates that when the left island takes longer to deplete,
the right island also tends to take longer (rabbits jumping back and forth prolongs both)
This framework for joint probabilities extends naturally to more complex scenarios. We can model multiple dependent processes, extract conditional distributions, and analyze the dependencies between different random variables in our model. The key is careful state space construction that encodes all relevant information, combined with judicious use of rewards to extract the quantities of interest. In population genetics applications, this approach is used to model the joint distribution of coalescence times across multiple loci or populations, capturing the complex dependencies induced by recombination and migration.
We can increase granularity for better performance: