Source code for scdiffeq.io._model._version

# -- import packages: ---------------------------------------------------------
import ABCParse
import pathlib
import pandas as pd


# -- import local dependencies: -----------------------------------------------
from ._hparams import HParams
from ._checkpoint import Checkpoint
from ._grouped_metrics import GroupedMetrics


# -- set typing: --------------------------------------------------------------
from typing import Dict, List, Union


# -- operational class: -------------------------------------------------------
class Version(ABCParse.ABCParse):
    """scDiffEq Version object container

    Attributes
    ----------
    _path : Union[pathlib.Path, str]
        Path to the version within an scDiffEq project.
    _groupby : str
        Grouping method for metrics, default is "epoch".
    """

    def __init__(
        self,
        path: Union[pathlib.Path, str] = None,
        groupby: str = "epoch",
        *args,
        **kwargs,
    ):
        """Instantiate Version by providing a path

        Parameters
        ----------
        path : Union[pathlib.Path, str], optional
            Path to the version within an scDiffEq project, by default None.
            A path must be supplied before any path-dependent property is
            accessed; otherwise ``_PATH`` raises ``TypeError``.
        groupby : str, optional
            Grouping method for metrics, by default "epoch"

        Returns
        -------
        None
        """
        # Hand every local name to ABCParse. The properties below read the
        # resulting ``self._path`` / ``self._groupby`` attributes. No other
        # locals may be created before this call, since ``locals()`` would
        # pick them up.
        self.__parse__(locals())

    @property
    def _PATH(self) -> pathlib.Path:
        """Check and format the provided path

        A ``str`` path is converted to ``pathlib.Path`` on first access and
        cached back onto ``self._path``.

        Returns
        -------
        pathlib.Path
            Formatted path

        Raises
        ------
        TypeError
            If the path is not of type pathlib.Path or str (this includes the
            default ``None``).
        """
        if isinstance(self._path, str):
            # Cache the conversion so later accesses take the fast path.
            self._path = pathlib.Path(self._path)
        if not isinstance(self._path, pathlib.Path):
            # Previously this fell through and returned None, which surfaced
            # later as an unrelated AttributeError in dependent properties.
            raise TypeError(
                "Version path must be a pathlib.Path or str, "
                f"got {type(self._path).__name__}"
            )
        return self._path

    @property
    def _NAME(self) -> str:
        """Version name from provided path

        Returns
        -------
        str
            Name of the version (final component of the path)
        """
        return self._PATH.name

    @property
    def _CONTENTS(self) -> List[pathlib.Path]:
        """Return one-level glob of the provided path

        Returns
        -------
        List[pathlib.Path]
            List of contents in the provided path
        """
        return list(self._PATH.glob("*"))

    @property
    def hparams(self):
        """Check if the .yaml exists and instantiate the HParams class each time

        Returns
        -------
        HParams or None
            A new HParams instance if ``hparams.yaml`` exists in the version
            directory, otherwise ``None``.
        """
        hparams_path = self._PATH.joinpath("hparams.yaml")
        if hparams_path.exists():
            return HParams(hparams_path)
        return None

    @property
    def metrics_df(self) -> pd.DataFrame:
        """Check if metrics.csv path exists and read it

        Returns
        -------
        pd.DataFrame or None
            DataFrame read from ``metrics.csv`` if it exists, otherwise
            ``None``. The file is re-read on every access.
        """
        metrics_path = self._PATH.joinpath("metrics.csv")
        if metrics_path.exists():
            return pd.read_csv(metrics_path)
        return None

    @property
    def per_epoch_metrics(self):
        """Group metrics by the specified groupby attribute

        Returns
        -------
        GroupedMetrics
            Result of calling a fresh GroupedMetrics instance on
            ``metrics_df``. The instance is also stored on
            ``self._GROUPED_METRICS``.
        """
        # NOTE(review): when metrics.csv is absent, metrics_df is None and
        # None is passed straight to GroupedMetrics. Confirm it handles that.
        self._GROUPED_METRICS = GroupedMetrics(groupby=self._groupby)
        return self._GROUPED_METRICS(self.metrics_df)

    @property
    def _CKPT_PATHS(self) -> List[pathlib.Path]:
        """Formatted checkpoint paths

        Returns
        -------
        List[pathlib.Path]
            Entries of the ``checkpoints`` sub-directory (empty if it does
            not exist)
        """
        # Path.glob already yields Path objects; no re-wrapping is needed.
        return list(self._PATH.joinpath("checkpoints").glob("*"))

    @property
    def _SORTED_CKPT_KEYS(self) -> List:
        """Sorting for organization's sake

        Returns
        -------
        List
            Checkpoint keys in sorted order, with the ``"last"`` key (if
            present) moved to the end
        """
        keys = list(self.ckpts.keys())
        # "last" is excluded from sorting so it need not be comparable with
        # the other epoch keys.
        sorted_keys = sorted(key for key in keys if key != "last")
        if "last" in keys:
            sorted_keys.append("last")
        return sorted_keys

    @property
    def ckpts(self) -> Dict:
        """Format and update available checkpoints for the version

        The checkpoints directory is re-scanned on every access. A new
        Checkpoint is built for each file and stored under its ``epoch``.

        Returns
        -------
        Dict
            Dictionary of checkpoints keyed by ``Checkpoint.epoch``
        """
        if not hasattr(self, "_CHECKPOINTS"):
            self._CHECKPOINTS = {}
        # NOTE(review): entries are only added or overwritten, never removed,
        # so checkpoints deleted from disk stay in the dict. Confirm this is
        # intended.
        for ckpt_path in self._CKPT_PATHS:
            ckpt = Checkpoint(ckpt_path)
            self._CHECKPOINTS[ckpt.epoch] = ckpt
        return self._CHECKPOINTS

    def __repr__(self) -> str:
        """Return the name of the object

        Returns
        -------
        str
            Name of the version
        """
        return self._NAME