# Source code for scdiffeq.plotting._simulation_expression_gif

# -- import packages: ---------------------------------------------------------
import anndata
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import numpy as np
import pandas as pd
import tempfile
import os

# -- set typing: --------------------------------------------------------------
from typing import Callable, Dict, List, Optional, Union


# -- Helper functions: --------------------------------------------------------
def _create_expr_default_background(adata_sim, ax, use_key, background_s, background_inner_s):
    """Paint the plain background: every point in black, then again in white.

    The coordinates are read from ``adata_sim.obsm[use_key]`` (a DataFrame is
    unwrapped to its underlying array). The black layer uses ``background_s``
    and the white layer drawn over it uses ``background_inner_s``.
    """
    coords = adata_sim.obsm[use_key]
    coords = coords.values if isinstance(coords, pd.DataFrame) else coords
    xs, ys = coords[:, 0], coords[:, 1]

    # Black goes down first so that the white layer is painted on top of it.
    for face_color, size in (("k", background_s), ("w", background_inner_s)):
        ax.scatter(xs, ys, c=face_color, ec="None", rasterized=True, s=size)


def _create_expr_grouped_background(
    adata_sim, ax, use_key, background_groupby, background_cmap, background_s, background_inner_s
):
    """Paint the background one group at a time, outlining each group in its color.

    Group membership comes from ``adata_sim.obs[background_groupby]`` and the
    coordinates from ``adata_sim.obsm[use_key]``. For each group a colored
    layer (size ``background_s``) is drawn, immediately followed by a white
    layer (size ``background_inner_s``). When ``background_cmap`` is None the
    groups cycle through matplotlib's ``tab10`` colors; groups missing from a
    supplied mapping fall back to black.
    """
    membership = adata_sim.obs[background_groupby]
    coords = adata_sim.obsm[use_key]
    if isinstance(coords, pd.DataFrame):
        coords = coords.values

    categories = membership.unique()

    palette = background_cmap
    if palette is None:
        tab10 = plt.cm.tab10.colors
        palette = {name: tab10[idx % len(tab10)] for idx, name in enumerate(categories)}

    for name in categories:
        selected = membership == name
        xs = coords[selected, 0]
        ys = coords[selected, 1]
        outline = palette.get(name, "k")
        ax.scatter(xs, ys, c=outline, ec="None", rasterized=True, s=background_s)
        ax.scatter(xs, ys, c="w", ec="None", rasterized=True, s=background_inner_s)


def _draw_umap_labels(ax, umap_labels):
    """Place free-form text annotations on ``ax``.

    Each entry of ``umap_labels`` is a dict. The keys ``"text"``, ``"x"`` and
    ``"y"`` (defaulting to ``""``, ``0`` and ``0``) give the string and its
    position; every other key is forwarded to ``ax.text`` as a keyword
    argument. ``fontsize=10`` and centered alignment are used unless the
    entry overrides them. Passing ``None`` draws nothing.
    """
    positional_keys = ("text", "x", "y")

    if umap_labels is not None:
        for spec in umap_labels:
            # Start from the defaults, then let the entry's own keys win.
            text_kwargs = {"fontsize": 10, "ha": "center", "va": "center"}
            text_kwargs.update(
                (key, value) for key, value in spec.items() if key not in positional_keys
            )
            ax.text(spec.get("x", 0), spec.get("y", 0), spec.get("text", ""), **text_kwargs)


