# -- import packages: ---------------------------------------------------------
import anndata
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
# -- set typing: --------------------------------------------------------------
from typing import Dict, List, Optional, Union
# -- API-facing function: -----------------------------------------------------
[docs]
def simulation_umap(
    adata_sim: anndata.AnnData,
    color: str = "t",
    use_key: str = "X_umap",
    gene_key: str = "X_gene_inv",
    gene_ids_key: str = "gene_ids",
    ax: Optional[plt.Axes] = None,
    figsize: tuple = (4, 4),
    cmap: Union[str, mcolors.Colormap] = "viridis",
    categorical_cmap: Optional[Dict[str, str]] = None,
    s: float = 1.0,
    alpha: float = 0.8,
    title: Optional[str] = None,
    show_colorbar: bool = True,
    colorbar_label: Optional[str] = None,
    show_legend: bool = True,
    legend_loc: str = "best",
    save: bool = False,
    savename: Optional[str] = None,
    save_format: str = "svg",
    dpi: int = 300,
    **kwargs,
) -> plt.Axes:
    """
    Plot UMAP embedding of simulated data, colored by obs attribute or gene expression.

    Parameters
    ----------
    adata_sim : anndata.AnnData
        Simulated data from ``sdq.tl.simulate()``, with UMAP coordinates
        in obsm and optionally gene expression in obsm after calling
        ``sdq.tl.invert_scaled_gex()``.
    color : str, default="t"
        What to color points by. Can be:
        - Column name in ``adata_sim.obs`` (e.g., "t", "fate", "sim_i")
        - Gene name (looked up in ``adata_sim.uns[gene_ids_key]``; the values
          are then read from ``adata_sim.obsm[gene_key]``)
    use_key : str, default="X_umap"
        Key in ``adata_sim.obsm`` containing UMAP coordinates.
    gene_key : str, default="X_gene_inv"
        Key in ``adata_sim.obsm`` containing gene expression matrix.
    gene_ids_key : str, default="gene_ids"
        Key in ``adata_sim.uns`` containing gene names.
    ax : plt.Axes, optional
        Matplotlib axes to plot on. If None, creates new figure.
    figsize : tuple, default=(4, 4)
        Figure size (width, height) in inches if creating new figure.
    cmap : str or Colormap, default="viridis"
        Colormap for continuous values.
    categorical_cmap : Dict[str, str], optional
        Mapping from category names to colors for categorical data.
        Categories absent from the mapping are drawn in "gray".
    s : float, default=1.0
        Point size.
    alpha : float, default=0.8
        Point transparency.
    title : str, optional
        Plot title. If None, uses the color parameter.
    show_colorbar : bool, default=True
        Whether to show colorbar for continuous values.
    colorbar_label : str, optional
        Label for colorbar. If None, uses the color parameter.
    show_legend : bool, default=True
        Whether to show legend for categorical values.
    legend_loc : str, default="best"
        Legend location.
    save : bool, default=False
        Whether to save the figure.
    savename : str, optional
        Filename for saving. If None, auto-generates from color parameter.
    save_format : str, default="svg"
        Format for saving figure.
    dpi : int, default=300
        Resolution for saving figure.
    **kwargs
        Additional keyword arguments passed to ``ax.scatter()``
        (e.g., ``zorder``, ``edgecolors``, ``linewidths``).

    Returns
    -------
    plt.Axes
        The matplotlib axes object.

    Raises
    ------
    ValueError
        If ``color`` is neither a column of ``adata_sim.obs`` nor a gene name
        found in ``adata_sim.uns[gene_ids_key]``.
    KeyError
        If ``use_key`` (or, when coloring by a gene, ``gene_key``) is missing
        from ``adata_sim.obsm``.

    Notes
    -----
    An obs column is treated as categorical when its dtype is ``category`` or
    its first non-missing value is a ``str``; everything else (including gene
    expression) is plotted as continuous. Missing categorical values are drawn
    in light gray and labelled ``"NA"``. With ``save=True`` the figure that
    owns ``ax`` is saved, not whichever pyplot figure happens to be current.

    Examples
    --------
    >>> import scdiffeq as sdq
    >>> # Color by time
    >>> sdq.pl.simulation_umap(adata_sim, color="t")
    >>> # Color by fate
    >>> sdq.pl.simulation_umap(adata_sim, color="fate", categorical_cmap={"Mon.": "orange", "Neu.": "blue"})
    >>> # Color by gene expression
    >>> sdq.pl.simulation_umap(adata_sim, color="Myc")
    """
    # -- Get UMAP coordinates -------------------------------------------------
    umap_coords = adata_sim.obsm[use_key]
    x = _column_as_array(umap_coords, 0)
    y = _column_as_array(umap_coords, 1)

    # -- Determine color values -----------------------------------------------
    color_values, is_categorical = _resolve_color_values(
        adata_sim, color, gene_key, gene_ids_key
    )

    # -- Setup figure ---------------------------------------------------------
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # -- Plot -----------------------------------------------------------------
    if is_categorical:
        # Missing entries were normalised to None, so they can be excluded
        # before np.unique (which cannot sort NaN alongside strings) and never
        # match any category in the equality masks below.
        missing = pd.isna(color_values)
        unique_cats = np.unique(color_values[~missing])

        # Default categorical colors
        if categorical_cmap is None:
            default_colors = plt.cm.tab10.colors
            categorical_cmap = {
                cat: default_colors[i % len(default_colors)]
                for i, cat in enumerate(unique_cats)
            }

        if missing.any():
            # Drawn first so the labelled categories render on top of them.
            ax.scatter(
                x[missing], y[missing], c=["lightgray"], s=s, alpha=alpha,
                label="NA", **kwargs,
            )
        for cat in unique_cats:
            mask = color_values == cat
            cat_color = categorical_cmap.get(cat, "gray")
            ax.scatter(
                x[mask], y[mask], c=[cat_color], s=s, alpha=alpha,
                label=cat, **kwargs,
            )

        # Skip the legend when nothing was plotted (avoids matplotlib's
        # "No artists with labels found" warning on empty data).
        if show_legend and (unique_cats.size > 0 or missing.any()):
            ax.legend(
                loc=legend_loc, frameon=True, facecolor="white",
                edgecolor="lightgray", fontsize=8, markerscale=3,
            )
    else:
        # Continuous
        scatter = ax.scatter(x, y, c=color_values, cmap=cmap, s=s, alpha=alpha, **kwargs)
        if show_colorbar:
            # Attach to ax's own figure rather than pyplot's current figure.
            cbar = ax.figure.colorbar(scatter, ax=ax, shrink=0.8)
            if colorbar_label is None:
                colorbar_label = color
            cbar.set_label(colorbar_label, fontsize=10)

    # -- Formatting -----------------------------------------------------------
    if title is None:
        title = color
    ax.set_title(title, fontsize=11)

    # Remove all spines, ticks, and labels for clean UMAP look
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")

    # -- Save -----------------------------------------------------------------
    if save:
        if savename is None:
            savename = f"scDiffEq.simulation_umap.{color}.{save_format}"
        # plt.savefig would save the *current* figure, which is not
        # necessarily the one holding a caller-supplied ``ax``.
        _root_figure(ax).savefig(savename, format=save_format, dpi=dpi, bbox_inches="tight")

    return ax


