Figure S4#

Import packages#

[1]:
%load_ext nb_black

import scdiffeq as sdq
import dev
import cellplots as cp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
import numpy as np
import ABCParse
import os
import pathlib
import seaborn as sns
import larry
import scipy.stats
import glob
import matplotlib

Read data and load project#

[2]:
h5ad_path = (
    "/home/mvinyard/data/adata.reprocessed_19OCT2023.more_feature_inclusive.h5ad"
)
adata = sdq.io.read_h5ad(h5ad_path)
AnnData object with n_obs × n_vars = 130887 × 2492
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train'
    var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes'
    uns: 'fate_counts', 'h5ad_path', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_umap', 'cell_fate_df'
    layers: 'X_scaled'
[3]:
project = sdq.io.Project(path="./LightningSDE-FixedPotential-RegularizedVelocityRatio/")

Compute fate prediction accuracy for each version#

[4]:
RATIOS = {}
for version, path in project._VERSION_PATHS.items():
    try:
        v = getattr(project, version)
        acc = dev.fate_prediction_accuracy(v)
        target_ratio = v.hparams["velocity_ratio_params"]["target"]
        RATIOS[version] = target_ratio, v, acc
    except:
        pass

For each un-regularized training run, compute drift [f] and diffusion [g]#

[5]:
def f(df, key="training"):
    return df.filter(regex="velo_f").filter(regex=key).dropna().sum(1).mean()


def g(df, key="training"):
    return df.filter(regex="velo_g").filter(regex=key).dropna().sum(1).mean()


UnEnforcedResults = {}
for en, (version, results) in enumerate(RATIOS.items()):
    target_ratio, v, acc = results
    if v.hparams["velocity_ratio_params"]["enforce"] == 0:
        print(version)
        grouped = v.metrics_df.groupby("epoch")
        UnEnforcedResults[version] = {
            "f_training": grouped.apply(f, key="training"),
            "f_validation": grouped.apply(f, key="validation"),
            "g_training": grouped.apply(g, key="training"),
            "g_validation": grouped.apply(g, key="validation"),
        }
version_0
version_1
version_2
version_3
version_4

Compute the empirical unenforced ratio#

[6]:
unenforced_ratio = np.array(
    [
        v["f_validation"].iloc[-1] / v["g_validation"].iloc[-1]
        for i, (k, v) in enumerate(UnEnforcedResults.items())
    ]
)
unenforced_ratio.mean(), unenforced_ratio.std()
[6]:
(0.3669882612570243, 0.13488926187436895)
[7]:
rel = unenforced_ratio / (1 + unenforced_ratio)
print(rel.mean(), rel.std())
print(unenforced_ratio.mean() / (1 + unenforced_ratio.mean()))
0.26198354528285633 0.06584885182554004
0.2684648227480443
[8]:
OrgResults = {}
for en, (version, results) in enumerate(RATIOS.items()):
    target_ratio, v, acc = results
    if v.hparams["velocity_ratio_params"]["enforce"] == 0:
        target_ratio = unenforced_ratio.mean()  # 0.3669882612570243
    if not target_ratio in OrgResults.keys():
        OrgResults[target_ratio] = []
    best = acc[acc.loc["unique_train.all_fates"].idxmax()]["unique_test.all_fates"]
    OrgResults[target_ratio].append(best)
[9]:
sorted_results = {
    key: OrgResults[key]
    for key in sorted(np.array(list(OrgResults.keys())).astype(float))
}
[10]:
for k, v in sorted_results.items():
    print("{:>9.4f} {:.3f}  {:.4f} | {}".format(k, np.mean(v), np.std(v), len(v)))
   0.0010 0.523  0.0164 | 5
   0.0100 0.499  0.0253 | 5
   0.3670 0.513  0.0271 | 5
   0.5000 0.520  0.0320 | 5
   1.0000 0.530  0.0160 | 5
   1.5000 0.506  0.0154 | 5
   2.2500 0.529  0.0386 | 5
   2.5000 0.548  0.0260 | 5
   3.0000 0.522  0.0381 | 5
   5.0000 0.530  0.0157 | 5
  10.0000 0.523  0.0347 | 5
  20.0000 0.391  0.0713 | 5
  30.0000 0.292  0.0467 | 4

