Figure S1#

Import packages#

[ ]:
# import ABCParse
# import adata_query
# import anndata
# import autodevice
# import torch
#
# import pandas as pd
# import pathlib

# from typing import List
[4]:
%load_ext nb_black

import cellplots as cp
import scdiffeq as sdq
import scdiffeq_analyses as sdq_an
import larry
import matplotlib.pyplot as plt
import tqdm.notebook
import numpy as np
import pathlib
import ABCParse
import pandas as pd
import anndata

from typing import Dict

# Record which scdiffeq build produced these results (version and install path).
# NOTE(review): the printed path is an absolute, machine-specific location —
# consider clearing that output before sharing the notebook.
print(sdq.__version__, sdq.__path__)

# Colour lookups reused by the plotting cells: `time_cmap` is indexed by the
# position of a time step, `larry_cmap` is passed as the background palette.
time_cmap = sdq_an.pl.TimeColorMap()()
# NOTE(review): `_dict` is a private attribute — presumably a label -> colour
# mapping; confirm there is no public accessor.
larry_cmap = larry.pl.InVitroColorMap()._dict
0.1.1rc0 ['/Users/mvinyard/GitHub/scDiffEq/scdiffeq']
[5]:
# Load the LARRY training AnnData and build the kNN object handed to the simulations.
h5ad_path = "adata.LARRY_train.19MARCH2024.h5ad"
adata = sdq.io.read_h5ad(h5ad_path)
kNN = sdq.tl.kNN(adata)

# Project directory whose model versions/checkpoints are summarised in the next cell.
project_path = "./LightningSDE-FixedPotential-RegularizedVelocityRatio/"
project = sdq.io.Project(project_path)
AnnData object with n_obs × n_vars = 130887 × 2492
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train'
    var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes'
    uns: 'fate_counts', 'h5ad_path', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_umap', 'cell_fate_df'
    layers: 'X_scaled'
[3]:
# Summarise the project's checkpoints. The resulting table (one row per model
# version, shown as the cell output) provides the 'ckpt_path' column used later.
results = sdq_an.parsers.SummarizedCheckpointResults(project)
best = results()
best
[3]:
train test ckpt_path epoch
version_0 0.549363 0.506403 LightningSDE-FixedPotential-RegularizedVelocit... 2241
version_1 0.530255 0.494761 LightningSDE-FixedPotential-RegularizedVelocit... 2101
version_2 0.542994 0.498254 LightningSDE-FixedPotential-RegularizedVelocit... 1461
version_3 0.574841 0.569267 LightningSDE-FixedPotential-RegularizedVelocit... 664
version_4 0.544586 0.509895 LightningSDE-FixedPotential-RegularizedVelocit... 966
[6]:
def manifold_recovery(model, kNN, N_values = (1, 10, 100, 1_000, 10_000, 20_000, 50_000), n_seeds: int = 5):
    """Run the batched simulator for several sample sizes and seeds.

    Parameters
    ----------
    model
        Model passed to the ``sdq_an.tl.BatchedSimulator`` instance on every run.
    kNN
        Neighbor object used to construct ``sdq_an.tl.BatchedSimulator``.
    N_values
        Iterable of sample sizes; each value is forwarded to the simulator as ``N``.
        The default is an immutable tuple so it cannot be altered between calls.
    n_seeds
        Number of repeats. Repeat ``seed`` reseeds NumPy's global RNG with ``seed``.

    Returns
    -------
    dict
        ``{seed: {N: {"t_idxs": ..., "idxs": ..., "Z0_idxs": ...}}}`` holding the three
        values returned by the simulator for every (seed, N) pair.
    """
    Results = {}
    for seed in tqdm.notebook.tqdm(range(n_seeds)):
        Results[seed] = {}
        # Reseed once per repeat (not once per N).
        np.random.seed(seed)
        for N in tqdm.notebook.tqdm(N_values):
            # A fresh simulator is built for every run, as in the original implementation.
            batched_simulator = sdq_an.tl.BatchedSimulator(kNN = kNN)
            t_idxs, idxs, Batched_Z0 = batched_simulator(model, N = N)
            Results[seed][N] = {
                "t_idxs": t_idxs,
                "idxs": idxs,
                "Z0_idxs": Batched_Z0,
            }
    return Results
[7]:
# Run the manifold-recovery simulations for the best checkpoint of every model version.
scDiffEqResults = {}
for version, best_ckpt_path in best['ckpt_path'].items():
    # Key by the numeric suffix of the version label (e.g. "version_3" -> 3).
    # Derived from the loop's own `version`, so the cell does not depend on a
    # `key` variable left over from some other cell.
    v = int(str(version).split("_")[-1])
    scdiffeq_model = sdq.io.load_model(adata = adata, ckpt_path=best_ckpt_path)
    scDiffEqResults[v] = manifold_recovery(scdiffeq_model, kNN = kNN)

