Figure S5 A, B#

Import packages#

[1]:
%load_ext nb_black

import scdiffeq as sdq
import dev
import cellplots as cp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
import numpy as np
import ABCParse
import os
import pathlib
import seaborn as sns
import larry
import scipy.stats
import glob
import matplotlib

Read data and load project#

[2]:
# Path to the input AnnData (.h5ad) file.
# NOTE(review): absolute, machine-specific path — it has to be edited before
# this notebook can run on another machine.
h5ad_path = (
    "/home/mvinyard/data/adata.reprocessed_19OCT2023.more_feature_inclusive.h5ad"
)
adata = sdq.io.read_h5ad(h5ad_path)
AnnData object with n_obs × n_vars = 130887 × 2492
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train'
    var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes'
    uns: 'fate_counts', 'h5ad_path', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_umap', 'cell_fate_df'
    layers: 'X_scaled'
[3]:
# Open the project at this (relative) directory; the individual training runs
# are accessed further down as attributes named "version_<N>".
project = sdq.io.Project(path="./LightningSDE-FixedPotential-RegularizedVelocityRatio/")

Compute fate prediction accuracy for each version#

[4]:
# RATIOS maps version name -> (target velocity ratio, version object, result of
# dev.fate_prediction_accuracy for that version).
RATIOS = {}
for version, path in project._VERSION_PATHS.items():
    try:
        v = getattr(project, version)
        acc = dev.fate_prediction_accuracy(v)
        target_ratio = v.hparams["velocity_ratio_params"]["target"]
    except Exception as err:
        # Best-effort: a version that cannot be loaded or scored is skipped, but
        # it is reported so that missing runs are visible instead of vanishing
        # silently. `Exception` (not a bare except) lets KeyboardInterrupt through.
        print(f"Skipping {version}: {type(err).__name__}: {err}")
        continue
    RATIOS[version] = target_ratio, v, acc

For each un-regularized training run, compute drift [f] and diffusion [g]#

[5]:
def f(df, key="training"):
    """Average drift magnitude for one group of metric rows.

    Keeps the columns whose name matches ``velo_f`` and then ``key``, drops
    rows containing NaN, sums the remaining columns per row and returns the
    mean of those row sums.
    """
    drift_columns = df.filter(regex="velo_f")
    selected = drift_columns.filter(regex=key)
    row_totals = selected.dropna().sum(axis=1)
    return row_totals.mean()


def g(df, key="training"):
    """Average diffusion magnitude for one group of metric rows.

    Keeps the columns whose name matches ``velo_g`` and then ``key``, drops
    rows containing NaN, sums the remaining columns per row and returns the
    mean of those row sums.
    """
    diffusion_columns = df.filter(regex="velo_g")
    selected = diffusion_columns.filter(regex=key)
    row_totals = selected.dropna().sum(axis=1)
    return row_totals.mean()


# Per-epoch drift (f) and diffusion (g) summaries, for training and validation,
# collected only for versions whose velocity-ratio "enforce" setting is 0.
UnEnforcedResults = {}
for en, (version, results) in enumerate(RATIOS.items()):
    target_ratio, v, acc = results
    if v.hparams["velocity_ratio_params"]["enforce"] != 0:
        # Regularized run: not part of the un-enforced baseline.
        continue
    print(version)
    grouped = v.metrics_df.groupby("epoch")
    UnEnforcedResults[version] = {
        "f_training": grouped.apply(f, key="training"),
        "f_validation": grouped.apply(f, key="validation"),
        "g_training": grouped.apply(g, key="training"),
        "g_validation": grouped.apply(g, key="validation"),
    }
version_0
version_1
version_2
version_3
version_4

Compute the empirical unenforced ratio#

