# Source code for scdiffeq.plotting._velocity_stream

# -- import packages: ---------------------------------------------------------
import ABCParse
import anndata
import cellplots as cp
import logging
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib

# -- import local dependencies: -----------------------------------------------
from ..tools import VelocityEmbedding, GridVelocity

# -- set type hints: ----------------------------------------------------------
from typing import Any, Dict, List, Optional, Union, Tuple

# -- configure logger: --------------------------------------------------------
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# -- Operational class: -------------------------------------------------------
class VelocityStreamPlot(ABCParse.ABCParse):
    """Generate velocity stream plots (optionally with a cell scatter) on matplotlib axes.

    Embedding coordinates/velocities are obtained from ``VelocityEmbedding``
    and grid coordinates/velocities from ``GridVelocity`` (both from
    ``..tools``); the grid is drawn with ``Axes.streamplot``. Cells can be
    overlaid as a scatter colored by a single color, a numeric ``adata.obs``
    column (mapped through ``cmap``, with an optional colorbar) or a
    non-numeric / categorical ``adata.obs`` column (one color per group).

    Attributes:
        density (float): The density of the grid for velocity vectors. Defaults to 1.
        smooth (float): Smoothing factor applied to the velocity field. Defaults to 0.5.
        n_neighbors (Optional[int]): Number of neighbors to consider for local averaging. Defaults to None.
        min_mass (float): Minimum mass (weight) threshold for considering a point in the velocity field. Defaults to 1.
        autoscale (bool): Flag to automatically scale the vectors. Defaults to True.
        stream_adjust (bool): Adjust the streamplot parameters for optimal visualization. Defaults to True.
        cutoff_percentile (float): Percentile for cutoff to ignore outlier velocities. Defaults to 0.05.
        velocity_key (str): Key in `adata` to access velocity vectors. Defaults to "velocity".
        self_transitions (bool): Whether to consider self-transitions in the velocity calculations. Defaults to True.
        use_negative_cosines (bool): Flag to use negative cosines to adjust directionality. Defaults to True.
        T_scale (float): Scaling factor for the transition matrix. Defaults to 10.
        args: Variable length argument list.
        kwargs: Arbitrary keyword arguments.

    Note:
        The AnnData object `adata` is passed at call time, not at initialization.
    """

    def __init__(
        self,
        density: float = 1,
        smooth: float = 0.5,
        n_neighbors: Optional[int] = None,
        min_mass: float = 1,
        autoscale: bool = True,
        stream_adjust: bool = True,
        cutoff_percentile: float = 0.05,
        velocity_key: str = "velocity",
        self_transitions: bool = True,
        use_negative_cosines: bool = True,
        T_scale: float = 10,
        *args,
        **kwargs,
    ) -> None:
        """Initializes the VelocityStreamPlot object with parameters to configure the velocity stream plot.

        Sets up the velocity-embedding and grid-velocity helpers used later by
        ``__call__``.

        Args:
            density (float): Density of the grid for velocity vectors. Higher values create a denser grid. Defaults to 1.
            smooth (float): Smoothing factor applied to the velocity vectors, influencing the smoothness of the stream plot. Defaults to 0.5.
            n_neighbors (Optional[int]): Number of nearest neighbors to use for local averaging of velocities. If None, a default heuristic is used. Defaults to None.
            min_mass (float): Minimum mass (weight) threshold for considering a point in the velocity field. Helps in filtering out noise. Defaults to 1.
            autoscale (bool): If True, scales the magnitude of velocity vectors automatically. Defaults to True.
            stream_adjust (bool): If True, adjusts stream plot parameters for optimal visualization. Defaults to True.
            cutoff_percentile (float): Percentile for cutoff to filter out outlier velocities, specified as a fraction between 0 and 1. Defaults to 0.05.
            velocity_key (str): Key in the AnnData object `adata` to access velocity vectors. Defaults to "velocity".
            self_transitions (bool): If True, considers self-transitions in velocity calculations. Defaults to True.
            use_negative_cosines (bool): If True, uses negative cosines when projecting velocities. Defaults to True.
            T_scale (float): Scaling factor for the transition matrix T. Defaults to 10.
            args: Additional positional arguments not specifically defined.
            kwargs: Additional keyword arguments not specifically defined.
        """
        # __parse__ receives locals(), so it must run before any other local is bound.
        self.__parse__(locals())

        self._velocity_emb = VelocityEmbedding(
            velocity_key=velocity_key,
            self_transitions=self_transitions,
            use_negative_cosines=use_negative_cosines,
            T_scale=T_scale,
        )
        self._grid_velocity = GridVelocity(
            density=density,
            smooth=smooth,
            n_neighbors=n_neighbors,
            min_mass=min_mass,
            autoscale=autoscale,
            stream_adjust=stream_adjust,
            cutoff_percentile=cutoff_percentile,
        )

    # -- lazily computed (and cached) embedding / grid arrays: ----------------
    def _compute_embedding(self) -> None:
        """Compute and cache ``_X_emb`` / ``_V_emb`` for the current ``_adata``."""
        self._X_emb, self._V_emb = self._velocity_emb(self._adata)

    def _compute_grid(self) -> None:
        """Compute and cache ``_X_grid`` / ``_V_grid`` from the embedding arrays."""
        self._X_grid, self._V_grid = self._grid_velocity(self.X_emb, self.V_emb)

    def _clear_cache(self) -> None:
        """Drop cached embedding/grid arrays so they are recomputed on next access."""
        for attr in ("_X_emb", "_V_emb", "_X_grid", "_V_grid"):
            if hasattr(self, attr):
                delattr(self, attr)

    @property
    def X_emb(self):
        """Embedding coordinates returned by ``VelocityEmbedding`` (cached)."""
        if not hasattr(self, "_X_emb"):
            self._compute_embedding()
        return self._X_emb

    @property
    def V_emb(self):
        """Embedding velocities returned by ``VelocityEmbedding`` (cached)."""
        if not hasattr(self, "_V_emb"):
            self._compute_embedding()
        return self._V_emb

    @property
    def X_grid(self):
        """Grid coordinates returned by ``GridVelocity`` (cached)."""
        if not hasattr(self, "_X_grid"):
            self._compute_grid()
        return self._X_grid

    @property
    def V_grid(self):
        """Grid velocities returned by ``GridVelocity`` (cached)."""
        if not hasattr(self, "_V_grid"):
            self._compute_grid()
        return self._V_grid

    @property
    def x(self):
        return self.X_grid[0]

    @property
    def y(self):
        return self.X_grid[1]

    @property
    def u(self):
        return self.V_grid[0]

    @property
    def v(self):
        return self.V_grid[1]

    # -- axis limits: ----------------------------------------------------------
    @property
    def xmin(self):
        return np.min(self.X_emb[:, 0])

    @property
    def xmax(self):
        return np.max(self.X_emb[:, 0])

    @property
    def ymin(self):
        return np.min(self.X_emb[:, 1])

    @property
    def ymax(self):
        return np.max(self.X_emb[:, 1])

    @property
    def xmargin(self):
        """Horizontal padding: ``add_margin`` times the x-range of the embedding."""
        return (self.xmax - self.xmin) * self._add_margin

    @property
    def ymargin(self):
        """Vertical padding: ``add_margin`` times the y-range of the embedding."""
        return (self.ymax - self.ymin) * self._add_margin

    def _set_margin(self, ax) -> None:
        """Set ``ax`` limits to the embedding extent padded by the margins."""
        ax.set_xlim(self.xmin - self.xmargin, self.xmax + self.xmargin)
        ax.set_ylim(self.ymin - self.ymargin, self.ymax + self.ymargin)

    # -- streamlines: ----------------------------------------------------------
    @property
    def _STREAMPLOT_KWARGS(self) -> Dict[str, Any]:
        """Keyword arguments for ``Axes.streamplot``; ``stream_kwargs`` take precedence."""
        kwargs = {
            "color": self._stream_color,
            "density": self._stream_density,
            "linewidth": self._linewidth,
            "zorder": self._stream_zorder,
            "arrowsize": self._arrowsize,
            "arrowstyle": self._arrowstyle,
            "maxlength": self._maxlength,
            "integration_direction": self._integration_direction,
        }
        kwargs.update(getattr(self, "_stream_kwargs", None) or {})
        return kwargs

    def streamplot(self, ax) -> None:
        """Draw the grid velocity streamlines on ``ax`` and set padded limits."""
        ax.streamplot(self.x, self.y, self.u, self.v, **self._STREAMPLOT_KWARGS)
        self._set_margin(ax)

    # -- scatter: --------------------------------------------------------------
    @property
    def _SCATTER_KWARGS(self) -> Dict[str, Any]:
        """Keyword arguments for ``Axes.scatter``; ``scatter_kwargs`` take precedence."""
        kwargs = {
            "c": self._c,
            "zorder": self._scatter_zorder,
            "ec": "None",
            "alpha": 0.2,
            "s": 50,
            "cmap": getattr(self, "_cmap", None),
            "rasterized": self._rasterized,
        }
        kwargs.update(getattr(self, "_scatter_kwargs", None) or {})
        return kwargs

    def _SCATTER_CMAP(self, groups) -> Dict:
        """Map each group to a color.

        ``cmap`` may be a ``{group: color}`` dict (returned unchanged), a
        sequence of colors, a matplotlib colormap name, or None (``tab20``).
        Palettes shorter than the number of groups are cycled.

        Args:
            groups: Iterable of group labels (e.g. the keys of a groupby mapping).

        Returns:
            Dict mapping each group label to a matplotlib color.
        """
        labels = list(groups)
        cmap = getattr(self, "_cmap", None)
        if isinstance(cmap, dict):
            return cmap

        if cmap is None:
            palette = list(matplotlib.cm.tab20.colors)
        elif isinstance(cmap, str):
            colormap = plt.get_cmap(cmap)
            if colormap.N < 256:
                # Qualitative palette (e.g. tab10): use its discrete entries.
                palette = [colormap(i) for i in range(colormap.N)]
            else:
                # Continuous colormap: sample one evenly spaced color per group.
                palette = [colormap(x) for x in np.linspace(0, 1, max(len(labels), 1))]
        else:
            palette = list(cmap)

        if not palette:
            palette = list(matplotlib.cm.tab20.colors)
        return {label: palette[i % len(palette)] for i, label in enumerate(labels)}

    def _c_is_obs_column(self, columns) -> bool:
        """Return True if ``c`` names one of ``columns`` (False for unhashable ``c``)."""
        try:
            return self._c in columns
        except TypeError:  # e.g. an array of per-cell values/colors
            return False

    def _scatter_continuous(self, ax, values, kwargs: Dict[str, Any]) -> None:
        """Scatter cells colored by a numeric series, optionally with a colorbar."""
        vals = values.to_numpy(dtype=float, na_value=np.nan)
        # Plot low values first so high values end up on top; NaNs sort last.
        order = np.argsort(vals, kind="stable")
        point_kwargs = {**kwargs, "c": vals[order]}
        self._img = ax.scatter(
            self.X_emb[order, 0], self.X_emb[order, 1], **point_kwargs
        )
        if not self._disable_cbar:
            # Attach the colorbar to this axes' figure, not whatever is current;
            # user-supplied cbar_kwargs (e.g. `ax`, `cax`) take precedence.
            cbar_kwargs = {"ax": ax, **(getattr(self, "_cbar_kwargs", None) or {})}
            cbar = ax.figure.colorbar(self._img, **cbar_kwargs)
            cbar.solids.set(alpha=1)

    def _scatter_groups(self, ax, obs_df, kwargs: Dict[str, Any]) -> None:
        """Scatter each group of a non-numeric obs column in its own color."""
        # Colors come from `color=` per group; `c`/`cmap` must not be passed
        # (a dict cmap is rejected by matplotlib).
        base_kwargs = {k: v for k, v in kwargs.items() if k not in ("c", "cmap")}
        # `.indices` yields positional index arrays, matching the rows of X_emb.
        groups = obs_df.groupby(self._c, observed=True).indices
        colors = self._SCATTER_CMAP(groups)
        group_zorder = getattr(self, "_group_zorder", None) or {}
        for group, positions in groups.items():
            group_kwargs = dict(base_kwargs)  # fresh copy: no zorder leak between groups
            if group in group_zorder:
                group_kwargs["zorder"] = group_zorder[group]
            ax.scatter(
                self.X_emb[positions, 0],
                self.X_emb[positions, 1],
                color=colors[group],
                **group_kwargs,
            )

    def scatter(self, ax) -> None:
        """Scatter the cells on ``ax``, overlaying the stream plot.

        If ``c`` names an ``adata.obs`` column (or the index name), numeric
        columns are colored through ``cmap`` and non-numeric (e.g.
        categorical) columns are drawn one color per group. Any other ``c`` is
        passed to ``Axes.scatter`` as-is.

        Args:
            ax (matplotlib.axes.Axes): The axis to draw on.
        """
        obs_df = self._adata.obs.reset_index()
        kwargs = self._SCATTER_KWARGS

        if self._c_is_obs_column(obs_df.columns):
            series = obs_df[self._c]
            # dtype.kind: b=bool, i/u=int, f=float; categorical/object report "O".
            if series.dtype.kind in "biuf":
                self._scatter_continuous(ax, series, kwargs)
            else:
                self._scatter_groups(ax, obs_df, kwargs)
            return

        if isinstance(self._c, str):
            # A single color has no data to colormap; passing cmap only warns.
            kwargs.pop("cmap", None)
        ax.scatter(self.X_emb[:, 0], self.X_emb[:, 1], **kwargs)

    # -- saving: ---------------------------------------------------------------
    @property
    def scdiffeq_figure_dir(self):
        return pathlib.Path("scdiffeq_figures")

    def _mk_fig_dir(self) -> None:
        """Create the figure directory if it does not exist yet."""
        fig_dir = self.scdiffeq_figure_dir
        if not fig_dir.exists():
            fig_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"mkdir: {fig_dir}")

    @property
    def sdq_info(self):
        return self._adata.uns["sdq_info"]

    @property
    def data_model_info_tag(self) -> str:
        return f"{self.sdq_info['project']}.version_{self.sdq_info['version']}.ckpt_{self.sdq_info['ckpt']}"

    @property
    def fname_basis(self) -> pathlib.Path:
        """Output path without extension, tagged with ``adata.uns['sdq_info']`` if present."""
        if "sdq_info" not in self._adata.uns:
            return self.scdiffeq_figure_dir.joinpath("velocity_stream")
        try:
            tag = self.data_model_info_tag
        except (KeyError, TypeError, IndexError):
            # sdq_info lacks project/version/ckpt entries: use its str form.
            tag = self.sdq_info
        return self.scdiffeq_figure_dir.joinpath(f"velocity_stream.{tag}")

    @property
    def SVG_path(self) -> pathlib.Path:
        """``fname_basis`` with ``.svg`` appended (not a suffix replacement)."""
        return pathlib.Path(f"{self.fname_basis}.svg")

    @property
    def PNG_path(self) -> pathlib.Path:
        """``fname_basis`` with ``.png`` appended (not a suffix replacement)."""
        return pathlib.Path(f"{self.fname_basis}.png")

    def save_img(self, fig=None) -> None:
        """Save a figure as SVG and PNG in ``scdiffeq_figure_dir``.

        Args:
            fig: Figure to save. Defaults to the current pyplot figure.
        """
        self._mk_fig_dir()
        target = plt.gcf() if fig is None else fig
        target.savefig(self.SVG_path, dpi=self._svg_dpi)
        target.savefig(self.PNG_path, dpi=self._png_dpi)
        logger.info(f"Saved to: \n  {self.SVG_path}\n  {self.PNG_path}")

    def __call__(
        self,
        adata: anndata.AnnData,
        ax: Optional[Union[plt.Axes, List[plt.Axes]]] = None,
        stream_color: str = "k",
        c: str = "dodgerblue",
        group_zorder: Optional[Dict] = None,
        cmap: Optional[Union[Dict, List, Tuple, str]] = "plasma_r",
        linewidth: float = 0.5,
        stream_density: float = 2.5,
        add_margin: float = 0.1,
        arrowsize: float = 1,
        density: float = 1,
        arrowstyle: str = "-|>",
        maxlength: float = 4,
        integration_direction: str = "both",
        scatter_zorder: int = 0,
        stream_zorder: int = 10,
        rasterized: bool = True,
        mpl_kwargs: Optional[Dict] = None,
        scatter_kwargs: Optional[Dict] = None,
        stream_kwargs: Optional[Dict] = None,
        cbar_kwargs: Optional[Dict] = None,
        disable_scatter: bool = False,
        disable_cbar: bool = False,
        save: bool = False,
        png_dpi: Optional[float] = 500,
        svg_dpi: Optional[float] = 250,
        *args,
        **kwargs,
    ) -> List[plt.Axes]:
        """Generate velocity stream plots for the provided AnnData object.

        Args:
            adata (anndata.AnnData): The AnnData object containing the data for plotting.
            ax (Optional[Union[plt.Axes, List[plt.Axes]]]): Axis or list of axes to draw on.
                If None, a single-panel figure is created via ``cellplots.plot``.
            stream_color (str): Streamline color.
            c (str): A matplotlib color, or an ``adata.obs`` column to color cells by.
            group_zorder (Optional[Dict]): ``{group: zorder}`` overrides for a non-numeric ``c``.
            cmap (Optional[Union[Dict, List, Tuple, str]]): Colormap for a numeric ``c``; for a
                non-numeric ``c`` a ``{group: color}`` dict, a color sequence or a colormap name.
            linewidth (float): Streamline width.
            stream_density (float): ``density`` argument of ``Axes.streamplot``.
            add_margin (float): Axis padding as a fraction of the embedding range.
            arrowsize (float): Streamline arrow size.
            density (float): Stored on the instance; the grid density is fixed at init.
            arrowstyle (str): Streamline arrow style.
            maxlength (float): Maximum streamline length (``Axes.streamplot``).
            integration_direction (str): ``"forward"``, ``"backward"`` or ``"both"``.
            scatter_zorder (int): Z-order of the scatter points.
            stream_zorder (int): Z-order of the streamlines.
            rasterized (bool): Rasterize the scatter points.
            mpl_kwargs (Optional[Dict]): Overrides for ``cellplots.plot`` (None means ``{}``).
            scatter_kwargs (Optional[Dict]): Overrides for ``Axes.scatter`` (None means ``{}``).
            stream_kwargs (Optional[Dict]): Overrides for ``Axes.streamplot`` (None means ``{}``).
            cbar_kwargs (Optional[Dict]): Extra colorbar arguments (None means ``{}``).
            disable_scatter (bool): If True, draw only the streamlines.
            disable_cbar (bool): If True, skip the colorbar for a numeric ``c``.
            save (bool): If True, save the figure as SVG and PNG.
            png_dpi (Optional[float]): DPI for the PNG.
            svg_dpi (Optional[float]): DPI for the SVG.
            **kwargs: Additional keyword arguments.

        Returns:
            List[plt.Axes]: The axes that were drawn on.
        """
        # None defaults avoid shared mutable defaults; normalize to empty dicts.
        mpl_kwargs = {} if mpl_kwargs is None else mpl_kwargs
        scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs
        stream_kwargs = {} if stream_kwargs is None else stream_kwargs
        cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

        # Cached embedding/grid arrays belong to the AnnData they came from.
        if getattr(self, "_adata", None) is not adata:
            self._clear_cache()

        # __update__ receives locals(), so no new local may be bound above.
        self.__update__(locals())

        if ax is None:
            plot_kwargs = {
                "nplots": 1,
                "ncols": 1,
                "height": 1.0,
                "width": 1.0,
                "delete": "all",
                "del_xy_ticks": [True],
            }
            plot_kwargs.update(mpl_kwargs)
            _, axes = cp.plot(**plot_kwargs)
        else:
            axes = ABCParse.as_list(ax)

        for axis in axes:
            self.streamplot(axis)
            if not self._disable_scatter:
                self.scatter(axis)

        if self._save:
            # Save the figure actually drawn on, which need not be plt.gcf().
            self.save_img(fig=axes[0].figure if len(axes) else None)

        return axes


