# Figure S1
## Import packages
[ ]:
# import ABCParse
# import adata_query
# import anndata
# import autodevice
# import torch
#
# import pandas as pd
# import pathlib
# from typing import List
[4]:
%load_ext nb_black
import cellplots as cp
import scdiffeq as sdq
import scdiffeq_analyses as sdq_an
import larry
import matplotlib.pyplot as plt
import tqdm.notebook
import numpy as np
import pathlib
import ABCParse
import pandas as pd
import anndata
from typing import Dict
# Record which scDiffEq build produced these results.
print(sdq.__version__, sdq.__path__)

# Colour maps reused below: `time_cmap` is indexed per simulated step,
# `larry_cmap` is passed as the cell-type background palette.
_time_colormap = sdq_an.pl.TimeColorMap()
time_cmap = _time_colormap()
_larry_colormap = larry.pl.InVitroColorMap()
larry_cmap = _larry_colormap._dict
0.1.1rc0 ['/Users/mvinyard/GitHub/scDiffEq/scdiffeq']
[5]:
# Inputs: the LARRY training AnnData and the scDiffEq project directory whose
# checkpoints are summarized below.
h5ad_path = "adata.LARRY_train.19MARCH2024.h5ad"
project_path = "./LightningSDE-FixedPotential-RegularizedVelocityRatio/"

adata = sdq.io.read_h5ad(h5ad_path)
kNN = sdq.tl.kNN(adata)
project = sdq.io.Project(project_path)
AnnData object with n_obs × n_vars = 130887 × 2492
obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train'
var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes'
uns: 'fate_counts', 'h5ad_path', 'time_occupance'
obsm: 'X_clone', 'X_pca', 'X_umap', 'cell_fate_df'
layers: 'X_scaled'
[3]:
# Summarize the project's checkpoints per model version. `best` is displayed with
# columns train / test / ckpt_path / epoch (see output); presumably the best
# checkpoint per version. Requires `project` from the data-loading cell.
results = sdq_an.parsers.SummarizedCheckpointResults(project)
best = results()
best
[3]:
| train | test | ckpt_path | epoch | |
|---|---|---|---|---|
| version_0 | 0.549363 | 0.506403 | LightningSDE-FixedPotential-RegularizedVelocit... | 2241 |
| version_1 | 0.530255 | 0.494761 | LightningSDE-FixedPotential-RegularizedVelocit... | 2101 |
| version_2 | 0.542994 | 0.498254 | LightningSDE-FixedPotential-RegularizedVelocit... | 1461 |
| version_3 | 0.574841 | 0.569267 | LightningSDE-FixedPotential-RegularizedVelocit... | 664 |
| version_4 | 0.544586 | 0.509895 | LightningSDE-FixedPotential-RegularizedVelocit... | 966 |
[6]:
def manifold_recovery(
    model,
    kNN,
    N_values=(1, 10, 100, 1_000, 10_000, 20_000, 50_000),
    n_seeds: int = 5,
):
    """Simulate from ``model`` for several initial-batch sizes and random seeds.

    For every seed in ``range(n_seeds)`` the global NumPy RNG is reseeded. Then, for
    every ``N`` in ``N_values``, a ``sdq_an.tl.BatchedSimulator`` built on ``kNN`` is
    called with ``model`` and ``N``.

    Parameters
    ----------
    model
        Trained model passed straight to the simulator (scDiffEq or PRESCIENT here).
    kNN
        kNN graph object used to construct the simulator.
    N_values : iterable of int
        Initial-batch sizes to simulate. The default is a tuple rather than a list to
        avoid a shared mutable default.
    n_seeds : int
        Number of seeds (``0 .. n_seeds - 1``).

    Returns
    -------
    dict
        ``{seed: {N: {"t_idxs": ..., "idxs": ..., "Z0_idxs": ...}}}``, holding the
        three values returned by each simulator call.
    """
    Results = {}
    for seed in tqdm.notebook.tqdm(range(n_seeds), desc="seeds"):
        Results[seed] = {}
        np.random.seed(seed)
        # leave=False: do not keep one finished inner bar per seed in the output.
        for N in tqdm.notebook.tqdm(N_values, desc=f"seed {seed}", leave=False):
            # A fresh simulator per N, as before. It is not hoisted out of the loop in
            # case the simulator keeps state between calls.
            batched_simulator = sdq_an.tl.BatchedSimulator(kNN=kNN)
            t_idxs, idxs, Batched_Z0 = batched_simulator(model, N=N)
            Results[seed][N] = {
                "t_idxs": t_idxs,
                "idxs": idxs,
                "Z0_idxs": Batched_Z0,
            }
    return Results
[7]:
# Simulate from the best checkpoint of each scDiffEq model version.
# `best` is indexed by labels like "version_3"; results are keyed by the integer 3.
scDiffEqResults = {}
for version, best_ckpt_path in best["ckpt_path"].items():
    # Parse the loop's own `version` label. The old code parsed `key`, which is not
    # defined in this cell (stale kernel state), so every model could land on the
    # same key and overwrite the previous one.
    v = int(str(version).split("_")[-1])
    scdiffeq_model = sdq.io.load_model(adata=adata, ckpt_path=best_ckpt_path)
    scDiffEqResults[v] = manifold_recovery(scdiffeq_model, kNN=kNN)