[6]:
# Validation drift / diffusion ratio taken from the last epoch group of each
# un-enforced version (one value per version).
unenforced_ratio = np.array(
    [
        run["f_validation"].iloc[-1] / run["g_validation"].iloc[-1]
        for run in UnEnforcedResults.values()
    ]
)
unenforced_ratio.mean(), unenforced_ratio.std()
[6]:
(0.3669882612570243, 0.13488926187436895)
[7]:
# Re-express each ratio r = f / g as r / (1 + r), which equals f / (f + g).
rel = unenforced_ratio / (1 + unenforced_ratio)
print(rel.mean(), rel.std())
# Same transform applied to the mean ratio (instead of averaging the
# transformed per-version values above).
print(unenforced_ratio.mean() / (1 + unenforced_ratio.mean()))
0.26198354528285633 0.06584885182554004
0.2684648227480443
[8]:
# Collect one score per version, grouped by velocity ratio. Versions with
# "enforce" == 0 are filed under the mean empirical ratio rather than under the
# target stored in their hparams.
OrgResults = {}
for en, (version, results) in enumerate(RATIOS.items()):
    target_ratio, v, acc = results
    if v.hparams["velocity_ratio_params"]["enforce"] == 0:
        target_ratio = unenforced_ratio.mean()  # 0.3669882612570243
    OrgResults.setdefault(target_ratio, [])
    # Test score at the epoch where the training score peaks.
    best = acc[acc.loc["unique_train.all_fates"].idxmax()]["unique_test.all_fates"]
    OrgResults[target_ratio].append(best)
[9]:
# Same content as OrgResults, re-inserted in ascending numeric order of the
# velocity-ratio key so that iteration below is sorted.
sorted_results = {
    key: OrgResults[key]
    for key in sorted(np.array(list(OrgResults.keys())).astype(float))
}
[10]:
# Printed columns: velocity ratio, mean score, std of score | number of runs.
# Note: this loop rebinds the notebook-level names `k` and `v`.
for k, v in sorted_results.items():
    print("{:>9.4f} {:.3f}  {:.4f} | {}".format(k, np.mean(v), np.std(v), len(v)))
   0.0010 0.523  0.0164 | 5
   0.0100 0.499  0.0253 | 5
   0.3670 0.513  0.0271 | 5
   0.5000 0.520  0.0320 | 5
   1.0000 0.530  0.0160 | 5
   1.5000 0.506  0.0154 | 5
   2.2500 0.529  0.0386 | 5
   2.5000 0.548  0.0260 | 5
   3.0000 0.522  0.0381 | 5
   5.0000 0.530  0.0157 | 5
  10.0000 0.523  0.0347 | 5
  20.0000 0.391  0.0713 | 5
  30.0000 0.292  0.0467 | 4

Define helper functions for isolating each version's best epoch (highest training accuracy) and its test score#

[13]:
def _test_from_best_train(acc):
    best_train_epoch = acc.loc["unique_train.all_fates"].idxmax()
    best_train = acc[best_train_epoch]
    return best_train_epoch, best_train["unique_test.all_fates"]


def _get_best_F_hat_path(version, best_epoch):
    """use glob to filter and grab the right saved F_hat"""
    try:
        regex = f"LightningSDE-FixedPotential-RegularizedVelocityRatio/version_{version}/fate_prediction_metrics/{best_epoch}*/F_hat.processed.csv"
        return glob.glob(regex)[0]
    except:
        regex = f"LightningSDE-FixedPotential-RegularizedVelocityRatio/version_{version}/fate_prediction_metrics/on_train_end*/F_hat.processed.csv"
        print(regex)
        return glob.glob(regex)[0]


def _convert_best_epoch_name(row):
    if row["best_epoch"] == 2500:
        return "last"
    if row["best_epoch"] == 2499:
        return "on_train_end"
    return f"epoch_{row['best_epoch']}"


def get_best_results(RATIOS):
    """Tabulate the best-training-epoch result of every version.

    Parameters
    ----------
    RATIOS : dict
        Maps a version name (e.g. ``"version_12"``) to a
        ``(target_ratio, version_object, acc)`` triple.

    Returns
    -------
    pd.DataFrame
        One row per version with columns ``target_ratio`` (converted to
        ``str``), ``best_epoch``, ``best_score`` and ``version`` (the text
        following ``"version_"``). The target ratio is used exactly as stored
        in ``RATIOS``; this function does not look at the ``enforce`` setting.
    """
    rows = []
    for version, (target_ratio, _, acc) in RATIOS.items():
        # Epoch is picked on the training score; the stored score is the test one.
        best_epoch, best_score = _test_from_best_train(acc)
        rows.append(
            {
                "target_ratio": str(target_ratio),
                "best_epoch": best_epoch,
                "best_score": best_score,
                "version": version.split("version_")[-1],
            }
        )
    return pd.DataFrame(rows)
