Figure S5 C

Import packages

[1]:
%load_ext nb_black

# standard library
import glob
import os
import pathlib

# scientific Python stack
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns

# project-specific packages
import ABCParse
import cellplots as cp
import larry
import scdiffeq as sdq

import dev  # local analysis helpers (provides dev.fate_prediction_accuracy)

Read data and load project

[2]:
h5ad_path = (
    "/home/mvinyard/data/adata.reprocessed_19OCT2023.more_feature_inclusive.h5ad"
)
adata = sdq.io.read_h5ad(h5ad_path)
AnnData object with n_obs × n_vars = 130887 × 2492
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train'
    var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes'
    uns: 'fate_counts', 'h5ad_path', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_umap', 'cell_fate_df'
    layers: 'X_scaled'
[3]:
project = sdq.io.Project(path="./LightningSDE-FixedPotential-RegularizedVelocityRatio/")

Compute fate prediction accuracy for each version

[4]:
RATIOS = {}
for version in project._VERSION_PATHS:
    try:
        v = getattr(project, version)
        acc = dev.fate_prediction_accuracy(v)
        target_ratio = v.hparams["velocity_ratio_params"]["target"]
        RATIOS[version] = target_ratio, v, acc
    except Exception:
        # skip versions with missing metrics or incomplete hyperparameters
        pass
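For context, the code below assumes each `acc` returned by `dev.fate_prediction_accuracy` is a DataFrame with metric names as rows (including `unique_train.all_fates` and `unique_test.all_fates`) and checkpoint epochs as columns; the test accuracy is always read off at the epoch that maximizes training accuracy. A minimal sketch of that lookup, using hypothetical values:

# Illustration only (hypothetical values): select the test accuracy at the
# epoch with the best training accuracy.
toy_acc = pd.DataFrame(
    {1891: [0.61, 0.51], 2500: [0.58, 0.49]},
    index=["unique_train.all_fates", "unique_test.all_fates"],
)
best_epoch = toy_acc.loc["unique_train.all_fates"].idxmax()  # -> 1891
toy_acc[best_epoch]["unique_test.all_fates"]  # -> 0.51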

For each unregularized training run, compute drift [f] and diffusion [g]

[5]:
def f(df, key="training"):
    """Average, over logged steps, of the summed `velo_f` (drift) columns matching `key`."""
    return df.filter(regex="velo_f").filter(regex=key).dropna().sum(1).mean()


def g(df, key="training"):
    """Average, over logged steps, of the summed `velo_g` (diffusion) columns matching `key`."""
    return df.filter(regex="velo_g").filter(regex=key).dropna().sum(1).mean()


UnEnforcedResults = {}
for version, results in RATIOS.items():
    target_ratio, v, acc = results
    if v.hparams["velocity_ratio_params"]["enforce"] == 0:
        print(version)
        grouped = v.metrics_df.groupby("epoch")
        UnEnforcedResults[version] = {
            "f_training": grouped.apply(f, key="training"),
            "f_validation": grouped.apply(f, key="validation"),
            "g_training": grouped.apply(g, key="training"),
            "g_validation": grouped.apply(g, key="validation"),
        }
version_0
version_1
version_2
version_3
version_4
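Here `v.metrics_df` is assumed to contain per-step drift (`velo_f*`) and diffusion (`velo_g*`) magnitude columns, tagged by stage (`training`/`validation`); `f` and `g` sum the matching columns within each logged step, then average over the epoch. A minimal sketch with a hypothetical metrics frame (column names are assumptions, not the actual logger schema):

# Illustration only: one epoch with two logged steps.
toy_metrics = pd.DataFrame(
    {
        "epoch": [0, 0],
        "velo_f_training": [1.0, 3.0],
        "velo_g_training": [2.0, 4.0],
    }
)
toy_metrics.groupby("epoch").apply(f, key="training")  # epoch 0 -> (1 + 3) / 2 = 2.0
toy_metrics.groupby("epoch").apply(g, key="training")  # epoch 0 -> (2 + 4) / 2 = 3.0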

Compute the empirical unenforced ratio

[6]:
unenforced_ratio = np.array(
    [
        v["f_validation"].iloc[-1] / v["g_validation"].iloc[-1]
        for v in UnEnforcedResults.values()
    ]
)
unenforced_ratio.mean(), unenforced_ratio.std()
[6]:
(0.3669882612570243, 0.13488926187436895)
[7]:
rel = unenforced_ratio / (1 + unenforced_ratio)
print(rel.mean(), rel.std())
print(unenforced_ratio.mean() / (1 + unenforced_ratio.mean()))
0.26198354528285633 0.06584885182554004
0.2684648227480443
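To interpret these numbers: the velocity ratio is vR = f/g, so the quantity computed here is the drift's share of the total velocity, f/(f + g) = (f/g)/(1 + f/g) = vR/(1 + vR). Plugging in the mean unenforced ratio of ~0.367 gives ~0.268, the second value printed above.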
[8]:
OrgResults = {}
for version, results in RATIOS.items():
    target_ratio, v, acc = results
    if v.hparams["velocity_ratio_params"]["enforce"] == 0:
        # unenforced runs are assigned the mean empirical ratio computed above
        target_ratio = unenforced_ratio.mean()  # 0.3669882612570243
    if target_ratio not in OrgResults:
        OrgResults[target_ratio] = []
    # test accuracy at the epoch with the best training accuracy
    best = acc[acc.loc["unique_train.all_fates"].idxmax()]["unique_test.all_fates"]
    OrgResults[target_ratio].append(best)
[9]:
sorted_results = {
    float(key): OrgResults[key] for key in sorted(OrgResults, key=float)
}
[10]:
print("{:>9} {:<5}  {:<6} | {}".format("vR", "mean", "std", "n"))
for k, v in sorted_results.items():
    print("{:>9.4f} {:.3f}  {:.4f} | {}".format(k, np.mean(v), np.std(v), len(v)))
       vR mean   std    | n
   0.0010 0.523  0.0164 | 5
   0.0100 0.499  0.0253 | 5
   0.3670 0.513  0.0271 | 5
   0.5000 0.520  0.0320 | 5
   1.0000 0.530  0.0160 | 5
   1.5000 0.506  0.0154 | 5
   2.2500 0.529  0.0386 | 5
   2.5000 0.548  0.0260 | 5
   3.0000 0.522  0.0381 | 5
   5.0000 0.530  0.0157 | 5
  10.0000 0.523  0.0347 | 5
  20.0000 0.391  0.0713 | 5
  30.0000 0.292  0.0467 | 4
[11]:
# observed fate-probability matrix, used below as the ground-truth reference
F_obs = larry.tasks.fate_prediction.F_obs
[15]:
def _test_from_best_train(acc):
    """Return the epoch with the best training accuracy and the test accuracy at that epoch."""
    best_train_epoch = acc.loc["unique_train.all_fates"].idxmax()
    best_train = acc[best_train_epoch]
    return best_train_epoch, best_train["unique_test.all_fates"]


def _get_best_F_hat_path(version, best_epoch):
    """Use glob to locate the saved F_hat for a given version and epoch;
    fall back to the on_train_end checkpoint if no per-epoch file exists."""
    try:
        pattern = f"LightningSDE-FixedPotential-RegularizedVelocityRatio/version_{version}/fate_prediction_metrics/{best_epoch}*/F_hat.processed.csv"
        return glob.glob(pattern)[0]
    except IndexError:
        pattern = f"LightningSDE-FixedPotential-RegularizedVelocityRatio/version_{version}/fate_prediction_metrics/on_train_end*/F_hat.processed.csv"
        print(pattern)
        return glob.glob(pattern)[0]


def _convert_best_epoch_name(row):
    """Map an epoch number to its checkpoint directory name."""
    if row["best_epoch"] == 2500:
        return "last"
    if row["best_epoch"] == 2499:
        return "on_train_end"
    return f"epoch_{row['best_epoch']}"


def get_best_results(RATIOS):
    BestResults = []
    for version, results in RATIOS.items():
        if version not in ["version_46", "version_71"]:
            target_ratio, v, acc = results
            best_epoch, best_score = _test_from_best_train(acc)
            v_key = version.split("version_")[-1]
            BestResults.append(
                {
                    "target_ratio": str(target_ratio),
                    "best_epoch": best_epoch,
                    "best_score": best_score,
                    "version": v_key,
                }
            )
    return pd.DataFrame(BestResults)
[16]:
best_results = get_best_results(RATIOS)
best_results
[16]:
   target_ratio  best_epoch  best_score  version
0           2.0        1891    0.508731        0
1           2.0        2500    0.485448        1
2           5.0        2285    0.522701       10
3          0.01        1674    0.472643       11
4           1.0        1294    0.552969       12
..          ...         ...         ...      ...
58         2.25        1603    0.516880       62
59         2.25        2285    0.477299       63
60          5.0        1330    0.514552        7
61          5.0        1119    0.555297        8
62          1.0        1287    0.534342        9

63 rows × 4 columns

[17]:
def get_grouped_F_hat_dfs(group_df):
    """Load the best-epoch F_hat DataFrame for every run in a target-ratio group."""
    GroupF_hats = []
    for _, row in group_df.iterrows():
        F_hat_path = _get_best_F_hat_path(
            version=row["version"], best_epoch=_convert_best_epoch_name(row)
        )
        F_hat = pd.read_csv(F_hat_path, index_col=0)
        GroupF_hats.append(F_hat)
    return GroupF_hats
[18]:
best_F_hats = (
    best_results.groupby("target_ratio").apply(get_grouped_F_hat_dfs).to_dict()
)
LightningSDE-FixedPotential-RegularizedVelocityRatio/version_50/fate_prediction_metrics/on_train_end*/F_hat.processed.csv
[19]:
Entropy = {}
for key, val in best_F_hats.items():
    # row-wise Shannon entropy of each run's predicted fate distributions
    Entropy[key] = [scipy.stats.entropy(F_hat, axis=1) for F_hat in val]

MeanStdEntropy = {}
for en, (k, v) in enumerate(Entropy.items()):
    y = np.array(v).flatten()
    y = y[np.isfinite(y)]  # drop non-finite entropies (e.g., from all-zero rows)
    MeanStdEntropy[en] = {"vR": float(k), "mean": y.mean(), "std": y.std()}
entropy = pd.DataFrame(MeanStdEntropy).T
entropy = entropy.sort_values("vR").reset_index(drop=True)
entropy
[19]:
        vR      mean       std
0    0.001  0.738279  0.361552
1    0.010  0.762121  0.391919
2    0.500  0.549093  0.370522
3    1.000  0.503532  0.374606
4    1.500  0.486654  0.380101
5    2.000  0.582116  0.397537
6    2.250  0.431386  0.363275
7    2.500  0.354065  0.373081
8    3.000  0.368025  0.403702
9    5.000  0.340746  0.377944
10  10.000  0.215317  0.313107
11  20.000  0.026230  0.117157
12  30.000  0.000000  0.000000
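For reference, `scipy.stats.entropy(F_hat, axis=1)` computes the Shannon entropy of each row of `F_hat` (each cell's predicted fate distribution, normalized to sum to 1), so a run whose predictions collapse onto a single fate per cell has mean entropy 0, as in the vR = 30 row above. A minimal sketch with hypothetical rows:

# Illustration only: uniform over two fates vs. fully committed to one fate.
toy_F_hat = np.array([[0.5, 0.5], [1.0, 0.0]])
scipy.stats.entropy(toy_F_hat, axis=1)  # -> array([0.6931..., 0.0]) (natural log)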
[21]:
# baseline: mean row-wise entropy of the observed fate distributions
mean_obs_entropy = scipy.stats.entropy(F_obs, axis=1).mean()
[23]:
fig, axes = cp.plot(
    height=0.4,
    width=0.7,
    title=["Entropy as a function of vR"],
    x_label=["Velocity Ratio"],
    y_label=["Mean Entropy"],
)
axes[0].set_xscale("log")
[x.set_linewidth(0.5) for x in axes[0].spines.values()]
axes[0].xaxis.set_tick_params(width=0.5)
axes[0].yaxis.set_tick_params(width=0.5)
axes[0].scatter(entropy["vR"], entropy["mean"], s=20, c="k", ec="None")
axes[0].errorbar(
    entropy["vR"],
    entropy["mean"],
    lw=0.5,
    yerr=entropy["std"],
    capsize=3,
    capthick=0.5,
    elinewidth=0.5,
    color="k",
)
plt.xlim(0.0005, 50)
axes[0].hlines(
    y=mean_obs_entropy, xmin=0.0005, xmax=50, color="dodgerblue", ls="--", lw=0.5
)
plt.savefig("entropy_function_of_velo_ratio.fig_s7.svg")
[Figure: "Entropy as a function of vR" — mean entropy (y) vs. velocity ratio (x, log scale) with standard-deviation error bars; the dashed line marks the mean entropy of the observed fates.]