sdq.io.write_pickle(scDiffEqResults, "./scDiffEqResults.sim_nn.pkl")
- [INFO] | Input data configured.
- [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
- [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
- [INFO] | Input data configured.
- [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
- [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
- [INFO] | Input data configured.
- [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
- [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
- [INFO] | Input data configured.
- [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
- [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
- [INFO] | Input data configured.
- [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
- [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
[3]:
import pathlib

# PRESCIENT checkpoints, one per training seed: ".../seed_<k>/train.best.pt".
PRESCIENT_CKPT_DIR = pathlib.Path(
    "./prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/"
)


def ckpt_seed(path: pathlib.Path) -> int:
    """Return the integer <k> from a checkpoint path ``.../seed_<k>/train.best.pt``."""
    return int(path.parent.name.split("seed_")[1])


# `glob` order depends on the filesystem, so sort by seed (numerically, so seed_10
# follows seed_9) to get a reproducible iteration and legend order.
ckpts = sorted(PRESCIENT_CKPT_DIR.glob("seed_*/train.best.pt"), key=ckpt_seed)
ckpts
[3]:
[PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_2/train.best.pt'),
PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_4/train.best.pt'),
PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_3/train.best.pt'),
PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_1/train.best.pt'),
PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_0/train.best.pt')]
[8]:
def _seed_from_ckpt(path) -> int:
    """Return <k> for a checkpoint stored at ``.../seed_<k>/train.best.pt``."""
    return int(path.parent.name.split("seed_")[1])


# Simulate from each PRESCIENT training seed and cache the results to disk.
PRESCIENTResults = {}
for ckpt in ckpts:
    model_seed = _seed_from_ckpt(ckpt)
    prescient = sdq_an.models.PRESCIENT(adata)
    prescient.load_from_ckpt(ckpt)
    PRESCIENTResults[model_seed] = manifold_recovery(prescient, kNN=kNN)
sdq.io.write_pickle(PRESCIENTResults, "./PRESCIENTResults.sim_nn.pkl")
[100]:
# Reload both cached simulation results so the analysis below does not depend on
# the simulation cells having run in this kernel.
# NOTE(review): read_pickle presumably unpickles, which can execute code; only load
# files this notebook wrote.
scDiffEqResults = sdq.io.read_pickle("./scDiffEqResults.sim_nn.pkl")
PRESCIENTResults = sdq.io.read_pickle("./PRESCIENTResults.sim_nn.pkl")
[248]:
class AggrResults(ABCParse.ABCParse):
    """Summarize manifold-recovery simulations as unique-cell coverage statistics.

    A single simulation's coverage is the number of unique cell indices in its
    ``"idxs"`` entry. Coverage is averaged over simulation seeds and divided by
    ``len(adata)``.
    """

    def __init__(self, adata: anndata.AnnData):
        """
        Parameters
        ----------
        adata : anndata.AnnData
            Reference dataset. ``len(adata)`` is the denominator of the ``pct`` columns.
        """
        self.__parse__(locals())

    def neighbor_sum(self, result):
        """Map each ``N`` in ``result`` to its number of unique visited cell indices."""
        return {n: len(np.unique(value["idxs"])) for n, value in result.items()}

    def summarize_neighbor_stats(self, result_group):
        """Aggregate coverage over the simulation seeds of one model.

        Parameters
        ----------
        result_group : dict
            ``{sim_seed: {N: {"idxs": ..., ...}}}``.

        Returns
        -------
        pd.DataFrame
            Indexed by ``N``, with columns ``mean`` and ``std`` (across seeds; pandas
            sample std, so NaN for a single seed), plus ``pct`` and ``pct.std``: the
            same values as fractions of ``len(adata)``, not multiplied by 100.
        """
        # Must call the method. The old code called an undefined bare
        # `neighbor_sum`, which raised NameError.
        per_seed = pd.DataFrame(
            {sim_seed: self.neighbor_sum(result) for sim_seed, result in result_group.items()}
        )
        df = pd.DataFrame(
            {"mean": per_seed.mean(1), "std": per_seed.std(1)}, index=per_seed.index
        )
        n_cells = len(self._adata)
        df["pct"] = df["mean"].div(n_cells)
        df["pct.std"] = df["std"].div(n_cells)
        return df

    def compute_stats(self, ResultDict):
        """Apply ``summarize_neighbor_stats`` to every model in ``ResultDict``."""
        return {
            model_seed: self.summarize_neighbor_stats(result_group)
            for model_seed, result_group in ResultDict.items()
        }

    def __call__(self, ResultsDict: Dict, *args, **kwargs):
        """Return ``{model_seed: stats DataFrame}`` for ``{model_seed: {sim_seed: {N: result}}}``."""
        self.__update__(locals())
        return self.compute_stats(ResultsDict)
def plot_manifold_recovery(ax, counts: pd.DataFrame, c="k", label=None, xmax=100_000):
    """Plot coverage ``pct`` (with ``pct.std`` error bars) against initial batch size.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    counts : pd.DataFrame
        Indexed by N, with columns ``pct`` and ``pct.std``.
    c
        Colour for the line, error bars and markers.
    label
        Legend label, attached to the scatter markers only.
    xmax
        Right x-limit (new, defaults to the previous hard-coded 100_000).
    """
    ax.errorbar(
        counts.index,
        y=counts["pct"],
        yerr=counts["pct.std"],
        capsize=4,
        c=c,
        lw=1,
        capthick=0.5,
        elinewidth=0.5,
    )
    ax.scatter(counts.index, counts["pct"], s=15, ec="None", c=c, label=label)
    ax.set_xscale("log")
    # On a log axis a left limit of 0 is invalid: matplotlib warns and ignores it.
    # Set only the right limit, which is what the old call actually achieved.
    ax.set_xlim(right=xmax)
    for spine in ax.spines.values():
        spine.set_linewidth(0.5)
[249]:
# Coverage statistics for every scDiffEq model version.
sdq_aggr_results = AggrResults(adata=adata)
sdq_results = sdq_aggr_results(ResultsDict=scDiffEqResults)
[251]:
# Coverage statistics for every PRESCIENT training seed.
prescient_aggr_results = AggrResults(adata=adata)
prescient_results = prescient_aggr_results(ResultsDict=PRESCIENTResults)
[337]:
# Manifold coverage vs. number of t0 cells sampled: PRESCIENT (orange) vs scDiffEq (navy).
fig, axes = cp.plot(
    1,
    1,
    height=0.5,
    width=0.8,
    title=[""],
    x_label=["$t_{0}$ cells sampled"],
    # `pct` is a fraction of len(adata) in [0, 1] (see the y-limits below), not a
    # percentage, so the axis label says "fraction".
    y_label=["manifold recovery (fraction)"],
)
ax = axes[0]
ax.grid(alpha=0.2)
# Distinct loop names avoid leaking a global `key` that other cells could pick up by accident.
for prescient_seed, stats in prescient_results.items():
    plot_manifold_recovery(ax, stats, c="#fb8500", label=f"PRESCIENT-{prescient_seed}")
for sdq_version, stats in sdq_results.items():
    plot_manifold_recovery(ax, stats, c="#023047", label=f"scDiffEq-{sdq_version}")
ax.set_ylim(-0.1, 1)
ax.legend(loc=(1.1, 0), fontsize=6)
fig.savefig("percent_manifold_recovery.compared.svg", dpi=500)
[285]:
# Re-embed all cells with the cached UMAP model so every panel shares one embedding.
# NOTE(review): this loads a pickle; only use a trusted umap.pkl.
umap_model = sdq.io.read_pickle("./umap.pkl")
_umap_coords = umap_model.transform(adata.obsm["X_pca"])
adata.obsm["X_umap"] = _umap_coords
X_umap = _umap_coords
[307]:
# Initial-cell batch sizes: one per decade from 1 to 10^4, then two denser points.
N_values = [10**k for k in range(5)] + [20_000, 50_000]
[308]:
# One panel per batch size, all in a single row by default.
nplots = len(N_values)
ncols = nplots
[ ]:
[342]:
# For one scDiffEq model version / simulation seed, draw one panel per batch size N:
# visited cells in light grey over the cell-type manifold, and the cells in the first
# `t_idxs` entry in black.
UMAP_MODEL_VERSION = 0
UMAP_SIM_SEED = 0
umap_results = scDiffEqResults[UMAP_MODEL_VERSION][UMAP_SIM_SEED]

fig, axes = cp.plot(
    nplots,
    ncols=4,
    height=1,
    width=1.1,
    wspace=0.2,
    title=N_values,
    del_xy_ticks=[True],
    delete="all",
)
for ax, n_t0 in zip(axes, N_values):
    cp.umap_manifold(adata, groupby="Cell type annotation", c_background=larry_cmap, ax=ax)
    sim = umap_results[n_t0]
    xu = adata[sim["idxs"]].obsm["X_umap"].toarray()
    ax.scatter(xu[:, 0], xu[:, 1], c="lightgrey", s=5, ec="None", zorder=101, rasterized=True)
    xu_init = adata[sim["t_idxs"][0]].obsm["X_umap"].toarray()
    ax.scatter(xu_init[:, 0], xu_init[:, 1], c="k", s=5, ec="None", zorder=105, rasterized=True)
# NOTE(review): "emanifold" looks like a typo for "manifold". The name is left unchanged
# in case downstream figure assembly references this file.
fig.savefig("reconstructed_emanifold.svg", dpi=500)
[ ]:
[ ]:
[ ]:
[324]:
# Quick look: cells visited by PRESCIENT (model seed 0, simulation seed 0), one panel per N.
fig, axes = cp.plot(nplots, ncols, height=0.5, width=0.6, wspace=0.2, title=N_values)
for panel_idx, ax in enumerate(axes):
    visited = PRESCIENTResults[0][0][N_values[panel_idx]]["idxs"]
    coords = adata[visited].obsm["X_umap"].toarray()
    ax.scatter(coords[:, 0], coords[:, 1], c="k", s=5, ec="None")
[299]:
# Scratch check: for PRESCIENT runs with N = 10, scatter the UMAP coordinates of each
# `t_idxs` entry (presumably one array per simulated time step), coloured by time_cmap.
# ki = PRESCIENT training seed, kj = simulation seed (structure of PRESCIENTResults).
# NOTE(review): indentation was lost in this export, so which loop the `break` exits
# cannot be determined here. Confirm against the .ipynb before relying on the figure.
N = 10
for ki, vi in PRESCIENTResults.items():
for kj, vj in vi.items():
for en, tx in enumerate(vj[N]['t_idxs']):
xu = adata[tx].obsm['X_umap'].toarray()
plt.scatter(xu[:,0], xu[:,1], c = time_cmap[en])
break
[305]:
# Scratch check: for scDiffEq runs with N = 100, scatter the UMAP coordinates of each
# `t_idxs` entry (presumably one array per simulated time step), coloured by time_cmap.
# ki = model version, kj = simulation seed (structure of scDiffEqResults).
# NOTE(review): indentation was lost in this export, so which loop the `break` exits
# cannot be determined here. Confirm against the .ipynb before relying on the figure.
N = 100
for ki, vi in scDiffEqResults.items():
for kj, vj in vi.items():
for en, tx in enumerate(vj[N]['t_idxs']):
xu = adata[tx].obsm['X_umap'].toarray()
plt.scatter(xu[:,0], xu[:,1], c = time_cmap[en], s =5, ec = "None")
break
[290]:
[290]:
array([[-4.7767506 , 2.2991521 ],
[-4.751574 , 2.2924092 ],
[-4.7922254 , 2.3088753 ],
[-4.7701173 , 2.3060944 ],
[-3.730114 , 2.0308683 ],
[-4.6966057 , 2.2970014 ],
[-4.7885785 , 2.3057384 ],
[-4.7145424 , 2.2900527 ],
[-2.799617 , 2.094828 ],
[-4.520895 , 2.2299535 ],
[-4.0035825 , 2.080969 ],
[-4.077326 , 2.1154609 ],
[-4.126018 , 2.1284027 ],
[-3.7489493 , 2.0350966 ],
[-4.5104537 , 2.2310894 ],
[-3.8782945 , 2.0610917 ],
[-2.4291978 , 1.9067383 ],
[-3.1943645 , 1.8676742 ],
[-2.137622 , 1.8053168 ],
[-1.2488801 , 1.8713697 ],
[-2.828694 , 1.8808211 ],
[-2.5772822 , 1.9022101 ],
[-3.6448293 , 1.9752283 ],
[-2.2755578 , 1.8793008 ],
[-1.8821764 , 1.7493634 ],
[-1.7083832 , 2.1759975 ],
[-3.407939 , 1.9186658 ],
[-4.627491 , 2.25463 ],
[-2.4494815 , 1.9833689 ],
[-1.2900019 , 1.9528916 ],
[-2.84367 , 1.9602436 ],
[-3.093011 , 1.8990626 ],
[-4.088991 , 2.1289902 ],
[-4.0178757 , 2.0948043 ],
[-2.406933 , 1.8706136 ],
[-4.727838 , 2.286141 ],
[-4.1435513 , 2.1279745 ],
[-4.404973 , 2.204887 ],
[-3.424393 , 2.0004501 ],
[-2.169357 , 1.9204342 ],
[-1.7109792 , 2.014365 ],
[-4.28242 , 2.1670592 ],
[-4.7827945 , 2.289518 ],
[-4.631333 , 2.2555234 ],
[-2.8044171 , 1.9213554 ],
[-4.723055 , 2.284161 ],
[-3.8156185 , 2.051848 ],
[-4.782151 , 2.311531 ],
[-4.7453065 , 2.29529 ],
[-3.7365954 , 2.02583 ],
[-3.3190424 , 1.920155 ],
[-4.148372 , 2.1349893 ],
[-4.781355 , 2.3058398 ],
[-4.707344 , 2.2932124 ],
[-4.6173687 , 2.260463 ],
[-4.474077 , 2.2896361 ],
[-2.33877 , 1.904511 ],
[-3.4120798 , 1.9855748 ],
[-4.603611 , 2.2634861 ],
[-4.390932 , 2.1097834 ],
[-4.2884197 , 2.1765602 ],
[-3.6562302 , 2.024878 ],
[-4.376602 , 2.2077658 ],
[-4.7052116 , 2.285015 ],
[-4.817945 , 2.320222 ],
[-4.7487855 , 2.3012662 ],
[-4.703614 , 2.2773175 ],
[-2.146724 , 1.9080111 ],
[-1.9268184 , 1.8085632 ],
[-4.774681 , 2.3067036 ],
[-4.7081666 , 2.236575 ],
[-3.1987333 , 1.9757875 ],
[-2.2101836 , 1.9295732 ],
[-2.359103 , 1.9303504 ],
[-0.54273707, 2.332872 ],
[-1.7061679 , 1.5992018 ],
[-1.7459401 , 1.724105 ],
[-2.9199185 , 1.9773192 ],
[-2.0617707 , 1.8422886 ],
[-1.1858304 , 1.6803308 ],
[-1.1671245 , 1.8714094 ],
[-1.2884988 , 2.0289717 ],
[-1.0132746 , 1.538424 ],
[-0.6175913 , 1.2363846 ],
[-0.24020359, 2.2742941 ],
[-2.0959451 , 1.7266654 ],
[-2.6734173 , 1.946631 ],
[-2.3008351 , 1.7731018 ],
[-0.8175638 , 2.3937385 ],
[-1.2194148 , 1.8056425 ],
[-2.7113516 , 1.9033945 ],
[-1.9453726 , 1.7494576 ],
[-1.4844509 , 1.6917282 ],
[-3.3489366 , 1.9479777 ],
[-2.9134173 , 1.9360551 ],
[-4.984356 , 2.3937087 ],
[-4.153409 , 2.1668627 ],
[-4.157774 , 2.141215 ],
[-3.542011 , 2.0028887 ],
[-4.766685 , 2.2969787 ],
[-3.3385284 , 1.9091482 ],
[-3.5845873 , 2.012328 ],
[-3.1789517 , 1.9486539 ],
[-4.189625 , 2.154807 ],
[-1.8829033 , 1.8276869 ],
[-4.528968 , 2.2358205 ],
[-3.6750882 , 2.0254116 ],
[-4.4000435 , 2.2143018 ],
[-1.9954984 , 1.7884823 ],
[-2.280135 , 1.8127214 ],
[-3.7739947 , 2.0387168 ],
[-4.0527167 , 2.1239612 ],
[-2.345219 , 1.9140921 ],
[-4.3017764 , 2.1742327 ],
[-1.6105335 , 2.5966332 ],
[-1.8481888 , 1.8800026 ],
[-3.855828 , 2.057885 ],
[-1.305972 , 2.1226351 ],
[-2.9495556 , 2.0011113 ],
[-1.1789442 , 2.2637095 ],
[-3.2122586 , 1.9825445 ],
[-0.25810847, 2.4837782 ],
[-3.1180627 , 1.988762 ],
[-0.05822635, 2.294555 ],
[-0.16174728, 1.2416362 ],
[-1.4222533 , 1.9183836 ],
[-1.4995182 , 2.0934231 ],
[-1.4360064 , 2.4125535 ],
[-0.4102964 , 1.944689 ],
[-1.2651199 , 1.6715618 ],
[-3.134154 , 1.9723772 ],
[-0.9801621 , 1.544471 ],
[-2.303235 , 2.044336 ],
[-2.9910946 , 1.9370008 ],
[-2.6345243 , 1.790739 ],
[-3.6756222 , 2.0153544 ],
[-2.3627424 , 1.939171 ],
[-1.4661573 , 1.9627844 ],
[-1.6793529 , 1.7012086 ],
[ 0.10910934, 2.2468128 ],
[-2.933586 , 1.9017797 ],
[-1.7851615 , 1.7097328 ],
[-0.75305486, 1.6694813 ],
[-0.73143506, 1.9915167 ],
[-3.539323 , 1.9688922 ],
[-2.26666 , 1.878858 ],
[-1.6995754 , 2.2655506 ],
[-3.8711774 , 2.0630822 ],
[-1.0680475 , 2.0221918 ],
[-2.0614972 , 1.8914615 ],
[-3.8507752 , 2.0568938 ],
[-1.0543239 , 2.0187333 ],
[-2.0775762 , 1.8910428 ],
[-1.2309679 , 1.8817514 ],
[-4.1467996 , 2.153647 ],
[-1.5962312 , 1.8439301 ],
[-1.8533264 , 1.7818593 ],
[-2.1884966 , 1.8760082 ],
[-0.66960716, 1.9204426 ],
[-2.0973227 , 1.8876605 ],
[-1.9171301 , 1.9794712 ],
[-2.1456065 , 1.7624878 ],
[-0.8691584 , 3.2571151 ],
[-0.7808767 , 2.0864775 ],
[-0.25690806, 2.6229534 ],
[-0.71821773, 2.041627 ],
[-2.2361987 , 2.1920946 ],
[-1.1814722 , 2.0180087 ],
[-0.729251 , 1.6594745 ],
[-4.7158985 , 2.2846072 ],
[-0.55212027, 1.4645027 ],
[-2.204938 , 1.7548376 ],
[-3.995943 , 2.0859911 ],
[-3.1256936 , 1.9882712 ],
[-1.6979619 , 1.7312567 ],
[-0.8494791 , 2.8690326 ],
[-2.1584218 , 1.9350797 ],
[-3.8358026 , 2.0448985 ],
[-3.5836513 , 1.9995531 ],
[-2.7644076 , 1.9613734 ],
[-2.0347838 , 1.9357398 ],
[-2.7370496 , 1.9348994 ],
[-2.2033062 , 1.938819 ],
[-2.8526711 , 1.8933761 ],
[-2.0242667 , 1.8547733 ],
[-0.83780175, 1.7192647 ],
[-2.1525958 , 1.9961902 ],
[-2.0460703 , 1.8378136 ],
[-4.5389924 , 2.2428405 ],
[-1.9798508 , 1.8843231 ],
[-4.086821 , 2.0913935 ],
[-1.9977895 , 1.7919558 ],
[-1.8838537 , 1.8307939 ],
[-1.3958234 , 2.5133884 ],
[-2.2842946 , 1.8045471 ],
[-2.9516404 , 1.8886069 ],
[-2.2793531 , 1.8364666 ],
[-1.8230039 , 1.6384932 ],
[ 0.72093844, 0.9156169 ],
[ 0.38623238, 1.3397753 ],
[-0.47300208, 2.4113874 ],
[-0.675761 , 1.4174216 ],
[-2.3249373 , 1.8631749 ],
[-2.6554165 , 1.9921925 ],
[-1.1044148 , 2.5359423 ],
[-1.4751823 , 2.8336987 ],
[-1.0327219 , 1.5116489 ],
[-0.01917923, 2.0600245 ],
[-2.6522262 , 1.9597028 ],
[-0.7513429 , 1.3865578 ],
[-0.09501639, 1.3680749 ],
[ 0.33406922, 1.5168957 ],
[-0.7643318 , 1.5379997 ],
[-2.7683823 , 1.8615081 ],
[-2.4744775 , 2.0519133 ],
[-1.2062119 , 1.8387916 ],
[-0.7156174 , 2.1516287 ],
[-3.7890117 , 2.037874 ],
[-2.9559467 , 1.9195647 ],
[-4.3911223 , 2.1971684 ],
[-0.729833 , 1.8268378 ],
[-2.0762165 , 1.731518 ],
[-1.1386932 , 1.7291292 ],
[-2.445149 , 2.3220508 ],
[-3.198302 , 1.9383322 ],
[-2.275375 , 1.7492725 ],
[-1.8237638 , 2.351345 ],
[-1.025823 , 2.031286 ],
[-1.5539654 , 2.5678115 ],
[-2.811343 , 1.9894252 ],
[-3.0138085 , 1.8937446 ],
[-1.1578529 , 2.1156797 ],
[-2.023731 , 2.0005655 ],
[-0.61503357, 1.6073158 ],
[-3.3622997 , 1.9318975 ],
[-0.45347777, 1.2922595 ],
[-3.2503445 , 1.9277506 ],
[-1.9878559 , 1.5770032 ],
[-1.7636534 , 1.8790101 ],
[-0.2964629 , 2.1392663 ],
[-1.2730513 , 1.8397765 ],
[-1.4599669 , 1.9270794 ],
[-1.0695912 , 1.5269443 ],
[-2.027702 , 1.8886124 ],
[-1.8414644 , 1.9475881 ],
[-1.8725396 , 3.1599863 ],
[-2.0199184 , 1.6323391 ],
[-2.1432328 , 1.7975762 ],
[-0.9427852 , 2.0535746 ],
[-2.5239174 , 1.8949182 ],
[-1.870131 , 1.8583711 ],
[-0.56513214, 1.7634226 ],
[-1.5007592 , 1.9873817 ],
[-2.092919 , 1.8435044 ],
[-3.200143 , 2.0069041 ],
[-3.4224265 , 1.9545591 ],
[-4.1098366 , 2.13951 ],
[-4.787919 , 2.3049254 ],
[-1.7640824 , 1.9770857 ],
[-4.062653 , 2.122304 ],
[-2.1140466 , 1.9008923 ],
[-4.0412173 , 2.064319 ],
[-3.858661 , 2.0751364 ],
[-4.2212524 , 2.16817 ],
[-2.5894473 , 1.9529262 ],
[-0.9722 , 1.9531602 ],
[-1.9878243 , 1.9318092 ],
[-1.3211379 , 1.7502551 ],
[-2.8651552 , 1.9603915 ],
[-1.9431208 , 1.9380236 ],
[-1.7437401 , 1.8046956 ],
[-2.1138468 , 1.9618275 ],
[-2.1316159 , 2.0636852 ],
[-2.4741142 , 1.9525416 ],
[-3.917766 , 2.070964 ],
[-1.9619702 , 1.8870287 ],
[-1.5602932 , 2.061163 ],
[-1.8825613 , 2.009452 ],
[-4.644383 , 2.256327 ],
[-2.097393 , 1.8533278 ],
[-0.7100427 , 2.205196 ],
[-3.9505098 , 2.1237686 ]], dtype=float32)
[ ]:
# (Removed the incomplete expression `adata[]`, which was a SyntaxError that stopped
# "Run All". Index `adata` with explicit cell indices where a subset is needed.)
[274]:
first_run = scDiffEqResults[0][0]  # model version 0, simulation seed 0
first_run.keys()
[274]:
dict_keys([1, 10, 100, 1000, 10000, 20000, 50000])
[264]:
# Take the first (N, result) pair of model version 0 / simulation seed 0 for inspection.
k, v = next(iter(scDiffEqResults[0][0].items()))
[268]:
v["t_idxs"]  # list of index arrays; presumably one per simulated time step
[268]:
[array([ 15908, 39131, 44702, 45744, 54688, 60569, 61879, 67845,
68890, 85905, 88351, 88586, 91275, 93300, 93787, 95807,
97817, 103643, 123098, 127122]),
array([ 15908, 17264, 36136, 39131, 44702, 54688, 56825, 61879,
69108, 71160, 87339, 88586, 91275, 93889, 95807, 97817,
103643, 108016, 117512, 127122]),
array([ 14388, 14455, 15908, 15939, 17361, 24480, 39618, 44702,
58515, 61879, 78999, 79602, 88414, 88586, 92662, 100930,
108573, 110100, 116701, 127517]),
array([ 15807, 15908, 17264, 24371, 39131, 40697, 56825, 71160,
88871, 91275, 93889, 95050, 97406, 99011, 108016, 108653,
110123, 111422, 117512, 127122]),
array([ 8764, 14284, 14522, 17999, 18005, 40578, 46285, 63257,
74561, 80941, 84036, 86730, 93764, 94094, 97479, 99763,
100789, 106842, 106967, 124968]),
array([ 15220, 16637, 17999, 18005, 24038, 36480, 40578, 46285,
63223, 68273, 69320, 72349, 83609, 84066, 102354, 106967,
107626, 107880, 111809, 119194]),
array([ 13888, 14284, 14522, 17999, 18005, 27566, 40578, 59786,
63257, 73678, 84036, 86730, 93764, 94094, 97479, 99763,
100789, 102410, 106842, 124968]),
array([ 13791, 18780, 27497, 37115, 64537, 64825, 67666, 70497,
71612, 72349, 87658, 88313, 88416, 91461, 93479, 102470,
115197, 116790, 117838, 126954]),
array([ 1900, 12211, 13791, 27497, 29577, 31015, 64537, 64825,
67666, 69320, 72349, 84066, 87658, 88313, 91461, 93479,
102354, 115197, 116790, 126954]),
array([ 1900, 12211, 16637, 18005, 24038, 31015, 31147, 40578,
63223, 69320, 72349, 73799, 84066, 87658, 91461, 102354,
104747, 106967, 107626, 126954]),
array([ 16637, 17999, 18005, 24038, 31147, 40578, 63223, 67551,
69320, 73678, 73799, 84066, 87658, 91569, 102354, 104747,
106842, 106967, 107626, 113228]),
array([ 1900, 12211, 13791, 24038, 31015, 31147, 64537, 64825,
69320, 72349, 73799, 84066, 87658, 88313, 91461, 93479,
102354, 104747, 115197, 126954]),
array([ 1900, 12211, 13791, 31015, 31147, 64537, 64825, 69320,
72349, 73799, 84066, 87658, 88313, 91461, 93479, 102354,
104747, 107626, 115197, 126954]),
array([ 9841, 12014, 13791, 23170, 27497, 27550, 35677, 37115,
37472, 37918, 64537, 64825, 67666, 68579, 71612, 88416,
91461, 102470, 115197, 119475]),
array([ 9841, 12014, 13791, 18780, 23170, 27497, 27550, 32004,
37115, 37472, 64537, 64825, 67666, 68579, 71612, 88416,
91461, 102470, 115197, 119475]),
array([ 9841, 13085, 16872, 23170, 27497, 27550, 35003, 35677,
37115, 37472, 38575, 67666, 67941, 68579, 88416, 95140,
102470, 115197, 119475, 128766]),
array([ 265, 1368, 1583, 7732, 8465, 8587, 28291, 29722,
30537, 34630, 34904, 38015, 38020, 38112, 38575, 39770,
64461, 65762, 90814, 116578]),
array([ 265, 760, 1368, 7732, 8587, 21429, 27517, 28291,
30537, 34904, 37017, 38015, 38112, 38575, 39770, 39886,
65762, 67941, 90814, 116578]),
array([ 1900, 11022, 12014, 12211, 31015, 31147, 32004, 37793,
64825, 65623, 66490, 66931, 73799, 84066, 87111, 87658,
91461, 93707, 107626, 130054]),
array([ 1900, 11022, 11574, 12014, 19402, 22583, 31147, 32004,
64537, 64825, 66931, 71612, 73799, 87658, 91461, 93707,
95731, 115197, 126954, 130054]),
array([ 14868, 19464, 41782, 59786, 83400, 84937, 86749, 89470,
90597, 93995, 97773, 98898, 100924, 104677, 105809, 110250,
120754, 126500, 126632, 126648]),
array([ 19464, 41782, 59786, 83400, 84937, 86749, 88284, 89470,
93995, 94934, 97773, 98898, 100924, 105809, 110250, 120754,
126500, 126632, 126648, 129242]),
array([ 19464, 41782, 59786, 83400, 84937, 86749, 88284, 89470,
93995, 94934, 97773, 98898, 100924, 105809, 110250, 120754,
126500, 126632, 126648, 129242]),
array([ 14868, 19464, 41782, 56256, 59786, 83400, 84937, 88765,
89470, 93995, 97773, 98898, 100924, 105809, 106842, 110250,
120754, 126500, 126632, 126648]),
array([ 41782, 71942, 86749, 87578, 88423, 89144, 94692, 94934,
95329, 96004, 99320, 107115, 108107, 108292, 109785, 110963,
124970, 126500, 127538, 129242]),
array([ 47697, 73697, 80912, 84953, 86824, 86891, 88795, 89996,
90595, 90660, 90793, 90849, 96593, 100511, 106519, 107513,
107774, 122056, 123668, 130002]),
array([ 49765, 50066, 52552, 71624, 80912, 81272, 87702, 88339,
88423, 88795, 90660, 90793, 98478, 99481, 103678, 106099,
107064, 108642, 122056, 123668]),
array([ 47355, 49765, 50066, 52552, 80912, 81272, 87702, 88423,
88795, 90660, 90793, 98478, 99481, 103678, 106099, 107064,
108642, 122056, 123668, 125085]),
array([ 49765, 50066, 52552, 80912, 81272, 87702, 88339, 88423,
88795, 90660, 90793, 98478, 99481, 103678, 106099, 107064,
108642, 122056, 123668, 125085]),
array([ 47355, 49765, 50066, 52552, 80912, 81272, 87702, 88423,
88795, 90660, 90793, 98478, 99481, 103678, 106099, 107064,
108642, 122056, 123668, 125085]),
array([ 19464, 71942, 83400, 88284, 93995, 94692, 94934, 95329,
96004, 100924, 106099, 107115, 108107, 109785, 110963, 124970,
126500, 126648, 127538, 129242]),
array([ 41782, 47243, 71942, 72471, 80346, 86749, 88284, 89144,
94692, 94934, 95329, 96004, 107115, 108107, 109785, 110963,
124970, 126500, 127538, 129242]),
array([ 58393, 73697, 80912, 86891, 88795, 89713, 89996, 90595,
90849, 93182, 96593, 98466, 99124, 99926, 100334, 100511,
106519, 107513, 122056, 130002]),
array([ 73697, 77664, 86891, 88795, 89713, 89996, 90595, 90849,
93182, 94918, 96593, 99124, 99539, 99926, 100334, 100511,
103018, 106519, 107513, 130002]),
array([ 73697, 77664, 86891, 88795, 89713, 89996, 90595, 90849,
93182, 94918, 96593, 99124, 99539, 99926, 100334, 100511,
103018, 106519, 107513, 130002]),
array([ 73697, 77664, 86891, 88795, 89713, 89996, 90595, 90849,
93182, 96593, 99124, 99539, 100334, 100511, 103018, 106519,
106808, 107513, 110575, 130002]),
array([ 73697, 77664, 86891, 88795, 89713, 89996, 90595, 90849,
93182, 96593, 99124, 99926, 100334, 100511, 103018, 106519,
106808, 107513, 110575, 130002]),
array([ 73697, 86891, 88795, 89713, 89996, 90595, 90849, 93182,
96593, 99124, 99926, 100511, 101483, 103018, 106519, 106808,
107513, 109941, 110575, 130002]),
array([ 43548, 53869, 73697, 74557, 76603, 77664, 80436, 86891,
88795, 89713, 89996, 90595, 90849, 96593, 99539, 100334,
100511, 106519, 107513, 130002]),
array([ 73697, 77664, 86268, 86891, 89996, 90595, 90849, 94918,
96593, 98242, 98358, 100334, 100458, 100511, 106519, 107513,
108245, 110148, 110604, 130002]),
array([ 73697, 77664, 86891, 89996, 90231, 90595, 90849, 94918,
96593, 98358, 100334, 100458, 100511, 103018, 106519, 107513,
108245, 110148, 110604, 130002])]
[267]:
v["Z0_idxs"]  # initial cell(s) returned by the batched simulator
[267]:
array(['15908'], dtype=object)
[ ]:
[271]:
adata[v["idxs"]]  # view of every cell visited in this simulation
[271]:
View of AnnData object with n_obs × n_vars = 282 × 2492
obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train'
var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes'
uns: 'fate_counts', 'h5ad_path', 'time_occupance'
obsm: 'X_clone', 'X_pca', 'X_umap', 'cell_fate_df'
layers: 'X_scaled'
[266]:
# Keys stored for each simulation by `manifold_recovery`: t_idxs, idxs, Z0_idxs.
v.keys()
[266]:
dict_keys(['t_idxs', 'idxs', 'Z0_idxs'])
[ ]:
[78]:
# Aggregate coverage statistics with the current AggrResults API. The old code called
# an undefined `prepare_counts` helper and `AggrResults(val)()` with no results dict,
# both of which fail on a fresh kernel. Keys are normalized to ints ("version_3" -> 3).
sdq_results_aggr = {
    int(str(model_key).split("_")[-1]): stats
    for model_key, stats in AggrResults(adata)(scDiffEqResults).items()
}
prescient_results_aggr = AggrResults(adata)(PRESCIENTResults)
[87]:
[88]:
# Earlier draft of the comparison figure (blue/red, no y-limit). It now saves under its
# own filename: on Run All it runs after the final figure cell and used to overwrite
# percent_manifold_recovery.compared.svg.
fig, axes = cp.plot(
    1,
    1,
    height=0.5,
    width=0.8,
    title=[""],
    x_label=["$t_{0}$ cells sampled"],
    y_label=["% manifold recovery"],
)
ax = axes[0]
ax.grid(alpha=0.2)
for model_key, stats in sdq_results_aggr.items():
    plot_manifold_recovery(ax, stats, c="b", label=f"scDiffEq-{model_key}")
for model_key, stats in prescient_results_aggr.items():
    plot_manifold_recovery(ax, stats, c="r", label=f"PRESCIENT-{model_key}")
ax.legend(loc=(1.1, 0), fontsize=6)
fig.savefig("percent_manifold_recovery.compared.draft.svg", dpi=500)
## UMAPs
[22]:
# Single-panel UMAP of one simulation: cell-type manifold, visited cells (dim grey),
# cells at each `t_idxs` entry (coloured by step), and the initial cell(s) in black on top.
# `res` was not defined anywhere, so this cell failed with NameError on a fresh kernel.
# NOTE(review): the intended model/seed/N is not recoverable from the notebook.
# Model version 0, simulation seed 0, N = N_values[0] (the result inspected above) is a guess.
res = scDiffEqResults[0][0][N_values[0]]
axes = cp.umap_manifold(
    adata, groupby="Cell type annotation", c_background=larry_cmap, alpha=0.1, s=2
)
axes = cp.umap(adata[res["idxs"]], ax=axes[0], c="dimgrey", zorder=101)
for en, ix in enumerate(res["t_idxs"]):
    # Earlier steps get a higher zorder, so they are drawn on top of later ones.
    axes = cp.umap(adata[ix], c=time_cmap[en], ax=axes[0], zorder=int(205 - en))
axes = cp.umap(adata[res["Z0_idxs"]], c="k", ax=axes[0], zorder=int(305))