Joint probability

Phase-type distributions can be extended to model joint probabilities of multiple random variables. This is particularly important in applications like population genetics, where we might be interested in the joint distribution of coalescence times for multiple lineages, or in reliability theory, where we might want the joint distribution of failure times for multiple components. This section shows how to compute exact joint probabilities: in the coalescent example below, the target is the joint probability mass function P(ξ_1 = k_1, ..., ξ_{n-1} = k_{n-1}) of the site-frequency counts, where ξ_i is the number of mutations carried by exactly i of the n samples.

from phasic import Graph, with_ipv # ALWAYS import phasic first to set jax backend correctly
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns
from typing import Optional
from tqdm.auto import tqdm
from vscodenb import set_vscode_theme, vscode_theme

np.random.seed(42)
set_vscode_theme()
sns.set_palette('tab10')
nr_samples = 4

@with_ipv([nr_samples]+[0]*(nr_samples-1))
def coalescent(state):
    # block-counting coalescent: state[i] is the number of lineages that
    # subtend i+1 of the n samples; a merger of an (i+1)- and a
    # (j+1)-lineage produces an (i+j+2)-lineage
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i+j+1] += 1
            # rate is state[i]*state[j] for two distinct classes and
            # state[i] choose 2 within a class
            transitions.append((new, state[i]*(state[j]-same)/(1+same)))
    return transitions

def joint_prob_reward_callback(state, current_rewards=None, 
                               mutation_rate=1, reward_limit=10, 
                               tot_reward_limit=np.inf):
    """Return the allowed mutation rate for each reward dimension, with the
    combined rate of all disallowed (out-of-bound) mutations appended as a
    final "trash" rate."""

    # reward_limit may be a scalar (one bound shared by all features) or an
    # array with one bound per feature
    reward_limits = np.broadcast_to(reward_limit, len(state))
    
    reward_dims = len(reward_limits)
    if current_rewards is None:
        current_rewards = np.zeros(reward_dims)

    reward_rates = np.zeros(reward_dims)
    trash_rate = 0
    
    for i in range(reward_dims):
        rate = state[i] * mutation_rate 
        r = np.zeros(reward_dims)
        r[i] = 1
        if np.all(current_rewards + r <= reward_limits) and np.sum(current_rewards + r) <= tot_reward_limit:
            reward_rates[i] = rate
        else:
            trash_rate = trash_rate + rate

    return np.append(reward_rates, trash_rate)
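
As a quick illustration (a hypothetical check, not part of the pipeline below), calling the callback on the initial coalescent state shows that only the singleton class accumulates rate, and nothing is trashed while the accumulated rewards are within their limits:

rates = joint_prob_reward_callback(np.array([4, 0, 0, 0]),
                                   mutation_rate=0.1, reward_limit=3)
# four singleton lineages each mutate at rate 0.1; the other ton classes are
# empty and no rate is trashed yet, so this prints [0.4 0.  0.  0.  0. ]
print(rates)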


def joint_prob_graph(graph, reward_rates_callback, mutation_rate: Optional[float] = None, 
                     reward_limit: Optional[int] = 0, tot_reward_limit: Optional[float] = np.inf):

    starting_vertex = graph.starting_vertex()
    # the callback returns one rate per reward dimension plus a final trash
    # rate, so the number of reward dimensions is its length minus one
    reward_dims = len(reward_rates_callback(starting_vertex.state(), mutation_rate=mutation_rate, 
                                            reward_limit=reward_limit, tot_reward_limit=tot_reward_limit)) - 1

    orig_state_vector_length = len(graph.vertex_at(1).state())
    state_vector_length = orig_state_vector_length + reward_dims

    state_indices = np.arange(orig_state_vector_length)
    joint_reward_state_indices = np.arange(orig_state_vector_length, 
                                           state_vector_length)

    new_graph = Graph(state_vector_length)
    new_starting_vertex = new_graph.starting_vertex()

    null_rewards = np.zeros(reward_dims)

    index = 0
    # add edges from starting vertex (IPV)
    for edge in starting_vertex.edges():
        new_starting_vertex.add_edge(
            new_graph.find_or_create_vertex(
                np.append(edge.to().state(), null_rewards).astype(int)),
            1)

    prev_completion = 0
    pbar = tqdm(position=0, total=1, miniters=0, desc='visited/created', bar_format='{l_bar}{bar}')

    index = index + 1
    
    trash_rates = {}
    t_vertex_indices = np.array([], dtype=int)
    while index < new_graph.vertices_length():

        new_vertex = new_graph.vertex_at(index)
        new_state = new_vertex.state()
        state = new_vertex.state()[state_indices]
        vertex = graph.find_vertex(state)

        # non-mutation transitions (coalescence)
        for edge in vertex.edges():
            new_child_state = np.append(
                edge.to().state(), 
                new_state[joint_reward_state_indices]
                )

            if np.all(new_state == new_child_state):
                continue
                
            new_child_vertex = new_graph.find_or_create_vertex(
                new_child_state)
            new_vertex.add_edge(new_child_vertex,
                edge.weight()
            )

            # if the new child is absorbing in the base graph, record it as a "t-state"
            if not graph.find_vertex(new_child_state[state_indices]).edges():
                t_vertex_indices = np.append(t_vertex_indices, new_child_vertex.index())

        # mutation transitions
        current_state = new_state[state_indices]
        current_rewards = new_state[joint_reward_state_indices]
        rates = reward_rates_callback(current_state, current_rewards, 
                                    mutation_rate=mutation_rate, 
                                    reward_limit=reward_limit, 
                                    tot_reward_limit=tot_reward_limit) # list of all allowed mutation transition rates with trash rate appended

        trash_rates[index] = rates[reward_dims]
        for i in range(reward_dims):
            rate = rates[i]
            if rate > 0:
                new_rewards = current_rewards.copy()
                new_rewards[i] = new_rewards[i] + 1
                new_child_state = np.append(current_state, new_rewards)

                # if the new child is absorbing in the base graph, do not
                # add any mutation children
                if not graph.find_vertex(new_child_state[state_indices]).edges():
                    continue

                # use find_or_create_vertex (not create_vertex) so the
                # vertex can be found again later with find_vertex
                new_child_vertex = new_graph.find_or_create_vertex(new_child_state)
                new_vertex.add_edge(new_child_vertex, rate)

        index = index + 1 

        completion = index/new_graph.vertices_length()
        pbar.update(completion - prev_completion)
        prev_completion = completion

    pbar.close()

    # trash states: two vertices that cycle between each other forever,
    # soaking up the probability mass of out-of-bound reward combinations
    trash_vertex = new_graph.find_or_create_vertex(np.repeat(0, state_vector_length))
    trash_loop_vertex = new_graph.create_vertex(np.repeat(0, state_vector_length))
    trash_vertex.add_edge(trash_loop_vertex, 1)
    trash_loop_vertex.add_edge(trash_vertex, 1)

    # add trash edges
    for i, rate in trash_rates.items():
        if rate > 0:
            new_graph.vertex_at(i).add_edge(trash_vertex, rate) 

    # add edges from t-states to new final absorbing
    new_absorbing = new_graph.create_vertex(np.repeat(0, state_vector_length))
    t_vertex_indices = np.unique(t_vertex_indices)
    for i in t_vertex_indices:
        new_graph.vertex_at(i).add_edge(new_absorbing, 1)

    # normalize graph
    new_graph.normalize()

    return new_graph

Now let's construct the joint probability graph. In addition to a mutation rate, you must specify an upper bound on the number of discrete rewards to account for. The reward_limit keyword argument sets a maximum value for each feature (how many tons of each kind to account for); it accepts either a scalar, applying the same maximum to all features, or an array with a separate maximum for each feature. The tot_reward_limit keyword argument instead caps the total reward summed across all features. The two arguments can be combined to define the upper bound.

Since we only model a subset of feature combinations, the joint distribution is deficient: the deficit is the total probability of all feature combinations beyond the upper bound, and it can be computed as one minus the sum of the joint probabilities.

base_graph = Graph(coalescent)
base_graph.plot()


joint_graph = joint_prob_graph(base_graph, joint_prob_reward_callback, mutation_rate=0.1, reward_limit=3, tot_reward_limit=np.inf)
# joint_graph = joint_prob_graph(base_graph, joint_prob_reward_callback, mutation_rate=0.01, tot_reward_limit=np.inf)
# joint_graph.plot(rankdir='LR', nodesep=0.1, ranksep=0.99, size=(10, 10), by_state=lambda s: sum(s[:4])==1)
joint_graph.plot(rankdir='TB', nodesep=0.1, ranksep=0.99, size=(10, 10))
Graph has too many nodes (168). Please set max_nodes to a higher value.
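
The plot can be retried with the node limit raised (a sketch, assuming plot accepts the max_nodes keyword named in the message above):

joint_graph.plot(rankdir='TB', nodesep=0.1, ranksep=0.99, size=(10, 10), max_nodes=200)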

The expectation is infinite because of the endless loop between the two trash states:

joint_graph.expectation()
[ERROR] phasic.c: Failed to parse multiplier '@.Inf@e182798644813286610' at command 169
[WARNING] phasic.c: MPFR execution failed - falling back to double precision
inf

The joint probabilities are extracted as the expected sojourn times of the transient states that connect only to the absorbing state:

def joint_prob_table(joint_graph, obs2idx):
    """Tabulate the joint probabilities: the expected sojourn time of each
    t-state, indexed by its reward (ton count) tuple."""
    idx2obs = {v: k for k, v in obs2idx.items()}
    assert len(idx2obs) == len(obs2idx)

    t_indices = list(idx2obs.keys())
    sojourn_times = joint_graph.expected_sojourn_time(t_indices)
    assert len(sojourn_times) == len(t_indices)
    records = []
    for idx, prob in zip(t_indices, sojourn_times):
        obs = idx2obs[idx]
        records.append([*obs, prob])
    joint_probs = pd.DataFrame(records, columns=list(range(1, nr_samples+1)) + ['prob'])
    return joint_probs

def coalescent_obs2idx_map(graph, base_graph_state_length):
    """Map each reward (ton count) tuple to the index of its t-state, i.e.
    a transient state connected directly to the absorbing state."""
    t_vertex_indices = []
    for vertex in graph.vertices():
        for edge in vertex.edges():
            if len(edge.to().edges()) == 0:
                t_vertex_indices.append(vertex.index())
                break
    t_vertex_indices = np.unique(t_vertex_indices)
    states = graph.states()
    joint_reward_state_indices = np.arange(base_graph_state_length, graph.state_length())
    state_reward_matrix = states[t_vertex_indices, :][:, joint_reward_state_indices]
    mapping = {}
    for rewards, idx in zip(state_reward_matrix, t_vertex_indices):
        mapping[tuple(rewards.tolist())] = int(idx)
    return mapping

obs2idx = coalescent_obs2idx_map(joint_graph, base_graph.state_length())
joint = joint_prob_table(joint_graph, obs2idx)
joint
1 2 3 4 prob
0 0 0 0 0 7.102273e-01
1 0 1 0 0 6.097911e-02
2 1 0 0 0 1.268904e-01
3 0 0 1 0 3.945707e-02
4 0 2 0 0 8.424030e-03
... ... ... ... ... ...
59 1 3 3 0 4.866498e-09
60 3 2 3 0 1.145715e-08
61 3 3 2 0 3.694461e-09
62 2 3 3 0 1.729842e-09
63 3 3 3 0 4.520252e-10

64 rows × 5 columns

Deficit of the multivariate PMF:

1 - sum(joint['prob'])
0.0006308355342619087
We can visualize the joint distribution of a pair of ton classes, here doubletons and tripletons, by marginalizing over the other classes:

ton_pair = [2, 3]  # doubletons and tripletons
plot_df = joint[ton_pair + ['prob']].groupby(ton_pair).sum().reset_index()
plot_df = plot_df.pivot(index=ton_pair[1], columns=ton_pair[0], values='prob')

with vscode_theme(style='ticks'):
    ax = sns.heatmap(
        plot_df,
        cmap='viridis',
        cbar_kws={'label': 'Probability'},
        xticklabels=1,
        yticklabels=1,
        norm=LogNorm(),
    )
    ax.set(xlabel=ton_pair[0], ylabel=ton_pair[1])
    ax.invert_yaxis()

Compute the marginal expectations to verify that they match the standard site frequency spectrum (SFS):

tons = np.arange(1, nr_samples)
marginals = [np.sum(joint[t] * joint['prob']) for t in tons]
sns.barplot(x=tons, y=marginals);
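
As a sanity check, the marginal expectations should be close to the classical coalescent prediction E[ξ_i] = 2μ/i (here μ = 0.1), up to the small deficit introduced by the reward limits. The comparison below is a sketch assuming that standard identity:

mu = 0.1  # the mutation_rate used to build joint_graph above
for i, m in zip(tons, marginals):
    print(f'{i}-tons: marginal {m:.4f}  vs  2*mu/{i} = {2*mu/i:.4f}')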

Joint depletion times in the rabbit model

# Construct a rabbit model that tracks time to depletion for each island separately
def construct_joint_time_model(nr_rabbits, flood_left, flood_right):
    """
    Construct a model tracking joint distribution of depletion times.
    State: [rabbits_left, rabbits_right, left_depleted, right_depleted]
    """
    graph = Graph(4)

    initial_state = [nr_rabbits, 0, 0, 0]  # Start with all rabbits on left
    graph.starting_vertex().add_edge(graph.find_or_create_vertex(initial_state), 1)

    index = 1
    while index < graph.vertices_length():
        vertex = graph.vertex_at(index)
        state = list(vertex.state())

        # If left island has rabbits and not yet depleted
        if state[0] > 0 and state[2] == 0:
            # Jump to right
            child_state = [state[0] - 1, state[1] + 1, state[2], state[3]]
            if child_state[0] == 0:
                child_state[2] = 1  # Mark left as depleted
            vertex.add_edge(graph.find_or_create_vertex(child_state), 1)

            # Left flooding
            child_state = [0, state[1], 1, state[3]]  # All left rabbits die, mark depleted
            vertex.add_edge(graph.find_or_create_vertex(child_state), flood_left)

        # If right island has rabbits and not yet depleted
        if state[1] > 0 and state[3] == 0:
            # Jump to left
            child_state = [state[0] + 1, state[1] - 1, state[2], state[3]]
            if child_state[1] == 0:
                child_state[3] = 1  # Mark right as depleted
            vertex.add_edge(graph.find_or_create_vertex(child_state), 1)

            # Right flooding
            child_state = [state[0], 0, state[2], 1]  # All right rabbits die, mark depleted
            vertex.add_edge(graph.find_or_create_vertex(child_state), flood_right)

        index += 1

    return graph

# Create the joint model
joint_graph = construct_joint_time_model(3, 2.0, 4.0)
print(f"Created joint model with {joint_graph.vertices_length()} states")
print(f"\nFirst few states (rabbits_left, rabbits_right, left_depleted, right_depleted):")
for i in range(min(10, joint_graph.vertices_length())):
    print(f"  State {i}: {joint_graph.vertex_at(i).state()}")
Created joint model with 24 states

First few states (rabbits_left, rabbits_right, left_depleted, right_depleted):
  State 0: [0 0 0 0]
  State 1: [3 0 0 0]
  State 2: [2 1 0 0]
  State 3: [0 0 1 0]
  State 4: [1 2 0 0]
  State 5: [0 1 1 0]
  State 6: [3 0 0 1]
  State 7: [2 0 0 1]
  State 8: [0 3 1 0]
  State 9: [0 2 1 0]

Now we can use rewards to extract information about the joint distribution. By defining rewards that are non-zero only until each island is depleted, we can compute the marginal time until depletion for each island. By looking at the joint accumulated rewards, we can explore the correlation between these depletion times.

# Define rewards: earn reward while island is not depleted
states = joint_graph.states()

# Reward 1: time spent before left depletion (left_depleted == 0)
reward_before_left_depletion = (states[:, 2] == 0).astype(float)

# Reward 2: time spent before right depletion (right_depleted == 0)
reward_before_right_depletion = (states[:, 3] == 0).astype(float)

# Compute expectations
E_time_to_left_depletion = joint_graph.expectation(reward_before_left_depletion)
E_time_to_right_depletion = joint_graph.expectation(reward_before_right_depletion)

print(f"Expected time until left island depleted: {E_time_to_left_depletion:.4f}")
print(f"Expected time until right island depleted: {E_time_to_right_depletion:.4f}")

# Compute covariance between the two times
cov = joint_graph.covariance(reward_before_left_depletion, reward_before_right_depletion)
var_left = joint_graph.variance(reward_before_left_depletion)
var_right = joint_graph.variance(reward_before_right_depletion)
correlation = cov / np.sqrt(var_left * var_right)

print(f"\nCovariance between depletion times: {cov:.6f}")
print(f"Correlation: {correlation:.4f}")
print("The positive correlation indicates that when the left island takes longer to deplete,")
print("the right island also tends to take longer (rabbits jumping back and forth prolongs both)")
Expected time until left island depleted: 0.4836
Expected time until right island depleted: 0.4017

Covariance between depletion times: 0.131722
Correlation: 0.7704
The positive correlation indicates that when the left island takes longer to deplete,
the right island also tends to take longer (rabbits jumping back and forth prolongs both)

This framework for joint probabilities extends naturally to more complex scenarios. We can model multiple dependent processes, extract conditional distributions, and analyze the dependencies between different random variables in our model. The key is careful state space construction that encodes all relevant information, combined with judicious use of rewards to extract the quantities of interest. In population genetics applications, this approach is used to model the joint distribution of coalescence times across multiple loci or populations, capturing the complex dependencies induced by recombination and migration.
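
For instance, a conditional distribution can be read directly off the joint table computed above (a sketch using the joint DataFrame, whose columns are the ton classes 1-4 and 'prob'):

# conditional PMF of the doubleton count given the singleton count,
# renormalizing the (deficient) joint PMF within each singleton class
pair = joint.groupby([1, 2])['prob'].sum().reset_index()
pair['prob'] /= pair.groupby(1)['prob'].transform('sum')
pair.pivot(index=1, columns=2, values='prob')  # rows: singletons, columns: doubletons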


# )