import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import anndata as ad
import json
import torch
from collections import defaultdict
[docs]
def preprocess_adata(adata_raw, modality, hvg_num=3000, n_comps=50, target_sum=None):
    """
    Preprocess an AnnData object according to its modality.

    Depending on the modality, the pipeline performs highly variable feature
    selection, normalization, log-transformation, scaling, and PCA. The input
    object is copied first, so `adata_raw` itself is not modified.

    Parameters
    ----------
    adata_raw : AnnData
        The raw AnnData object to be processed.
    modality : str
        The modality type. Expected to be one of:

        - `'RNA'`, `'RNA_panel2'`: RNA count matrix; supports different panels
          within the same RNA modality. HVGs are selected with the
          `seurat_v3` flavor.
        - `'H3K27me3'`, `'H3K27ac'`, `'ATAC'`, `'H3K4me3'`: Epigenomic signals.
          We recommend first converting raw epigenomics data to gene scores
          before using this function. Gene score generation scripts are
          available at the spatial-Mux-seq repository.
        - `'Protein'`: Protein abundance matrix; CLR normalization is applied.
          For COMET protein data, we recommend arcsinh normalization instead.
        - `'Metabolite'`: Metabolite expression matrix.
        - `'HE'`: Histology image embeddings; PCA is applied directly without
          any normalization.

        Any other value falls through to the generic normalize / log1p /
        scale / PCA path without HVG selection.
    hvg_num : int, optional
        Number of highly variable features to select. If None (or 0), or if the
        number of input features does not exceed `hvg_num`, HVG selection is
        skipped. Default is 3000.
    n_comps : int, optional
        Number of PCA components to compute. Default is 50.
    target_sum : float or None, optional
        Target sum for total-count normalization (used in `normalize_total`).
        If None, the Scanpy default is used. Default is None.

    Returns
    -------
    adata : AnnData
        A preprocessed copy with the PCA embedding stored in
        `.obsm['X_pca']`. For non-HE modalities `.X` holds the normalized and
        scaled values. For the Protein modality the number of components is
        `n_comps` when enough proteins are available, otherwise 20 or 15, and
        is additionally capped at `min(n_obs, n_vars) - 1` so that the PCA
        solver always receives a feasible rank.
    """
    adata = adata_raw.copy()
    adata.var_names_make_unique()

    # Histology embeddings are already continuous features: PCA only.
    if modality == 'HE':
        sc.tl.pca(adata, n_comps=n_comps)
        return adata

    # Highly variable feature selection (never applied to Protein).
    hvg_modalities = {"RNA", "RNA_panel2", "H3K27me3", "H3K27ac", "ATAC", "H3K4me3", "Metabolite"}
    if hvg_num and len(adata.var_names) > hvg_num and modality in hvg_modalities:
        batch_key = 'batch' if 'batch' in adata.obs else None
        if modality in ("RNA", "RNA_panel2"):
            sc.pp.highly_variable_genes(
                adata, n_top_genes=hvg_num, flavor="seurat_v3", batch_key=batch_key
            )
        else:
            sc.pp.highly_variable_genes(adata, n_top_genes=hvg_num, batch_key=batch_key)
        # Materialize the subset: the steps below modify the data in place,
        # which should not happen on an AnnData view.
        adata = adata[:, adata.var['highly_variable']].copy()

    if modality == "Protein":
        adata = clr_normalize_each_cell(adata)
        sc.pp.scale(adata)
        # For COMET data, arcsinh normalization can be used instead of CLR:
        #     adata.X = np.arcsinh(adata.X / 5)
        n_proteins = adata.shape[1]
        if n_proteins >= n_comps:
            requested_comps = n_comps
        elif n_proteins >= 20:
            requested_comps = 20
        else:
            requested_comps = 15
        # The default ARPACK solver needs n_comps < min(n_obs, n_vars); small
        # protein panels would otherwise raise inside PCA.
        max_feasible_comps = min(adata.shape) - 1
        sc.tl.pca(adata, n_comps=min(requested_comps, max_feasible_comps))
    else:
        # A falsy target_sum (None or 0) selects the Scanpy default.
        sc.pp.normalize_total(adata, target_sum=target_sum or None)
        sc.pp.log1p(adata)
        sc.pp.scale(adata)
        sc.tl.pca(adata, n_comps=n_comps)
    return adata
