Source code for scdiffeq.plotting._temporal_expression

# -- import packages: ---------------------------------------------------------
import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# -- set typing: --------------------------------------------------------------
from typing import Dict, List, Optional, Union


# -- API-facing function: -----------------------------------------------------
def temporal_expression(
    adata_sim: anndata.AnnData,
    gene: str,
    groupby: str = "final_state",
    groups: Optional[List[str]] = None,
    use_key: str = "X_gene_inv",
    time_key: str = "t",
    gene_ids_key: str = "gene_ids",
    show_std: bool = True,
    std_alpha: float = 0.2,
    ax: Optional[plt.Axes] = None,
    figsize: tuple = (3, 2.5),
    cmap: Optional[Dict[str, str]] = None,
    linewidth: float = 2.0,
    x_label: str = "t(d)",
    y_label: str = "Log-norm. expression",
    title: Optional[str] = None,
    show_legend: bool = True,
    legend_loc: Union[str, tuple] = "best",
    grid: bool = True,
    grid_alpha: float = 0.3,
    save: bool = False,
    savename: Optional[str] = None,
    save_format: str = "svg",
    dpi: int = 300,
) -> plt.Axes:
    """
    Plot gene expression over simulated time, grouped by fate.

    Computes mean and standard deviation at each time step and plots as
    line (mean) with shaded fill-between region (±1 std).

    Parameters
    ----------
    adata_sim : anndata.AnnData
        Simulated data from ``sdq.tl.simulate()``, with gene expression
        stored in obsm after calling ``sdq.tl.invert_scaled_gex()``.
    gene : str
        Gene name to plot.
    groupby : str, default="final_state"
        Column in ``adata_sim.obs`` for grouping trajectories (e.g., cell
        fate).
    groups : List[str], optional
        Specific groups to plot. If None, plots all groups. Use this to
        exclude certain groups (e.g., ``groups=["Mon.", "Neu."]`` to only
        plot those two fates). Requested groups that do not occur in
        ``adata_sim.obs[groupby]`` are skipped.
    use_key : str, default="X_gene_inv"
        Key in ``adata_sim.obsm`` containing the gene expression matrix.
    time_key : str, default="t"
        Column in ``adata_sim.obs`` containing time values.
    gene_ids_key : str, default="gene_ids"
        Key in ``adata_sim.uns`` containing gene names. May be a dict
        (``{index: gene_name}``) or a list-like of names.
    show_std : bool, default=True
        Whether to show standard deviation as shaded fill-between region.
    std_alpha : float, default=0.2
        Transparency of the standard deviation shading.
    ax : plt.Axes, optional
        Matplotlib axes to plot on. If None, creates new figure.
    figsize : tuple, default=(3, 2.5)
        Figure size (width, height) in inches if creating new figure.
    cmap : Dict[str, str], optional
        Mapping from group names to colors. If None, uses default colormap.
        Groups missing from the mapping are drawn in gray.
    linewidth : float, default=2.0
        Width of the mean line.
    x_label : str, default="t(d)"
        Label for x-axis.
    y_label : str, default="Log-norm. expression"
        Label for y-axis.
    title : str, optional
        Plot title. If None, uses gene name in italic.
    show_legend : bool, default=True
        Whether to show legend.
    legend_loc : str or tuple, default="best"
        Legend location.
    grid : bool, default=True
        Whether to show grid.
    grid_alpha : float, default=0.3
        Transparency of grid lines.
    save : bool, default=False
        Whether to save the figure that owns ``ax``.
    savename : str, optional
        Filename for saving. If None, auto-generates from gene name.
    save_format : str, default="svg"
        Format for saving figure.
    dpi : int, default=300
        Resolution for saving figure.

    Returns
    -------
    plt.Axes
        The matplotlib axes object.

    Raises
    ------
    ValueError
        If ``gene`` is not found in ``adata_sim.uns[gene_ids_key]``.

    Examples
    --------
    >>> import scdiffeq as sdq
    >>> adata_sim = sdq.tl.simulate(adata, diffeq=model, idx=idx)
    >>> sdq.tl.invert_scaled_gex(adata_sim, ...)
    >>> sdq.tl.annotate_cell_fate(adata_sim, ...)
    >>> sdq.pl.temporal_expression(
    ...     adata_sim,
    ...     gene="Spi1",
    ...     groupby="final_state",
    ...     cmap={"Mon.": "orange", "Neu.": "#4a7298"}
    ... )
    """
    # -- Get gene index -------------------------------------------------------
    gene_idx = _resolve_gene_index(adata_sim.uns[gene_ids_key], gene, gene_ids_key)

    # -- Extract expression and metadata --------------------------------------
    expr_matrix = adata_sim.obsm[use_key]
    if isinstance(expr_matrix, pd.DataFrame):
        # DataFrame: use iloc for positional indexing
        expression = expr_matrix.iloc[:, gene_idx].values
    else:
        # Array-like: use direct indexing
        expression = expr_matrix[:, gene_idx]
        # A column sliced from a sparse matrix is still 2-D and sparse;
        # densify it so it can become a DataFrame column.
        if hasattr(expression, "toarray"):
            expression = expression.toarray()
    # No-op for an ordinary 1-D array; flattens (n, 1) columns.
    expression = np.asarray(expression).ravel()

    # -- Build dataframe for groupby operations -------------------------------
    df = pd.DataFrame({
        "expression": expression,
        "time": adata_sim.obs[time_key].values,
        "group": adata_sim.obs[groupby].values,
    })

    # -- Compute mean and std per (time, group) -------------------------------
    # observed=True: for a categorical grouping column, only aggregate the
    # (time, group) pairs that actually occur instead of every category.
    stats = (
        df.groupby(["time", "group"], observed=True)["expression"]
        .agg(["mean", "std"])
        .reset_index()
    )

    # -- Setup figure ---------------------------------------------------------
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # -- Determine which groups to plot ---------------------------------------
    unique_groups = df["group"].unique()
    if groups is not None:
        # Filter to only requested groups, keeping the caller's order
        plot_groups = [g for g in groups if g in unique_groups]
    else:
        plot_groups = list(unique_groups)

    # -- Default colormap -----------------------------------------------------
    if cmap is None:
        default_colors = plt.cm.tab10.colors
        cmap = {
            g: default_colors[i % len(default_colors)]
            for i, g in enumerate(plot_groups)
        }

    # -- Plot each group ------------------------------------------------------
    for group in plot_groups:
        group_data = stats[stats["group"] == group].sort_values("time")
        t = group_data["time"].values
        mean = group_data["mean"].values
        std = group_data["std"].values
        color = cmap.get(group, "gray")

        # Plot mean line
        ax.plot(t, mean, color=color, linewidth=linewidth, label=group, zorder=2)

        # Plot std fill-between
        if show_std:
            ax.fill_between(
                t,
                mean - std,
                mean + std,
                color=color,
                alpha=std_alpha,
                linewidth=0,
                zorder=1,
            )

    # -- Formatting -----------------------------------------------------------
    ax.set_xlabel(x_label, fontsize=10)
    ax.set_ylabel(y_label, fontsize=10)

    if title is None:
        title = f"$\\it{{{gene}}}$"
    ax.set_title(title, fontsize=11)

    if grid:
        ax.grid(True, alpha=grid_alpha, zorder=0)

    if show_legend:
        ax.legend(
            loc=legend_loc,
            frameon=True,
            facecolor="white",
            edgecolor="lightgray",
            fontsize=8,
        )

    # Remove top and right spines
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # -- Save -----------------------------------------------------------------
    if save:
        if savename is None:
            savename = f"scDiffEq.temporal_expression.{gene}.{save_format}"
        # Save the figure that owns `ax`, not pyplot's "current" figure:
        # when the caller supplies `ax`, the two are not necessarily the same.
        fig = ax.figure
        if not hasattr(fig, "savefig"):
            # NOTE(review): presumably `ax` lives in a SubFigure here, whose
            # `.figure` is the enclosing Figure -- confirm against matplotlib.
            fig = fig.figure
        fig.savefig(savename, format=save_format, dpi=dpi, bbox_inches="tight")

    return ax


