Load a trained scDiffEq model
This example shows how to load a model trained on the full LARRY dataset for the fate-perturbation task (Task 3 in the scDiffEq manuscript).
Import dependencies
[1]:
import larry
import scdiffeq as sdq
import scdiffeq_analyses as sdq_an  # installed from the scdiffeq-analyses repo (see below)
F_obs = larry.tasks.fate_prediction.F_obs
Load data
[2]:
adata = sdq.datasets.larry()
adata
[2]:
AnnData object with n_obs × n_vars = 130887 × 2447
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'clone_idx', 'fate_observed', 't0_fated'
    var: 'gene_ids', 'hv_genes', 'use_genes'
    uns: 'fate_counts', 'time_occupance'
    obsm: 'X_clone', 'X_pca', 'X_scaled', 'cell_fate_df'
Load the project (sdq.io.Project)
Important: before running this notebook, clone the scdiffeq-analyses repo and install it locally:
git clone https://github.com/scDiffEq/scdiffeq-analyses.git
cd scdiffeq-analyses
pip install -e .
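To confirm that the editable install worked before running the imports at the top of this notebook, you can ask Python's import system whether the package is discoverable. This is a minimal standard-library sketch. The helper name `is_installed` is hypothetical and is not part of scdiffeq or scdiffeq-analyses.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if the top-level module `module_name` can be found.

    Hypothetical helper: find_spec locates the module without importing it.
    """
    return importlib.util.find_spec(module_name) is not None


# The package imported as `sdq_an` in this notebook
if not is_installed("scdiffeq_analyses"):
    print("scdiffeq_analyses not found: run `pip install -e .` inside the scdiffeq-analyses clone")
```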
[3]:
# Relative path: run this notebook from the directory that contains the scdiffeq-analyses clone
project_path = "./scdiffeq-analyses/manuscript/models/LARRY.full_dataset/LightningSDE-FixedPotential-RegularizedVelocityRatio/"
project = sdq.io.Project(path=project_path)
Get the best checkpoints
[4]:
best_ckpts_df = sdq_an.parsers.best_checkpoints(project=project)
best_ckpts_df
[4]:
|  | train | test | ckpt_path | epoch |
|---|---|---|---|---|
| version_0 | 0.571656 | 0.551804 | /Users/michaelvinyard/GitHub/scdiffeq-analyses... | 2500 |
| version_1 | 0.541401 | 0.465658 | /Users/michaelvinyard/GitHub/scdiffeq-analyses... | 1706 |
| version_2 | 0.547771 | 0.499418 | /Users/michaelvinyard/GitHub/scdiffeq-analyses... | 1238 |
| version_3 | 0.496815 | 0.504075 | /Users/michaelvinyard/GitHub/scdiffeq-analyses... | 1245 |
| version_4 | 0.562102 | 0.522701 | /Users/michaelvinyard/GitHub/scdiffeq-analyses... | 1662 |
[5]:
# version_0 has the highest train and test values in the table above
ckpt_path = best_ckpts_df['ckpt_path'].loc['version_0']
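Instead of hard-coding `version_0`, you could choose the row programmatically. The sketch below assumes that a higher `test` value is better; the table does not state which metric the columns hold. It copies the train/test/epoch values from the table above into plain Python data so it runs without scdiffeq. `best_version` is a hypothetical helper.

```python
# train/test values and epochs copied from the best_ckpts_df table above
# (ckpt_path omitted because it is truncated in the rendered output).
best_ckpts = {
    "version_0": {"train": 0.571656, "test": 0.551804, "epoch": 2500},
    "version_1": {"train": 0.541401, "test": 0.465658, "epoch": 1706},
    "version_2": {"train": 0.547771, "test": 0.499418, "epoch": 1238},
    "version_3": {"train": 0.496815, "test": 0.504075, "epoch": 1245},
    "version_4": {"train": 0.562102, "test": 0.522701, "epoch": 1662},
}


def best_version(table: dict, metric: str = "test") -> str:
    """Return the version key whose `metric` is largest (assumes higher is better)."""
    return max(table, key=lambda version: table[version][metric])


print(best_version(best_ckpts))  # version_0
```

With the real DataFrame, the pandas equivalent is `best_ckpts_df['ckpt_path'].loc[best_ckpts_df['test'].idxmax()]`.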
Load the model
[6]:
model = sdq.io.load_model(adata=adata, ckpt_path=ckpt_path)
print(model)
model.to("mps:0")  # Apple-silicon GPU; use "cuda:0" for an NVIDIA GPU
Seed set to 0
scDiffEq
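The `model.to(...)` call above hard-codes an Apple-silicon device. A small helper can pick whichever accelerator PyTorch reports as available. This is a sketch: `pick_device` is hypothetical, and falling back to `"cpu"` assumes that `model.to` accepts any PyTorch device string, whereas the notebook only shows `"mps:0"` and `"cuda:0"`.

```python
def pick_device() -> str:
    """Return a PyTorch device string: CUDA first, then Apple MPS, else CPU."""
    try:
        import torch
    except ImportError:  # PyTorch not installed: fall back to the CPU
        return "cpu"
    if torch.cuda.is_available():
        return "cuda:0"
    # torch.backends.mps exists in PyTorch >= 1.12; guard for older versions
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps:0"
    return "cpu"


device = pick_device()
print(device)
```

Usage: `model.to(pick_device())`.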
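Alternatively, load only the DiffEq
If you only need the underlying differential-equation module (here, LightningSDE-FixedPotential-RegularizedVelocityRatio), sdq.io.load_diffeq loads it from the checkpoint alone, without passing adata.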
[7]:
DiffEq = sdq.io.load_diffeq(ckpt_path=ckpt_path)
print(DiffEq)
LightningSDE-FixedPotential-RegularizedVelocityRatio