Load trained scDiffEq model

This example shows how to load a model that was trained on the full LARRY dataset for the fate perturbation task (Task 3 in the scDiffEq manuscript).

Import dependencies

[1]:
import larry
import scdiffeq as sdq
import scdiffeq_analyses as sdq_an  # from the scdiffeq-analyses repo; see "Load the project" below

F_obs = larry.tasks.fate_prediction.F_obs

Load data

[2]:
adata = sdq.datasets.larry()
adata
[2]:
AnnData object with n_obs × n_vars = 130887 × 2447
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated'
    var: 'gene_ids', 'hv_genes', 'use_genes'
    uns: 'fate_counts', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_scaled', 'cell_fate_df'

Load the project (sdq.io.Project)

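Important: this notebook depends on the scdiffeq-analyses repository, which provides the scdiffeq_analyses package imported above and the trained model files loaded below. Clone it and install it locally before running any cells:

git clone https://github.com/scDiffEq/scdiffeq-analyses.git
cd scdiffeq-analyses
pip install -e .

The project_path below is relative to the directory from which you ran git clone.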
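Before running the notebook, it can help to confirm that the three modules imported in the first cell resolve in the current environment. The sketch below relies on importlib.util.find_spec, which returns None when the import system cannot locate a top-level module, without executing that module. The helper name missing_packages is hypothetical and not part of scDiffEq or scdiffeq-analyses.

```python
import importlib.util


def missing_packages(names):
    """Return the subset of top-level module `names` that cannot be located.

    Hypothetical helper: importlib.util.find_spec returns None for a
    top-level module that is not installed, and does not import it.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


# The three modules imported in the first cell of this notebook.
missing = missing_packages(["larry", "scdiffeq", "scdiffeq_analyses"])
if missing:
    print("Not importable yet:", ", ".join(missing))
else:
    print("All notebook dependencies are importable.")
```

If scdiffeq_analyses shows up as missing, the pip install -e . step was most likely run in a different Python environment from the one used by the notebook kernel.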
[3]:
project_path = "./scdiffeq-analyses/manuscript/models/LARRY.full_dataset/LightningSDE-FixedPotential-RegularizedVelocityRatio/"
project = sdq.io.Project(path=project_path)
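Because project_path is relative, a wrong working directory is an easy way for this cell to fail. The sketch below checks that the directory exists and lists the training runs inside it. It assumes, without confirmation from the scDiffEq docs, that each run is stored in its own version_N subdirectory, mirroring the version_0 to version_4 index of the best-checkpoint table in the next step; list_versions is a hypothetical helper.

```python
import tempfile
from pathlib import Path

PREFIX = "version_"


def list_versions(project_path):
    """Return the names of version_N subdirectories of project_path, sorted by N.

    Assumption: each training run lives in its own version_N folder.
    Raises FileNotFoundError if project_path is not an existing directory.
    """
    root = Path(project_path)
    if not root.is_dir():
        raise FileNotFoundError(f"Project directory not found: {root.resolve()}")
    names = [
        p.name
        for p in root.iterdir()
        if p.is_dir() and p.name.startswith(PREFIX) and p.name[len(PREFIX):].isdigit()
    ]
    # Sort numerically so that version_10 comes after version_9, not after version_1.
    return sorted(names, key=lambda name: int(name[len(PREFIX):]))


# Demo on a throwaway directory; in the notebook you would call list_versions(project_path).
with tempfile.TemporaryDirectory() as tmp:
    for n in (0, 1, 10, 2):
        (Path(tmp) / f"{PREFIX}{n}").mkdir()
    (Path(tmp) / "notes.txt").write_text("not a version folder")
    print(list_versions(tmp))  # ['version_0', 'version_1', 'version_2', 'version_10']
```

The error message prints the resolved absolute path, which makes it obvious when the notebook is running from a directory other than the parent of the clone.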

Get the best checkpoints

[4]:
best_ckpts_df = sdq_an.parsers.best_checkpoints(project=project)
best_ckpts_df
[4]:
           train     test      ckpt_path                                          epoch
version_0  0.571656  0.551804  /Users/michaelvinyard/GitHub/scdiffeq-analyses...  2500
version_1  0.541401  0.465658  /Users/michaelvinyard/GitHub/scdiffeq-analyses...  1706
version_2  0.547771  0.499418  /Users/michaelvinyard/GitHub/scdiffeq-analyses...  1238
version_3  0.496815  0.504075  /Users/michaelvinyard/GitHub/scdiffeq-analyses...  1245
version_4  0.562102  0.522701  /Users/michaelvinyard/GitHub/scdiffeq-analyses...  1662
[5]:
# version_0 has the highest train and test values in the table above
ckpt_path = best_ckpts_df.loc['version_0', 'ckpt_path']
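The cell above hard-codes version_0. To pick a run programmatically instead, pandas offers a one-liner such as `best_ckpts_df.loc[best_ckpts_df['test'].idxmax(), 'ckpt_path']` (Series.idxmax returns the index label of the largest value). The stdlib sketch below shows the same selection logic on the values printed above. It assumes that a higher test value is better, which the table itself does not state; best_version is a hypothetical helper.

```python
# Values copied from the best_ckpts_df output above; ckpt_path is left out
# because it is truncated in the printed table.
scores = {
    "version_0": {"train": 0.571656, "test": 0.551804, "epoch": 2500},
    "version_1": {"train": 0.541401, "test": 0.465658, "epoch": 1706},
    "version_2": {"train": 0.547771, "test": 0.499418, "epoch": 1238},
    "version_3": {"train": 0.496815, "test": 0.504075, "epoch": 1245},
    "version_4": {"train": 0.562102, "test": 0.522701, "epoch": 1662},
}


def best_version(scores, metric="test"):
    """Return the version whose `metric` value is largest (assumes higher is better)."""
    return max(scores, key=lambda version: scores[version][metric])


print(best_version(scores))  # version_0
```

Selecting on the test column rather than train is the usual way to avoid favoring a run that merely fits its training data well, assuming test is a held-out score.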

Load the model

[6]:
model = sdq.io.load_model(adata=adata, ckpt_path=ckpt_path)
print(model)
model.to("mps:0")  # Apple-silicon GPU; use e.g. "cuda:0" for an NVIDIA GPU, or "cpu"
Seed set to 0
scDiffEq
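The device string passed to model.to above only works on Apple-silicon Macs. A small helper can choose the device at run time instead. The sketch below assumes that the model accepts PyTorch device strings (as the "mps:0" and "cuda:0" examples suggest) and queries torch.cuda.is_available() and torch.backends.mps.is_available(). The helpers pick_device and detect_device are hypothetical and not part of scDiffEq.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Prefer an NVIDIA GPU, then an Apple-silicon GPU, then fall back to CPU."""
    if cuda_available:
        return "cuda:0"
    if mps_available:
        return "mps:0"
    return "cpu"


def detect_device() -> str:
    """Ask PyTorch which accelerators exist; return "cpu" if torch is not installed."""
    try:
        import torch
    except ImportError:
        return "cpu"
    # torch.backends.mps is missing on older PyTorch releases, so look it up defensively.
    mps = getattr(torch.backends, "mps", None)
    mps_available = mps is not None and mps.is_available()
    return pick_device(torch.cuda.is_available(), mps_available)


print(detect_device())
```

In the notebook this would replace the hard-coded string, e.g. model.to(detect_device()). Keeping the decision in pick_device, separate from the torch queries, makes the preference order easy to test on a machine without a GPU.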

Alternatively: load only the DiffEq

[7]:
DiffEq = sdq.io.load_diffeq(ckpt_path=ckpt_path)
print(DiffEq)
LightningSDE-FixedPotential-RegularizedVelocityRatio