Graph Neural Networks in Integrative Omics
- Sergiu Netotea, NBIS, 2024
Graph Embeddings
- What is a graph embedding?
Graph embedding is a technique used to represent the vertices (or nodes) of a graph in a continuous vector space, often in a lower dimension, while preserving the graph's structural and relational information.
- Why are we doing this?
Dimensionality Reduction: Graphs often have high-dimensional and sparse data. Embedding reduces the dimensions while preserving important information, making it easier to work with. The goal of embedding is to ensure that nodes that are close or have similar roles in the original graph remain close in the embedding space. This includes properties such as neighborhood similarity, shortest paths, or connectivity patterns.
Continuous Vector Representation: Each node in the graph is represented as a vector of real numbers. These vectors can then be used as input to machine learning models for tasks such as node classification, link prediction, and clustering. This allows graphs, which are inherently discrete and complex, to be processed and analyzed using machine learning algorithms that work on continuous data.
Common applications:
- Node Classification: Predicting the category or label of a node.
- Link Prediction: Predicting the likelihood of a connection between two nodes.
- Graph Clustering: Grouping nodes based on similarity in their embeddings.
- Recommendation Systems: Embedding nodes (e.g., users or items) can be used in collaborative filtering models.
Methods for Graph Embedding:
- Matrix Factorization: Traditional methods involve decomposing a graph’s adjacency matrix or Laplacian matrix into lower-dimensional representations. For example, Laplacian Eigenmaps or Graph Factorization.
- Random Walk-based Methods:
- DeepWalk: It generates random walks from each node and treats them like sentences in natural language processing (NLP). Word2Vec-like algorithms are then used to generate embeddings.
- Node2Vec: A variation of DeepWalk, it adjusts the random walks to balance between exploring local neighborhoods and global structures.
- Graph Neural Networks (GNNs): These are neural networks specifically designed for graph structures. GNNs learn node representations by aggregating information from a node’s neighbors iteratively. Popular GNN variants include:
- Graph Convolutional Networks (GCNs): Generalize the concept of convolution from image data to graph data.
- Graph Attention Networks (GATs): Use attention mechanisms to weigh the importance of neighboring nodes in the embedding process.
- Graph Autoencoders: These use neural networks to encode graph information into embeddings and then decode the embeddings to reconstruct the graph. The idea is that a good embedding should allow for accurate graph reconstruction.
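As an illustration of the matrix factorization family listed above, here is a minimal sketch of a spectral embedding. It is a simplification: it uses the eigenvectors of the unnormalized Laplacian $ L = D - A $, whereas Laplacian Eigenmaps proper solves a degree-weighted generalized eigenproblem. The function name `laplacian_eigenmaps` and the toy graph are assumptions made for this example.

```python
import numpy as np

def laplacian_eigenmaps(adjacency, dim=2):
    """Embed nodes with the low-frequency eigenvectors of L = D - A.

    Simplified sketch: unnormalized Laplacian, undirected connected graph assumed.
    """
    A = np.asarray(adjacency, dtype=float)
    degree = A.sum(axis=1)
    L = np.diag(degree) - A
    # eigh is for symmetric matrices and returns eigenvalues in ascending order
    eigenvalues, eigenvectors = np.linalg.eigh(L)
    # Column 0 is the constant eigenvector (eigenvalue ~0), so it is skipped
    return eigenvectors[:, 1:dim + 1], eigenvalues

# Toy graph (hypothetical): triangles (0,1,2) and (3,4,5) joined by the edge 2-3
A = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4), (3, 5), (4, 5)]:
    A[i, j] = A[j, i] = 1.0

embedding, eigenvalues = laplacian_eigenmaps(A, dim=1)
print(embedding.ravel())  # the sign of each coordinate separates the two triangles
```

In this toy case a single coordinate is enough to place the two triangles on opposite sides of zero, which is the sense in which the embedding preserves neighborhood structure.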
Node2Vec
Notes on the algorithm:
- Random Walk Generation: For each node in the graph, Node2Vec performs a series of biased random walks. The process is repeated multiple times to create a corpus of node sequences, similar to sentences.
- Return parameter (p): Controls the likelihood of immediately returning to the previous node. A high value of p discourages stepping back to the node just visited, encouraging exploration.
- In-out parameter (q): Controls whether the random walk should explore the graph deeply (depth-first search, DFS) or broadly (breadth-first search, BFS):
- q > 1: Favors BFS, which explores neighbors that are close in proximity but not necessarily deeply connected.
- q < 1: Favors DFS, exploring nodes further away, possibly revealing more global structural information.
- Skip-Gram Model: Once these node sequences are generated, the skip-gram model (from Word2Vec) is trained to learn embeddings. The skip-gram model is trained to predict a node’s neighbors within a certain window (context) in the sequence, ensuring that nodes with similar neighbors have similar vector representations.
The key innovation of Node2Vec lies in its ability to tune the random walk process to explore the graph more effectively, capturing both homophily (local communities) and structural equivalence (nodes playing similar roles across the graph).
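The walk and context-window steps described above can be sketched in plain Python. This is a minimal illustration, assuming an unweighted, undirected graph stored as a dictionary of neighbor sets; the names `node2vec_walk` and `skipgram_pairs` and the toy graph are hypothetical, and the actual skip-gram training (for example with a Word2Vec implementation) is left out.

```python
import random

def node2vec_walk(graph, start, length, p=1.0, q=1.0, rng=random):
    """One second-order biased walk on an unweighted, undirected graph."""
    walk = [start]
    while len(walk) < length:
        current = walk[-1]
        neighbors = sorted(graph[current])
        if not neighbors:
            break
        if len(walk) == 1:
            # First step has no previous node, so it is uniform
            walk.append(rng.choice(neighbors))
            continue
        previous = walk[-2]
        weights = []
        for candidate in neighbors:
            if candidate == previous:
                weights.append(1.0 / p)   # step back to the previous node
            elif candidate in graph[previous]:
                weights.append(1.0)       # stay close to the previous node (BFS-like)
            else:
                weights.append(1.0 / q)   # move further away (DFS-like)
        walk.append(rng.choices(neighbors, weights=weights, k=1)[0])
    return walk

def skipgram_pairs(walk, window=2):
    """(center, context) pairs that a skip-gram model would be trained on."""
    pairs = []
    for i, center in enumerate(walk):
        lo = max(0, i - window)
        hi = min(len(walk), i + window + 1)
        pairs.extend((center, walk[j]) for j in range(lo, hi) if j != i)
    return pairs

# Toy graph (hypothetical): triangles (0,1,2) and (3,4,5) bridged by the edge 2-3
graph = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2, 4, 5}, 4: {3, 5}, 5: {3, 4}}
rng = random.Random(0)
corpus = [node2vec_walk(graph, node, length=8, p=4.0, q=0.5, rng=rng)
          for node in graph for _ in range(3)]
pairs = skipgram_pairs(corpus[0], window=2)
```

With `p=4.0` and `q=0.5` the walks rarely backtrack and tend to drift outward. Setting `p = q = 1` gives an unbiased walk of the kind DeepWalk uses.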
During the lab: clustering the Karate Club dataset.
Graph neural networks. Why?
Main benefits of GNNs over other graph embeddings:
Direct Aggregation of Node Features:
- GNNs explicitly aggregate information from neighboring nodes, learning how to combine features from the local neighborhood. This aggregation process allows nodes to learn not only their structural relationships but also meaningful patterns from their neighbors' feature vectors, leading to more informative embeddings.
- Word2Vec-based methods: These methods rely on random walks and treat nodes as if they were isolated entities without directly accounting for the features of the neighboring nodes during the embedding process. They are more focused on structural properties (node proximity), but they do not integrate node attributes effectively.
End-to-End Learning:
- GNNs: can be trained in an end-to-end fashion for specific tasks, such as node classification, link prediction, or graph classification. This means the model can be optimized directly for a target objective, resulting in embeddings tailored to the specific task.
- Word2Vec-based methods: These methods learn node embeddings in an unsupervised manner and cannot be directly optimized for downstream tasks. Once the embeddings are learned, they are typically used as input features for another machine learning model, creating a two-step process rather than end-to-end learning.
Handling Node and Edge Features:
- GNNs: GNNs are designed to handle graphs that have rich node and edge attributes. These attributes (such as node labels, node degrees, or edge weights) are incorporated into the learning process, allowing GNNs to capture more complex relationships within the graph.
- Word2Vec-based methods: Random walks and context windows used in Word2Vec-based methods focus mainly on the graph topology, without the ability to naturally incorporate additional node or edge attributes into the embedding process.
Other benefits:
- Capturing Higher-Order and Global Relationships: GNNs are capable of learning both local and global structural information. Deeper GNNs can capture higher-order neighborhood information and complex global relationships within the graph. While Node2Vec can adjust between local and global exploration using its parameters, it still primarily captures node proximity rather than higher-order, task-specific relationships.
- Graph-level Learning: GNNs are flexible enough to perform tasks at the node level (e.g., node classification) and graph level (e.g., graph classification, where the entire graph is classified). GNNs can aggregate node-level features into graph-level representations, which is important for tasks like molecular property prediction in chemistry or disease prediction in biological networks. Word2Vec-based methods are mainly focused on node-level tasks, such as node classification or link prediction, and do not naturally extend to graph-level tasks.
- Modeling Edge-Level Relationships: GNNs can model edge-level relationships and handle different types of edges, edge weights, or directed edges, allowing for more expressive representations of complex relationships in graphs.
- Explainability and Interpretability: By explicitly modeling feature aggregation and graph structure through multiple layers, GNNs can provide a more interpretable model. For example, attention mechanisms in Graph Attention Networks (GATs) can highlight which neighbors are most influential in determining a node’s representation.
GNNs, how are they studied?
- They are a type of geometric learning. Geometric learning refers to a broader class of machine learning techniques that handle data with geometric or topological structures, such as graphs, manifolds, or other non-Euclidean spaces.
- What is different from traditional NNs: Unlike traditional neural networks that work on grid-like data (e.g., images or sequences) with a regular, Euclidean structure, GNNs handle data that is structured as graphs. Graphs represent entities (nodes) and their relationships (edges) in irregular or complex topologies, which require geometric reasoning beyond simple Euclidean space.
- They learn on a different data structure. A graph can be loosely thought of as a discrete analogue of a manifold, where nodes exist in a space with connectivity defined by edges. The relationships among nodes are not defined by a simple Euclidean distance but by the graph's structure, which may involve complex geometry.
- Permutation Invariance/Equivariance: A key challenge in learning from graph data is that the nodes in a graph have no fixed ordering, unlike pixels in an image. GNNs are designed to respect this: graph-level outputs are permutation invariant (unchanged when the nodes are relabeled), while node-level outputs are permutation equivariant (reordered in the same way as the nodes). This is a geometric property because the model must learn to recognize patterns in the graph based on its inherent structure, not based on any arbitrary ordering.
- Convolutional Operations in Graphs: GNNs generalize the convolution operation used in Convolutional Neural Networks (CNNs) from grid-like data (e.g., images) to graphs. Graph convolutions rely on the local structure of the graph, taking into account the neighborhood of each node. This involves reasoning about the geometric relationships between nodes in the graph's structure.
- Geometric Priors and Invariance: GNNs leverage geometric priors such as locality (information passing occurs among neighbors) and global structures like connectivity, which are crucial for learning meaningful representations from the graph. These concepts are inherently geometric because they involve understanding the shape, structure, and relationships within the graph.
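The permutation property mentioned in the list above can be checked numerically. This is a small sketch, assuming the simplest possible aggregation (summing neighbor features, with no learned weights); the graph, the features, and the permutation are arbitrary choices made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Small undirected graph with 4 nodes and 3 features per node (arbitrary example)
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
X = rng.normal(size=(4, 3))

def aggregate(A, X):
    """Sum the features of each node's neighbors (one unweighted message-passing step)."""
    return A @ X

# Relabel the nodes: row i of P @ X is row perm[i] of X
perm = np.array([2, 0, 3, 1])
P = np.eye(4)[perm]

H = aggregate(A, X)
H_perm = aggregate(P @ A @ P.T, P @ X)

# Node-level output is equivariant: relabeling the input relabels the output the same way
equivariant = np.allclose(H_perm, P @ H)
# A sum readout over nodes is invariant: the graph-level vector does not change
invariant = np.allclose(H_perm.sum(axis=0), H.sum(axis=0))
print(equivariant, invariant)
```

Both checks hold because $ P^T P = I $, so $ (P A P^T)(P X) = P (A X) $.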
Graph Convolutional Networks (GCNs)
- Description: GCNs generalize convolutional operations to graph-structured data, focusing on neighborhood aggregation.
- Example Uses: Social network analysis, where the goal is to predict user behavior based on the connections and interactions between users. Predicting labels of nodes in a partially labeled graph (e.g., classifying users in a social network). Predicting whether an edge exists between two nodes (e.g., friend recommendation in social networks). Classifying entire graphs (e.g., classifying molecules based on their structure).
- Single-cell Transcriptomics: GCNs can model gene-gene interactions across different cells by representing the cells as nodes and the gene expression correlations as edges. This helps identify cell types and gene regulatory networks.
Wang, T., Bai, J. & Nabavi, S. Single-cell classification using graph convolutional networks. BMC Bioinformatics 22, 364 (2021). https://doi.org/10.1186/s12859-021-04278-2
Yuan, Y., Bar-Joseph, Z. GCNG: graph convolutional networks for inferring gene interaction from spatial transcriptomics data. Genome Biol 21, 300 (2020). https://doi.org/10.1186/s13059-020-02214-w
- Bulk RNA-Seq & omics integration: GCNs can be used to integrate bulk transcriptomic data with other omics data (like DNA methylation, protein, etc) by constructing a graph where nodes represent genes, and edges represent regulatory relationships, helping in uncovering gene expression regulation patterns.
Wang, T., Shao, W., Huang, Z. et al. MOGONET integrates multi-omics data using graph convolutional networks allowing patient classification and biomarker identification. Nat Commun 12, 3445 (2021). https://doi.org/10.1038/s41467-021-23774-w
Chuang, YH., Huang, SH., Hung, TM. et al. Convolutional neural network for human cancer types prediction by integrating protein interaction networks and omics data. Sci Rep 11, 20691 (2021). https://doi.org/10.1038/s41598-021-98814-y
Ji, R., Geng, Y. & Quan, X. Inferring gene regulatory networks with graph convolutional network based on causal feature reconstruction. Sci Rep 14, 21342 (2024). https://doi.org/10.1038/s41598-024-71864-8
Key Concepts in GCNs
- Graph Structure: A graph G is typically represented by a set of nodes V and a set of edges E, where each edge connects two nodes. From the graph structure one can derive an adjacency matrix A.
- Node Features: Each node can have a feature vector associated with it. For example, in a single cell dataset the feature can be gene expression associated to a node.
- Message Passing and Aggregation: The core mechanism in GCNs is the aggregation of information from a node's neighbors to update the node's feature representation. This process, sometimes called "message passing," allows a node to gain a more informed representation based on its graph context.
GCN Layer (Graph Convolution Layer)
A GCN layer generalizes the concept of the standard convolutional layer. Instead of applying a kernel to a grid of pixels (as in CNNs), GCN layers perform a neighborhood aggregation operation.
For each layer in a GCN, the node feature matrix $ H $ is updated as follows: $$ H^{(l+1)} = \sigma\left( \hat{A} H^{(l)} W^{(l)} \right) $$ Where:
- $ H^{(l)} $ is the feature matrix at layer $ l $, with shape $ N \times D_l $ (i.e., $ N $ nodes and $ D_l $ features per node at layer $ l $).
- $ \hat{A} $ is the normalized adjacency matrix of the graph, typically with self-loops added so that each node keeps its own features; the normalization prevents feature values from exploding or vanishing.
- $ W^{(l)} $ is the learnable weight matrix at layer $ l $.
- $ \sigma $ is an activation function, such as ReLU.
Main idea:
- Neighborhood Aggregation: Instead of using a fixed kernel, the graph structure dictates how information is shared between nodes.
- Weight Transformation: The aggregated features are transformed using a weight matrix, which is learned during training. This step is similar to the linear transformation in fully connected layers of neural networks.
- The activation function, such as ReLU, is applied to introduce non-linearity, allowing the model to capture complex patterns in the graph.
- Multi-Layer GCN: A GCN typically stacks multiple graph convolutional layers. Each layer aggregates more information from a larger neighborhood in the graph. After multiple layers, a node’s representation is influenced not only by its immediate neighbors but also by neighbors that are further away.
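The update rule and the multi-layer behaviour described above can be sketched with NumPy. This is a forward pass only, assuming the common symmetric normalization with self-loops, $ \hat{A} = \tilde{D}^{-1/2} (A + I) \tilde{D}^{-1/2} $, where $ \tilde{D} $ is the degree matrix of $ A + I $. The function names, the path graph, and the fixed weight matrix are assumptions made for the example; in practice $ W^{(l)} $ is learned.

```python
import numpy as np

def normalize_adjacency(A):
    """Symmetric normalization with self-loops: D~^(-1/2) (A + I) D~^(-1/2)."""
    A_tilde = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))
    return A_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(A_hat, H, W):
    """One GCN layer: aggregate neighbors, apply weights, then ReLU."""
    return np.maximum(A_hat @ H @ W, 0.0)

# Path graph 0 - 1 - 2 - 3 (hypothetical toy example)
A = np.zeros((4, 4))
for i, j in [(0, 1), (1, 2), (2, 3)]:
    A[i, j] = A[j, i] = 1.0
A_hat = normalize_adjacency(A)

# A single feature that is non-zero only on node 0, and a fixed 1x1 weight
H0 = np.array([[1.0], [0.0], [0.0], [0.0]])
W = np.array([[1.0]])

H1 = gcn_layer(A_hat, H0, W)  # signal reaches node 1 (1 hop)
H2 = gcn_layer(A_hat, H1, W)  # signal reaches node 2 (2 hops), not yet node 3
print(H1.ravel())
print(H2.ravel())
```

After one layer only nodes 0 and 1 carry signal; after two layers node 2 does as well, while node 3 (three hops away) is still zero. Each extra layer extends the neighborhood that influences a node by one hop.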
Graph Attention Networks (GAT)
- Description: GATs use attention mechanisms to learn the importance of neighboring nodes when aggregating information. This allows the model to focus on the most relevant neighbors when updating a node's features.
- Example Use: Citation networks, where a paper's relevance is predicted by selectively focusing on certain referenced papers.
- Spatial Transcriptomics, scRNA-seq: GATs can model spatial relationships between neighboring cells by treating each cell as a node, with attention mechanisms helping focus on spatially relevant neighboring cells. This allows for more accurate spatial gene expression patterns to emerge.
Sun H, Qu H, Duan K, Du W. scMGCN: A Multi-View Graph Convolutional Network for Cell Type Identification in scRNA-seq Data. Int J Mol Sci. 2024 Feb 13;25(4):2234. doi: 10.3390/ijms25042234. PMID: 38396909; PMCID: PMC10889820.
Yang, W., Wang, P., Xu, S. et al. Deciphering cell–cell communication at single-cell resolution for spatial transcriptomics with subgraph-based graph attention network. Nat Commun 15, 7101 (2024). https://doi.org/10.1038/s41467-024-51329-2
Zizhan Gao, Kai Cao, Lin Wan, Graspot: a graph attention network for spatial transcriptomics data integration with optimal transport, Bioinformatics, Volume 40, Issue Supplement_2, September 2024, Pages ii137–ii145, https://doi.org/10.1093/bioinformatics/btae394
Long, Y., Ang, K.S., Li, M. et al. Spatially informed clustering, integration, and deconvolution of spatial transcriptomics with GraphST. Nat Commun 14, 1155 (2023). https://doi.org/10.1038/s41467-023-36796-3
Guangyi Chen, Zhi-Ping Liu, Graph attention network for link prediction of gene regulations from single-cell RNA-sequencing data, Bioinformatics, Volume 38, Issue 19, October 2022, Pages 4522–4529, https://doi.org/10.1093/bioinformatics/btac559
Li, H., Han, Z., Sun, Y. et al. CGMega: explainable graph neural network framework with attention mechanisms for cancer gene module dissection. Nat Commun 15, 5997 (2024). https://doi.org/10.1038/s41467-024-50426-6
- Metagenomics: In metagenomic studies, GATs can integrate various layers of data, such as species abundance and functional annotations, by focusing on the most relevant species in a microbiome community, helping identify key microbes involved in disease.
Andre Lamurias, Mantas Sereika, Mads Albertsen, Katja Hose, Thomas Dyhre Nielsen, Metagenomic binning with assembly graph embeddings, Bioinformatics, Volume 38, Issue 19, October 2022, Pages 4481–4487, https://doi.org/10.1093/bioinformatics/btac557
Key concepts in GATs
Problem with Basic GNNs: In traditional GNNs, the information from neighboring nodes is aggregated with fixed weights (uniform or based only on node degree), so the importance of a neighbor is not learned from its features. This may result in a loss of relevant information. Graph Attention Networks (GAT) are a type of neural network designed to process graph-structured data by leveraging the attention mechanism to assign different, learned weights to neighboring nodes.
Input Features: Each node $ v_i \in V $ is associated with a feature vector $ h_i \in R^F $ , where $ F $ is the dimensionality of the node features.
Linear Transformation: First, apply a shared linear transformation to each node's feature vector: $ h_i' = W h_i $, where $ W $ is a learnable weight matrix and $ h_i' $ is the transformed feature vector for node i.
Attention Mechanism: For each node $ v_i $, compute an attention coefficient for each of its neighboring nodes. The coefficient between nodes i and j is computed by applying a non-linear activation function (LeakyReLU) to the dot product of a learnable attention vector $ a $ with the concatenation of the two nodes' transformed feature vectors $ h_i' $ and $ h_j' $. Self-Attention: The same scoring function is shared across all node pairs and does not itself depend on the graph structure; the structure enters by computing coefficients only for nodes j in the neighborhood $ N(i) $. $$ e_{ij} = \text{LeakyReLU}\left( a^T [ h_i' \, || \, h_j' ] \right) $$
Normalization: The attention coefficients are normalized with a softmax over each node's neighborhood to make them comparable: $$ \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in N(i)} \exp(e_{ik})} $$ where $ \alpha_{ij} $ represents the normalized attention coefficient, reflecting the importance of node j in the neighborhood of node i.
Aggregation: The node features are then updated by aggregating the transformed features of its neighbors, weighted by the attention coefficients: $$ h_i^{(l+1)} = \sigma \left( \sum_{j \in N(i)} \alpha_{ij} h_j' \right) $$ where $ \sigma $ is a non-linear activation function such as ReLU.
- Multi-head Attention (optional): To stabilize the learning process and increase model capacity, GAT uses multi-head attention. This means the attention mechanism is applied multiple times in parallel, and the results are concatenated or averaged. For K attention heads, the output for node i after multi-head attention is: $$ h_i^{(l+1)} = \, ||_{k=1}^K \sigma \left( \sum_{j \in N(i)} \alpha_{ij}^{(k)} h_j'^{(k)} \right) $$ where $ ||_{k=1}^K $ denotes concatenation of the results from each attention head $ k $.
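The single-head computation above (linear transformation, attention scores, softmax, aggregation) can be sketched as a NumPy forward pass. Several things here are assumptions made for the example: self-loops are included so that $ N(i) $ contains node i, the LeakyReLU slope is set to 0.2, the parameters are random rather than learned, and features are stored as rows, so `H @ W` plays the role of $ W h_i $.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_layer(A, H, W, a):
    """Single-head GAT forward pass.

    A: (N, N) adjacency with self-loops, H: (N, F) features,
    W: (F, F_out) weights, a: (2 * F_out,) attention vector.
    """
    H_prime = H @ W                      # row i is the transformed vector h_i'
    f_out = H_prime.shape[1]
    # a^T [h_i' || h_j'] splits into a[:f_out] . h_i' + a[f_out:] . h_j'
    src = H_prime @ a[:f_out]
    dst = H_prime @ a[f_out:]
    e = leaky_relu(src[:, None] + dst[None, :])
    # Masked attention: only pairs connected in the graph get a finite score
    e = np.where(A > 0, e, -np.inf)
    # Row-wise softmax (max subtracted for numerical stability)
    e = e - e.max(axis=1, keepdims=True)
    alpha = np.exp(e)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    # Weighted sum of neighbor features followed by ReLU
    return np.maximum(alpha @ H_prime, 0.0), alpha

rng = np.random.default_rng(0)
# Path graph 0 - 1 - 2 with self-loops (hypothetical toy example)
A = np.array([[1, 1, 0],
              [1, 1, 1],
              [0, 1, 1]], dtype=float)
H = rng.normal(size=(3, 4))
W = rng.normal(size=(4, 2))
a = rng.normal(size=(4,))

H_next, alpha = gat_layer(A, H, W, a)
print(alpha.round(3))  # each row sums to 1; non-neighbors get weight 0
```

A multi-head version would call `gat_layer` K times with separate `W` and `a` and concatenate (or average) the K outputs along the feature axis.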
Other Graph NN types
Here’s a very incomplete classification of other graph neural networks (GNNs), along with some example uses for each type. Each type of GNN serves a specific purpose, based on the structure and nature of the graph data.
Graph Recurrent Networks (GRN)
- Description: GRNs introduce recurrent neural networks to handle graph sequences, useful for dynamic graphs.
- Example Use: Traffic forecasting, where the road network's state evolves over time, and predictions are made based on previous patterns.
- Single-cell Temporal Analysis: GRNs can capture time-evolving cellular states in dynamic single-cell experiments. For example, in developmental biology, GRNs can track the progression of cells over time by integrating temporal transcriptomics data.
- Spatial Transcriptomics Time-Series: For spatial data over time (e.g., tissue healing or disease progression), GRNs can model how spatial expression patterns change, helping understand how spatial gene expression correlates with temporal dynamics.
Graph Autoencoders (GAE)
- Description: GAEs learn embeddings by encoding graph structures and then reconstructing the graph from these embeddings.
- Example Use: Recommendation systems, where user-item interactions are represented as a graph and new recommendations are generated.
- Bulk and Single-cell Omics Integration: GAEs can be applied to integrate bulk and single-cell data by encoding both types of data into a unified low-dimensional space, capturing common biological signals while reconstructing multi-omics profiles. This can help in identifying shared molecular pathways across different scales.
Dong, K., Zhang, S. Deciphering spatial domains from spatially resolved transcriptomics with an adaptive graph attention auto-encoder. Nat Commun 13, 1739 (2022). https://doi.org/10.1038/s41467-022-29439-6
STGNNks: Identifying cell types in spatial transcriptomics data based on graph neural network, denoising auto-encoder, and k-sums clustering. https://www.sciencedirect.com/science/article/abs/pii/S0010482523009058
- Metagenomics: In metagenomics, GAEs can encode microbial community structures and gene abundance patterns to infer functional or ecological relationships between microbes, aiding in the identification of microbial communities driving disease.
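The "encode, then reconstruct the graph" idea behind graph autoencoders can be illustrated with the decoding half alone. This sketch assumes an inner-product decoder, $ \sigma(Z Z^T) $, which is a common choice, and scores it with a binary cross-entropy against the adjacency matrix. The encoder (usually a GCN) is omitted, and the two sets of embeddings are hand-picked for the example, not learned.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def decode(Z):
    """Inner-product decoder: predicted edge probability for every node pair."""
    return sigmoid(Z @ Z.T)

def reconstruction_loss(A, A_prob, eps=1e-9):
    """Mean binary cross-entropy between the adjacency matrix and the predictions."""
    return -np.mean(A * np.log(A_prob + eps) + (1 - A) * np.log(1 - A_prob + eps))

# Toy graph (hypothetical): two connected pairs, 0-1 and 2-3, with self-loops in the target
A = np.array([[1, 1, 0, 0],
              [1, 1, 0, 0],
              [0, 0, 1, 1],
              [0, 0, 1, 1]], dtype=float)

# Hand-picked 1-dimensional embeddings, for illustration only
Z_good = np.array([[2.0], [2.0], [-2.0], [-2.0]])  # connected nodes share a position
Z_bad = np.array([[2.0], [-2.0], [2.0], [-2.0]])   # connected nodes are split apart

loss_good = reconstruction_loss(A, decode(Z_good))
loss_bad = reconstruction_loss(A, decode(Z_bad))
print(loss_good, loss_bad)
```

Embeddings that place connected nodes together reconstruct the graph with a much lower loss, which is the signal a graph autoencoder uses to train its encoder.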
Spatial-Temporal GNNs
- Description: These models integrate spatial and temporal information, capturing both graph structure and time-evolving patterns.
- Example Use: Predicting pedestrian movement in smart cities by analyzing how the flow of people changes over time.
- Spatial Transcriptomics: These GNNs can combine spatial gene expression and temporal data to model processes like tissue regeneration, tracking how gene expression in different regions changes over time. It can provide insights into cellular differentiation and tissue architecture.
- Single-cell Multi-omics: Spatial-temporal GNNs can integrate transcriptomics, epigenomics, and proteomics data from single cells, allowing for a detailed view of cellular dynamics, such as differentiation over time and space, important for developmental biology and cancer studies.
GraphSAGE
- Description: GraphSAGE samples a fixed-size neighborhood for each node and aggregates its information, allowing it to scale to larger graphs.
- Example Use: Large-scale molecular property prediction, where molecular structures are represented as graphs and properties are inferred from sampled subgraphs.
- Single-cell RNA-seq and Proteomics Integration: GraphSAGE can be used to integrate gene expression data with protein expression data by aggregating information from neighboring cells or samples. This can reveal the co-regulation of genes and proteins, helping to identify pathways active in specific cell types.
- Metagenomics and Bulk Omics: In metagenomics studies, GraphSAGE can scale to large datasets, integrating microbial abundance and functional genomics (like gene expression) across environments. This can help uncover how microbial communities influence host transcriptomes in different contexts, such as gut health or disease states.
Andre Lamurias, Mantas Sereika, Mads Albertsen, Katja Hose, Thomas Dyhre Nielsen, Metagenomic binning with assembly graph embeddings, Bioinformatics, Volume 38, Issue 19, October 2022, Pages 4481–4487, https://doi.org/10.1093/bioinformatics/btac557
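The sampling-and-aggregation step that lets GraphSAGE scale can be sketched as follows. This is a simplified illustration with a mean aggregator: separate weight matrices for the node itself and for its sampled neighbors stand in for the usual concatenate-then-transform step, nodes with fewer than `k` neighbors simply use all of them, and the output normalization is omitted. The function names, graph, features, and weights are all assumptions made for the example.

```python
import random
import numpy as np

def sample_neighbors(graph, node, k, rng):
    """Return at most k neighbors of a node, sampled without replacement."""
    neighbors = sorted(graph[node])
    if len(neighbors) <= k:
        return neighbors
    return rng.sample(neighbors, k)

def sage_mean_layer(graph, features, W_self, W_neigh, k, rng):
    """One GraphSAGE-style layer with a mean aggregator over sampled neighbors."""
    updated = {}
    for node in graph:
        sampled = sample_neighbors(graph, node, k, rng)
        if sampled:
            neigh_mean = np.mean([features[n] for n in sampled], axis=0)
        else:
            neigh_mean = np.zeros_like(features[node])
        h = features[node] @ W_self + neigh_mean @ W_neigh
        updated[node] = np.maximum(h, 0.0)  # ReLU
    return updated

# Star graph (hypothetical): hub 0 connected to leaves 1..5
graph = {0: {1, 2, 3, 4, 5}, 1: {0}, 2: {0}, 3: {0}, 4: {0}, 5: {0}}
features = {n: np.array([float(n), 1.0]) for n in graph}
W_self = np.ones((2, 3))
W_neigh = np.ones((2, 3))

rng = random.Random(0)
hub_sample = sample_neighbors(graph, 0, 2, rng)
out = sage_mean_layer(graph, features, W_self, W_neigh, k=2, rng=rng)
print(hub_sample, out[1])
```

The cost per node depends on `k` and not on the node's degree, so a hub with thousands of neighbors costs the same as a leaf. This is what makes mini-batch training on large graphs practical.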
Libraries, platforms, field reviews, etc:
- Wandy, J., Daly, R. GraphOmics: an interactive platform to explore and integrate multi-omics data. BMC Bioinformatics 22, 603 (2021). https://doi.org/10.1186/s12859-021-04500-1
- Comparative Analysis of Multi-Omics Integration Using Advanced Graph Neural Networks for Cancer Classification, Fadi Alharbi, Aleksandar Vakanski, Boyu Zhang, Murtada K. Elbashir, Mohanad Mohammed, https://doi.org/10.48550/arXiv.2410.05325
- Stanford Graph ML course: https://snap.stanford.edu/class/cs224w-2023/
- https://huggingface.co/blog/intro-graphml
- https://pytorch-geometric.readthedocs.io/en/latest/
- DGL documentation: https://docs.dgl.ai/tutorials/models/index.html