[docs]
def load_data(data_dict, n_comps=50, hvg_num=3000, target_sum=None, use_harmony=True, metacell = False):
    """
    Process input spatial multi-modal data, returning processed feature matrices and spatial coordinates.

    Shared modalities (present in more than one section) are concatenated and
    jointly processed. Unique modalities (present in only one section) are
    processed independently. Each section's feature matrix is stored as a
    PyTorch tensor for downstream modeling. Spatial coordinates are checked for
    consistency across modalities; if they are inconsistent within a section,
    an error is raised.

    Parameters
    ----------
    data_dict : dict
        A dictionary mapping each modality name (e.g., `'RNA'`, `'Protein'`) to
        a list of AnnData objects, one per tissue section. Each AnnData should
        contain `.X`, `.obs`, `.var`, and `.obsm['spatial']`. If a modality is
        missing from a section, use `None` as a placeholder in the list.
    n_comps : int, optional
        Number of PCA components to compute. Default is 50.
    hvg_num : int, optional
        Number of highly variable features to select. If the feature dimension
        is smaller than `hvg_num`, HVG selection is skipped. Default is 3000.
    target_sum : float or None, optional
        Target sum for total-count normalization (used in
        `scanpy.pp.normalize_total`). Applied to both shared and unique
        modalities. If None, the Scanpy default is used. Default is None.
    use_harmony : bool, optional
        Whether to perform Harmony integration across sections for shared
        modalities. If False, only joint PCA is applied. Default is True.
    metacell : bool, optional
        Whether to merge each 2×2 spatial grid of cells into a "metacell" to
        reduce memory usage and improve speed. Applies to all modalities.
        Default is False.

    Returns
    -------
    feature_dict : dict
        A dictionary mapping each section name (e.g., `'s1'`, `'s2'`) to a
        sub-dictionary of processed feature tensors, one per modality. Each
        feature is a float tensor with one row per cell.
    spatial_loc_dict : dict
        A dictionary mapping each section name to the spatial coordinates taken
        from `.obsm['spatial']`.
    data_dict : dict
        The (possibly metacell-converted) input dictionary. Each AnnData object
        is modified in place to include the reduced features in `.obsm` under
        `'<modality>_harmony'` or `'<modality>_pca'`.

    Raises
    ------
    ValueError
        If the modalities of one section carry different spatial coordinates.
    """
    if metacell:
        # This will combine every 2x2 adjacent cells into a meta-cell across all modalities.
        print('Combine adjacent 4 cells into metacell to save memory and speed up computation')
        data_dict = construct_metacell_data_dict(data_dict)

    feature_dict = {}
    spatial_loc_dict = {}
    # default=0 lets an empty data_dict yield empty results instead of failing in max().
    num_sections = max((len(sections) for sections in data_dict.values()), default=0)

    # Section indices in which each modality is actually present.
    present_sections = {
        modality: [idx for idx, adata in enumerate(sections) if adata is not None]
        for modality, sections in data_dict.items()
    }
    # A modality is "shared" when it is present in more than one section.
    shared_modality_sections = {
        modality: section_ids
        for modality, section_ids in present_sections.items()
        if len(section_ids) > 1
    }

    # Process shared modalities jointly across sections.
    for modality, section_ids in shared_modality_sections.items():
        print(f'-------- Processing shared modality {modality} across sections --------')
        adata_list = [data_dict[modality][idx] for idx in section_ids]

        if modality == 'HE':
            adata_sub_list = [adata.copy() for adata in adata_list]
        else:
            # Restrict every section to the features common to all of them.
            common_var_names = adata_list[0].var_names
            for adata in adata_list[1:]:
                common_var_names = common_var_names.intersection(adata.var_names)
            adata_sub_list = [adata[:, common_var_names].copy() for adata in adata_list]

        # Suffix observation names with the section index so they stay unique after concat.
        for idx, adata_sub in zip(section_ids, adata_sub_list):
            adata_sub.obs_names = adata_sub.obs_names + f"_{idx}"

        adata_combined = ad.concat(adata_sub_list)
        adata_combined.obs['batch'] = [
            f'batch_{idx}'
            for idx, adata in zip(section_ids, adata_list)
            for _ in range(adata.shape[0])
        ]
        # target_sum is forwarded here as well, so shared and unique modalities
        # are normalized consistently.
        adata_combined = preprocess_adata(
            adata_combined, modality, hvg_num=hvg_num, n_comps=n_comps, target_sum=target_sum
        )

        if use_harmony:
            print(f"Running Harmony for {modality}")
            sc.external.pp.harmony_integrate(adata_combined, key='batch')
            pca_data_combined = adata_combined.obsm['X_pca_harmony']
        else:
            pca_data_combined = adata_combined.obsm['X_pca']

        # Split the joint embedding back into per-section blocks (concat order).
        split_indices = np.cumsum([adata.shape[0] for adata in adata_list])[:-1]
        combined_data_splits = np.split(pca_data_combined, split_indices)
        key_name = f'{modality}_harmony' if use_harmony else f'{modality}_pca'
        for section, section_embedding in zip(section_ids, combined_data_splits):
            data_dict[modality][section].obsm[key_name] = section_embedding
            # Copy so the tensor does not share memory with the joint embedding.
            feature_dict.setdefault(section, {})[modality] = torch.from_numpy(
                section_embedding.copy()
            ).float()

    # Process unique modalities independently.
    for modality, sections in data_dict.items():
        if modality in shared_modality_sections:
            continue
        for section, adata in enumerate(sections):
            if adata is None:
                continue
            print(f'-------- Processing unique modality {modality} for section {section+1} --------')
            section_features = feature_dict.setdefault(section, {})
            adata_processed = preprocess_adata(
                adata, modality, hvg_num=hvg_num, n_comps=n_comps, target_sum=target_sum
            )
            pca_data = adata_processed.obsm['X_pca'].copy()
            data_dict[modality][section].obsm[f'{modality}_pca'] = pca_data
            section_features[modality] = torch.from_numpy(pca_data).float()

    # Rename 0-based section indices to 1-based section names ('s1', 's2', ...).
    feature_dict = {f's{int(k) + 1}': v for k, v in feature_dict.items()}

    # Collect spatial locations and verify they agree across modalities.
    for section_idx in range(num_sections):
        print(f'Extracting spatial location for section {section_idx+1}')
        spatial_list = [
            sections[section_idx].obsm['spatial']
            for sections in data_dict.values()
            if section_idx < len(sections)
            and sections[section_idx] is not None
            and 'spatial' in sections[section_idx].obsm
        ]
        if not spatial_list:
            continue
        reference = spatial_list[0]
        if not all(np.array_equal(reference, spatial) for spatial in spatial_list[1:]):
            raise ValueError(f"Section {section_idx+1} contains inconsistent spatial information across different modalities!")
        spatial_loc_dict[f's{section_idx+1}'] = reference

    return feature_dict, spatial_loc_dict, data_dict