Plot F_obs and F_hat#

[11]:
matplotlib.rcParams["xtick.major.width"] = 0.5


class FateBiasHeatmap(ABCParse.ABCParse):

    def __init__(self, savedir=os.getcwd(), *args, **kwargs):
        self.__parse__(locals())

        self.img_dir = pathlib.Path(savedir).joinpath("F_hat_plots")
        if not self.img_dir.exists():
            self.img_dir.mkdir()

        self._plot_obs()

    @property
    def _LARRY_cmap(self):
        return larry.pl.InVitroColorMap()._dict

    def _col_filter(self, df, search_cols):
        return df[[col for col in search_cols if col in df.columns]]

    @property
    def F_obs(self):
        if not hasattr(self, "_F_obs"):
            self._F_obs = larry.tasks.fate_prediction.F_obs
        return self._F_obs

    def _plot_obs(self):
        self.cg = sns.clustermap(
            self.F_obs,
            figsize=(3, 5),
            cmap="Blues",
            xticklabels=self.F_obs.columns,
            yticklabels=[],
            cbar_pos=None,  # cbar_pos,
            vmin=0,
            vmax=1,
            center=0.5,
            col_colors=[self._LARRY_cmap[fate] for fate in self.F_obs.columns],
            rasterized=True,
        )
        self.cg.ax_heatmap.set_title("True fate bias", fontsize=10, y=1.3)

        plt.savefig(self.img_dir.joinpath("F_obs.png"))

    @property
    def clustered_cells(self):
        return self.cg.dendrogram_row.reordered_ind

    @property
    def clustered_fates(self):
        return self.cg.dendrogram_col.reordered_ind

    def augment_fate_columns(self, F_hat):
        """"""
        F_hat_ = F_hat.copy()
        for col in self.F_obs.columns:
            if not col in F_hat_:
                F_hat_[col] = 0

        return F_hat_

    def _compose_F_hat_plot(self, F_hat):

        F_hat_plot = self.augment_fate_columns(F_hat)
        F_hat_plot = F_hat_plot[self.F_obs.columns[self.clustered_fates]]
        return F_hat_plot.iloc[self.clustered_cells]

    def plot_F_hat(self, F_hat_plot, title):
        self._cg_F_hat = sns.clustermap(
            F_hat_plot,
            figsize=(3, 5),
            cmap="Blues",
            xticklabels=F_hat_plot.columns,
            yticklabels=[],
            cbar_pos=None,
            row_cluster=False,
            col_cluster=False,
            vmin=0,
            vmax=1,
            center=0.5,
            col_colors=[self._LARRY_cmap[fate] for fate in F_hat_plot.columns],
            rasterized=True,
        )
        if not title is None:
            self._cg_F_hat.ax_heatmap.set_title(title, fontsize=10, y=1.05)

    def __call__(self, F_hat, title=None, *args, **kwargs):
        self.__update__(locals())

        F_hat_plot = self._compose_F_hat_plot(F_hat)
        self.plot_F_hat(F_hat_plot, title=title)
[12]:
def _test_from_best_train(acc):
    best_train_epoch = acc.loc["unique_train.all_fates"].idxmax()
    best_train = acc[best_train_epoch]
    return best_train_epoch, best_train["unique_test.all_fates"]


def _get_best_F_hat_path(version, best_epoch):
    """use glob to filter and grab the right saved F_hat"""
    try:
        regex = f"LightningSDE-FixedPotential-RegularizedVelocityRatio/version_{version}/fate_prediction_metrics/{best_epoch}*/F_hat.processed.csv"
        return glob.glob(regex)[0]
    except:
        regex = f"LightningSDE-FixedPotential-RegularizedVelocityRatio/version_{version}/fate_prediction_metrics/on_train_end*/F_hat.processed.csv"
        print(regex)
        return glob.glob(regex)[0]


def _convert_best_epoch_name(row):
    if row["best_epoch"] == 2500:
        return "last"
    if row["best_epoch"] == 2499:
        return "on_train_end"
    return f"epoch_{row['best_epoch']}"