def _create_expression_progenitor_frame(
    adata_sim,
    ax_umap,
    ax_expr,
    background_fn,
    progenitor_x,
    progenitor_y,
    progenitor_color,
    progenitor_s,
    progenitor_label,
    show_time_label,
    time_label_loc,
    time_label_fmt,
    time_label_fontsize,
    t_min,
    t_max,
    umap_title,
    x_all,
    y_all,
    umap_cmap,
    color,
    expr_ylim,
    x_label,
    y_label,
    gene,
    plot_groups,
    expr_cmap,
    linewidth,
    umap_labels=None,
):
    """Create content for dual-panel progenitor intro frame.

    Draws onto two existing axes and returns nothing.

    UMAP panel (``ax_umap``): the background produced by
    ``background_fn(adata_sim, ax_umap)``, a single marker at
    ``(progenitor_x, progenitor_y)`` with an arrow annotation reading
    ``progenitor_label``, an optional time label formatted with ``t_min``,
    an optional title, a horizontal colorbar spanning ``t_min``..``t_max``
    with a progress marker at its left edge, and any ``umap_labels``.
    Axis limits are the extent of ``x_all`` / ``y_all`` padded by 0.5.

    Expression panel (``ax_expr``): an empty, fully styled axis (limits,
    labels, italic ``gene`` title, grid, dashed marker at ``t_min``) whose
    legend lists every entry of ``plot_groups`` using the colors in
    ``expr_cmap`` (``"gray"`` when a group is missing from the mapping).

    Parameters
    ----------
    adata_sim
        Passed through unchanged to ``background_fn``.
    ax_umap, ax_expr
        Matplotlib axes for the left (UMAP) and right (expression) panels.
    background_fn
        Callable invoked as ``background_fn(adata_sim, ax_umap)``.
    progenitor_x, progenitor_y, progenitor_color, progenitor_s, progenitor_label
        Position, color, marker size and annotation text of the progenitor.
    show_time_label, time_label_loc, time_label_fmt, time_label_fontsize
        Whether and how to draw the time label; ``time_label_loc`` is in
        axes coordinates and ``time_label_fmt`` is formatted with ``t_min``.
    t_min, t_max
        Range used for the colorbar and the expression x-axis.
    umap_title
        Optional title for the UMAP panel (skipped when falsy).
    x_all, y_all
        Coordinate arrays used only to fix the UMAP axis limits.
    umap_cmap, color
        Colormap of the colorbar and the text used as its label.
    expr_ylim, x_label, y_label, gene
        y-limits, axis labels and title text of the expression panel.
    plot_groups, expr_cmap, linewidth
        Groups to list in the legend, their colors and the line width.
    umap_labels
        Optional list of label dicts forwarded to ``_draw_umap_labels``.
    """
    # === UMAP Panel ===
    background_fn(adata_sim, ax_umap)

    ax_umap.scatter(
        progenitor_x,
        progenitor_y,
        c=progenitor_color,
        s=progenitor_s,
        edgecolors="white",
        linewidths=1.5,
        zorder=300,
    )

    # The annotation text is offset by a fixed 1.5 data units from the marker.
    ax_umap.annotate(
        progenitor_label,
        xy=(progenitor_x, progenitor_y),
        xytext=(progenitor_x + 1.5, progenitor_y + 1.5),
        fontsize=12,
        fontweight="bold",
        color=progenitor_color,
        arrowprops=dict(arrowstyle="->", color=progenitor_color, lw=2),
        zorder=301,
    )

    if show_time_label:
        ax_umap.text(
            time_label_loc[0],
            time_label_loc[1],
            time_label_fmt.format(t_min),
            transform=ax_umap.transAxes,
            fontsize=time_label_fontsize,
            verticalalignment="top",
            fontweight="bold",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

    if umap_title:
        ax_umap.set_title(umap_title, fontsize=12)

    for spine in ax_umap.spines.values():
        spine.set_visible(False)
    ax_umap.set_xticks([])
    ax_umap.set_yticks([])
    ax_umap.set_xlabel("")
    ax_umap.set_ylabel("")
    ax_umap.set_xlim(x_all.min() - 0.5, x_all.max() + 0.5)
    ax_umap.set_ylim(y_all.min() - 0.5, y_all.max() + 0.5)

    norm = mcolors.Normalize(vmin=t_min, vmax=t_max)
    sm = cm.ScalarMappable(cmap=umap_cmap, norm=norm)
    sm.set_array([])
    # Attach the colorbar through the figure that owns ``ax_umap`` instead of
    # pyplot's "current figure", so the result does not depend on global state.
    cbar = ax_umap.figure.colorbar(
        sm, ax=ax_umap, shrink=0.6, orientation="horizontal", location="bottom"
    )
    cbar.set_label(color, fontsize=10)

    # Progress marker sits at the very start of the colorbar for the intro frame.
    cbar_ax = cbar.ax
    cbar_xmin = cbar_ax.get_xlim()[0]
    cbar_ax.axvline(x=cbar_xmin, color="dodgerblue", linewidth=2, zorder=10)

    # Draw UMAP labels
    _draw_umap_labels(ax_umap, umap_labels)

    # === Expression Panel ===
    ax_expr.set_xlim(t_min, t_max)
    ax_expr.set_ylim(expr_ylim)
    ax_expr.set_xlabel(x_label, fontsize=10)
    ax_expr.set_ylabel(y_label, fontsize=10)
    ax_expr.set_title(f"$\\it{{{gene}}}$", fontsize=11)
    ax_expr.spines["top"].set_visible(False)
    ax_expr.spines["right"].set_visible(False)
    ax_expr.grid(True, alpha=0.3, zorder=0)

    ax_expr.axvline(x=t_min, color="dodgerblue", linewidth=2, linestyle="--", alpha=0.7, zorder=5)

    # Empty lines exist only to populate the legend before any data is shown.
    for group in plot_groups:
        ax_expr.plot([], [], color=expr_cmap.get(group, "gray"), linewidth=linewidth, label=group)
    ax_expr.legend(loc="best", frameon=True, facecolor="white", edgecolor="lightgray", fontsize=8)


def _create_expression_frame(
    adata_sim,
    ax_umap,
    ax_expr,
    t_current,
    frame_alpha,
    background_fn,
    x_all,
    y_all,
    time_values,
    color_values,
    umap_cmap,
    s,
    alpha,
    trail_alpha,
    leading_edge_scale,
    vmin,
    vmax,
    show_time_label,
    time_label_loc,
    time_label_fmt,
    time_label_fontsize,
    umap_title,
    t_min,
    t_max,
    color,
    stats_full,
    plot_groups,
    expr_cmap,
    linewidth,
    show_std,
    std_alpha,
    expr_ylim,
    x_label,
    y_label,
    gene,
    umap_labels=None,
    **kwargs,
):
    """Create content for a dual-panel animation frame.

    Draws onto two existing axes and returns nothing.

    UMAP panel (``ax_umap``): the background from
    ``background_fn(adata_sim, ax_umap)``; all points with
    ``time_values < t_current`` as a faded "trail"; all points with
    ``time_values == t_current`` as an enlarged "leading edge"; an optional
    time label; an optional title; a horizontal colorbar spanning
    ``t_min``..``t_max`` with a marker at the position of ``t_current``; and
    any ``umap_labels``.

    Expression panel (``ax_expr``): for every group in ``plot_groups`` the
    mean line (and, when ``show_std`` is true, a mean ± std band) using the
    rows of ``stats_full`` with ``time <= t_current``, plus a dashed vertical
    marker at ``t_current`` and a legend.

    Parameters
    ----------
    adata_sim
        Passed through unchanged to ``background_fn``.
    ax_umap, ax_expr
        Matplotlib axes for the left (UMAP) and right (expression) panels.
    t_current
        Time of this frame.
    frame_alpha
        Multiplier applied to the opacity of the foreground artists; used by
        the caller to fade the frame out.
    background_fn
        Callable invoked as ``background_fn(adata_sim, ax_umap)``.
    x_all, y_all, time_values, color_values
        Per-point arrays indexed by the same boolean masks: coordinates,
        time, and the value mapped through ``umap_cmap``.
    umap_cmap, vmin, vmax
        Colormap and color limits for the scatter points.
    s, alpha, trail_alpha, leading_edge_scale
        Point size, base opacity, opacity multiplier of the trail and size
        multiplier of the leading edge.
    show_time_label, time_label_loc, time_label_fmt, time_label_fontsize
        Whether and how to draw the time label; ``time_label_loc`` is in
        axes coordinates and ``time_label_fmt`` is formatted with
        ``t_current``.
    umap_title
        Optional title for the UMAP panel (skipped when falsy).
    t_min, t_max
        Range used for the colorbar and the expression x-axis.
    color
        Text used as the colorbar label.
    stats_full
        DataFrame with columns ``"time"``, ``"group"``, ``"mean"`` and
        ``"std"``.
    plot_groups, expr_cmap, linewidth
        Groups to draw, their colors (``"gray"`` when missing) and line width.
    show_std, std_alpha
        Whether to draw the std band and its base opacity.
    expr_ylim, x_label, y_label, gene
        y-limits, axis labels and title text of the expression panel.
    umap_labels
        Optional list of label dicts forwarded to ``_draw_umap_labels``.
    **kwargs
        Extra keyword arguments forwarded to both UMAP ``scatter`` calls.
    """
    # === UMAP Panel ===
    background_fn(adata_sim, ax_umap)

    # Everything simulated up to and including the current time is visible.
    mask = time_values <= t_current
    x = x_all[mask]
    y = y_all[mask]
    c = color_values[mask]
    t_pts = time_values[mask]

    # Older points: drawn underneath and dimmed by ``trail_alpha``.
    trail_mask = t_pts < t_current
    if np.any(trail_mask):
        ax_umap.scatter(
            x[trail_mask],
            y[trail_mask],
            c=c[trail_mask],
            cmap=umap_cmap,
            s=s,
            alpha=alpha * trail_alpha * frame_alpha,
            vmin=vmin,
            vmax=vmax,
            zorder=200,
            edgecolors="none",
            **kwargs,
        )

    # Points at exactly the current time: larger and on top of the trail.
    leading_mask = t_pts == t_current
    if np.any(leading_mask):
        ax_umap.scatter(
            x[leading_mask],
            y[leading_mask],
            c=c[leading_mask],
            cmap=umap_cmap,
            s=s * leading_edge_scale,
            alpha=frame_alpha,
            vmin=vmin,
            vmax=vmax,
            zorder=202,
            edgecolors="none",
            **kwargs,
        )

    if show_time_label:
        ax_umap.text(
            time_label_loc[0],
            time_label_loc[1],
            time_label_fmt.format(t_current),
            transform=ax_umap.transAxes,
            fontsize=time_label_fontsize,
            verticalalignment="top",
            fontweight="bold",
            alpha=frame_alpha,
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8 * frame_alpha),
        )

    if umap_title:
        ax_umap.set_title(umap_title, fontsize=12)

    for spine in ax_umap.spines.values():
        spine.set_visible(False)
    ax_umap.set_xticks([])
    ax_umap.set_yticks([])
    ax_umap.set_xlabel("")
    ax_umap.set_ylabel("")
    # Limits come from the full data set so the view is identical in every frame.
    ax_umap.set_xlim(x_all.min() - 0.5, x_all.max() + 0.5)
    ax_umap.set_ylim(y_all.min() - 0.5, y_all.max() + 0.5)

    norm = mcolors.Normalize(vmin=t_min, vmax=t_max)
    sm = cm.ScalarMappable(cmap=umap_cmap, norm=norm)
    sm.set_array([])
    # Attach the colorbar through the figure that owns ``ax_umap`` instead of
    # pyplot's "current figure", so the result does not depend on global state.
    cbar = ax_umap.figure.colorbar(
        sm, ax=ax_umap, shrink=0.6, orientation="horizontal", location="bottom"
    )
    cbar.set_label(color, fontsize=10)

    # Map t_current linearly onto the colorbar's x-range; a degenerate time
    # range (t_max == t_min) pins the marker to the right end.
    cbar_ax = cbar.ax
    cbar_xmin, cbar_xmax = cbar_ax.get_xlim()
    progress_x = (
        cbar_xmin + (t_current - t_min) / (t_max - t_min) * (cbar_xmax - cbar_xmin)
        if t_max > t_min
        else cbar_xmax
    )
    cbar_ax.axvline(x=progress_x, color="dodgerblue", linewidth=2, zorder=10)

    # Draw UMAP labels
    _draw_umap_labels(ax_umap, umap_labels)

    # === Expression Panel ===
    stats_current = stats_full[stats_full["time"] <= t_current]

    for group in plot_groups:
        group_data = stats_current[stats_current["group"] == group].sort_values("time")
        if len(group_data) == 0:
            continue

        t = group_data["time"].values
        mean = group_data["mean"].values
        std = group_data["std"].values
        color_line = expr_cmap.get(group, "gray")

        ax_expr.plot(t, mean, color=color_line, linewidth=linewidth, label=group, alpha=frame_alpha, zorder=2)

        if show_std:
            ax_expr.fill_between(
                t,
                mean - std,
                mean + std,
                color=color_line,
                alpha=std_alpha * frame_alpha,
                linewidth=0,
                zorder=1,
            )

    ax_expr.set_xlim(t_min, t_max)
    ax_expr.set_ylim(expr_ylim)
    ax_expr.set_xlabel(x_label, fontsize=10)
    ax_expr.set_ylabel(y_label, fontsize=10)
    ax_expr.set_title(f"$\\it{{{gene}}}$", fontsize=11)
    ax_expr.spines["top"].set_visible(False)
    ax_expr.spines["right"].set_visible(False)
    ax_expr.grid(True, alpha=0.3, zorder=0)

    ax_expr.axvline(
        x=t_current, color="dodgerblue", linewidth=2, linestyle="--", alpha=0.7 * frame_alpha, zorder=5
    )

    # Collapse duplicate labels so each group appears once in the legend.
    handles, labels = ax_expr.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax_expr.legend(
        by_label.values(), by_label.keys(), loc="best", frameon=True, facecolor="white", edgecolor="lightgray", fontsize=8
    )