### CLR normalization for protein data, adapted from SpatialGlue
[docs]
def clr_normalize_each_cell(adata, inplace=True):
    """
    Normalize each cell's protein counts using Centered Log-Ratio (CLR) normalization,
    following the approach used in Seurat and SpatialGlue.

    For every row `x` of `.X`, the result is `log1p(x / g)` where
    `g = exp(sum(log1p(x[x > 0])) / len(x))`.

    Parameters
    ----------
    adata : AnnData
        Input AnnData object where `.X` stores raw count data (e.g., protein
        abundance). `.X` may be dense or a SciPy sparse container.
    inplace : bool, optional
        Whether to modify the input `adata` in place. If True, the
        normalization overwrites `adata.X`. If False, a normalized copy of
        `adata` is returned. Default is True.

    Returns
    -------
    adata : AnnData
        The AnnData object with CLR-normalized, dense `.X`. If `inplace=True`,
        this is the modified input object; otherwise it is a new copy.
    """
    def seurat_clr(x):
        # Scale factor: exp of the mean (over all features) of log1p of the
        # strictly positive entries.
        s = np.sum(np.log1p(x[x > 0]))
        exp = np.exp(s / len(x))
        return np.log1p(x / exp)

    if not inplace:
        adata = adata.copy()

    counts = adata.X
    # `.toarray()` works for both sparse matrices and sparse arrays; the old
    # `.A` shortcut is deprecated/removed in recent SciPy releases.
    dense_counts = counts.toarray() if scipy.sparse.issparse(counts) else np.array(counts)
    adata.X = np.apply_along_axis(seurat_clr, 1, dense_counts)
    return adata