def get_best_results(RATIOS):
    BestResults = []
    for en, (version, results) in enumerate(RATIOS.items()):
        if not version in ["version_46", "version_71"]:

            target_ratio, v, acc = results
            #             if target_ratio == 30:
            #                 print(acc)
            #                 print()
            #         if not str(target_ratio) in OrgResults.keys():
            #             BestResults[] = []

            best_epoch, best_score = _test_from_best_train(acc)
            v_key = version.split("version_")[-1]
            BestResults.append(
                {
                    "target_ratio": str(target_ratio),
                    "best_epoch": best_epoch,
                    "best_score": best_score,
                    "version": v_key,
                }
            )

    return pd.DataFrame(BestResults)
[13]:
project_path = project._PROJECT_PATH

fate_bias_heatmap = FateBiasHeatmap()
paths = list(project_path.glob("*/fate_prediction_metrics/*/F_hat.processed.csv"))
plt.savefig("F_obs.svg", dpi=400)
../_images/_analyses_FigureS4_19_0.png
[14]:
best_results = get_best_results(RATIOS)
best_results
[14]:
target_ratio best_epoch best_score version
0 2.0 1891 0.508731 0
1 2.0 2500 0.485448 1
2 5.0 2285 0.522701 10
3 0.01 1674 0.472643 11
4 1.0 1294 0.552969 12
... ... ... ... ...
58 2.25 1603 0.516880 62
59 2.25 2285 0.477299 63
60 5.0 1330 0.514552 7
61 5.0 1119 0.555297 8
62 1.0 1287 0.534342 9

63 rows × 4 columns

[15]:
best_results.loc[best_results["target_ratio"] == "3.0"]
[15]:
target_ratio best_epoch best_score version
41 3.0 1339 0.579744 47
43 3.0 1311 0.474971 49
46 3.0 2499 0.540163 51
48 3.0 1341 0.528522 53
50 3.0 2500 0.485448 55
[16]:
def get_grouped_F_hat_dfs(group_df):
    GroupF_hats = []
    for i, row in group_df.iterrows():
        #         print(row)
        F_hat_path = _get_best_F_hat_path(
            version=row["version"], best_epoch=_convert_best_epoch_name(row)
        )
        F_hat = pd.read_csv(F_hat_path, index_col=0)
        GroupF_hats.append(F_hat)
    return GroupF_hats
[17]:
# "LightningSDE-FixedPotential-RegularizedVelocityRatio/version_59/fate_prediction_metrics/on_train_end.epoch_2500/"
best_results
[17]:
target_ratio best_epoch best_score version
0 2.0 1891 0.508731 0
1 2.0 2500 0.485448 1
2 5.0 2285 0.522701 10
3 0.01 1674 0.472643 11
4 1.0 1294 0.552969 12
... ... ... ... ...
58 2.25 1603 0.516880 62
59 2.25 2285 0.477299 63
60 5.0 1330 0.514552 7
61 5.0 1119 0.555297 8
62 1.0 1287 0.534342 9

63 rows × 4 columns

[18]:
best_F_hats = (
    best_results.groupby("target_ratio").apply(get_grouped_F_hat_dfs).to_dict()
)
LightningSDE-FixedPotential-RegularizedVelocityRatio/version_50/fate_prediction_metrics/on_train_end*/F_hat.processed.csv
[20]:
for k, v in best_F_hats.items():
    fate_bias_heatmap(F_hat=v[0], title=f" Rv= {k}")
    plt.savefig(f"./F_hat.Rv_{k}.svg", dpi=400)
../_images/_analyses_FigureS4_25_0.png
../_images/_analyses_FigureS4_25_1.png
../_images/_analyses_FigureS4_25_2.png
../_images/_analyses_FigureS4_25_3.png
../_images/_analyses_FigureS4_25_4.png
../_images/_analyses_FigureS4_25_5.png
../_images/_analyses_FigureS4_25_6.png
../_images/_analyses_FigureS4_25_7.png
../_images/_analyses_FigureS4_25_8.png
../_images/_analyses_FigureS4_25_9.png
../_images/_analyses_FigureS4_25_10.png
../_images/_analyses_FigureS4_25_11.png
../_images/_analyses_FigureS4_25_12.png