Source code for scdiffeq.io._model._checkpoint

# -- import packages: ---------------------------------------------------------
import ABCParse
import logging
import pandas as pd
import pathlib
import torch


# -- set typing: --------------------------------------------------------------
from typing import Union, Dict


# -- set up logging: ----------------------------------------------------------
logger = logging.getLogger(__name__)


# -- operational class: -------------------------------------------------------
[docs] class Checkpoint(ABCParse.ABCParse):
[docs] def __init__(self, path: Union[pathlib.Path, str], *args, **kwargs) -> None: """Instantiates checkpoint object. Args: path (Union[pathlib.Path, str]): Path to saved checkpoint. """ self.__parse__(locals())
@property def path(self) -> pathlib.Path: """ Returns: pathlib.Path The path to the checkpoint. """ return pathlib.Path(self._path) @property def _fname(self) -> str: """ Returns: str Filename without extension. """ return self.path.name.split(".")[0] @property def version(self): """ Returns: str Version of the checkpoint. """ if not hasattr(self, "_version"): v, n = self.path.parent.parent.name.split("_") self._version = " ".join([v.capitalize(), n]) return self._version @property def _PATH_F_HAT_RAW(self): """ Returns: pathlib.Path Path to the raw F_hat file. """ if not hasattr(self, "_FATE_PREDICTION_METRICS_PATH"): base_path = self.path.parent.parent.joinpath("fate_prediction_metrics") converted_name = ( self.path.name.replace("=", "_").replace("-", ".").split(".ckpt")[0] ) self._FATE_PREDICTION_METRICS_PATH = base_path.joinpath( f"{converted_name}/F_hat.unfiltered.csv" ) return self._FATE_PREDICTION_METRICS_PATH @property def F_hat(self): """ Returns: pd.DataFrame or None DataFrame containing F_hat data if the path exists, otherwise None. """ if not hasattr(self, "_F_hat"): if self._PATH_F_HAT_RAW.exists(): self._F_hat = pd.read_csv(self._PATH_F_HAT_RAW, index_col=0) self._F_hat.index = self._F_hat.index.astype(str) else: logger.warning(f"F_hat path does not exist.") self._F_hat = None return self._F_hat @property def epoch(self) -> Union[int, str]: """ Returns: Union[int, str] Epoch number if not 'last', otherwise 'last'. """ if self._fname != "last": return int(self._fname.split("=")[1].split("-")[0]) return self._fname @property def ckpt(self) -> Dict[str, "LightningCheckpoint"]: """ Returns: Dict[str, "LightningCheckpoint"] State dictionary created by PyTorch Lightning. """ if not hasattr(self, "_ckpt"): self._ckpt = torch.load( self.path, weights_only=False, map_location="cpu" ) return self._ckpt
[docs] def __repr__(self) -> str: """ Returns: str Object description of checkpoint at epoch. """ return f"ckpt epoch: {self.epoch} [{self.version}]"