# -- API-facing function: -----------------------------------------------------
def velocity_stream(
    adata: anndata.AnnData,
    ax: Optional[Union[plt.Axes, List[plt.Axes]]] = None,
    c: str = "dodgerblue",
    cmap: Optional[Union[Dict, List, Tuple, str]] = "plasma_r",
    group_zorder: Optional[Dict] = None,
    linewidth: float = 0.5,
    stream_density: float = 2.5,
    add_margin: float = 0.1,
    arrowsize: float = 1,
    arrowstyle: str = "-|>",
    maxlength: float = 4,
    integration_direction: str = "both",
    scatter_zorder: int = 101,
    stream_zorder: int = 201,
    density: float = 1,
    smooth: float = 0.5,
    n_neighbors: Optional[int] = None,
    min_mass: float = 1,
    autoscale: bool = True,
    stream_adjust: bool = True,
    cutoff_percentile: float = 0.05,
    velocity_key: str = "velocity",
    self_transitions: bool = True,
    use_negative_cosines: bool = True,
    T_scale: float = 10,
    disable_scatter: bool = False,
    disable_cbar: bool = False,
    stream_kwargs: Optional[Dict[str, Any]] = None,
    scatter_kwargs: Optional[Dict[str, Any]] = None,
    cbar_kwargs: Optional[Dict] = None,
    mpl_kwargs: Optional[Dict[str, Any]] = None,
    return_axes: bool = False,
    save: Optional[bool] = False,
    rasterized: bool = True,
    png_dpi: Optional[float] = 500,
    svg_dpi: Optional[float] = 250,
    *args,
    **kwargs,
) -> Optional[List[plt.Axes]]:
    """Generate velocity stream plots for single-cell data.

    Convenience wrapper around the ``VelocityStreamPlot`` class: it builds the
    class from the embedding/grid parameters and calls it with the plotting
    parameters, so users need not instantiate the class themselves.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object containing the data to plot.
    ax : Optional[Union[plt.Axes, List[plt.Axes]]], optional
        Axes object or list of axes on which to draw. If None, a new figure and
        axes are created. **Default**: ``None``.
    c : str, optional
        Either a matplotlib color applied to every cell, or a column name in
        ``adata.obs``. Numeric columns are colored through ``cmap`` (with a
        colorbar unless ``disable_cbar``); categorical columns are colored per
        group. **Default**: ``"dodgerblue"``.
    cmap : Optional[Union[Dict, List, Tuple, str]], optional
        Colormap for a numeric ``c``; for a categorical ``c``, a
        ``{group: color}`` dict or a sequence of colors.
        **Default**: ``"plasma_r"``.
    group_zorder : Optional[Dict], optional
        ``{group: zorder}`` overrides so that selected groups of a categorical
        ``c`` are drawn on top of others. **Default**: ``None``.
    linewidth : float, optional
        Line width of the streamlines. **Default**: ``0.5``.
    stream_density : float, optional
        ``density`` argument of ``Axes.streamplot``; higher values give more
        closely packed streamlines. **Default**: ``2.5``.
    add_margin : float, optional
        Margin added around the plotted data, as a fraction of the data range.
        **Default**: ``0.1``.
    arrowsize : float, optional
        Size of the streamline arrow heads. **Default**: ``1``.
    arrowstyle : str, optional
        Style of the streamline arrow heads. **Default**: ``"-|>"``.
    maxlength : float, optional
        Maximum length of a streamline, in axes coordinates (passed to
        ``Axes.streamplot``). **Default**: ``4``.
    integration_direction : str, optional
        Direction of streamline integration: ``"forward"``, ``"backward"`` or
        ``"both"``. **Default**: ``"both"``.
    scatter_zorder : int, optional
        Z-order of the scatter points. **Default**: ``101``.
    stream_zorder : int, optional
        Z-order of the streamlines. **Default**: ``201``.
    density : float, optional
        Density of the grid for velocity vectors; higher values create a denser
        grid. **Default**: ``1``.
    smooth : float, optional
        Smoothing factor applied to the velocity vectors. **Default**: ``0.5``.
    n_neighbors : Optional[int], optional
        Number of nearest neighbors used for local averaging of velocities; if
        None, a default heuristic is used. **Default**: ``None``.
    min_mass : float, optional
        Minimum mass (weight) threshold for a point of the velocity field.
        **Default**: ``1``.
    autoscale : bool, optional
        Automatically scale the velocity vectors. **Default**: ``True``.
    stream_adjust : bool, optional
        Adjust stream plot parameters for visualization. **Default**: ``True``.
    cutoff_percentile : float, optional
        Percentile (as a fraction between 0 and 1) used to filter outlier
        velocities. **Default**: ``0.05``.
    velocity_key : str, optional
        Key in ``adata`` under which velocity vectors are stored.
        **Default**: ``"velocity"``.
    self_transitions : bool, optional
        Whether self-transitions are considered in the velocity calculations.
        **Default**: ``True``.
    use_negative_cosines : bool, optional
        Whether negative cosines are used when projecting velocities.
        **Default**: ``True``.
    T_scale : float, optional
        Scaling factor for the transition matrix. **Default**: ``10``.
    disable_scatter : bool, optional
        If True, disables the scatter plot overlay. **Default**: ``False``.
    disable_cbar : bool, optional
        If True, disables the colorbar drawn for a numeric ``c``.
        **Default**: ``False``.
    stream_kwargs, scatter_kwargs, cbar_kwargs, mpl_kwargs : Optional[Dict], optional
        Additional keyword arguments for the stream plot, scatter plot,
        colorbar, and figure creation, respectively. None is treated as an
        empty dict. **Default**: ``None``.
    return_axes : bool, optional
        If True, returns the matplotlib axes with the generated plots.
        **Default**: ``False``.
    save : bool, optional
        If True, saves the generated plot in SVG and PNG formats.
        **Default**: ``False``.
    rasterized : bool, optional
        Rasterize the scatter points. **Default**: ``True``.
    png_dpi : Optional[float], optional
        DPI used when saving the PNG image. **Default**: ``500``.
    svg_dpi : Optional[float], optional
        DPI used when saving the SVG image. **Default**: ``250``.

    Returns
    -------
    Optional[List[plt.Axes]]
        The axes with the generated plots if ``return_axes`` is True,
        otherwise ``None``.
    """
    # None defaults avoid shared mutable defaults; the plot class expects dicts.
    stream_kwargs = {} if stream_kwargs is None else stream_kwargs
    scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
    mpl_kwargs = {} if mpl_kwargs is None else mpl_kwargs

    # Route arguments to the constructor vs. the call. ABCParse.function_kwargs
    # is assumed to pick the locals() entries named in each method's signature.
    init_kwargs = ABCParse.function_kwargs(VelocityStreamPlot.__init__, locals())
    call_kwargs = ABCParse.function_kwargs(VelocityStreamPlot.__call__, locals())

    velo_stream_plot = VelocityStreamPlot(**init_kwargs)
    axes = velo_stream_plot(**call_kwargs)
    if return_axes:
        return axes
    return None