# -- import packages: ----------------------------------------------------------
import ABCParse
import anndata
import logging
import numpy as np
import os
import pandas as pd
import pathlib
import sklearn.decomposition
import sklearn.preprocessing
# -- import local dependencies: ------------------------------------------------
from .. import io
from ._figshare_downloader import figshare_downloader
# -- set type hints: -----------------------------------------------------------
from typing import Dict, Optional, Union
# -- configure logger: ----------------------------------------------------------
logger = logging.getLogger(__name__)
def _annotate_larry_cytotrace(adata: anndata.AnnData, data_dir: Union[str, pathlib.Path]) -> Dict[str, pd.DataFrame]:
""" """
obs_write_path = str(pathlib.Path(data_dir).joinpath("larry.ct_obs_df.csv"))
var_write_path = str(pathlib.Path(data_dir).joinpath("larry.ct_var_df.csv"))
figshare_downloader(
figshare_id="54312011",
write_path=obs_write_path,
)
figshare_downloader(
figshare_id="54312008",
write_path=var_write_path,
)
obs_df = pd.read_csv(obs_write_path, index_col = 0)
var_df = pd.read_csv(var_write_path, index_col = 0)
# Convert indices to strings
obs_df.index = obs_df.index.astype(str)
var_df.index = var_df.index.astype(str)
# Convert all columns to strings
for col in obs_df.columns:
obs_df[col] = obs_df[col].astype(str)
for col in var_df.columns:
var_df[col] = var_df[col].astype(str)
adata.obs = pd.concat([adata.obs, obs_df], axis = 1)
adata.var = pd.concat([adata.var, var_df], axis = 1)
return adata
# -- Controller class: ---------------------------------------------------------
class LARRYInVitroDataset(ABCParse.ABCParse):
FIGSHARE_IDS = {
None: 55415231, # New biology-rich dataset (default)
"fate_prediction": 52612805, # Original fate prediction dataset
}
def __init__(
self,
data_dir=os.getcwd(),
variant: Optional[str] = None,
filter_genes: bool = True,
reduce_dimensions: bool = True,
cytotrace: bool = True,
force_download: bool = False,
*args,
**kwargs,
):
self.__parse__(locals())
@property
def _scdiffeq_parent_data_dir(self):
path = pathlib.Path(self._data_dir).joinpath("scdiffeq_data")
if not path.exists():
path.mkdir()
return path
@property
def data_dir(self):
path = self._scdiffeq_parent_data_dir.joinpath("larry")
if not path.exists():
path.mkdir()
return path
@property
def FNAME(self):
suffix = f"_{self._variant}" if self._variant else ""
return f"larry{suffix}.h5ad"
@property
def h5ad_path(self) -> pathlib.Path:
return self.data_dir.joinpath(self.FNAME)
@property
def _DO_PREPROCESSING(self):
return any([self._filter_genes, self._reduce_dimensions])
def download(self):
figshare_id = self.FIGSHARE_IDS[self._variant]
figshare_downloader(
figshare_id=figshare_id,
write_path=self.h5ad_path,
)
def _gene_filtering(self, adata: anndata.AnnData) -> anndata.AnnData:
return adata[:, adata.var["use_genes"]].copy()
def _dimension_reduction(self, adata: anndata.AnnData):
"""Do sample dimension reduction"""
# -- instantiate models: ----------------------------------------------
scaler = sklearn.preprocessing.StandardScaler()
pca = sklearn.decomposition.PCA(n_components=50)
# -- fit transform data: ----------------------------------------------
X_raw = adata.X
if not isinstance(X_raw, np.ndarray):
X_raw = X_raw.toarray()
adata.obsm["X_scaled"] = scaler.fit_transform(X_raw)
adata.obsm["X_pca"] = pca.fit_transform(adata.obsm["X_scaled"])
# -- save models: -----------------------------------------------------
io.write_pickle(
obj=scaler,
path=self.data_dir.joinpath("scaler.pkl"),
)
io.write_pickle(
obj=pca,
path=self.data_dir.joinpath("pca.pkl"),
)
def _preprocess(self, adata: anndata.AnnData) -> anndata.AnnData:
if self._DO_PREPROCESSING:
logger.info("Preprocessing...")
if self._cytotrace:
_annotate_larry_cytotrace(adata=adata, data_dir=self.data_dir)
if self._filter_genes:
adata = self._gene_filtering(adata)
if self._reduce_dimensions:
self._dimension_reduction(adata)
adata.write_h5ad(self.h5ad_path)
self._adata = adata
def _safe_read(self):
logger.info(f"Loading data from {self.h5ad_path}")
try:
adata = anndata.read_h5ad(self.h5ad_path)
if "ct_pseudotime" in adata.obs.columns:
adata.obs['ct_pseudotime'] = adata.obs['ct_pseudotime'].astype(float)
adata.obs.index.name = "index"
self._adata = adata
except Exception as e:
logger.error(f"Error loading data from {self.h5ad_path}: {e}")
raise e
return adata
@property
def adata(self) -> anndata.AnnData:
if not hasattr(self, "_adata"):
if not self.h5ad_path.exists() or self._force_download:
self.download()
adata = anndata.read_h5ad(self.h5ad_path)
self._preprocess(adata=adata)
return self._safe_read()
[docs]
def larry(
data_dir: str = os.getcwd(),
variant: Optional[str] = None,
filter_genes: bool = True,
reduce_dimensions: bool = True,
cytotrace: bool = True,
force_download: bool = False,
) -> anndata.AnnData:
"""LARRY in vitro dataset
Args:
data_dir: str, default=os.getcwd()
Path to the directory where the data will be saved.
variant: Optional[str], default=None
Dataset variant to download. None (default) downloads the full biology-rich
dataset. "fate_prediction" downloads the original fate prediction dataset.
filter_genes: bool, default=True
Whether to filter genes.
reduce_dimensions: bool, default=True
Whether to reduce dimensions.
cytotrace: bool, default=True
Whether to annotate LARRY with pre-computed CytoTRACE annotations.
force_download: bool, default=False
Whether to force download the data.
Returns:
anndata.AnnData: Preprocessed AnnData object.
"""
data_handler = LARRYInVitroDataset(
data_dir=data_dir,
variant=variant,
filter_genes=filter_genes,
reduce_dimensions=reduce_dimensions,
cytotrace=cytotrace,
force_download=force_download,
)
return data_handler.adata