# Private helper: map a gene name to its column position in the expression matrix.
def _resolve_gene_index(
    gene_ids_raw: Union[Dict, List[str], np.ndarray, pd.Index, pd.Series],
    gene: str,
    gene_ids_key: str,
) -> int:
    """
    Return the positional index of ``gene`` within the stored gene names.

    Parameters
    ----------
    gene_ids_raw
        Gene names as stored in ``adata_sim.uns[gene_ids_key]``: a dict
        (``{index: gene_name}``, searched by its values in iteration order)
        or any list-like of names.
    gene : str
        Gene name to look up.
    gene_ids_key : str
        Name of the ``uns`` key; used only in the error message.

    Returns
    -------
    int
        Position of the first occurrence of ``gene``.

    Raises
    ------
    ValueError
        If ``gene`` is not among the stored gene names.
    """
    if isinstance(gene_ids_raw, dict):
        # Handle dict format: {index: gene_name}
        gene_names = list(gene_ids_raw.values())
    elif isinstance(gene_ids_raw, (pd.Index, pd.Series, np.ndarray)):
        gene_names = gene_ids_raw.tolist()
    else:
        gene_names = list(gene_ids_raw)

    if gene not in gene_names:
        # Slicing already copes with fewer than five names.
        preview = gene_names[:5]
        raise ValueError(
            f"Gene '{gene}' not found in adata_sim.uns['{gene_ids_key}']. "
            f"Available genes: {preview}..."
        )
    return gene_names.index(gene)