Source code for scdiffeq.tools._simulation

# -- import packages: ---------------------------------------------------------
import ABCParse
import adata_query
import anndata
import autodevice
import lightning
import numpy as np
import pandas as pd
import torch

# -- set type hints: ----------------------------------------------------------
from typing import Optional

# -- operational cls: ---------------------------------------------------------
[docs] class Simulation(ABCParse.ABCParse): """Sampled trajectories from an scDiffEq model"""
[docs] def __init__( self, use_key: str = "X_pca", time_key: str = "Time point", N: int = 1, device: Optional[torch.device] = autodevice.AutoDevice(), *args, **kwargs, ): """ Args: use_key (str): description. **Default**: "X_pca" N (int): number of trajectories to sample from the model. **Default**: 2000 device (bool): description. **Default**: "cuda:0" Returns: None """ self.__parse__(locals()) self._T_GIVEN = False
@property def _adata_init(self): ... @property def idx(self): """ """ if not hasattr(self, "_idx"): self._idx = self._adata.obs.index return self._idx @property def Z0(self) -> torch.Tensor: """ """ if not hasattr(self, "_Z0"): self._Z0 = adata_query.fetch( self._adata[self.idx], key=self._use_key, torch=True, device=self._device, ) if self._N > 1: self._Z0 = self._Z0[None, :, :] self._Z0 = self._Z0.expand(self._N, -1, -1).flatten(0, 1) return self._Z0 @property def _TIME(self) -> pd.Series: if not self._T_GIVEN: return self._adata.obs[self._time_key] return self.t @property def _T_MIN(self) -> float: """ """ return self._TIME.min() @property def _T_MAX(self) -> float: """ """ return self._TIME.max() @property def _N_STEPS(self) -> float: return int(((self._T_MAX - self._T_MIN) / self._dt) + 1) @property def t(self) -> torch.Tensor: """ """ if not hasattr(self, "_t"): self._t = torch.linspace( self._T_MIN, self._T_MAX, self._N_STEPS, ).to(self._device) return self._t @property def _N_CELLS(self): return self._N * len(self.idx)
[docs] def forward(self, Z0, t) -> torch.Tensor: return self._diffeq.forward(Z0, t).detach().cpu().flatten(0, 1).numpy()
[docs] def _to_adata_sim(self, Z_hat: np.ndarray) -> anndata.AnnData: """ Args: Z_hat (np.ndarray) Returns: adata_sim (anndata.AnnData) """ adata_sim = anndata.AnnData(Z_hat) adata_sim.obs["t"] = np.repeat(self.t.detach().cpu().numpy(), self._N_CELLS) adata_sim.obs["z0_idx"] = np.tile(np.tile(self.idx, self._N_STEPS), self._N) adata_sim.obs["sim_i"] = np.tile( np.arange(self._N).repeat(len(self.idx)), self._N_STEPS ) adata_sim.obs["sim"] = adata_sim.obs["z0_idx"].astype(str) + adata_sim.obs[ "sim_i" ].astype(str) adata_sim.uns["sim_idx"] = self.idx adata_sim.uns["simulated"] = True return adata_sim
[docs] def __call__( self, diffeq, adata: anndata.AnnData, idx: pd.Index, dt: float = 0.1, t: Optional[torch.Tensor] = None, *args, **kwargs, ) -> anndata.AnnData: """Simulate trajectories by sampling from an scDiffEq model. Args: diffeq (): lightning model. adata (anndata.AnnData): Input AnnDat object. idx (pd.Index): cell indices (corresponding to `adata` from which the model should initiate sampled trajectories. Returns: adata_sim (anndata.AnnData) """ self.__update__(locals()) self._diffeq.to(self._device) if not t is None: self._T_GIVEN = True Z_hat = self.forward(self.Z0, self.t) return self._to_adata_sim(Z_hat)
[docs] def simulate( adata: anndata.AnnData, diffeq: lightning.LightningModule, idx: Optional[pd.Index] = None, use_key: str = "X_pca", time_key: str = "Time point", N: Optional[int] = 1, t: Optional[torch.Tensor] = None, dt: Optional[float] = 0.1, device: Optional[torch.device] = autodevice.AutoDevice(), *args, **kwargs, ) -> anndata.AnnData: """ Simulate trajectories by sampling from an scDiffEq model. Parameters ---------- adata : anndata.AnnData Input AnnData object. idx : pd.Index Cell indices (corresponding to `adata`) from which the model should initiate sampled trajectories. diffeq : lightning.LightningModule The differential equation model. use_key : str, optional adata accession key for the input data. Default is "X_pca". N : int, optional Number of trajectories to sample from the model. Default is 2000. device : torch.device, optional Device to run the simulation on. Default is True. Returns ------- anndata.AnnData AnnData object encapsulating scDiffEq model simulation. """ if diffeq.__repr__() == "scDiffEq": diffeq = diffeq.DiffEq simulation = Simulation( use_key=use_key, time_key=time_key, N=N, device=device, ) return simulation(diffeq=diffeq, adata=adata, idx=idx, t=t, dt=dt)