COSIE.COSIE_framework.COSIE_model.train_model
- COSIE_model.train_model(file_path, config, optimizer, device, feature_dict, spatial_loc_dict, data_dict, Linkage_indicator, num_hvg=3000, n_x=1, n_y=1)
Train the COSIE model on spatial multimodal data.
Supports both full-graph and subgraph-level training. The training mode is selected based on the n_x and n_y values.
Parameters
- file_path : str
Directory path where final embeddings will be saved as .npy files.
- config : dict
Configuration dictionary defining model and training hyperparameters.
- optimizer : torch.optim.Optimizer
Optimizer for updating model parameters.
- device : str
Device identifier (e.g., 'cuda:0', 'cpu').
- feature_dict : dict
A dictionary mapping each section name (e.g., 's1', 's2') to a sub-dictionary of processed feature tensors, one per modality. Each feature is a torch.FloatTensor.
- spatial_loc_dict : dict
A dictionary mapping each section name to a 2D NumPy array of spatial coordinates.
- data_dict : dict
A dictionary mapping each modality name (e.g., 'RNA', 'Protein') to a list of AnnData objects, one per tissue section. Each AnnData should contain .X, .obs, .var, and .obsm['spatial']. If a modality is missing from a section, use None as a placeholder in the list.
- Linkage_indicator : dict
A dictionary specifying which tissue section pairs and modality pairs should be linked. For example:
{("s1", "s2"): [("RNA", "RNA"), ("RNA", "Protein")], ("s2", "s3"): [("ATAC", "RNA")]}
constructs linkage between sections s1 and s2 using both RNA-RNA strong linkage and RNA-Protein weak linkage, and constructs linkage between sections s2 and s3 using ATAC-RNA linkage.
- num_hvg : int, optional
Number of highly variable features to retain for feature matching during linkage construction. Default is 3000.
- n_x : int, optional
Number of spatial partitions along the x-axis per section. Default is 1. If set to 1, no splitting is applied.
- n_y : int, optional
Number of spatial partitions along the y-axis per section. Default is 1. If set to 1, no splitting is applied. When both n_x = 1 and n_y = 1 (default), the model runs in full-graph training mode. When either n_x or n_y is greater than 1, the model switches to subgraph training mode and partitions each section into a grid of n_x × n_y subregions.
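The rules for n_x, n_y, and Linkage_indicator can be sketched with two small helpers that sanity-check inputs before training. Both `training_mode` and `check_linkage` are hypothetical names introduced here for illustration and are not part of COSIE. The sketch also assumes that sections are named 's1', 's2', and so on in list order, and that the first modality in each pair belongs to the first section of the key. Placeholder strings stand in for AnnData objects.

```python
# Hypothetical helpers for illustration only; they are not part of COSIE.

def training_mode(n_x=1, n_y=1):
    """Return (mode, number of subregions per section) implied by n_x and n_y."""
    if n_x < 1 or n_y < 1:
        raise ValueError("n_x and n_y must be positive integers")
    if n_x == 1 and n_y == 1:
        return "full-graph", 1
    return "subgraph", n_x * n_y


def check_linkage(linkage_indicator, data_dict):
    """List (section, modality) pairs that are linked but missing from data_dict.

    Assumption: section 's1' is index 0 in each modality list, 's2' is index 1, etc.
    """
    problems = []
    for (sec_a, sec_b), modality_pairs in linkage_indicator.items():
        for mod_a, mod_b in modality_pairs:
            for sec, mod in ((sec_a, mod_a), (sec_b, mod_b)):
                idx = int(sec[1:]) - 1
                sections = data_dict.get(mod, [])
                if idx < 0 or idx >= len(sections) or sections[idx] is None:
                    problems.append((sec, mod))
    return problems


Linkage_indicator = {
    ("s1", "s2"): [("RNA", "RNA"), ("RNA", "Protein")],
    ("s2", "s3"): [("ATAC", "RNA")],
}

# Placeholder strings stand in for AnnData objects; None marks a missing modality.
data_dict = {
    "RNA": ["adata_rna_s1", "adata_rna_s2", "adata_rna_s3"],
    "Protein": [None, "adata_protein_s2", None],
    "ATAC": [None, "adata_atac_s2", None],
}

mode, n_subregions = training_mode(n_x=2, n_y=3)
missing = check_linkage(Linkage_indicator, data_dict)
```

With n_x=2 and n_y=3, each section would be split into a 2 × 3 grid, and an empty `missing` list means every linked modality is present in the sections that reference it.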
Returns
- final_embeddings : dict
Dictionary mapping section names to their final learned embedding matrices as NumPy arrays. These embeddings are also saved to {file_path}/s1_embedding.npy, etc. Format:
{'s1': np.ndarray of shape (n1_cells, latent_dim), 's2': np.ndarray of shape (n2_cells, latent_dim), …}
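To locate the saved files after training, the naming pattern above can be reproduced as a short sketch. `embedding_paths` is a hypothetical helper, and it assumes that every section follows the same `{section}_embedding.npy` pattern shown for s1.

```python
import os

# Hypothetical helper for illustration only; it is not part of COSIE.
def embedding_paths(file_path, section_names):
    """Map each section name to the path where its embedding is expected.

    Assumption: every section is saved as '<section>_embedding.npy' under file_path.
    """
    return {
        name: os.path.join(file_path, f"{name}_embedding.npy")
        for name in section_names
    }


paths = embedding_paths("results", ["s1", "s2"])
```

Each path could then be read back with `numpy.load`, which should give an array of shape (n_cells, latent_dim) for that section.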