# -- import packages: ---------------------------------------------------------
import adata_query
import anndata
import logging
import numpy as np
import pandas as pd
import voyager
# -- set type hints: ----------------------------------------------------------
from typing import List, Optional, Tuple, Union
# -- configure logger: --------------------------------------------------------
logger = logging.getLogger(__name__)
# -- metric mapping: ----------------------------------------------------------
# Lookup table from the user-facing metric name to the voyager distance space
# that the index is constructed with.
METRIC_TO_SPACE = dict(
    euclidean=voyager.Space.Euclidean,
    cosine=voyager.Space.Cosine,
    inner_product=voyager.Space.InnerProduct,
)
# -- operational class: -------------------------------------------------------
class kNN:
    """k-Nearest Neighbors container using voyager backend.

    This class provides a kNN graph interface for single-cell data stored in
    AnnData objects. It uses voyager (Spotify's HNSW implementation) for
    efficient approximate nearest neighbor search. The index is built
    immediately on construction from ``adata`` / ``use_key``.

    Args:
        adata: AnnData object containing the data.
        use_key: Key to fetch data from adata (e.g., "X_pca"). Default: "X_pca".
        n_neighbors: Number of neighbors to return in queries. Default: 20.
        metric: Distance metric to use. One of "euclidean", "cosine",
            "inner_product". Default: "euclidean". An unrecognised name
            falls back to euclidean and logs a warning.
        space: Alternative to metric - directly specify voyager.Space.
            If provided, overrides metric parameter.

    Attributes:
        adata: The AnnData object.
        use_key: Key used to fetch data.
        n_neighbors: Number of neighbors for queries.
        space: The voyager.Space used for distance computation.
    """

    # Class-level default; flipped to True on the instance once _build() ran.
    _KNN_IDX_BUILT: bool = False

    def __init__(
        self,
        adata: anndata.AnnData,
        use_key: str = "X_pca",
        n_neighbors: int = 20,
        metric: str = "euclidean",
        space: Optional[voyager.Space] = None,
    ):
        self.adata = adata
        self.use_key = use_key
        self.n_neighbors = n_neighbors

        # An explicit voyager.Space always wins over the metric name.
        if space is not None:
            self.space = space
        else:
            if metric not in METRIC_TO_SPACE:
                # Keep the historical euclidean fallback, but make it visible
                # so that a typo such as "Cosine" does not go unnoticed.
                logger.warning(
                    "Unknown metric %r; falling back to 'euclidean'. "
                    "Valid metrics: %s",
                    metric,
                    sorted(METRIC_TO_SPACE),
                )
            self.space = METRIC_TO_SPACE.get(metric, voyager.Space.Euclidean)
        self._metric = metric

        self._build()

    @property
    def X(self) -> np.ndarray:
        """Fetch the data array from adata (fetched once, then cached)."""
        if not hasattr(self, "_X"):
            self._X = adata_query.fetch(self.adata, self.use_key, torch=False)
        return self._X

    @property
    def n_dim(self) -> int:
        """Number of dimensions in the data."""
        return self.X.shape[1]

    @property
    def n_obs(self) -> int:
        """Number of observations (cells) in the index; 0 before a build."""
        return len(getattr(self, "_build_indices", ()))

    @property
    def index(self) -> voyager.Index:
        """The voyager Index object (created lazily on first access)."""
        if not hasattr(self, "_index"):
            self._index = voyager.Index(
                space=self.space, num_dimensions=self.n_dim
            )
        return self._index

    def _build(self) -> None:
        """Build the kNN index by adding all items from the data.

        Calling this again on an already-built instance is a no-op, so the
        same vectors are never inserted into the index twice.
        """
        if self._KNN_IDX_BUILT:
            return
        # Items are inserted one at a time, in row order; the ids voyager
        # hands back are kept so n_obs can report the index size.
        self._build_indices = [
            self.index.add_item(x_cell) for x_cell in self.X
        ]
        self._KNN_IDX_BUILT = True
        logger.info("Built kNN index with %d items", len(self._build_indices))

    def query(
        self,
        X_query: np.ndarray,
        n_neighbors: Optional[int] = None,
        include_distances: bool = False,
        exclude_self: bool = True,
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Query the kNN index for nearest neighbors.

        Args:
            X_query: Query points of shape (n_queries, n_dim). A single
                point of shape (n_dim,) is treated as one query.
            n_neighbors: Number of neighbors to return. If None, uses
                self.n_neighbors.
            include_distances: If True, also return distances.
            exclude_self: If True (default), one extra neighbor is requested
                and the first hit of every query is dropped. The drop is
                positional: it removes the closest hit whether or not that
                hit is the query point itself, so pass False when the query
                points are not part of the index.

        Returns:
            If include_distances is False:
                neighbors: Array of shape (n_queries, n_neighbors) with
                neighbor indices.
            If include_distances is True:
                Tuple of (neighbors, distances) arrays.
        """
        if n_neighbors is None:
            n_neighbors = self.n_neighbors
        n_neighbors = int(n_neighbors)

        # voyager works on float32 vectors; promote a lone vector to a batch.
        X_query = np.atleast_2d(np.asarray(X_query, dtype=np.float32))

        if X_query.size == 0:
            # Nothing to look up: return correctly shaped empty results.
            neighbors = np.empty((0, n_neighbors), dtype=int)
            distances = np.empty((0, n_neighbors), dtype=np.float32)
        else:
            k = n_neighbors + 1 if exclude_self else n_neighbors
            # One batched call; voyager returns (neighbor_ids, distances),
            # each of shape (n_queries, k) for 2-D input.
            neighbors, distances = self.index.query(X_query, k=k)
            neighbors = np.asarray(neighbors)
            distances = np.asarray(distances)
            if exclude_self:
                # Drop the first hit (assumed to be the query point itself).
                neighbors = neighbors[:, 1:]
                distances = distances[:, 1:]
            neighbors = neighbors.astype(int)

        if include_distances:
            return neighbors, distances
        return neighbors

    def _count_values(self, col: pd.Series) -> dict:
        """Count value occurrences in a Series."""
        return col.value_counts().to_dict()

    def _max_count(self, col: pd.Series) -> str:
        """Get the most frequent value in a Series."""
        return col.value_counts().idxmax()

    def count(
        self,
        query_result: np.ndarray,
        obs_key: str,
        max_only: bool = False,
        n_neighbors: Optional[int] = None,
    ) -> Union[List[dict], List[str]]:
        """Count neighbor annotations from query results.

        Args:
            query_result: Array of neighbor indices from query().
            obs_key: Key in adata.obs to count.
            max_only: If True, return only the most frequent annotation per
                query.
            n_neighbors: Number of neighbors per query. Only needed to
                regroup a flattened (1-D) query_result; for 2-D input the
                width of query_result is used. If None, uses
                self.n_neighbors.

        Returns:
            If max_only is False:
                List of dicts mapping annotation values to counts.
            If max_only is True:
                List of most frequent annotation values.
        """
        if n_neighbors is None:
            n_neighbors = self.n_neighbors

        query_result = np.asarray(query_result)
        # Trust the actual shape of the result over the n_neighbors argument
        # so a mismatch cannot silently regroup neighbors across queries.
        if query_result.ndim == 2:
            width = query_result.shape[1]
        else:
            width = n_neighbors

        # Positional lookup straight into the obs column; one row per query.
        annotations = self.adata.obs[obs_key].to_numpy()
        per_query = annotations[query_result.ravel()].reshape(-1, width)

        reducer = self._max_count if max_only else self._count_values
        return [reducer(pd.Series(row)) for row in per_query]

    def aggregate(
        self,
        X_query: np.ndarray,
        obs_key: str,
        max_only: bool = False,
        n_neighbors: Optional[int] = None,
    ) -> pd.DataFrame:
        """Query neighbors and aggregate annotation counts.

        Combines query() and count() into a single operation.

        Args:
            X_query: Query points of shape (n_queries, n_dim).
            obs_key: Key in adata.obs to aggregate.
            max_only: If True, return only the most frequent annotation per
                query.
            n_neighbors: Number of neighbors. If None, uses self.n_neighbors.

        Returns:
            DataFrame with one row per query. With max_only False the
            columns are the annotation values (sorted) holding counts, with
            0 where a value was absent; with max_only True there is a single
            column named obs_key holding the most frequent annotation.
        """
        neighbor_ids = self.query(X_query=X_query, n_neighbors=n_neighbors)
        counted = self.count(
            query_result=neighbor_ids,
            obs_key=obs_key,
            max_only=max_only,
            n_neighbors=n_neighbors,
        )
        _df = pd.DataFrame(counted).fillna(0).sort_index(axis=1)
        if not max_only:
            return _df
        # A list of scalars yields a single column labelled 0.
        return _df.rename({0: obs_key}, axis=1)

    def multi_aggregate(
        self,
        X_query: np.ndarray,
        obs_key: str,
        max_only: bool = False,
        n_neighbors: Optional[int] = None,
    ) -> Union[List[pd.DataFrame], pd.DataFrame]:
        """Aggregate annotations for multiple query sets.

        Args:
            X_query: Multiple query sets of shape (n_sets, n_queries, n_dim).
            obs_key: Key in adata.obs to aggregate.
            max_only: If True, return only the most frequent annotation per
                query.
            n_neighbors: Number of neighbors. If None, uses self.n_neighbors.

        Returns:
            If max_only is False:
                List of DataFrames, one per query set.
            If max_only is True:
                Single DataFrame with columns for each query set.
        """
        _list_of_dfs = [
            self.aggregate(
                X_query=query_set,
                obs_key=obs_key,
                max_only=max_only,
                n_neighbors=n_neighbors,
            )
            for query_set in X_query
        ]
        if max_only:
            concat_df = pd.concat(_list_of_dfs, axis=1)
            # Relabel the per-set columns 0..n_sets-1.
            concat_df.columns = range(len(_list_of_dfs))
            return concat_df
        return _list_of_dfs

    def __repr__(self) -> str:
        """String representation of the kNN instance."""
        attrs = {
            "built": self._KNN_IDX_BUILT,
            "n_obs": self.n_obs,
            "n_dim": self.n_dim,
            "n_neighbors": self.n_neighbors,
            "use_key": self.use_key,
        }
        lines = "".join(f"\n {key}: {val}" for key, val in attrs.items())
        return "k-nearest neighbor graph\n" + lines