Figure 2F

Simulated interpolation

[ ]:
%load_ext nb_black

import matplotlib.pyplot as plt
import scdiffeq as sdq
import scdiffeq_analyses as sdq_an
import adata_query
import larry
import pandas as pd
import numpy as np
import tqdm.notebook
import ABCParse
import autodevice
import torch
import pathlib
import glob

[1]:
h5ad_path = (
    "/home/mvinyard/data/adata.reprocessed_19OCT2023.more_feature_inclusive.h5ad"
)
adata = sdq.io.read_h5ad(h5ad_path)

project_path = "/home/mvinyard/experiments/LARRY.full_dataset/LightningSDE-FixedPotential-RegularizedVelocityRatio"
project = sdq.io.Project(path=project_path)

InterpolationData = larry.tasks.interpolation.InterpolationData
# Define the Sinkhorn divergence class before instantiating the module-level
# metric used by PiecewiseDistance.forward below.
SinkhornDivergence = sdq.core.lightning_models.base.SinkhornDivergence
sinkhorn = SinkhornDivergence()

class InterpolationTask(ABCParse.ABCParse):
    def __init__(
        self,
        adata,
        time_key="Time point",
        use_key="X_pca",
        t0=2,
        n_samples=10_000,
        lineage_key="clone_idx",
        device=autodevice.AutoDevice(),
        backend="auto",
        silent=False,
        PCA=None,
        *args,
        **kwargs,
    ):
        self.__parse__(locals())

        self.data = InterpolationData(**self._DATA_KWARGS)

        self.SinkhornDivergence = SinkhornDivergence(**self._SINKHORN_KWARGS)

    @property
    def _DATA_KWARGS(self):
        return ABCParse.function_kwargs(
            func=InterpolationData, kwargs=self._PARAMS
        )

    @property
    def _SINKHORN_KWARGS(self):
        return ABCParse.function_kwargs(
            func=SinkhornDivergence, kwargs=self._PARAMS
        )

    def forward_without_grad(self, DiffEq):
        """Forward integrate over the model without gradients."""
        with torch.no_grad():
            X_hat = DiffEq.forward(self.data.X0, self.data.t)
            return self._parse_forward_out(X_hat)

    def forward_with_grad(self, DiffEq):
        """Forward integrate over the model retaining gradients."""
        torch.set_grad_enabled(True)
        X_hat = DiffEq.forward(self.data.X0, self.data.t)
        return self._parse_forward_out(X_hat)

    @property
    def potential(self):
        # Potential-parameterized models obtain the drift via autograd of the
        # learned potential, so they need gradients during the forward pass.
        return "Potential" in str(self.DiffEq)

    def _parse_forward_out(self, X_hat):
        """Unpack tuple outputs (e.g., from KL-divergence models) to the state prediction."""
        if isinstance(X_hat, tuple):
            return X_hat[0]
        return X_hat

    def _dimension_reduce_pca(self, X_hat):
        # Project each simulated time point back onto the fitted PCA basis.
        return torch.stack(
            [torch.Tensor(self.PCA.transform(x)) for x in X_hat.detach().cpu().numpy()]
        ).to(self.device)


    def __call__(self, trainer, DiffEq, *args, **kwargs):

        self.__update__(locals())

        if self.potential:
            X_hat = self.forward_with_grad(DiffEq)
        else:
            X_hat = self.forward_without_grad(DiffEq)

        if self.PCA is not None:
            X_hat = self._dimension_reduce_pca(X_hat)

        # X_hat is indexed along t = [day 2, day 4, day 6]: index 1 is the
        # held-out day-4 sample, index 2 the day-6 training sample.
        d4_loss = self.SinkhornDivergence(X_hat[1], self.data.X_test_d4).item()
        d6_loss = self.SinkhornDivergence(X_hat[2], self.data.X_train_d6).item()

        if not self.silent:
            print(
                "- Epoch: {:<5}| Day 4 loss: {:.2f} | Day 6 loss: {:.2f}".format(
                    DiffEq.current_epoch, d4_loss, d6_loss,
                ),
            )

        return d4_loss, d6_loss

AnnData object with n_obs × n_vars = 130887 × 2492
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated', 'train'
    var: 'gene_ids', 'hv_gene', 'must_include', 'exclude', 'use_genes'
    uns: 'fate_counts', 'h5ad_path', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_umap', 'cell_fate_df'
    layers: 'X_scaled'
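InterpolationTask above is defined for scoring interpolation against held-out populations, but it is not exercised directly in the cells below. For orientation, a minimal, hypothetical usage sketch for a single trained checkpoint (the ckpt_path value and passing trainer=None are assumptions; the trainer argument is only stored):

task = InterpolationTask(adata)  # defaults: time_key="Time point", use_key="X_pca", t0=2
model = sdq.io.load_model(adata, ckpt_path=ckpt_path)  # ckpt_path: an assumed checkpoint path
d4_loss, d6_loss = task(trainer=None, DiffEq=model.DiffEq)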
[5]:
class Results:
    def __init__(self, results: dict):
        self.d2 = pd.DataFrame([result["d2"] for result in results.values()]).T
        self.d4 = pd.DataFrame([result["d4"] for result in results.values()]).T
        self.d6 = pd.DataFrame([result["d6"] for result in results.values()]).T
        self.d2_d4 = pd.DataFrame([result["d2.d4"] for result in results.values()]).T
        self.d4_d6 = pd.DataFrame([result["d4.d6"] for result in results.values()]).T
        self.d2_d6 = pd.DataFrame([result["d2.d6"] for result in results.values()]).T
        self.d2_d2 = pd.DataFrame([result["d2.d2"] for result in results.values()]).T
        self.d4_d4 = pd.DataFrame([result["d4.d4"] for result in results.values()]).T
        self.d6_d6 = pd.DataFrame([result["d6.d6"] for result in results.values()]).T


class PiecewiseDistance(ABCParse.ABCParse):
    def __init__(
        self,
        adata,
        n_samples: int = 10_000,
        t=torch.linspace(2, 6, 41),
        device: torch.device = autodevice.AutoDevice(),
        *args,
        **kwargs,
    ):
        self.__parse__(locals())
        self._df = self._adata.obs.copy()
        # Restrict to cells with a clonal (lineage barcode) assignment.
        self._clonal_df = self._df.loc[self._df["clone_idx"].notna()]

    def sampling(self, group_df, N: int = 10_000):
        # Sample with replacement only when the group has fewer than N cells.
        replace = N > group_df.shape[0]
        return group_df.sample(N, replace=replace).index

    def sample_indices(self, clonal_df):
        print("Call clonal sampling")
        return {
            group: self.sampling(group_df)
            for group, group_df in clonal_df.groupby("Time point")
        }

    @property
    def X2(self):
        if not hasattr(self, "_X2"):
            self._X2 = adata_query.fetch(
                self._adata[self.indices[2]],
                key="X_pca",
                torch=True,
                device=self._device,
            )
        return self._X2

    @property
    def X4(self):
        if not hasattr(self, "_X4"):
            self._X4 = adata_query.fetch(
                self._adata[self.indices[4]],
                key="X_pca",
                torch=True,
                device=self._device,
            )
        return self._X4

    @property
    def X6(self):
        if not hasattr(self, "_X6"):
            self._X6 = adata_query.fetch(
                self._adata[self.indices[6]],
                key="X_pca",
                torch=True,
                device=self._device,
            )
        return self._X6

    @property
    def t(self):
        return self._t.to(self._device)

    def simulate(self, DiffEq):
        return DiffEq.forward(self.X2, t=self.t)

    @property
    def divs(self):
        if not hasattr(self, "_divs"):
            # Chunk the n_samples cells into ten contiguous blocks for
            # piecewise Sinkhorn evaluation.
            divs = np.linspace(0, self._n_samples, 11).astype(int)
            self._divs = (divs[:-1], divs[1:])
        return self._divs

    @property
    def Z_hat(self):
        if not hasattr(self, "_Z_hat"):
            self._Z_hat = self.simulate(self._DiffEq)
        return self._Z_hat

    def forward(self, XA, XB, *args, **kwargs):
        self.__update__(locals())

        self._distances = []
        div_i, div_j = self.divs
        with torch.no_grad():
            for i, j in zip(div_i, div_j):
                self._distances.append(sinkhorn(XA[i:j], XB[i:j]))
        return torch.stack(self._distances).detach().cpu().mean().item()

    def compute_d2_distance(self):
        return [self.forward(self.Z_hat[i], self.X2) for i in range(len(self.Z_hat))]

    def compute_d4_distance(self):
        return [self.forward(self.Z_hat[i], self.X4) for i in range(len(self.Z_hat))]

    def compute_d6_distance(self):
        return [self.forward(self.Z_hat[i], self.X6) for i in range(len(self.Z_hat))]

    def compute_self_distances(self):

        self.X2_X4 = self.forward(self.X2, self.X4)
        self.X2_X6 = self.forward(self.X2, self.X6)
        self.X4_X6 = self.forward(self.X4, self.X6)

        self.X2_X2 = self.forward(self.X2, self.X2)
        self.X4_X4 = self.forward(self.X4, self.X4)
        self.X6_X6 = self.forward(self.X6, self.X6)

    def compute_d2_d6(self):
        return self.forward(self.X2, self.X6)

    def compute_d4_d4(self):
        return self.forward(self.X4, self.X4)

    def compute_d4_d6(self):
        return self.forward(self.X4, self.X6)

    def __call__(self, DiffEq, N=5):
        self.__update__(locals())

        _Results = {}
        for i in tqdm.notebook.tqdm(range(self._N)):
            self.indices = self.sample_indices(self._clonal_df)
            self.compute_self_distances()
            _Results[i] = {
                "d2": self.compute_d2_distance(),
                "d4": self.compute_d4_distance(),
                "d6": self.compute_d6_distance(),
                "d2.d2": self.X2_X2,
                "d4.d4": self.X4_X4,
                "d6.d6": self.X6_X6,
                "d2.d4": self.X2_X4,
                "d4.d6": self.X4_X6,
                "d2.d6": self.X2_X6,
            }

            # Drop the cached samples and simulation so the next iteration
            # re-draws and re-simulates them.
            del self._X2
            del self._X4
            del self._X6
            del self._Z_hat

        return Results(_Results)
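PiecewiseDistance.forward scores two 10,000-cell point clouds in ten contiguous 1,000-cell chunks and averages the chunk-wise divergences, which keeps the Sinkhorn computation memory-bounded. A self-contained sketch of that chunking pattern on toy data, using geomloss.SamplesLoss directly as a stand-in for sdq's SinkhornDivergence (an assumption about its backend):

import torch
import geomloss  # assumed stand-in for sdq.core.lightning_models.base.SinkhornDivergence

sinkhorn_loss = geomloss.SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

# Toy stand-ins for two 10,000-cell x 50-PC point clouds.
XA, XB = torch.randn(10_000, 50), torch.randn(10_000, 50)

# Ten contiguous 1,000-cell chunks, mirroring the divs property above.
edges = torch.linspace(0, 10_000, 11).long().tolist()
with torch.no_grad():
    chunk_divs = torch.stack(
        [sinkhorn_loss(XA[i:j], XB[i:j]) for i, j in zip(edges[:-1], edges[1:])]
    )
mean_divergence = chunk_divs.mean().item()  # one scalar per pair, as in forward()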
[8]:
best_ckpts = sdq_an.parsers.summarize_best_checkpoints(project)

for vname, ckpt_path in best_ckpts["ckpt_path"].items():
    save_path = pathlib.Path(f"sdq.simulate_interpolation.{vname}.pkl")
    if not save_path.exists():  # skip model versions already cached to disk
        model = sdq.io.load_model(adata, ckpt_path=ckpt_path)
        pw_distance = PiecewiseDistance(adata)
        version_results = pw_distance(model.DiffEq)
        sdq.io.write_pickle(version_results, save_path)
 - [INFO] | Input data configured.
 - [INFO] | Bulding Annoy kNN Graph on adata.obsm['train']
Seed set to 0
 - [INFO] | Using the specified parameters, LightningSDE-FixedPotential-RegularizedVelocityRatio has been called.
Call clonal sampling
Call clonal sampling
Call clonal sampling
Call clonal sampling
Call clonal sampling
[11]:
paths = glob.glob("sdq.simulate_interpolation.*.pkl")
results = [sdq.io.read_pickle(path) for path in paths]

d2 = pd.DataFrame([res.d2.mean(1) for res in results]).T
d4 = pd.DataFrame([res.d4.mean(1) for res in results]).T
d6 = pd.DataFrame([res.d6.mean(1) for res in results]).T
d2_d6 = pd.DataFrame([res.d2_d6.mean(1) for res in results]).mean()[0]
d2_d4 = pd.DataFrame([res.d2_d4.mean(1) for res in results]).mean()[0]
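Each resulting frame should have 41 rows, one per step of the simulation grid t = torch.linspace(2, 6, 41), and one column per model version matched by the glob; a quick sanity check:

# 41 rows (simulation steps); one column per pickled model version.
print(d2.shape, d4.shape, d6.shape)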
[12]:
import cellplots as cp
[13]:
time_cmap = sdq_an.pl.TimeColorMap()()
[14]:
len(time_cmap)
[14]:
41
[22]:
colors = [time_cmap[4], time_cmap[20], time_cmap[40]]
[35]:
fig, axes = cp.plot(
    1,
    1,
    height=0.5,
    width=0.5,
    x_label=["time steps"],
    y_label=["Wasserstein Distance"],
    title=["Learned time concordance"],
)
ax = axes[0]
for en, d in enumerate([d2, d4, d6]):
    mean = d.mean(1)
    std = d.std(1)
    lo = mean - std
    hi = mean + std

    ax.fill_between(
        mean.index, lo, hi, color=colors[en], alpha=0.2, ec="None", zorder=2
    )
    ax.plot(mean, color=colors[en])
# Reference levels: empirical day 2 vs. day 6 and day 2 vs. day 4 distances.
ax.hlines(d2_d6, 0, 40, lw=1, color="k", ls="--", zorder=3)
ax.hlines(d2_d4, 0, 40, lw=1, color="k", ls="--", zorder=3)
ax.grid(True, alpha=0.2, c="lightgrey")
ax.set_ylim(0, 200)
ax.set_xlim(0, 40)
ax.tick_params(axis="both", width=0.5, length=2.5)
_ = [spine.set_linewidth(0.5) for spine in list(ax.spines.values())]
plt.savefig("interpolation.stepwise_distance.svg", dpi=500)
[Output figure: interpolation.stepwise_distance.svg, "Learned time concordance": Wasserstein distance vs. time steps for days 2, 4, and 6.]