# -- API-facing function: -----------------------------------------------------
def _gene_names_as_list(gene_ids_raw):
    """Return the gene-name container stored in ``adata.uns`` as a plain list."""
    if isinstance(gene_ids_raw, dict):
        return list(gene_ids_raw.values())
    if isinstance(gene_ids_raw, (pd.Index, pd.Series, np.ndarray)):
        return gene_ids_raw.tolist()
    if isinstance(gene_ids_raw, list):
        return gene_ids_raw
    return list(gene_ids_raw) if gene_ids_raw else []


def _expression_column(adata_sim, gene_key, gene_idx):
    """Return column ``gene_idx`` of ``adata_sim.obsm[gene_key]`` as an array."""
    expr_matrix = adata_sim.obsm[gene_key]
    if isinstance(expr_matrix, pd.DataFrame):
        return expr_matrix.iloc[:, gene_idx].values
    return expr_matrix[:, gene_idx]


def simulation_expression_gif(
    adata_sim: anndata.AnnData,
    gene: str,
    savename: str = "simulation_expression.gif",
    groupby: str = "final_state",
    groups: Optional[List[str]] = None,
    color: str = "t",
    use_key: str = "X_umap",
    time_key: str = "t",
    gene_key: str = "X_gene_inv",
    gene_ids_key: str = "gene_ids",
    figsize: tuple = (12, 6),
    expr_width_scale: float = 0.8,
    expr_height_scale: float = 0.8,
    # UMAP panel options
    umap_cmap: Union[str, mcolors.Colormap] = "plasma_r",
    s: float = 10.0,
    alpha: float = 0.8,
    background_fn: Optional[Callable] = None,
    background_groupby: Optional[str] = None,
    background_cmap: Optional[Dict[str, str]] = None,
    background_s: float = 100.0,
    background_inner_s: float = 65.0,
    umap_labels: Optional[List[Dict]] = None,
    # Expression panel options
    expr_cmap: Optional[Dict[str, str]] = None,
    linewidth: float = 2.0,
    show_std: bool = True,
    std_alpha: float = 0.2,
    x_label: str = "t(d)",
    y_label: str = "Log-norm. expression",
    # Shared options
    show_time_label: bool = True,
    time_label_fmt: str = "t = {:.1f}d",
    time_label_loc: tuple = (0.05, 0.95),
    time_label_fontsize: int = 12,
    umap_title: Optional[str] = None,
    fps: int = 10,
    duration: Optional[float] = None,
    hold_frames: int = 10,
    fade_frames: int = 8,
    leading_edge_scale: float = 2.0,
    trail_alpha: float = 0.5,
    show_progenitor: bool = True,
    progenitor_frames: int = 8,
    progenitor_label: str = "Progenitor",
    progenitor_s: float = 80.0,
    progenitor_color: str = "dodgerblue",
    dpi: int = 100,
    return_fig: bool = False,
    **kwargs,
) -> Union[str, tuple]:
    """
    Create a dual-panel GIF with synchronized UMAP trajectory and gene expression.

    Left panel shows the simulation trajectory growing over UMAP space.
    Right panel shows temporal gene expression (mean ± std) growing over time.

    Parameters
    ----------
    adata_sim : anndata.AnnData
        Simulated data from ``sdq.tl.simulate()``, with UMAP coordinates in obsm
        and gene expression after ``sdq.tl.invert_scaled_gex()``.
    gene : str
        Gene name to plot in the expression panel.
    savename : str, default="simulation_expression.gif"
        Output filename for the GIF.
    groupby : str, default="final_state"
        Column in ``adata_sim.obs`` for grouping trajectories in expression panel.
    groups : List[str], optional
        Specific groups to plot. If None, plots all groups.
    color : str, default="t"
        What to color UMAP points by. Can be column in obs or gene name.
        Falls back to the time values when it matches neither.
    use_key : str, default="X_umap"
        Key in ``adata_sim.obsm`` containing UMAP coordinates.
    time_key : str, default="t"
        Column in ``adata_sim.obs`` containing time values.
    gene_key : str, default="X_gene_inv"
        Key in ``adata_sim.obsm`` containing gene expression matrix.
    gene_ids_key : str, default="gene_ids"
        Key in ``adata_sim.uns`` containing gene names.
    figsize : tuple, default=(12, 6)
        Figure size (width, height) in inches for the dual-panel figure.
    expr_width_scale : float, default=0.8
        Width of expression panel relative to UMAP panel. 0.8 means expression
        panel is 80% as wide as UMAP panel.
    expr_height_scale : float, default=0.8
        Height scaling for expression panel. Uses GridSpec height_ratios to make
        expression panel shorter. 0.8 means expression panel is 80% as tall,
        with remaining space as padding.
    umap_cmap : str or Colormap, default="plasma_r"
        Colormap for UMAP continuous values.
    s : float, default=10.0
        Point size for simulation points on UMAP.
    alpha : float, default=0.8
        Point transparency on UMAP.
    background_fn : Callable, optional
        Custom function to plot UMAP background. Should accept (adata_sim, ax).
    background_groupby : str, optional
        Column in obs to group background cells by (e.g., "final_state").
        When set, background cells are colored by group using background_cmap.
    background_cmap : Dict[str, str], optional
        Mapping from group names to colors for background.
        Example: {"Mon.": "orange", "Neu.": "#4a7298"}
    background_s : float, default=100.0
        Point size for background outer points.
    background_inner_s : float, default=65.0
        Point size for background inner points.
    umap_labels : List[Dict], optional
        List of label dictionaries to draw on the UMAP panel. Each dict should
        have keys "text", "x", "y", and optionally any matplotlib text kwargs
        like "color", "fontsize", "weight", "ha", "va".
        Example: [{"text": "Monocyte", "x": 10.5, "y": 10, "color": "#F08700", "weight": "bold"}]
    expr_cmap : Dict[str, str], optional
        Mapping from group names to colors for expression panel.
    linewidth : float, default=2.0
        Width of mean lines in expression panel.
    show_std : bool, default=True
        Whether to show standard deviation shading in expression panel.
    std_alpha : float, default=0.2
        Transparency of std shading.
    x_label : str, default="t(d)"
        X-axis label for expression panel.
    y_label : str, default="Log-norm. expression"
        Y-axis label for expression panel.
    show_time_label : bool, default=True
        Whether to show time label on UMAP panel.
    time_label_fmt : str, default="t = {:.1f}d"
        Format string for time label.
    time_label_loc : tuple, default=(0.05, 0.95)
        Location of time label in axes coordinates.
    time_label_fontsize : int, default=12
        Font size for time label.
    umap_title : str, optional
        Title for UMAP panel.
    fps : int, default=10
        Frames per second for the GIF.
    duration : float, optional
        Total duration in seconds. If provided, overrides fps. The duration is
        spread evenly over every frame written to the GIF, including the
        progenitor, hold and fade frames.
    hold_frames : int, default=10
        Number of frames to hold at the end before fading.
    fade_frames : int, default=8
        Number of frames for the fade-out transition.
    leading_edge_scale : float, default=2.0
        Size multiplier for leading edge points on UMAP.
    trail_alpha : float, default=0.5
        Alpha multiplier for trail points on UMAP.
    show_progenitor : bool, default=True
        Whether to show progenitor intro frames at the start.
    progenitor_frames : int, default=8
        Number of frames to hold on the progenitor before starting animation.
    progenitor_label : str, default="Progenitor"
        Label text for the progenitor annotation.
    progenitor_s : float, default=80.0
        Point size for the progenitor marker.
    progenitor_color : str, default="dodgerblue"
        Color for the progenitor marker and annotation.
    dpi : int, default=100
        Resolution for each frame.
    return_fig : bool, default=False
        If True, also returns the final frame's (fig, ax_umap, ax_expr) tuple.
    **kwargs
        Additional keyword arguments passed to UMAP scatter.

    Returns
    -------
    str or tuple
        Path to the saved GIF file. If return_fig=True, returns
        (savename, fig, ax_umap, ax_expr) tuple with the final frame.

    Raises
    ------
    ImportError
        If Pillow is not installed.
    ValueError
        If ``gene`` is not among the names in ``adata_sim.uns[gene_ids_key]``.

    Examples
    --------
    >>> import scdiffeq as sdq
    >>> sdq.pl.simulation_expression_gif(
    ...     adata_sim,
    ...     gene="Spi1",
    ...     groupby="final_state",
    ...     expr_cmap={"Mon.": "orange", "Neu.": "#4a7298"}
    ... )
    """
    try:
        from PIL import Image
    except ImportError as err:
        raise ImportError(
            "PIL (Pillow) is required for GIF creation. Install it with: pip install Pillow"
        ) from err

    # -- Get UMAP coordinates -------------------------------------------------
    umap_coords = adata_sim.obsm[use_key]
    if isinstance(umap_coords, pd.DataFrame):
        x_all = umap_coords.iloc[:, 0].values
        y_all = umap_coords.iloc[:, 1].values
    else:
        x_all = umap_coords[:, 0]
        y_all = umap_coords[:, 1]

    # -- Get time values ------------------------------------------------------
    time_values = adata_sim.obs[time_key].values
    unique_times = np.sort(np.unique(time_values))
    t_min, t_max = unique_times.min(), unique_times.max()

    # -- Compute progenitor mean position (cells at the earliest time) --------
    t0_mask = time_values == t_min
    progenitor_x = np.mean(x_all[t0_mask])
    progenitor_y = np.mean(y_all[t0_mask])

    # -- Get color values for UMAP --------------------------------------------
    # Resolution order: obs column, then gene name, then fall back to time.
    color_values = None
    if color in adata_sim.obs.columns:
        color_values = adata_sim.obs[color].values
    else:
        color_gene_names = _gene_names_as_list(adata_sim.uns.get(gene_ids_key, {}))
        if color in color_gene_names:
            color_values = _expression_column(
                adata_sim, gene_key, color_gene_names.index(color)
            )
    if color_values is None:
        color_values = time_values

    vmin, vmax = np.nanmin(color_values), np.nanmax(color_values)

    # -- Get gene expression data for expression panel ------------------------
    gene_names = _gene_names_as_list(adata_sim.uns[gene_ids_key])
    if gene not in gene_names:
        raise ValueError(
            f"Gene '{gene}' not found in adata_sim.uns['{gene_ids_key}']. "
            f"Available genes: {gene_names[:5]}..."
        )
    expression = _expression_column(adata_sim, gene_key, gene_names.index(gene))

    group_labels = adata_sim.obs[groupby].values
    df = pd.DataFrame({"expression": expression, "time": time_values, "group": group_labels})
    stats_full = df.groupby(["time", "group"])["expression"].agg(["mean", "std"]).reset_index()

    unique_groups = df["group"].unique()
    if groups is not None:
        plot_groups = [g for g in groups if g in unique_groups]
    else:
        plot_groups = list(unique_groups)

    if expr_cmap is None:
        default_colors = plt.cm.tab10.colors
        expr_cmap = {g: default_colors[i % len(default_colors)] for i, g in enumerate(plot_groups)}

    # The std of a single observation is NaN. Treat it as zero spread when
    # sizing the y-axis so those rows still count and the limits stay finite.
    std_for_limits = stats_full["std"].fillna(0.0)
    expr_ymin = (stats_full["mean"] - std_for_limits).min()
    expr_ymax = (stats_full["mean"] + std_for_limits).max()
    expr_y_margin = (expr_ymax - expr_ymin) * 0.1
    expr_ylim = (expr_ymin - expr_y_margin, expr_ymax + expr_y_margin)

    # -- Setup figure layout --------------------------------------------------
    width_ratios = (1.0, expr_width_scale)

    def _create_dual_figure():
        """Create figure with UMAP (full height) and expression (scaled height) panels."""
        fig = plt.figure(figsize=figsize)
        # 2 columns: UMAP gets width_ratios[0], expression gets width_ratios[1]
        # 2 rows for expression column: top row is expr_height_scale, bottom is padding
        gs = fig.add_gridspec(
            2,
            2,
            width_ratios=width_ratios,
            height_ratios=(expr_height_scale, 1 - expr_height_scale),
            hspace=0.05,
        )
        # UMAP spans both rows in column 0
        ax_umap = fig.add_subplot(gs[:, 0])
        # Expression panel only in top row of column 1
        ax_expr = fig.add_subplot(gs[0, 1])
        return fig, ax_umap, ax_expr

    # -- Setup background function --------------------------------------------
    if background_fn is None:
        if background_groupby is not None:
            background_fn = lambda adata, ax: _create_expr_grouped_background(
                adata, ax, use_key, background_groupby, background_cmap, background_s, background_inner_s
            )
        else:
            background_fn = lambda adata, ax: _create_expr_default_background(
                adata, ax, use_key, background_s, background_inner_s
            )

    def _draw_progenitor(ax_umap, ax_expr):
        """Draw the intro frame onto the given axes."""
        _create_expression_progenitor_frame(
            adata_sim,
            ax_umap,
            ax_expr,
            background_fn,
            progenitor_x,
            progenitor_y,
            progenitor_color,
            progenitor_s,
            progenitor_label,
            show_time_label,
            time_label_loc,
            time_label_fmt,
            time_label_fontsize,
            t_min,
            t_max,
            umap_title,
            x_all,
            y_all,
            umap_cmap,
            color,
            expr_ylim,
            x_label,
            y_label,
            gene,
            plot_groups,
            expr_cmap,
            linewidth,
            umap_labels=umap_labels,
        )

    def _draw_frame(ax_umap, ax_expr, t_current, frame_alpha):
        """Draw the animation frame for ``t_current`` onto the given axes."""
        _create_expression_frame(
            adata_sim,
            ax_umap,
            ax_expr,
            t_current,
            frame_alpha,
            background_fn,
            x_all,
            y_all,
            time_values,
            color_values,
            umap_cmap,
            s,
            alpha,
            trail_alpha,
            leading_edge_scale,
            vmin,
            vmax,
            show_time_label,
            time_label_loc,
            time_label_fmt,
            time_label_fontsize,
            umap_title,
            t_min,
            t_max,
            color,
            stats_full,
            plot_groups,
            expr_cmap,
            linewidth,
            show_std,
            std_alpha,
            expr_ylim,
            x_label,
            y_label,
            gene,
            umap_labels=umap_labels,
            **kwargs,
        )

    def _render_to_image(frame_path, draw):
        """Render one figure to ``frame_path`` and return it as an in-memory image."""
        fig, ax_umap, ax_expr = _create_dual_figure()
        try:
            draw(ax_umap, ax_expr)
            fig.tight_layout()
            fig.savefig(frame_path, dpi=dpi, bbox_inches="tight")
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        # Copy the pixels out and close the file so no handle into the
        # temporary directory outlives this call.
        with Image.open(frame_path) as img:
            return img.copy()

    # -- Create frames --------------------------------------------------------
    frames = []
    with tempfile.TemporaryDirectory() as tmpdir:
        frame_idx = 0

        # Progenitor intro frames: rendered once, then repeated.
        if show_progenitor and progenitor_frames > 0:
            progenitor_img = _render_to_image(
                os.path.join(tmpdir, f"frame_{frame_idx:04d}.png"), _draw_progenitor
            )
            frames.extend(progenitor_img.copy() for _ in range(progenitor_frames))
            frame_idx += 1

        # Main animation frames
        for t in unique_times:
            frames.append(
                _render_to_image(
                    os.path.join(tmpdir, f"frame_{frame_idx:04d}.png"),
                    lambda ax_umap, ax_expr, t=t: _draw_frame(ax_umap, ax_expr, t, 1.0),
                )
            )
            frame_idx += 1

        # Hold frames at the end
        last_frame = frames[-1]
        frames.extend(last_frame.copy() for _ in range(hold_frames))

        # Fade out frames
        for fade_i in range(fade_frames):
            fade_alpha = 1.0 - (fade_i + 1) / fade_frames
            frames.append(
                _render_to_image(
                    os.path.join(tmpdir, f"fade_{fade_i:04d}.png"),
                    lambda ax_umap, ax_expr, a=fade_alpha: _draw_frame(
                        ax_umap, ax_expr, unique_times[-1], a
                    ),
                )
            )

    # -- Create GIF -----------------------------------------------------------
    if duration is not None:
        # ``duration`` is the total length of the GIF, so divide by the number
        # of frames actually written (intro + animation + hold + fade).
        frame_duration = int(duration * 1000 / len(frames))
    else:
        frame_duration = int(1000 / fps)

    frames[0].save(
        savename,
        save_all=True,
        append_images=frames[1:],
        duration=frame_duration,
        loop=0,
    )

    if return_fig:
        final_fig, final_ax_umap, final_ax_expr = _create_dual_figure()
        _draw_frame(final_ax_umap, final_ax_expr, unique_times[-1], 1.0)
        final_fig.tight_layout()
        return savename, final_fig, final_ax_umap, final_ax_expr

    return savename