# -- helpers: ------------------------------------------------------------------
def _column_as_array(matrix, idx: int) -> np.ndarray:
    """Return column ``idx`` of an obsm matrix as a flat 1-D numpy array.

    Handles pandas DataFrames, numpy arrays, and objects exposing
    ``toarray()`` (e.g. scipy sparse matrices, whose column slices are 2-D).
    """
    if isinstance(matrix, pd.DataFrame):
        return matrix.iloc[:, idx].to_numpy()
    column = matrix[:, idx]
    if hasattr(column, "toarray"):
        column = column.toarray()
    return np.asarray(column).ravel()


def _gene_name_list(gene_ids_raw) -> List:
    """Normalise the ``uns[gene_ids_key]`` entry into a list of gene names.

    A dict is read as ``{index: gene_name}`` and its values are taken in
    iteration order; index/series/array/list-likes are converted to a list;
    anything falsy yields an empty list.
    """
    if isinstance(gene_ids_raw, dict):
        # NOTE(review): the column index is the value's position in the dict's
        # iteration order; the keys themselves are not consulted. Confirm the
        # dict is always built in column order (e.g. after an h5ad round-trip).
        return list(gene_ids_raw.values())
    if isinstance(gene_ids_raw, (pd.Index, pd.Series, np.ndarray)):
        return gene_ids_raw.tolist()
    if isinstance(gene_ids_raw, list):
        return gene_ids_raw
    return list(gene_ids_raw) if gene_ids_raw else []


def _resolve_color_values(adata_sim, color: str, gene_key: str, gene_ids_key: str) -> tuple:
    """Look up ``color`` in ``adata_sim.obs``, then among gene names.

    Returns ``(values, is_categorical)``. Categorical values come back as a
    fresh object array with missing entries replaced by ``None``; continuous
    obs values come back as float with NaN for missing (so pandas nullable
    dtypes are plottable). Raises ValueError if ``color`` is not found.
    """
    if color in adata_sim.obs.columns:
        series = adata_sim.obs[color]
        non_null = series.dropna()
        # Inspect the first *non-missing* value so that empty columns and
        # columns with a leading NaN are classified correctly.
        is_categorical = isinstance(series.dtype, pd.CategoricalDtype) or (
            len(non_null) > 0 and isinstance(non_null.iloc[0], str)
        )
        if is_categorical:
            # copy=True: the array is mutated below and must not alias obs.
            values = series.to_numpy(dtype=object, copy=True)
            values[pd.isna(values)] = None
            return values, True
        return series.to_numpy(dtype=float, na_value=np.nan), False

    # Try to find as gene name
    gene_names = _gene_name_list(adata_sim.uns.get(gene_ids_key, {}))
    if color in gene_names:
        gene_idx = gene_names.index(color)
        return _column_as_array(adata_sim.obsm[gene_key], gene_idx), False

    raise ValueError(
        f"'{color}' not found in adata_sim.obs columns or gene names. "
        f"Available obs columns: {list(adata_sim.obs.columns)[:5]}..."
    )


def _root_figure(ax):
    """Return the top-level figure holding ``ax``.

    Subfigures lack ``savefig``, so walk up ``.figure`` parents until an
    object with ``savefig`` is found (or no further parent exists).
    """
    fig = ax.figure
    while not hasattr(fig, "savefig"):
        parent = getattr(fig, "figure", None)
        if parent is None or parent is fig:
            break
        fig = parent
    return fig