# def metacell_construction(adata):
# """
# Construct metacells by aggregating every 2×2 spatially adjacent grid of cells into one,
# to reduce memory usage and speed up computation.
# Parameters
# ----------
# adata : AnnData
# An AnnData object. The spatial coordinates should be stored in `adata.obsm['spatial']`. The expression matrix `adata.X` can be either dense or sparse.
# Returns
# -------
# adata_metacell : AnnData
# A new AnnData object where each observation (cell) corresponds to a metacell,
# formed by averaging a group of up to 4 adjacent spatial cells (in a 2×2 pattern).
# The returned object includes:
# - `.X`: Averaged expression matrix across grouped cells.
# - `.obsm['spatial']`: Spatial coordinates (mean of each group).
# - `.uns['meta_to_original']`: A list mapping each metacell to the indices of original cells it includes.
# - `.uns['original_cell_num']`: Total number of original cells before metacell construction.
# """
# spatial_coords = adata.obsm['spatial']
# expression = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
# # sort (y, x)
# sorted_indices = np.lexsort((spatial_coords[:, 0], spatial_coords[:, 1]))
# sorted_coords = spatial_coords[sorted_indices]
# sorted_expression = expression[sorted_indices]
# meta_cells = []
# meta_coords = []
# meta_to_original = []
# visited = set()
# # visit each cell
# for i in range(len(sorted_coords)):
# if i in visited:
# continue
# # current cell
# group_indices = [i]
# visited.add(i)
# # neighbors
# right = None
# down = None
# diag = None
# for j in range(i + 1, len(sorted_coords)):
# if j in visited:
# continue
# dx = sorted_coords[j][0] - sorted_coords[i][0]
# dy = sorted_coords[j][1] - sorted_coords[i][1]
# if right is None and dx > 0 and dy == 0:
# right = j
# visited.add(j)
# elif down is None and dx == 0 and dy > 0:
# down = j
# visited.add(j)
# elif diag is None and dx > 0 and dy > 0:
# diag = j
# visited.add(j)
# if right is not None and down is not None and diag is not None:
# break
# if right is not None and down is not None and diag is not None:
# group_indices.extend([right, down, diag])
# original_indices = sorted_indices[group_indices]
# meta_to_original.append(original_indices.tolist())
# merged_expression = sorted_expression[group_indices].mean(axis=0)
# merged_coords = sorted_coords[group_indices].mean(axis=0)
# meta_cells.append(merged_expression)
# meta_coords.append(merged_coords)
# # build new adata
# meta_X = np.vstack(meta_cells)
# meta_coords = np.array(meta_coords)
# adata_metacell = sc.AnnData(X=meta_X)
# adata_metacell.var_names = adata.var_names.copy()
# adata_metacell.obsm['spatial'] = meta_coords
# # adata_metacell.uns['meta_to_original'] = [json.dumps(indices) for indices in meta_to_original]
# adata_metacell.uns['meta_to_original'] = meta_to_original
# adata_metacell.uns['original_cell_num'] = adata.n_obs
# adata_metacell.X = np.array(adata_metacell.X)
# return adata_metacell
###