Let's try to estimate both the coalescent rate and the migration rate in a two-population model.
from phasic import (
Graph, with_ipv, set_log_level,
StateIndexer, Property,
GaussPrior, HalfCauchyPrior, ExpStepSize, ExpRegularization,
clear_caches, clear_jax_cache, clear_model_cache,
) # ALWAYS import phasic first to set jax backend correctly
import numpy as np
import jax.numpy as jnp
import pandas as pd
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns
from tqdm.auto import tqdm
from vscodenb import set_vscode_theme
# Seed NumPy's global random number generator.
np.random.seed(42)

# Plot appearance: notebook theme first, then seaborn's 'tab10' palette.
set_vscode_theme()
sns.set_palette('tab10')

# Set phasic's log level to WARNING.
set_log_level('WARNING')
Overriding theme from NOTEBOOK_THEME environment variable. <phasic._DeviceListFilter object at 0x14f9bd6d0>
<phasic.logging_config.set_log_level at 0x17f354490>
# Number of lineages placed in the initial state.
nr_samples = 2

# Each lineage class is described by three properties:
#   pop1   - descendant count attributed to population 1 (0..nr_samples)
#   pop2   - descendant count attributed to population 2 (0..nr_samples)
#   in_pop - the population (1 or 2) the lineage is currently in
indexer = StateIndexer(
    descendants=[
        Property('pop1', min_value=0, max_value=nr_samples),
        Property('pop2', min_value=0, max_value=nr_samples),
        Property('in_pop', min_value=1, max_value=2),
    ]
)

# Start from an all-zero state vector, then put all `nr_samples` lineages in
# the class "one pop1 descendant, no pop2 descendants, located in population 1".
initial = [0] * indexer.state_length
initial[indexer.descendants.props_to_index(pop1=1, pop2=0, in_pop=1)] = nr_samples
@with_ipv(initial)
def coalescent_islands(state):
    """Return the transitions out of ``state`` for a two-island coalescent.

    ``state`` holds one lineage count per class defined by
    ``indexer.descendants`` (properties ``pop1``, ``pop2`` and ``in_pop``).
    It is indexed by integers and by ``indexer.descendants.indices()`` and
    must support ``.copy()`` (presumably a NumPy array -- confirm against
    ``with_ipv``).

    Two kinds of events are generated:

    * Coalescence of two lineages that are in the same population. The
      merged lineage's ``pop1``/``pop2`` values are the sums of the parents'.
      Coefficient vector: ``[number_of_pairs, 0]``.
    * Migration of one lineage of class ``i`` to the other population.
      Coefficient vector: ``[0, state[i]]``.

    The two slots of the coefficient vector are presumably scaled by the
    coalescence and migration parameters respectively when
    ``Graph.update_weights`` is called -- confirm against the phasic docs.

    Returns:
        A list of ``[child_state, [coalescence_coeff, migration_coeff]]``
        entries; empty when at most one lineage remains.
    """
    desc = indexer.descendants
    transitions = []

    # One lineage (or none) left: nothing can happen from this state.
    if state[desc.indices()].sum() <= 1:
        return transitions

    n_classes = desc.state_length
    for i in range(n_classes):
        # Entries are lineage counts; skip classes with no lineages.
        if state[i] < 1:
            continue
        props_i = desc.index_to_props(i)

        # --- coalescence: pair class i with every class j >= i ------------
        for j in range(i, n_classes):
            if state[j] < 1:
                continue
            props_j = desc.index_to_props(j)
            # Only lineages in the same population can coalesce.
            if props_j.in_pop != props_i.in_pop:
                continue
            same = int(i == j)
            # A within-class pair needs at least two lineages.
            if same and state[i] < 2:
                continue

            child = state.copy()
            child[i] -= 1
            child[j] -= 1
            des_pop1 = props_i.pop1 + props_j.pop1
            des_pop2 = props_i.pop2 + props_j.pop2
            # NOTE(review): indentation was lost in the source; the append is
            # assumed to belong inside this bounds check -- confirm.
            if des_pop1 <= nr_samples and des_pop2 <= nr_samples:
                k = desc.props_to_index(
                    pop1=des_pop1,
                    pop2=des_pop2,
                    in_pop=props_i.in_pop,
                )
                child[k] += 1
                # n*(n-1)/2 pairs within one class, n_i*n_j across two.
                pairs = state[i] * (state[j] - same) / (1 + same)
                transitions.append([child, [pairs, 0]])

        # --- migration: one class-i lineage moves to the other population -
        other_pop = 2 if props_i.in_pop == 1 else 1
        child = state.copy()
        child[i] -= 1
        k = desc.props_to_index(
            pop1=props_i.pop1,
            pop2=props_i.pop2,
            in_pop=other_pop,
        )
        child[k] += 1
        transitions.append([child, [0, state[i]]])

    return transitions
# Build the graph from the transition callback.
graph = Graph(coalescent_islands)

# Parameter values used to generate data; presumably ordered as
# [coalescence, migration] to match the callback's coefficient vectors.
true_theta = [0.7, 0.3]
graph.update_weights(true_theta)

# Plot left-to-right, labelling by the `pop1` property looked up per index.
graph.plot(
    rankdir='LR',
    by_index=lambda i: f"Desc in pop1: {indexer.index_to_props(i).descendants.pop1} ",
)
Overriding theme from NOTEBOOK_THEME environment variable. <phasic._DeviceListFilter object at 0x14f9bd6d0>
# Same layout, labelled from the state vector: the sum of the entries
# selected by `props_to_index(pop1=1)`.
# NOTE(review): which entries a partial property lookup selects is not
# visible here -- confirm against the StateIndexer docs.
graph.plot(
    rankdir='LR',
    by_state=lambda s: f"In pop1: {s[indexer.descendants.props_to_index(pop1=1)].sum()} ",
)
Overriding theme from NOTEBOOK_THEME environment variable. <phasic._DeviceListFilter object at 0x14f9bd6d0>
# Draw 1000 observations from the graph with the weights set above.
observations = graph.sample(1000)

# Stage 1: a short run (5 particles, 20 iterations) whose only purpose is to
# give a rough mean and spread for each parameter.
svgd = graph.svgd(observations, n_particles=5, n_iterations=20)
res = svgd.get_results()

# One Gaussian prior per parameter, centred on the stage-1 estimates.
gauss_priors = [
    GaussPrior(mean, std)
    for mean, std in zip(res['theta_mean'], res['theta_std'])
]

# Stage 2: full run (library defaults) using the informed priors.
svgd = graph.svgd(observations, prior=gauss_priors)
svgd.summary()
Parameter Fixed MAP Mean SD HPD 95% lo HPD 95% hi
0 No 0.7210 0.6512 0.1303 0.3622 0.7249
1 No 0.2929 0.3236 0.2250 0.0000 0.7813
Particles: 40, Iterations: 100
<Figure size 640x480 with 0 Axes>
# Pairwise plot of the SVGD result, with the generating parameter values
# passed in for comparison.
svgd.plot_pairwise(true_theta= true_theta)
# Optional animated variant of the same plot; left disabled.
#svgd.animate_pairwise(true_theta=true_theta)