# Persist the results so they can be reloaded without re-simulating.
sdq.io.write_pickle(scDiffEqResults, "./scDiffEqResults.sim_nn.pkl")
 - [INFO] | Input data configured.
 - [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
 - [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
 - [INFO] | Input data configured.
 - [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
 - [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
 - [INFO] | Input data configured.
 - [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
 - [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
 - [INFO] | Input data configured.
 - [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
 - [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
 - [INFO] | Input data configured.
 - [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
 - [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
[3]:
# Collect the best PRESCIENT checkpoint of each training seed. Path.glob yields
# matches in no guaranteed order, so sort them for a reproducible listing (and
# therefore a reproducible results-dict / legend order downstream).
ckpts = sorted(pathlib.Path("./prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/").glob("seed_*/train.best.pt"))
ckpts
[3]:
[PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_2/train.best.pt'),
 PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_4/train.best.pt'),
 PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_3/train.best.pt'),
 PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_1/train.best.pt'),
 PosixPath('prescient_ckpts.for_scdiffeq_manuscript/weinreb-fate-benchmark-kegg/fate_kegg-softplus_2_400-1e-06/seed_0/train.best.pt')]
[8]:
# Repeat the manifold-recovery simulations with each PRESCIENT checkpoint.
PRESCIENTResults = {}
for ckpt_path in ckpts:
    # The training seed is encoded in the checkpoint's parent directory name ("seed_<n>").
    seed = int(ckpt_path.parent.name.split("seed_")[1])
    prescient = sdq_an.models.PRESCIENT(adata)
    prescient.load_from_ckpt(ckpt_path)
    PRESCIENTResults[seed] = manifold_recovery(prescient, kNN = kNN)

# Persist the results so they can be reloaded without re-simulating.
sdq.io.write_pickle(PRESCIENTResults, "./PRESCIENTResults.sim_nn.pkl")

[100]:
# Reload the scDiffEq simulation results from the pickle written earlier in the notebook.
# NOTE(review): there is no matching reload cell for PRESCIENTResults, so that
# variable must still be in memory for the cells below to work.
scDiffEqResults = sdq.io.read_pickle("./scDiffEqResults.sim_nn.pkl")
[248]:
class AggrResults(ABCParse.ABCParse):
    """Aggregate manifold-recovery simulation results into per-model summary tables."""

    def __init__(self, adata: anndata.AnnData):
        """Register `adata`; ``len(adata)`` is the denominator of the recovered fractions."""
        self.__parse__(locals())

    def neighbor_sum(self, result):
        """Return ``{N: number of unique entries in result[N]['idxs']}`` for one simulation seed."""
        return {n: len(np.unique(value['idxs'])) for n, value in result.items()}

    def summarize_neighbor_stats(self, result_group):
        """Summarise one model's results across simulation seeds.

        Parameters
        ----------
        result_group
            Mapping ``{sim_seed: {N: {"idxs": ..., ...}}}``.

        Returns
        -------
        pd.DataFrame
            Indexed by N with columns ``mean`` and ``std`` (unique-index counts across
            seeds) and ``pct`` / ``pct.std`` (the same values divided by ``len(adata)``).
        """
        # Must go through `self`: there is no module-level `neighbor_sum`.
        df = pd.DataFrame({sim_seed: self.neighbor_sum(result) for sim_seed, result in result_group.items()})
        # Rows are N values, columns are seeds, so aggregate across columns.
        df = pd.DataFrame({"mean": df.mean(axis=1), "std": df.std(axis=1)}, index = df.index)
        df['pct'] = df['mean'].div(len(self._adata))
        df['pct.std'] = df['std'].div(len(self._adata))
        return df

    def compute_stats(self, ResultDict):
        """Return ``{model_seed: summary DataFrame}`` for every model in `ResultDict`."""
        return {model_seed: self.summarize_neighbor_stats(result_group) for model_seed, result_group in ResultDict.items()}

    def __call__(self, ResultsDict: Dict, *args, **kwargs):
        """Summarise `ResultsDict` (``{model_seed: {sim_seed: {N: result}}}``); see `compute_stats`."""
        self.__update__(locals())

        return self.compute_stats(ResultsDict)

def plot_manifold_recovery(ax, counts: pd.DataFrame, c="k", label = None):
    """Draw one recovery curve (markers with error bars) on a log-scaled x-axis.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    counts
        Table whose index supplies the x values and whose ``'pct'`` and
        ``'pct.std'`` columns supply the y values and error-bar sizes.
    c
        Colour used for both the error bars and the markers.
    label
        Legend label attached to the scatter markers.
    """
    ax.errorbar(
        counts.index,
        y = counts['pct'],
        yerr=counts['pct.std'],
        capsize=4,
        c = c,
        lw = 1,
        capthick=0.5,
        elinewidth=0.5,
    )
    ax.scatter(counts.index, counts['pct'], s = 15, ec = "None", c = c, label=label)
    ax.set_xscale('log')
    # Only the upper limit is set: a lower limit of 0 is invalid on a log axis and
    # matplotlib would ignore it with a warning, leaving the left limit unchanged anyway.
    ax.set_xlim(right=100_000)
    for spine in ax.spines.values():
        spine.set_linewidth(0.5)
[249]:
# Summarise the scDiffEq simulations: one stats table per model key in scDiffEqResults.
sdq_aggr_results = AggrResults(adata)
sdq_results = sdq_aggr_results(scDiffEqResults)
[251]:
# Summarise the PRESCIENT simulations the same way: one stats table per training seed.
prescient_aggr_results = AggrResults(adata)
prescient_results = prescient_aggr_results(PRESCIENTResults)
[337]:
# Single-panel comparison of manifold recovery: PRESCIENT (orange) vs scDiffEq (dark blue),
# one curve per trained model.
# NOTE(review): the y-axis is labelled "%" but the plotted 'pct' column is a
# fraction (mean / len(adata)) and the y-limits are (-0.1, 1) — confirm the intended unit.
fig, axes = cp.plot(
    1,
    1,
    height = 0.5,
    width = 0.8,
    title = [''],
    x_label=['$t_{0}$ cells sampled'],
    y_label=['% manifold recovery'],
)
ax = axes[0]
ax.grid(alpha = 0.2)
for key, val in prescient_results.items():
    plot_manifold_recovery(ax, val, c = "#fb8500", label = f"PRESCIENT-{key}")
for key, val in sdq_results.items():
    plot_manifold_recovery(ax, val, c = "#023047", label = f"scDiffEq-{key}")
ax.set_ylim(-0.1, 1)
# Legend is placed outside the axes, to the right.
plt.legend(loc=(1.1, 0), fontsize = 6)
plt.savefig("percent_manifold_recovery.compared.svg", dpi = 500)
../_images/_analyses_FigureS1_15_0.png
[285]:
# Re-embed every cell with the pickled UMAP model, overwriting adata.obsm['X_umap'] in place
# (the same array is also bound to `X_umap`).
# NOTE(review): read_pickle presumably unpickles the file — only load pickles from a trusted source.
umap_model = sdq.io.read_pickle("./umap.pkl")
adata.obsm['X_umap'] = X_umap = umap_model.transform(adata.obsm['X_pca'])
[307]:
# Sample sizes shown in the panels below; same values as manifold_recovery's default N_values.
N_values = [1, 10, 100, 1_000, 10_000, 20_000, 50_000]
[308]:
# One panel per sample size; `ncols` is bound to the same count as `nplots`.
nplots = ncols = len(N_values)
[ ]:

[342]:
# One UMAP panel per sample size, using the scDiffEqResults[0][0] entry only.
# NOTE(review): the output file name "reconstructed_emanifold.svg" looks like a typo of "manifold" — confirm.
fig, axes = cp.plot(nplots, ncols=4, height = 1, width = 1.1, wspace = 0.2, title = N_values, del_xy_ticks=[True], delete="all")

for en, ax in enumerate(axes):
    # Background: all cells coloured by annotation.
    cp.umap_manifold(adata, groupby="Cell type annotation", c_background = larry_cmap, ax = ax)
    # Light grey: every cell index stored under 'idxs' for this sample size.
    ix = scDiffEqResults[0][0][N_values[en]]['idxs']
    xu = adata[ix].obsm['X_umap'].toarray()
    ax.scatter(xu[:,0], xu[:,1], c='lightgrey', s =5, ec = "None", zorder = 101, rasterized = True) # c = time_cmap[i]
    # Black, drawn on top: the first entry of 't_idxs'.
    ix_init = scDiffEqResults[0][0][N_values[en]]['t_idxs'][0]
    xu_init = adata[ix_init].obsm['X_umap'].toarray()
    ax.scatter(xu_init[:,0], xu_init[:,1], c='k', s =5, ec = "None", zorder = 105, rasterized = True)
plt.savefig("reconstructed_emanifold.svg", dpi = 500)
../_images/_analyses_FigureS1_20_0.png
[ ]:

[ ]:

[ ]:

[324]:
# Same per-sample-size view for the PRESCIENTResults[0][0] entry: scatter the
# UMAP coordinates of every index stored under 'idxs' (no background manifold here).
fig, axes = cp.plot(nplots, ncols, height = 0.5, width = 0.6, wspace = 0.2, title = N_values)

for en, ax in enumerate(axes):
    ix = PRESCIENTResults[0][0][N_values[en]]['idxs']
    xu = adata[ix].obsm['X_umap'].toarray()
    ax.scatter(xu[:,0], xu[:,1], c='k', s =5, ec = "None")
../_images/_analyses_FigureS1_24_0.png
[299]:
# Quick look at the PRESCIENT 't_idxs' entries for N = 10, coloured by their position in the list.
N = 10

for ki, vi in PRESCIENTResults.items():
    for kj, vj in vi.items():
        for en, tx in enumerate(vj[N]['t_idxs']):
            xu = adata[tx].obsm['X_umap'].toarray()
            plt.scatter(xu[:,0], xu[:,1], c = time_cmap[en])

    # Stop after the first model: only one model's seeds are drawn.
    break
../_images/_analyses_FigureS1_25_0.png
[305]:
# Quick look at the scDiffEq 't_idxs' entries for N = 100, coloured by their position in the list.
N = 100

for ki, vi in scDiffEqResults.items():
    for kj, vj in vi.items():
        for en, tx in enumerate(vj[N]['t_idxs']):
            xu = adata[tx].obsm['X_umap'].toarray()
            plt.scatter(xu[:,0], xu[:,1], c = time_cmap[en], s =5, ec = "None")
    # Stop after the first model: only one model's seeds are drawn.
    break
../_images/_analyses_FigureS1_26_0.png
[290]:

[290]:
array([[-4.7767506 ,  2.2991521 ],
       [-4.751574  ,  2.2924092 ],
       [-4.7922254 ,  2.3088753 ],
       [-4.7701173 ,  2.3060944 ],
       [-3.730114  ,  2.0308683 ],
       [-4.6966057 ,  2.2970014 ],
       [-4.7885785 ,  2.3057384 ],
       [-4.7145424 ,  2.2900527 ],
       [-2.799617  ,  2.094828  ],
       [-4.520895  ,  2.2299535 ],
       [-4.0035825 ,  2.080969  ],
       [-4.077326  ,  2.1154609 ],
       [-4.126018  ,  2.1284027 ],
       [-3.7489493 ,  2.0350966 ],
       [-4.5104537 ,  2.2310894 ],
       [-3.8782945 ,  2.0610917 ],
       [-2.4291978 ,  1.9067383 ],
       [-3.1943645 ,  1.8676742 ],
       [-2.137622  ,  1.8053168 ],
       [-1.2488801 ,  1.8713697 ],
       [-2.828694  ,  1.8808211 ],
       [-2.5772822 ,  1.9022101 ],
       [-3.6448293 ,  1.9752283 ],
       [-2.2755578 ,  1.8793008 ],
       [-1.8821764 ,  1.7493634 ],
       [-1.7083832 ,  2.1759975 ],
       [-3.407939  ,  1.9186658 ],
       [-4.627491  ,  2.25463   ],
       [-2.4494815 ,  1.9833689 ],
       [-1.2900019 ,  1.9528916 ],
       [-2.84367   ,  1.9602436 ],
       [-3.093011  ,  1.8990626 ],
       [-4.088991  ,  2.1289902 ],
       [-4.0178757 ,  2.0948043 ],
       [-2.406933  ,  1.8706136 ],
       [-4.727838  ,  2.286141  ],
       [-4.1435513 ,  2.1279745 ],
       [-4.404973  ,  2.204887  ],
       [-3.424393  ,  2.0004501 ],
       [-2.169357  ,  1.9204342 ],
       [-1.7109792 ,  2.014365  ],
       [-4.28242   ,  2.1670592 ],
       [-4.7827945 ,  2.289518  ],
       [-4.631333  ,  2.2555234 ],
       [-2.8044171 ,  1.9213554 ],
       [-4.723055  ,  2.284161  ],
       [-3.8156185 ,  2.051848  ],
       [-4.782151  ,  2.311531  ],
       [-4.7453065 ,  2.29529   ],
       [-3.7365954 ,  2.02583   ],
       [-3.3190424 ,  1.920155  ],
       [-4.148372  ,  2.1349893 ],
       [-4.781355  ,  2.3058398 ],
       [-4.707344  ,  2.2932124 ],
       [-4.6173687 ,  2.260463  ],
       [-4.474077  ,  2.2896361 ],
       [-2.33877   ,  1.904511  ],
       [-3.4120798 ,  1.9855748 ],
       [-4.603611  ,  2.2634861 ],
       [-4.390932  ,  2.1097834 ],
       [-4.2884197 ,  2.1765602 ],
       [-3.6562302 ,  2.024878  ],
       [-4.376602  ,  2.2077658 ],
       [-4.7052116 ,  2.285015  ],
       [-4.817945  ,  2.320222  ],
       [-4.7487855 ,  2.3012662 ],
       [-4.703614  ,  2.2773175 ],
       [-2.146724  ,  1.9080111 ],
       [-1.9268184 ,  1.8085632 ],
       [-4.774681  ,  2.3067036 ],
       [-4.7081666 ,  2.236575  ],
       [-3.1987333 ,  1.9757875 ],
       [-2.2101836 ,  1.9295732 ],
       [-2.359103  ,  1.9303504 ],
       [-0.54273707,  2.332872  ],
       [-1.7061679 ,  1.5992018 ],
       [-1.7459401 ,  1.724105  ],
       [-2.9199185 ,  1.9773192 ],
       [-2.0617707 ,  1.8422886 ],
       [-1.1858304 ,  1.6803308 ],
       [-1.1671245 ,  1.8714094 ],
       [-1.2884988 ,  2.0289717 ],
       [-1.0132746 ,  1.538424  ],
       [-0.6175913 ,  1.2363846 ],
       [-0.24020359,  2.2742941 ],
       [-2.0959451 ,  1.7266654 ],
       [-2.6734173 ,  1.946631  ],
       [-2.3008351 ,  1.7731018 ],
       [-0.8175638 ,  2.3937385 ],
       [-1.2194148 ,  1.8056425 ],
       [-2.7113516 ,  1.9033945 ],
       [-1.9453726 ,  1.7494576 ],
       [-1.4844509 ,  1.6917282 ],
       [-3.3489366 ,  1.9479777 ],
       [-2.9134173 ,  1.9360551 ],
       [-4.984356  ,  2.3937087 ],
       [-4.153409  ,  2.1668627 ],
       [-4.157774  ,  2.141215  ],
       [-3.542011  ,  2.0028887 ],
       [-4.766685  ,  2.2969787 ],
       [-3.3385284 ,  1.9091482 ],
       [-3.5845873 ,  2.012328  ],
       [-3.1789517 ,  1.9486539 ],
       [-4.189625  ,  2.154807  ],
       [-1.8829033 ,  1.8276869 ],
       [-4.528968  ,  2.2358205 ],
       [-3.6750882 ,  2.0254116 ],
       [-4.4000435 ,  2.2143018 ],
       [-1.9954984 ,  1.7884823 ],
       [-2.280135  ,  1.8127214 ],
       [-3.7739947 ,  2.0387168 ],
       [-4.0527167 ,  2.1239612 ],
       [-2.345219  ,  1.9140921 ],
       [-4.3017764 ,  2.1742327 ],
       [-1.6105335 ,  2.5966332 ],
       [-1.8481888 ,  1.8800026 ],
       [-3.855828  ,  2.057885  ],
       [-1.305972  ,  2.1226351 ],
       [-2.9495556 ,  2.0011113 ],
       [-1.1789442 ,  2.2637095 ],
       [-3.2122586 ,  1.9825445 ],
       [-0.25810847,  2.4837782 ],
       [-3.1180627 ,  1.988762  ],
       [-0.05822635,  2.294555  ],
       [-0.16174728,  1.2416362 ],
       [-1.4222533 ,  1.9183836 ],
       [-1.4995182 ,  2.0934231 ],
       [-1.4360064 ,  2.4125535 ],
       [-0.4102964 ,  1.944689  ],
       [-1.2651199 ,  1.6715618 ],
       [-3.134154  ,  1.9723772 ],
       [-0.9801621 ,  1.544471  ],
       [-2.303235  ,  2.044336  ],
       [-2.9910946 ,  1.9370008 ],
       [-2.6345243 ,  1.790739  ],
       [-3.6756222 ,  2.0153544 ],
       [-2.3627424 ,  1.939171  ],
       [-1.4661573 ,  1.9627844 ],
       [-1.6793529 ,  1.7012086 ],
       [ 0.10910934,  2.2468128 ],
       [-2.933586  ,  1.9017797 ],
       [-1.7851615 ,  1.7097328 ],
       [-0.75305486,  1.6694813 ],
       [-0.73143506,  1.9915167 ],
       [-3.539323  ,  1.9688922 ],
       [-2.26666   ,  1.878858  ],
       [-1.6995754 ,  2.2655506 ],
       [-3.8711774 ,  2.0630822 ],
       [-1.0680475 ,  2.0221918 ],
       [-2.0614972 ,  1.8914615 ],
       [-3.8507752 ,  2.0568938 ],
       [-1.0543239 ,  2.0187333 ],
       [-2.0775762 ,  1.8910428 ],
       [-1.2309679 ,  1.8817514 ],
       [-4.1467996 ,  2.153647  ],
       [-1.5962312 ,  1.8439301 ],
       [-1.8533264 ,  1.7818593 ],
       [-2.1884966 ,  1.8760082 ],
       [-0.66960716,  1.9204426 ],
       [-2.0973227 ,  1.8876605 ],
       [-1.9171301 ,  1.9794712 ],
       [-2.1456065 ,  1.7624878 ],
       [-0.8691584 ,  3.2571151 ],
       [-0.7808767 ,  2.0864775 ],
       [-0.25690806,  2.6229534 ],
       [-0.71821773,  2.041627  ],
       [-2.2361987 ,  2.1920946 ],
       [-1.1814722 ,  2.0180087 ],
       [-0.729251  ,  1.6594745 ],
       [-4.7158985 ,  2.2846072 ],
       [-0.55212027,  1.4645027 ],
       [-2.204938  ,  1.7548376 ],
       [-3.995943  ,  2.0859911 ],
       [-3.1256936 ,  1.9882712 ],
       [-1.6979619 ,  1.7312567 ],
       [-0.8494791 ,  2.8690326 ],
       [-2.1584218 ,  1.9350797 ],
       [-3.8358026 ,  2.0448985 ],
       [-3.5836513 ,  1.9995531 ],
       [-2.7644076 ,  1.9613734 ],
       [-2.0347838 ,  1.9357398 ],
       [-2.7370496 ,  1.9348994 ],
       [-2.2033062 ,  1.938819  ],
       [-2.8526711 ,  1.8933761 ],
       [-2.0242667 ,  1.8547733 ],
       [-0.83780175,  1.7192647 ],
       [-2.1525958 ,  1.9961902 ],
       [-2.0460703 ,  1.8378136 ],
       [-4.5389924 ,  2.2428405 ],
       [-1.9798508 ,  1.8843231 ],
       [-4.086821  ,  2.0913935 ],
       [-1.9977895 ,  1.7919558 ],
       [-1.8838537 ,  1.8307939 ],
       [-1.3958234 ,  2.5133884 ],
       [-2.2842946 ,  1.8045471 ],
       [-2.9516404 ,  1.8886069 ],
       [-2.2793531 ,  1.8364666 ],
       [-1.8230039 ,  1.6384932 ],
       [ 0.72093844,  0.9156169 ],
       [ 0.38623238,  1.3397753 ],
       [-0.47300208,  2.4113874 ],
       [-0.675761  ,  1.4174216 ],
       [-2.3249373 ,  1.8631749 ],
       [-2.6554165 ,  1.9921925 ],
       [-1.1044148 ,  2.5359423 ],
       [-1.4751823 ,  2.8336987 ],
       [-1.0327219 ,  1.5116489 ],
       [-0.01917923,  2.0600245 ],
       [-2.6522262 ,  1.9597028 ],
       [-0.7513429 ,  1.3865578 ],
       [-0.09501639,  1.3680749 ],
       [ 0.33406922,  1.5168957 ],
       [-0.7643318 ,  1.5379997 ],
       [-2.7683823 ,  1.8615081 ],
       [-2.4744775 ,  2.0519133 ],
       [-1.2062119 ,  1.8387916 ],
       [-0.7156174 ,  2.1516287 ],
       [-3.7890117 ,  2.037874  ],
       [-2.9559467 ,  1.9195647 ],
       [-4.3911223 ,  2.1971684 ],
       [-0.729833  ,  1.8268378 ],
       [-2.0762165 ,  1.731518  ],
       [-1.1386932 ,  1.7291292 ],
       [-2.445149  ,  2.3220508 ],
       [-3.198302  ,  1.9383322 ],
       [-2.275375  ,  1.7492725 ],
       [-1.8237638 ,  2.351345  ],
       [-1.025823  ,  2.031286  ],
       [-1.5539654 ,  2.5678115 ],
       [-2.811343  ,  1.9894252 ],
       [-3.0138085 ,  1.8937446 ],
       [-1.1578529 ,  2.1156797 ],
       [-2.023731  ,  2.0005655 ],
       [-0.61503357,  1.6073158 ],
       [-3.3622997 ,  1.9318975 ],
       [-0.45347777,  1.2922595 ],
       [-3.2503445 ,  1.9277506 ],
       [-1.9878559 ,  1.5770032 ],
       [-1.7636534 ,  1.8790101 ],
       [-0.2964629 ,  2.1392663 ],
       [-1.2730513 ,  1.8397765 ],
       [-1.4599669 ,  1.9270794 ],
       [-1.0695912 ,  1.5269443 ],
       [-2.027702  ,  1.8886124 ],
       [-1.8414644 ,  1.9475881 ],
       [-1.8725396 ,  3.1599863 ],
       [-2.0199184 ,  1.6323391 ],
       [-2.1432328 ,  1.7975762 ],
       [-0.9427852 ,  2.0535746 ],
       [-2.5239174 ,  1.8949182 ],
       [-1.870131  ,  1.8583711 ],
       [-0.56513214,  1.7634226 ],
       [-1.5007592 ,  1.9873817 ],
       [-2.092919  ,  1.8435044 ],
       [-3.200143  ,  2.0069041 ],
       [-3.4224265 ,  1.9545591 ],
       [-4.1098366 ,  2.13951   ],
       [-4.787919  ,  2.3049254 ],
       [-1.7640824 ,  1.9770857 ],
       [-4.062653  ,  2.122304  ],
       [-2.1140466 ,  1.9008923 ],
       [-4.0412173 ,  2.064319  ],
       [-3.858661  ,  2.0751364 ],
       [-4.2212524 ,  2.16817   ],
       [-2.5894473 ,  1.9529262 ],
       [-0.9722    ,  1.9531602 ],
       [-1.9878243 ,  1.9318092 ],
       [-1.3211379 ,  1.7502551 ],
       [-2.8651552 ,  1.9603915 ],
       [-1.9431208 ,  1.9380236 ],
       [-1.7437401 ,  1.8046956 ],
       [-2.1138468 ,  1.9618275 ],
       [-2.1316159 ,  2.0636852 ],
       [-2.4741142 ,  1.9525416 ],
       [-3.917766  ,  2.070964  ],
       [-1.9619702 ,  1.8870287 ],
       [-1.5602932 ,  2.061163  ],
       [-1.8825613 ,  2.009452  ],
       [-4.644383  ,  2.256327  ],
       [-2.097393  ,  1.8533278 ],
       [-0.7100427 ,  2.205196  ],
       [-3.9505098 ,  2.1237686 ]], dtype=float32)
[ ]:
# Scratch cell disabled: the original `adata[]` is an incomplete subscript
# (SyntaxError) and would stop Restart Kernel -> Run All at this point.
[274]:
# Inspect which sample sizes are stored for the scDiffEqResults[0][0] entry.
scDiffEqResults[0][0].keys()
[274]:
dict_keys([1, 10, 100, 1000, 10000, 20000, 50000])
[264]:
# Bind `k`, `v` to the first (N, result) pair of scDiffEqResults[0][0]; both
# names stay defined after the break and are inspected in the following cells.
for k, v in scDiffEqResults[0][0].items():
    break
[268]:
# Show the 't_idxs' entry: a list of index arrays (presumably one per simulated time step — confirm).
v['t_idxs']
[268]:
[array([ 15908,  39131,  44702,  45744,  54688,  60569,  61879,  67845,
         68890,  85905,  88351,  88586,  91275,  93300,  93787,  95807,
         97817, 103643, 123098, 127122]),
 array([ 15908,  17264,  36136,  39131,  44702,  54688,  56825,  61879,
         69108,  71160,  87339,  88586,  91275,  93889,  95807,  97817,
        103643, 108016, 117512, 127122]),
 array([ 14388,  14455,  15908,  15939,  17361,  24480,  39618,  44702,
         58515,  61879,  78999,  79602,  88414,  88586,  92662, 100930,
        108573, 110100, 116701, 127517]),
 array([ 15807,  15908,  17264,  24371,  39131,  40697,  56825,  71160,
         88871,  91275,  93889,  95050,  97406,  99011, 108016, 108653,
        110123, 111422, 117512, 127122]),
 array([  8764,  14284,  14522,  17999,  18005,  40578,  46285,  63257,
         74561,  80941,  84036,  86730,  93764,  94094,  97479,  99763,
        100789, 106842, 106967, 124968]),
 array([ 15220,  16637,  17999,  18005,  24038,  36480,  40578,  46285,
         63223,  68273,  69320,  72349,  83609,  84066, 102354, 106967,
        107626, 107880, 111809, 119194]),
 array([ 13888,  14284,  14522,  17999,  18005,  27566,  40578,  59786,
         63257,  73678,  84036,  86730,  93764,  94094,  97479,  99763,
        100789, 102410, 106842, 124968]),
 array([ 13791,  18780,  27497,  37115,  64537,  64825,  67666,  70497,
         71612,  72349,  87658,  88313,  88416,  91461,  93479, 102470,
        115197, 116790, 117838, 126954]),
 array([  1900,  12211,  13791,  27497,  29577,  31015,  64537,  64825,
         67666,  69320,  72349,  84066,  87658,  88313,  91461,  93479,
        102354, 115197, 116790, 126954]),
 array([  1900,  12211,  16637,  18005,  24038,  31015,  31147,  40578,
         63223,  69320,  72349,  73799,  84066,  87658,  91461, 102354,
        104747, 106967, 107626, 126954]),
 array([ 16637,  17999,  18005,  24038,  31147,  40578,  63223,  67551,
         69320,  73678,  73799,  84066,  87658,  91569, 102354, 104747,
        106842, 106967, 107626, 113228]),
 array([  1900,  12211,  13791,  24038,  31015,  31147,  64537,  64825,
         69320,  72349,  73799,  84066,  87658,  88313,  91461,  93479,
        102354, 104747, 115197, 126954]),
 array([  1900,  12211,  13791,  31015,  31147,  64537,  64825,  69320,
         72349,  73799,  84066,  87658,  88313,  91461,  93479, 102354,
        104747, 107626, 115197, 126954]),
 array([  9841,  12014,  13791,  23170,  27497,  27550,  35677,  37115,
         37472,  37918,  64537,  64825,  67666,  68579,  71612,  88416,
         91461, 102470, 115197, 119475]),
 array([  9841,  12014,  13791,  18780,  23170,  27497,  27550,  32004,
         37115,  37472,  64537,  64825,  67666,  68579,  71612,  88416,
         91461, 102470, 115197, 119475]),
 array([  9841,  13085,  16872,  23170,  27497,  27550,  35003,  35677,
         37115,  37472,  38575,  67666,  67941,  68579,  88416,  95140,
        102470, 115197, 119475, 128766]),
 array([   265,   1368,   1583,   7732,   8465,   8587,  28291,  29722,
         30537,  34630,  34904,  38015,  38020,  38112,  38575,  39770,
         64461,  65762,  90814, 116578]),
 array([   265,    760,   1368,   7732,   8587,  21429,  27517,  28291,
         30537,  34904,  37017,  38015,  38112,  38575,  39770,  39886,
         65762,  67941,  90814, 116578]),
 array([  1900,  11022,  12014,  12211,  31015,  31147,  32004,  37793,
         64825,  65623,  66490,  66931,  73799,  84066,  87111,  87658,
         91461,  93707, 107626, 130054]),
 array([  1900,  11022,  11574,  12014,  19402,  22583,  31147,  32004,
         64537,  64825,  66931,  71612,  73799,  87658,  91461,  93707,
         95731, 115197, 126954, 130054]),
 array([ 14868,  19464,  41782,  59786,  83400,  84937,  86749,  89470,
         90597,  93995,  97773,  98898, 100924, 104677, 105809, 110250,
        120754, 126500, 126632, 126648]),
 array([ 19464,  41782,  59786,  83400,  84937,  86749,  88284,  89470,
         93995,  94934,  97773,  98898, 100924, 105809, 110250, 120754,
        126500, 126632, 126648, 129242]),
 array([ 19464,  41782,  59786,  83400,  84937,  86749,  88284,  89470,
         93995,  94934,  97773,  98898, 100924, 105809, 110250, 120754,
        126500, 126632, 126648, 129242]),
 array([ 14868,  19464,  41782,  56256,  59786,  83400,  84937,  88765,
         89470,  93995,  97773,  98898, 100924, 105809, 106842, 110250,
        120754, 126500, 126632, 126648]),
 array([ 41782,  71942,  86749,  87578,  88423,  89144,  94692,  94934,
         95329,  96004,  99320, 107115, 108107, 108292, 109785, 110963,
        124970, 126500, 127538, 129242]),
 array([ 47697,  73697,  80912,  84953,  86824,  86891,  88795,  89996,
         90595,  90660,  90793,  90849,  96593, 100511, 106519, 107513,
        107774, 122056, 123668, 130002]),
 array([ 49765,  50066,  52552,  71624,  80912,  81272,  87702,  88339,
         88423,  88795,  90660,  90793,  98478,  99481, 103678, 106099,
        107064, 108642, 122056, 123668]),
 array([ 47355,  49765,  50066,  52552,  80912,  81272,  87702,  88423,
         88795,  90660,  90793,  98478,  99481, 103678, 106099, 107064,
        108642, 122056, 123668, 125085]),
 array([ 49765,  50066,  52552,  80912,  81272,  87702,  88339,  88423,
         88795,  90660,  90793,  98478,  99481, 103678, 106099, 107064,
        108642, 122056, 123668, 125085]),
 array([ 47355,  49765,  50066,  52552,  80912,  81272,  87702,  88423,
         88795,  90660,  90793,  98478,  99481, 103678, 106099, 107064,
        108642, 122056, 123668, 125085]),
 array([ 19464,  71942,  83400,  88284,  93995,  94692,  94934,  95329,
         96004, 100924, 106099, 107115, 108107, 109785, 110963, 124970,
        126500, 126648, 127538, 129242]),
 array([ 41782,  47243,  71942,  72471,  80346,  86749,  88284,  89144,
         94692,  94934,  95329,  96004, 107115, 108107, 109785, 110963,
        124970, 126500, 127538, 129242]),
 array([ 58393,  73697,  80912,  86891,  88795,  89713,  89996,  90595,
         90849,  93182,  96593,  98466,  99124,  99926, 100334, 100511,
        106519, 107513, 122056, 130002]),
 array([ 73697,  77664,  86891,  88795,  89713,  89996,  90595,  90849,
         93182,  94918,  96593,  99124,  99539,  99926, 100334, 100511,
        103018, 106519, 107513, 130002]),
 array([ 73697,  77664,  86891,  88795,  89713,  89996,  90595,  90849,
         93182,  94918,  96593,  99124,  99539,  99926, 100334, 100511,
        103018, 106519, 107513, 130002]),
 array([ 73697,  77664,  86891,  88795,  89713,  89996,  90595,  90849,
         93182,  96593,  99124,  99539, 100334, 100511, 103018, 106519,
        106808, 107513, 110575, 130002]),
 array([ 73697,  77664,  86891,  88795,  89713,  89996,  90595,  90849,
         93182,  96593,  99124,  99926, 100334, 100511, 103018, 106519,
        106808, 107513, 110575, 130002]),
 array([ 73697,  86891,  88795,  89713,  89996,  90595,  90849,  93182,
         96593,  99124,  99926, 100511, 101483, 103018, 106519, 106808,
        107513, 109941, 110575, 130002]),
 array([ 43548,  53869,  73697,  74557,  76603,  77664,  80436,  86891,
         88795,  89713,  89996,  90595,  90849,  96593,  99539, 100334,
        100511, 106519, 107513, 130002]),
 array([ 73697,  77664,  86268,  86891,  89996,  90595,  90849,  94918,
         96593,  98242,  98358, 100334, 100458, 100511, 106519, 107513,
        108245, 110148, 110604, 130002]),
 array([ 73697,  77664,  86891,  89996,  90231,  90595,  90849,  94918,
         96593,  98358, 100334, 100458, 100511, 103018, 106519, 107513,
        108245, 110148, 110604, 130002])]
[267]:
# Show the 'Z0_idxs' entry (the third value returned by the batched simulator).
v['Z0_idxs']
[267]:
array(['15908'], dtype=object)
[ ]:

[271]:
# Subset adata to the indices stored under 'idxs' to see how many observations are selected.
adata[v['idxs']]
[271]:
View of AnnData object with n_obs × n_vars = 282 × 2492
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train'
    var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes'
    uns: 'fate_counts', 'h5ad_path', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_umap', 'cell_fate_df'
    layers: 'X_scaled'
[266]:
# List the fields stored for a single (seed, N) result.
v.keys()
[266]:
dict_keys(['t_idxs', 'idxs', 'Z0_idxs'])
[ ]:

[78]:
# Summarise both result sets with the AggrResults API as defined in this notebook
# (constructed with `adata`, called with the full results dict). The returned
# tables already carry the 'pct' / 'pct.std' columns the plotting helper reads.
# scDiffEq keys may be ints or "version_<n>" labels; normalise them to ints.
sdq_results_aggr = {int(str(key).split("_")[-1]): val for key, val in AggrResults(adata)(scDiffEqResults).items()}
prescient_results_aggr = AggrResults(adata)(PRESCIENTResults)
[87]:

[88]:
# Second rendering of the recovery comparison (scDiffEq in blue, PRESCIENT in red).
# NOTE(review): this saves to the same "percent_manifold_recovery.compared.svg"
# path as the earlier comparison figure, so a top-to-bottom run overwrites that file.
fig, axes = cp.plot(
    1,
    1,
    height = 0.5,
    width = 0.8,
    title = [''],
    x_label=['$t_{0}$ cells sampled'],
    y_label=['% manifold recovery'],
)
ax = axes[0]
ax.grid(alpha = 0.2)
for key, val in sdq_results_aggr.items():
    plot_manifold_recovery(ax, val, c = "b", label = f"scDiffEq-{key}")

for key, val in prescient_results_aggr.items():
    plot_manifold_recovery(ax, val, c = "r", label = f"PRESCIENT-{key}")
plt.legend(loc=(1.1, 0), fontsize = 6)
plt.savefig("percent_manifold_recovery.compared.svg", dpi = 500)
../_images/_analyses_FigureS1_39_0.png

UMAPs#

[22]:
# Layered UMAP: annotated background, then all 'idxs' cells in grey, then each
# 't_idxs' entry in its time colour (earlier entries get a higher zorder), and
# finally the 'Z0_idxs' cells in black on top.
# NOTE(review): `res` is not defined anywhere in the visible notebook — presumably a
# single (seed, N) result dict such as scDiffEqResults[0][0][N]; confirm and define it here.
axes = cp.umap_manifold(adata, groupby = "Cell type annotation", c_background=larry_cmap,  alpha = 0.1, s = 2)
axes = cp.umap(adata[res['idxs']], ax = axes[0], c = "dimgrey", zorder = 101)
for en, ix in enumerate(res['t_idxs']):
    axes = cp.umap(adata[ix], c = time_cmap[en], ax = axes[0], zorder = int(205 - en))
axes = cp.umap(adata[res['Z0_idxs']], c = "k", ax = axes[0], zorder = int(305))
../_images/_analyses_FigureS1_41_0.png