[14]:
# One row per version: target ratio (as a string), best epoch, test score at
# that epoch, and the version number.
best_results = get_best_results(RATIOS)
best_results
[14]:
target_ratio best_epoch best_score version
0 2.0 1891 0.508731 0
1 2.0 2500 0.485448 1
2 5.0 2285 0.522701 10
3 0.01 1674 0.472643 11
4 1.0 1294 0.552969 12
... ... ... ... ...
59 2.25 1603 0.516880 62
60 2.25 2285 0.477299 63
61 5.0 1330 0.514552 7
62 5.0 1119 0.555297 8
63 1.0 1287 0.534342 9

64 rows × 4 columns

[31]:
# Scores of the un-regularized runs. They were stored in OrgResults under
# `unenforced_ratio.mean()`, so look them up with that same computed key
# (0.3669882612570243 in the recorded run) rather than a pasted float literal,
# which raises KeyError as soon as a re-run changes the last digit.
unreg = OrgResults[unenforced_ratio.mean()]
unreg
[31]:
[0.5087310826542492,
 0.4854481955762514,
 0.4982537834691501,
 0.5646100116414435,
 0.509895227008149]
[58]:
# Best scores grouped by target ratio; the groupby keys are the ratio strings
# stored in `best_results`.
grouped_accuracy = {
    group: group_df["best_score"].values
    for group, group_df in best_results.groupby("target_ratio")
}
# Add the un-regularized runs under their mean empirical ratio (a float key,
# 0.3669882612570243 in the recorded run). The key is computed instead of
# hard-coded so it stays valid when the runs are regenerated.
grouped_accuracy[float(unenforced_ratio.mean())] = unreg
[67]:
# Velocity ratios in plotting order.
sorted_keys = [
    0.001,
    0.01,
    # Mean empirical ratio of the un-regularized runs (0.3669882612570243 in
    # the recorded run); computed so the key cannot drift from the data.
    float(unenforced_ratio.mean()),
    0.5,
    1.0,
    1.5,
    # NOTE(review): in the recorded outputs the 2.0 group has exactly the same
    # mean/std as the un-regularized group, and the per-ratio table built from
    # OrgResults lists no 2.0 entry. This key therefore appears to show the
    # un-regularized runs a second time, at their (unenforced) hparams target —
    # confirm whether it should stay in the figure.
    2.0,
    2.25,
    2.5,
    3.0,
    5.0,
    10.0,
    20.0,
    30.0,
]
[73]:
fig, axes = cp.plot(height=0.8, width=2)

# Dedicated, seeded generator so the horizontal jitter — and therefore the
# saved figure — is identical on every run.
rng = np.random.default_rng(0)

for en, k in enumerate(sorted_keys):
    # Groups coming from `best_results` are keyed by the ratio as a string; the
    # un-regularized group was added under a float key. Check membership
    # explicitly instead of relying on a bare `except:`.
    y = grouped_accuracy[str(k)] if str(k) in grouped_accuracy else grouped_accuracy[k]
    x = np.full(len(y), fill_value=float(k))
    # Multiplicative jitter of up to +5% of x, so the spread looks uniform on
    # the log-scaled axis.
    xr = rng.random(len(x)) * (x / 20)
    x = x + xr
    axes[0].scatter(x, y, ec="None", c=cm.tab20.colors[en])
axes[0].set_xscale("log")
axes[0].set_ylim(0.2, 0.6)
axes[0].set_xlabel("Velocity Ratio")
axes[0].set_ylabel("Accuracy Score")
plt.savefig("fate_prediction.velocity_ratio.hp_optimization.svg")
../_images/_analyses_FigureS5AB_22_0.png
[82]:
# Printed columns: velocity ratio, mean score, std of score.
for en, k in enumerate(sorted_keys):
    # String key for groups from `best_results`, float key for the
    # un-regularized group; explicit membership test instead of a bare
    # `except:` so unrelated errors are not hidden.
    y = grouped_accuracy[str(k)] if str(k) in grouped_accuracy else grouped_accuracy[k]
    print("{:<20}".format(k), "{:.3f}".format(np.mean(y)), " {:.4f}".format(np.std(y)))
0.001                0.523  0.0164
0.01                 0.499  0.0253
0.3669882612570243   0.513  0.0271
0.5                  0.520  0.0320
1.0                  0.530  0.0160
1.5                  0.506  0.0154
2.0                  0.513  0.0271
2.25                 0.529  0.0386
2.5                  0.548  0.0260
3.0                  0.522  0.0381
5.0                  0.530  0.0157
10.0                 0.523  0.0347
20.0                 0.391  0.0713
30.0                 0.292  0.0467
[ ]: