COSIE.COSIE_framework.COSIE_model.train_model

COSIE_model.train_model(file_path, config, optimizer, device, feature_dict, spatial_loc_dict, data_dict, Linkage_indicator, num_hvg=3000, n_x=1, n_y=1)

Train the COSIE model on spatial multimodal data.

Supports both full-graph and subgraph-level training. The training mode is selected based on the n_x and n_y values.

Parameters

file_path : str

Directory path where final embeddings will be saved as .npy files.

config : dict

Configuration dictionary defining model and training hyperparameters.

optimizer : torch.optim.Optimizer

Optimizer for updating model parameters.

device : str

Device identifier (e.g., 'cuda:0', 'cpu').

feature_dict : dict

A dictionary mapping each section name (e.g., 's1', 's2') to a sub-dictionary of processed feature tensors, one per modality. Each feature is a torch.FloatTensor.

spatial_loc_dict : dict

A dictionary mapping each section name to a 2D NumPy array of spatial coordinates.

data_dict : dict

A dictionary mapping each modality name (e.g., 'RNA', 'Protein') to a list of AnnData objects, one per tissue section. Each AnnData should contain .X, .obs, .var, and .obsm['spatial']. If a modality is missing from a section, use None as a placeholder in the list.
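As a rough illustration of this layout, the sketch below checks that every modality list has one slot per section and reports which modalities are present in each section. Plain strings stand in for AnnData objects. The helper name `modalities_per_section` and the assumption that list position i corresponds to section 's{i+1}' are illustrative only and are not taken from COSIE.

```python
def modalities_per_section(data_dict):
    """Return {section_name: [modalities present]} for a data_dict-style mapping.

    Assumption (not confirmed by the COSIE docs): list position i
    corresponds to the section named 's{i+1}'.
    """
    lengths = {len(objs) for objs in data_dict.values()}
    if len(lengths) != 1:
        raise ValueError("every modality list needs exactly one entry per section")
    n_sections = lengths.pop()

    present = {}
    for i in range(n_sections):
        # A None placeholder marks a modality that was not measured in this section
        present[f"s{i + 1}"] = [
            modality for modality, objs in data_dict.items() if objs[i] is not None
        ]
    return present


# Strings stand in for AnnData objects in this sketch
data_dict = {
    "RNA": ["adata_rna_s1", "adata_rna_s2", None],
    "Protein": [None, "adata_protein_s2", "adata_protein_s3"],
}
print(modalities_per_section(data_dict))
```

A check of this kind catches a missing None placeholder before training starts, since a shorter list would otherwise shift every later section by one position.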

Linkage_indicator : dict

A dictionary specifying which tissue section pairs and modality pairs should be linked. Format:

{("s1", "s2"): [("RNA", "RNA"), ("RNA", "Protein")], ("s2", "s3"): [("ATAC", "RNA")]}

This means: construct linkage between sections s1 and s2 using both RNA-RNA strong linkage and RNA-Protein weak linkage, and construct linkage between sections s2 and s3 using ATAC-RNA linkage.
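To make the format concrete, the sketch below flattens a Linkage_indicator dictionary into one row per requested linkage. The function name `describe_linkages` is hypothetical. Labeling same-modality pairs as "strong" and cross-modality pairs as "weak" is an assumption extrapolated from the RNA-RNA and RNA-Protein example, so COSIE may classify pairs such as ATAC-RNA differently.

```python
def describe_linkages(linkage_indicator):
    """Flatten {(sec_a, sec_b): [(mod_a, mod_b), ...]} into a list of rows."""
    rows = []
    for (sec_a, sec_b), modality_pairs in linkage_indicator.items():
        for mod_a, mod_b in modality_pairs:
            # Assumption: same modality on both sides -> "strong", otherwise "weak"
            kind = "strong" if mod_a == mod_b else "weak"
            rows.append((sec_a, sec_b, mod_a, mod_b, kind))
    return rows


Linkage_indicator = {
    ("s1", "s2"): [("RNA", "RNA"), ("RNA", "Protein")],
    ("s2", "s3"): [("ATAC", "RNA")],
}
for row in describe_linkages(Linkage_indicator):
    print(row)
```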

num_hvg : int, optional

Number of highly variable features to retain for feature matching during linkage construction. Default is 3000.

n_x : int, optional

Number of spatial partitions along the x-axis per section. Default is 1. If set to 1, no splitting is applied.

n_y : int, optional

Number of spatial partitions along the y-axis per section. Default is 1. If set to 1, no splitting is applied. When n_x = 1 and n_y = 1 (the default), the model trains in full-graph mode. When either n_x or n_y is greater than 1, the model switches to subgraph training mode and partitions each section into a grid of n_x × n_y subregions.
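The mode rule and the grid split can be sketched as follows. The first function follows the rule stated above. The second assumes equal-width bins over a section's bounding box, which is only one plausible way to form an n_x × n_y grid, and the documentation does not say how COSIE draws the boundaries. Both function names are hypothetical.

```python
def training_mode(n_x=1, n_y=1):
    """Return the training mode implied by the partition counts."""
    if n_x < 1 or n_y < 1:
        raise ValueError("n_x and n_y must be positive integers")
    return "full-graph" if n_x == 1 and n_y == 1 else "subgraph"


def grid_cell(x, y, bounds, n_x, n_y):
    """Return the (column, row) index of the grid cell containing (x, y).

    bounds is (x_min, x_max, y_min, y_max). Equal-width binning is an
    assumption of this sketch, not documented COSIE behavior.
    """
    x_min, x_max, y_min, y_max = bounds

    def bin_index(value, lo, hi, n_bins):
        if hi == lo:
            return 0
        idx = int((value - lo) / (hi - lo) * n_bins)
        # Clamp so points on the upper boundary fall in the last bin
        return min(max(idx, 0), n_bins - 1)

    return bin_index(x, x_min, x_max, n_x), bin_index(y, y_min, y_max, n_y)


print(training_mode())        # default n_x = n_y = 1
print(training_mode(2, 1))    # any value above 1 triggers subgraph mode
print(grid_cell(6, 2, (0, 10, 0, 10), n_x=2, n_y=2))
```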

Returns

final_embeddings : dict

Dictionary mapping section names to their final learned embedding matrices as NumPy arrays. These embeddings are also saved to {file_path}/s1_embedding.npy, etc. Format:

{'s1': np.ndarray of shape (n1_cells, latent_dim), 's2': np.ndarray of shape (n2_cells, latent_dim), …}
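For locating the saved files afterwards, a small helper along these lines may be useful. It assumes every section follows the `{section}_embedding.npy` naming implied by the s1 example above. The function name `embedding_paths` and the directory name `results` are hypothetical.

```python
import os


def embedding_paths(file_path, section_names):
    """Map each section name to the path where its embedding is expected."""
    # Assumption: all sections follow the '<section>_embedding.npy' pattern
    return {
        name: os.path.join(file_path, f"{name}_embedding.npy")
        for name in section_names
    }


paths = embedding_paths("results", ["s1", "s2"])
for section, path in paths.items():
    print(section, path)
```

Each path can then be passed to `numpy.load` to read the embedding matrix back as a NumPy array.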