# -- import packages: ---------------------------------------------------------
import anndata
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import numpy as np
import pandas as pd
import tempfile
import os
# -- set typing: --------------------------------------------------------------
from typing import Callable, Dict, List, Optional, Union
# -- Helper functions: --------------------------------------------------------
def _create_default_background(adata_sim, ax, use_key, background_s, background_inner_s):
"""Default background with black outline and white fill."""
xu = adata_sim.obsm[use_key]
if isinstance(xu, pd.DataFrame):
xu = xu.values
ax.scatter(xu[:, 0], xu[:, 1], c="k", ec="None", rasterized=True, s=background_s)
ax.scatter(xu[:, 0], xu[:, 1], c="w", ec="None", rasterized=True, s=background_inner_s)
def _create_grouped_background(
adata_sim, ax, use_key, background_groupby, background_cmap, background_s, background_inner_s
):
"""Background colored by group membership."""
groups = adata_sim.obs[background_groupby]
xu = adata_sim.obsm[use_key]
if isinstance(xu, pd.DataFrame):
xu = xu.values
unique_groups = groups.unique()
if background_cmap is None:
default_colors = plt.cm.tab10.colors
group_colors = {g: default_colors[i % len(default_colors)] for i, g in enumerate(unique_groups)}
else:
group_colors = background_cmap
for group in unique_groups:
mask = groups == group
c = group_colors.get(group, "k")
ax.scatter(xu[mask, 0], xu[mask, 1], c=c, ec="None", rasterized=True, s=background_s)
ax.scatter(xu[mask, 0], xu[mask, 1], c="w", ec="None", rasterized=True, s=background_inner_s)
def _draw_umap_labels(ax, umap_labels):
"""Draw text labels on the UMAP axes."""
if umap_labels is None:
return
for label in umap_labels:
text = label.get("text", "")
x = label.get("x", 0)
y = label.get("y", 0)
kwargs = {k: v for k, v in label.items() if k not in ("text", "x", "y")}
# Set defaults
kwargs.setdefault("fontsize", 10)
kwargs.setdefault("ha", "center")
kwargs.setdefault("va", "center")
ax.text(x, y, text, **kwargs)
def _create_trajectory_progenitor_frame(
adata_sim,
ax,
background_fn,
progenitor_x,
progenitor_y,
progenitor_color,
progenitor_s,
progenitor_label,
show_time_label,
time_label_loc,
time_label_fmt,
time_label_fontsize,
t_min,
t_max,
title,
x_all,
y_all,
cmap,
color,
umap_labels=None,
):
"""Create content for progenitor intro frame on a given axes."""
background_fn(adata_sim, ax)
ax.scatter(
progenitor_x,
progenitor_y,
c=progenitor_color,
s=progenitor_s,
edgecolors="white",
linewidths=1.5,
zorder=300,
)
ax.annotate(
progenitor_label,
xy=(progenitor_x, progenitor_y),
xytext=(progenitor_x + 1.5, progenitor_y + 1.5),
fontsize=12,
fontweight="bold",
color=progenitor_color,
arrowprops=dict(arrowstyle="->", color=progenitor_color, lw=2),
zorder=301,
)
if show_time_label:
ax.text(
time_label_loc[0],
time_label_loc[1],
time_label_fmt.format(t_min),
transform=ax.transAxes,
fontsize=time_label_fontsize,
verticalalignment="top",
fontweight="bold",
bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
)
if title:
ax.set_title(title, fontsize=12)
for spine in ax.spines.values():
spine.set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("")
ax.set_ylabel("")
ax.set_xlim(x_all.min() - 0.5, x_all.max() + 0.5)
ax.set_ylim(y_all.min() - 0.5, y_all.max() + 0.5)
norm = mcolors.Normalize(vmin=t_min, vmax=t_max)
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, shrink=0.6, orientation="horizontal", location="bottom")
cbar.set_label(color, fontsize=10)
cbar_ax = cbar.ax
cbar_xmin, cbar_xmax = cbar_ax.get_xlim()
cbar_ax.axvline(x=cbar_xmin, color="dodgerblue", linewidth=2, zorder=10)
# Draw UMAP labels
_draw_umap_labels(ax, umap_labels)
def _create_trajectory_frame(
adata_sim,
ax,
t_current,
frame_alpha,
background_fn,
x_all,
y_all,
time_values,
color_values,
cmap,
s,
alpha,
trail_alpha,
leading_edge_scale,
vmin,
vmax,
show_time_label,
time_label_loc,
time_label_fmt,
time_label_fontsize,
title,
t_min,
t_max,
color,
umap_labels=None,
**kwargs,
):
"""Create content for a single animation frame on a given axes."""
background_fn(adata_sim, ax)
mask = time_values <= t_current
x = x_all[mask]
y = y_all[mask]
c = color_values[mask]
t_pts = time_values[mask]
trail_mask = t_pts < t_current
if np.any(trail_mask):
ax.scatter(
x[trail_mask],
y[trail_mask],
c=c[trail_mask],
cmap=cmap,
s=s,
alpha=alpha * trail_alpha * frame_alpha,
vmin=vmin,
vmax=vmax,
zorder=200,
edgecolors="none",
**kwargs,
)
leading_mask = t_pts == t_current
if np.any(leading_mask):
ax.scatter(
x[leading_mask],
y[leading_mask],
c=c[leading_mask],
cmap=cmap,
s=s * leading_edge_scale,
alpha=frame_alpha,
vmin=vmin,
vmax=vmax,
zorder=202,
edgecolors="none",
**kwargs,
)
else:
ax.scatter(
x,
y,
c=c,
cmap=cmap,
s=s,
alpha=alpha * trail_alpha * frame_alpha,
vmin=vmin,
vmax=vmax,
zorder=200,
edgecolors="none",
**kwargs,
)
if show_time_label:
ax.text(
time_label_loc[0],
time_label_loc[1],
time_label_fmt.format(t_current),
transform=ax.transAxes,
fontsize=time_label_fontsize,
verticalalignment="top",
fontweight="bold",
alpha=frame_alpha,
bbox=dict(boxstyle="round", facecolor="white", alpha=0.8 * frame_alpha),
)
if title:
ax.set_title(title, fontsize=12)
for spine in ax.spines.values():
spine.set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("")
ax.set_ylabel("")
ax.set_xlim(x_all.min() - 0.5, x_all.max() + 0.5)
ax.set_ylim(y_all.min() - 0.5, y_all.max() + 0.5)
norm = mcolors.Normalize(vmin=t_min, vmax=t_max)
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, shrink=0.6, orientation="horizontal", location="bottom")
cbar.set_label(color, fontsize=10)
cbar_ax = cbar.ax
cbar_xmin, cbar_xmax = cbar_ax.get_xlim()
progress_x = (
cbar_xmin + (t_current - t_min) / (t_max - t_min) * (cbar_xmax - cbar_xmin)
if t_max > t_min
else cbar_xmax
)
cbar_ax.axvline(x=progress_x, color="dodgerblue", linewidth=2, zorder=10)
# Draw UMAP labels
_draw_umap_labels(ax, umap_labels)
# -- API-facing function: -----------------------------------------------------
[docs]
def simulation_trajectory_gif(
adata_sim: anndata.AnnData,
savename: str = "simulation_trajectory.gif",
color: str = "t",
use_key: str = "X_umap",
time_key: str = "t",
gene_key: str = "X_gene_inv",
gene_ids_key: str = "gene_ids",
figsize: tuple = (6, 6),
cmap: Union[str, mcolors.Colormap] = "plasma_r",
s: float = 10.0,
alpha: float = 0.8,
background_fn: Optional[Callable] = None,
background_groupby: Optional[str] = None,
background_cmap: Optional[Dict[str, str]] = None,
background_s: float = 100.0,
background_inner_s: float = 65.0,
umap_labels: Optional[List[Dict]] = None,
show_time_label: bool = True,
time_label_fmt: str = "t = {:.1f}d",
time_label_loc: tuple = (0.05, 0.95),
time_label_fontsize: int = 12,
title: Optional[str] = None,
fps: int = 10,
duration: Optional[float] = None,
hold_frames: int = 10,
fade_frames: int = 8,
leading_edge_scale: float = 2.0,
trail_alpha: float = 0.5,
show_progenitor: bool = True,
progenitor_frames: int = 8,
progenitor_label: str = "Progenitor",
progenitor_s: float = 80.0,
progenitor_color: str = "dodgerblue",
dpi: int = 100,
return_fig: bool = False,
**kwargs,
) -> Union[str, tuple]:
"""
Create a GIF of simulation trajectories growing over UMAP space.
Parameters
----------
adata_sim : anndata.AnnData
Simulated data from ``sdq.tl.simulate()``, with UMAP coordinates
in obsm.
savename : str, default="simulation_trajectory.gif"
Output filename for the GIF.
color : str, default="t"
What to color points by. Can be column in obs or gene name.
use_key : str, default="X_umap"
Key in ``adata_sim.obsm`` containing UMAP coordinates.
time_key : str, default="t"
Column in ``adata_sim.obs`` containing time values.
gene_key : str, default="X_gene_inv"
Key in ``adata_sim.obsm`` containing gene expression matrix.
gene_ids_key : str, default="gene_ids"
Key in ``adata_sim.uns`` containing gene names.
figsize : tuple, default=(6, 6)
Figure size (width, height) in inches.
cmap : str or Colormap, default="plasma_r"
Colormap for continuous values.
s : float, default=10.0
Point size for simulation points.
alpha : float, default=0.8
Point transparency.
background_fn : Callable, optional
Custom function to plot background. Should accept (adata_sim, ax).
If None, uses default background (or fate-colored if background_groupby set).
background_groupby : str, optional
Column in obs to group background cells by (e.g., "final_state").
When set, background cells are colored by group using background_cmap.
background_cmap : Dict[str, str], optional
Mapping from group names to colors for background. Required if
background_groupby is set. Example: {"Mon.": "orange", "Neu.": "#4a7298"}
background_s : float, default=100.0
Point size for background outer points.
background_inner_s : float, default=65.0
Point size for background inner points.
umap_labels : List[Dict], optional
List of label dictionaries to draw on the UMAP. Each dict should
have keys "text", "x", "y", and optionally any matplotlib text kwargs
like "color", "fontsize", "weight", "ha", "va". Example:
[{"text": "Monocyte", "x": 10.5, "y": 10, "color": "#F08700", "weight": "bold"}]
show_time_label : bool, default=True
Whether to show time label on each frame.
time_label_fmt : str, default="t = {:.1f}d"
Format string for time label.
time_label_loc : tuple, default=(0.05, 0.95)
Location of time label in axes coordinates.
time_label_fontsize : int, default=12
Font size for time label.
title : str, optional
Plot title.
fps : int, default=10
Frames per second for the GIF.
duration : float, optional
Total duration in seconds. If provided, overrides fps.
hold_frames : int, default=10
Number of frames to hold at the end before fading.
fade_frames : int, default=8
Number of frames for the fade-out transition.
leading_edge_scale : float, default=2.0
Size multiplier for leading edge points (current time step).
trail_alpha : float, default=0.5
Alpha multiplier for trail points (older time steps), relative to base alpha.
show_progenitor : bool, default=True
Whether to show progenitor intro frames at the start.
progenitor_frames : int, default=8
Number of frames to hold on the progenitor before starting animation.
progenitor_label : str, default="Progenitor"
Label text for the progenitor annotation.
progenitor_s : float, default=80.0
Point size for the progenitor marker.
progenitor_color : str, default="dodgerblue"
Color for the progenitor marker and annotation.
dpi : int, default=100
Resolution for each frame.
return_fig : bool, default=False
If True, also returns the final frame's (fig, ax) tuple.
**kwargs
Additional keyword arguments passed to scatter.
Returns
-------
str or tuple
Path to the saved GIF file. If return_fig=True, returns
(savename, fig, ax) tuple with the final frame.
Examples
--------
>>> import scdiffeq as sdq
>>> # Basic usage
>>> sdq.pl.simulation_trajectory_gif(adata_sim, savename="my_sim.gif")
>>> # With custom background
>>> def my_background(adata_sim, ax):
... xu = adata_sim.obsm["X_umap"]
... ax.scatter(xu[:, 0], xu[:, 1], c="lightgray", s=50)
>>> sdq.pl.simulation_trajectory_gif(adata_sim, background_fn=my_background)
"""
try:
from PIL import Image
except ImportError:
raise ImportError(
"PIL (Pillow) is required for GIF creation. " "Install it with: pip install Pillow"
)
# -- Get UMAP coordinates -------------------------------------------------
umap_coords = adata_sim.obsm[use_key]
if isinstance(umap_coords, pd.DataFrame):
x_all = umap_coords.iloc[:, 0].values
y_all = umap_coords.iloc[:, 1].values
else:
x_all = umap_coords[:, 0]
y_all = umap_coords[:, 1]
# -- Get time values ------------------------------------------------------
time_values = adata_sim.obs[time_key].values
unique_times = np.sort(np.unique(time_values))
t_min, t_max = unique_times.min(), unique_times.max()
# -- Compute progenitor mean position (t=0 cells) -------------------------
t0_mask = time_values == t_min
progenitor_x = np.mean(x_all[t0_mask])
progenitor_y = np.mean(y_all[t0_mask])
# -- Get color values -----------------------------------------------------
color_values = None
if color in adata_sim.obs.columns:
color_values = adata_sim.obs[color].values
else:
gene_ids_raw = adata_sim.uns.get(gene_ids_key, {})
if isinstance(gene_ids_raw, dict):
gene_names = list(gene_ids_raw.values())
if color in gene_names:
gene_idx = gene_names.index(color)
expr_matrix = adata_sim.obsm[gene_key]
if isinstance(expr_matrix, pd.DataFrame):
color_values = expr_matrix.iloc[:, gene_idx].values
else:
color_values = expr_matrix[:, gene_idx]
else:
if isinstance(gene_ids_raw, (pd.Index, pd.Series)):
gene_ids = gene_ids_raw.tolist()
elif isinstance(gene_ids_raw, np.ndarray):
gene_ids = gene_ids_raw.tolist()
elif isinstance(gene_ids_raw, list):
gene_ids = gene_ids_raw
else:
gene_ids = list(gene_ids_raw) if gene_ids_raw else []
if color in gene_ids:
gene_idx = gene_ids.index(color)
expr_matrix = adata_sim.obsm[gene_key]
if isinstance(expr_matrix, pd.DataFrame):
color_values = expr_matrix.iloc[:, gene_idx].values
else:
color_values = expr_matrix[:, gene_idx]
if color_values is None:
color_values = time_values
vmin, vmax = np.nanmin(color_values), np.nanmax(color_values)
# -- Setup background function --------------------------------------------
if background_fn is None:
if background_groupby is not None:
background_fn = lambda adata, ax: _create_grouped_background(
adata, ax, use_key, background_groupby, background_cmap, background_s, background_inner_s
)
else:
background_fn = lambda adata, ax: _create_default_background(
adata, ax, use_key, background_s, background_inner_s
)
# -- Create frames --------------------------------------------------------
frames = []
with tempfile.TemporaryDirectory() as tmpdir:
frame_idx = 0
# Progenitor intro frames
if show_progenitor and progenitor_frames > 0:
fig, ax = plt.subplots(figsize=figsize)
_create_trajectory_progenitor_frame(
adata_sim,
ax,
background_fn,
progenitor_x,
progenitor_y,
progenitor_color,
progenitor_s,
progenitor_label,
show_time_label,
time_label_loc,
time_label_fmt,
time_label_fontsize,
t_min,
t_max,
title,
x_all,
y_all,
cmap,
color,
umap_labels=umap_labels,
)
frame_path = os.path.join(tmpdir, f"frame_{frame_idx:04d}.png")
plt.savefig(frame_path, dpi=dpi, bbox_inches="tight")
plt.close(fig)
progenitor_img = Image.open(frame_path)
for _ in range(progenitor_frames):
frames.append(progenitor_img.copy())
frame_idx += 1
# Main animation frames
for i, t in enumerate(unique_times):
fig, ax = plt.subplots(figsize=figsize)
_create_trajectory_frame(
adata_sim,
ax,
t,
1.0,
background_fn,
x_all,
y_all,
time_values,
color_values,
cmap,
s,
alpha,
trail_alpha,
leading_edge_scale,
vmin,
vmax,
show_time_label,
time_label_loc,
time_label_fmt,
time_label_fontsize,
title,
t_min,
t_max,
color,
umap_labels=umap_labels,
**kwargs,
)
frame_path = os.path.join(tmpdir, f"frame_{frame_idx:04d}.png")
plt.savefig(frame_path, dpi=dpi, bbox_inches="tight")
plt.close(fig)
frames.append(Image.open(frame_path))
frame_idx += 1
# Hold frames at the end
last_frame = frames[-1]
for _ in range(hold_frames):
frames.append(last_frame.copy())
# Fade out frames
for fade_i in range(fade_frames):
fade_alpha = 1.0 - (fade_i + 1) / fade_frames
fig, ax = plt.subplots(figsize=figsize)
_create_trajectory_frame(
adata_sim,
ax,
unique_times[-1],
fade_alpha,
background_fn,
x_all,
y_all,
time_values,
color_values,
cmap,
s,
alpha,
trail_alpha,
leading_edge_scale,
vmin,
vmax,
show_time_label,
time_label_loc,
time_label_fmt,
time_label_fontsize,
title,
t_min,
t_max,
color,
umap_labels=umap_labels,
**kwargs,
)
frame_path = os.path.join(tmpdir, f"fade_{fade_i:04d}.png")
plt.savefig(frame_path, dpi=dpi, bbox_inches="tight")
plt.close(fig)
frames.append(Image.open(frame_path))
# -- Create GIF -------------------------------------------------------
if duration is not None:
frame_duration = int(duration * 1000 / len(unique_times))
else:
frame_duration = int(1000 / fps)
frames[0].save(
savename,
save_all=True,
append_images=frames[1:],
duration=frame_duration,
loop=0,
)
if return_fig:
final_fig, final_ax = plt.subplots(figsize=figsize)
_create_trajectory_frame(
adata_sim,
final_ax,
unique_times[-1],
1.0,
background_fn,
x_all,
y_all,
time_values,
color_values,
cmap,
s,
alpha,
trail_alpha,
leading_edge_scale,
vmin,
vmax,
show_time_label,
time_label_loc,
time_label_fmt,
time_label_fontsize,
title,
t_min,
t_max,
color,
umap_labels=umap_labels,
**kwargs,
)
return savename, final_fig, final_ax
return savename