Author: Youngju Kim (@fjvbn20031)
Graph Neural Networks Complete Guide
Social networks, molecular structures, knowledge graphs, recommendation systems — countless real-world datasets are naturally represented as graphs. Graph Neural Networks (GNNs) are the core tool for applying deep learning to this non-Euclidean data. This guide systematically covers everything from graph theory fundamentals to the latest GNN architectures and hands-on implementations with PyTorch Geometric.
1. Graph Theory Fundamentals
Graph Definition
A graph G consists of a set of nodes V and a set of edges E, expressed as G = (V, E). Nodes represent entities, while edges represent relationships between entities.
- Nodes (Vertices): Represent entities. Examples: users, atoms, papers
- Edges: Represent relationships. Examples: friendships, chemical bonds, citations
- Node Features: Feature vectors attached to each node
- Edge Features: Feature vectors attached to each edge
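To make the definition G = (V, E) concrete, here is a small graph held in plain Python containers. The node names and feature values are made up for illustration, and `neighbors` is a hypothetical helper rather than a library function.

```python
# Hypothetical toy graph G = (V, E) with node and edge features
V = ["alice", "bob", "carol"]                     # nodes (entities)
E = [("alice", "bob"), ("bob", "carol")]          # edges (relationships)
node_features = {"alice": [0.1, 0.9], "bob": [0.5, 0.5], "carol": [0.8, 0.2]}
edge_features = {("alice", "bob"): [1.0], ("bob", "carol"): [0.3]}  # e.g. interaction strength

def neighbors(v, edges):
    """Neighbors of v, treating the edge list as undirected."""
    return sorted({b for a, b in edges if a == v} | {a for a, b in edges if b == v})

print(neighbors("bob", E))  # ['alice', 'carol']
```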
Directed vs Undirected Graphs
```python
import networkx as nx
import numpy as np

# Undirected graph
G_undirected = nx.Graph()
G_undirected.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (0, 2)])

# Directed graph
G_directed = nx.DiGraph()
G_directed.add_edges_from([(0, 1), (1, 2), (2, 0), (0, 3)])

print(f"Undirected - Nodes: {G_undirected.number_of_nodes()}, Edges: {G_undirected.number_of_edges()}")
print(f"Directed - Nodes: {G_directed.number_of_nodes()}, Edges: {G_directed.number_of_edges()}")
```
Adjacency Matrix and Edge List
```python
import torch
import numpy as np

# Adjacency matrix
# A[i][j] = 1 means an edge exists between nodes i and j
adj_matrix = torch.tensor([
    [0, 1, 1, 0],
    [1, 0, 1, 0],
    [1, 1, 0, 1],
    [0, 0, 1, 0]
], dtype=torch.float32)

# Edge index (used by PyG)
# shape: (2, num_edges) - first row: source nodes, second row: target nodes
edge_index = torch.tensor([
    [0, 0, 1, 1, 2, 2, 2, 3],  # Source nodes
    [1, 2, 0, 2, 0, 1, 3, 2]   # Target nodes
], dtype=torch.long)

print(f"Adjacency matrix shape: {adj_matrix.shape}")  # (4, 4)
print(f"Edge list shape: {edge_index.shape}")         # (2, 8)

# Convert adjacency matrix to edge index
def adj_to_edge_index(adj):
    """Convert adjacency matrix to edge index"""
    row, col = torch.where(adj > 0)
    return torch.stack([row, col], dim=0)

converted = adj_to_edge_index(adj_matrix)
print(f"Converted edge list:\n{converted}")
```
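The conversion also works in the other direction. The NumPy sketch below (helper names are hypothetical) rebuilds a dense adjacency matrix from a COO edge index and shows why an undirected graph needs both directions of every edge: four undirected edges become the eight directed entries of `edge_index` above.

```python
import numpy as np

def edge_index_to_adj(edge_index, num_nodes):
    """Dense 0/1 adjacency matrix from a (2, E) COO edge index."""
    adj = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    adj[edge_index[0], edge_index[1]] = 1.0
    return adj

def to_undirected(edge_index):
    """Add the reverse of every edge and drop duplicates (columns sorted by source, then target)."""
    both = np.concatenate([edge_index, edge_index[::-1]], axis=1)
    return np.unique(both, axis=1)

# Each undirected edge listed once...
half = np.array([[0, 0, 1, 2],
                 [1, 2, 2, 3]])
full = to_undirected(half)            # ...becomes 8 directed (source, target) columns
adj = edge_index_to_adj(full, num_nodes=4)
print(full)
print(adj)
```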
Graph Properties
```python
import networkx as nx
import numpy as np

def analyze_graph(G):
    """Analyze key properties of a graph"""
    # Degree
    degrees = dict(G.degree())
    avg_degree = np.mean(list(degrees.values()))

    # Clustering coefficient
    clustering = nx.average_clustering(G)

    # Average path length
    if nx.is_connected(G):
        avg_path = nx.average_shortest_path_length(G)
    else:
        # Use the largest connected component
        largest_cc = max(nx.connected_components(G), key=len)
        subgraph = G.subgraph(largest_cc)
        avg_path = nx.average_shortest_path_length(subgraph)

    # Centrality
    betweenness = nx.betweenness_centrality(G)
    pagerank = nx.pagerank(G)

    print(f"Nodes: {G.number_of_nodes()}")
    print(f"Edges: {G.number_of_edges()}")
    print(f"Average degree: {avg_degree:.2f}")
    print(f"Clustering coefficient: {clustering:.3f}")
    print(f"Average path length: {avg_path:.2f}")

    return {
        "degrees": degrees,
        "clustering": clustering,
        "avg_path": avg_path,
        "betweenness": betweenness,
        "pagerank": pagerank
    }

# Social network example (Karate Club)
G = nx.karate_club_graph()
stats = analyze_graph(G)
```
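The clustering coefficient can be computed directly from the adjacency matrix: the diagonal of A³ counts closed walks of length 3, and each triangle through node i appears there twice. This NumPy sketch (the function name is hypothetical) applies it to the 4-node graph from the adjacency example; for an unweighted graph the values should agree with `nx.clustering`, and nodes with degree below 2 get 0.

```python
import numpy as np

def local_clustering(adj):
    """Local clustering coefficient per node from a symmetric 0/1 adjacency matrix without self-loops."""
    adj = np.asarray(adj, dtype=float)
    deg = adj.sum(axis=1)
    triangles = np.diagonal(adj @ adj @ adj) / 2   # each triangle through i is counted twice in (A^3)_ii
    possible = deg * (deg - 1) / 2                 # number of neighbor pairs
    return np.divide(triangles, possible, out=np.zeros_like(triangles), where=possible > 0)

A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]])
c = local_clustering(A)
print(c)          # per-node coefficients: 1, 1, 1/3, 0
print(c.mean())   # average clustering coefficient
```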
Real-World Graphs
| Domain | Nodes | Edges | Task |
|---|---|---|---|
| Social Network | Users | Friendships | Community detection |
| Molecular Structure | Atoms | Chemical bonds | Property prediction |
| Knowledge Graph | Entities | Relations | Link prediction |
| Citation Network | Papers | Citations | Node classification |
| Traffic Network | Intersections | Roads | Route prediction |
| Recommendation | Users/Items | Interactions | Recommendation |
2. Motivation for Graph Machine Learning
Why CNNs and RNNs Fall Short
Traditional CNNs assume a grid structure: images work naturally because pixels are arranged in a regular 2D grid. RNNs assume a sequential structure.
Graphs, however, have:
- Irregular structure: each node has a different number of neighbors
- No natural node ordering: the same graph can be listed in any node order, so a model's output must not depend on that order (permutation invariance/equivariance)
- Long-range dependencies: distant nodes can still influence each other
```python
import torch

# Illustrating graph data characteristics
# Images: fixed-size grids
image = torch.randn(3, 224, 224)  # channels, height, width

# Sequences: ordered data
sequence = torch.randn(100, 512)  # sequence length, feature dim

# Graphs: variable neighborhood structure
# Node features: (num_nodes, feature_dim)
node_features = torch.randn(34, 16)  # 34 nodes, 16-dim features
# Edges: (2, num_edges) - sparse connections (random here, for illustration only)
edge_index = torch.randint(0, 34, (2, 78))
```
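The permutation point can be checked numerically. In this NumPy sketch (the random features and the particular permutation are arbitrary choices), relabeling the nodes with a permutation matrix P turns A into P A Pᵀ and X into P X. Neighbor aggregation A X is then permuted the same way (equivariance), a sum readout does not change at all (invariance), but a flattened feature vector, as an MLP would consume it, does change.

```python
import numpy as np

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
X = rng.normal(size=(4, 3))                 # 4 nodes, 3 features each

perm = np.array([2, 0, 3, 1])               # an arbitrary relabeling of the nodes
P = np.eye(4)[perm]                         # permutation matrix
A_p, X_p = P @ A @ P.T, P @ X               # same graph, different node order

# Neighbor-sum aggregation is permutation equivariant: outputs are permuted the same way
print(np.allclose(A_p @ X_p, P @ (A @ X)))            # True
# A sum readout over all nodes is permutation invariant
print(np.allclose(X_p.sum(axis=0), X.sum(axis=0)))    # True
# Flattening the feature matrix is not: the input vector depends on the node order
print(np.allclose(X_p.ravel(), X.ravel()))            # False
```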
Message Passing Paradigm
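Most GNNs are built on message passing: each node receives messages from its neighbors and uses them to update its own representation.
The Message Passing Neural Network (MPNN) framework has three steps:
- Message computation: compute the message sent from node u to node v along edge (u, v)
- Aggregation: each node combines all incoming neighbor messages with a permutation-invariant function such as sum, mean, or max
- Update: update the node representation using the aggregated message

$$m_v^{(l)} = \mathrm{AGGREGATE}\left(\left\{\mathrm{MESSAGE}\left(h_u^{(l-1)}, h_v^{(l-1)}, e_{uv}\right) : u \in \mathcal{N}(v)\right\}\right)$$

$$h_v^{(l)} = \mathrm{UPDATE}\left(h_v^{(l-1)}, m_v^{(l)}\right)$$

where $\mathcal{N}(v)$ is the set of neighbors of node v and $e_{uv}$ are the (optional) features of edge (u, v).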
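The three steps can be written as one generic function that takes MESSAGE, AGGREGATE, and UPDATE as arguments. This is a pure-Python sketch with scalar node states for readability; `mpnn_layer` and the toy choices of the three functions are assumptions for illustration, not a library API.

```python
def mpnn_layer(h, edges, message, aggregate, update):
    """One message-passing step.

    h: dict node -> state (a float here for readability)
    edges: list of directed (u, v) pairs; messages flow u -> v
    """
    inbox = {v: [] for v in h}
    for u, v in edges:
        inbox[v].append(message(h[u], h[v]))          # MESSAGE along edge (u, v)
    return {v: update(h[v], aggregate(inbox[v]))       # AGGREGATE, then UPDATE
            for v in h}

# Toy choices: identity message, sum aggregation, update = mean of old state and message
h0 = {0: 1.0, 1: 2.0, 2: 4.0}
edges = [(0, 1), (2, 1), (1, 0)]
h1 = mpnn_layer(
    h0, edges,
    message=lambda h_u, h_v: h_u,
    aggregate=lambda msgs: sum(msgs),                  # sum([]) == 0 for nodes without in-edges
    update=lambda h_v, m_v: 0.5 * (h_v + m_v),
)
print(h1)  # {0: 1.5, 1: 3.5, 2: 2.0}
```

Node 2 has no incoming edges, so its aggregated message is 0 and only its own state carries over.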
3. GNN Fundamental Equations
Aggregation and Update
```python
import torch
import torch.nn as nn
from torch_scatter import scatter_mean, scatter_sum, scatter_max

def manual_message_passing(node_features, edge_index, aggregation="mean"):
    """
    Manually implemented message passing
    node_features: (N, F) - N nodes, F-dimensional features
    edge_index: (2, E) - E edges; messages flow from edge_index[0] to edge_index[1]
    """
    src, dst = edge_index[0], edge_index[1]
    num_nodes = node_features.size(0)

    # Message: use source node features as messages
    messages = node_features[src]  # (E, F)

    # Aggregate: reduce all messages that arrive at the same target node
    if aggregation == "mean":
        aggregated = scatter_mean(messages, dst, dim=0, dim_size=num_nodes)
    elif aggregation == "sum":
        aggregated = scatter_sum(messages, dst, dim=0, dim_size=num_nodes)
    elif aggregation == "max":
        aggregated, _ = scatter_max(messages, dst, dim=0, dim_size=num_nodes)
    else:
        raise ValueError(f"Unknown aggregation: {aggregation}")

    # Update: original features + aggregated messages
    updated = node_features + aggregated
    return updated

# Example
N, F = 6, 8
node_features = torch.randn(N, F)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 0, 1], [1, 2, 3, 4, 0, 3, 4]])
output = manual_message_passing(node_features, edge_index, "mean")
print(f"Input shape: {node_features.shape}")
print(f"Output shape: {output.shape}")
```
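If `torch_scatter` is not installed, the same scatter reductions can be sketched in NumPy with unbuffered `np.add.at` / `np.maximum.at`, which accumulate correctly when an index repeats. `scatter_aggregate` is a hypothetical helper; here nodes that receive no message get a zero row (node 5 in the example edge list has no incoming edge).

```python
import numpy as np

def scatter_aggregate(messages, index, num_nodes, reduce="mean"):
    """Reduce rows of `messages` (E, F) into `num_nodes` rows according to `index` (E,)."""
    count = np.bincount(index, minlength=num_nodes)
    if reduce == "max":
        out = np.full((num_nodes, messages.shape[1]), -np.inf)
        np.maximum.at(out, index, messages)
        out[count == 0] = 0.0                     # nodes without messages -> zeros
        return out
    out = np.zeros((num_nodes, messages.shape[1]))
    np.add.at(out, index, messages)               # repeated indices accumulate
    if reduce == "mean":
        out = out / np.maximum(count, 1)[:, None] # avoid dividing by zero for isolated nodes
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 8))
edge_index = np.array([[0, 1, 2, 3, 4, 0, 1],
                       [1, 2, 3, 4, 0, 3, 4]])
src, dst = edge_index
agg_mean = scatter_aggregate(x[src], dst, num_nodes=6, reduce="mean")
agg_max = scatter_aggregate(x[src], dst, num_nodes=6, reduce="max")
print(agg_mean[3])   # mean of the features of nodes 2 and 0 (edges 2->3 and 0->3)
print(agg_mean[5])   # node 5 receives no messages: all zeros
```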
4. Key GNN Architectures
GCN (Graph Convolutional Network)
GCN, proposed by Kipf and Welling in 2017, derives an efficient layer-wise propagation rule starting from spectral graph theory.
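Layer-wise propagation rule, using the normalized adjacency matrix:

$$\tilde{A} = D^{-1/2}(A + I)D^{-1/2}$$

$$H^{(l+1)} = \sigma\left(\tilde{A} H^{(l)} W^{(l)}\right)$$

where $I$ is the identity matrix (it adds a self-loop to every node), $D$ is the degree matrix of $A + I$, i.e. $D_{ii} = \sum_j (A + I)_{ij}$, $W^{(l)}$ is the learnable weight matrix of layer $l$, and $\sigma$ is a nonlinearity such as ReLU.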
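A NumPy sketch of the propagation rule on the 4-node graph used earlier (the weights are random placeholders). It also illustrates why the symmetric normalization matters: the eigenvalues of the normalized matrix lie in (-1, 1], with the largest equal to 1, so applying it layer after layer does not inflate feature magnitudes the way raw A + I would.

```python
import numpy as np

def gcn_norm(adj):
    """D^(-1/2) (A + I) D^(-1/2), where D is the degree matrix of A + I."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]

A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
A_tilde = gcn_norm(A)
eigenvalues = np.linalg.eigvalsh(A_tilde)   # A_tilde is symmetric
print(np.round(A_tilde, 3))
print(eigenvalues.max())                    # largest eigenvalue: 1 (up to rounding)

# One GCN layer H' = ReLU(A_tilde H W) with random features and weights
rng = np.random.default_rng(0)
H, W = rng.normal(size=(4, 5)), rng.normal(size=(5, 2))
H_next = np.maximum(A_tilde @ H @ W, 0.0)
print(H_next.shape)                         # (4, 2)
```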
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]
print(f"Nodes: {data.num_nodes}")
print(f"Edges: {data.num_edges}")
print(f"Node feature dim: {data.num_node_features}")
print(f"Num classes: {dataset.num_classes}")
print(f"Train nodes: {data.train_mask.sum().item()}")
print(f"Val nodes: {data.val_mask.sum().item()}")
print(f"Test nodes: {data.test_mask.sum().item()}")

class GCN(nn.Module):
    """Graph Convolutional Network"""
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        # First GCN layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Second GCN layer
        x = self.conv2(x, edge_index)
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes
).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train_gcn():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

def test_gcn():
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
    results = {}
    for split, mask in [("train", data.train_mask),
                        ("val", data.val_mask),
                        ("test", data.test_mask)]:
        correct = pred[mask].eq(data.y[mask]).sum().item()
        results[split] = correct / mask.sum().item()
    return results

# Training loop
best_val_acc = 0
for epoch in range(200):
    loss = train_gcn()
    accs = test_gcn()
    if accs["val"] > best_val_acc:
        best_val_acc = accs["val"]
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1:03d} | Loss: {loss:.4f} | "
              f"Train: {accs['train']:.4f} | Val: {accs['val']:.4f} | "
              f"Test: {accs['test']:.4f}")
```
Manual GCN Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ManualGCNLayer(nn.Module):
    """Manual GCN layer implementation - for understanding internals"""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, adj):
        """
        x: node features (N, F_in)
        adj: normalized adjacency matrix (N, N)
        """
        # Linear transform: X * W
        support = x @ self.weight
        # Graph convolution: A_hat * X * W
        output = adj @ support + self.bias
        return output

    @staticmethod
    def normalize_adjacency(adj):
        """D^(-1/2) * (A + I) * D^(-1/2) normalization"""
        # Add self-loops
        N = adj.size(0)
        adj_hat = adj + torch.eye(N, device=adj.device)
        # Compute degree matrix of A + I
        deg = adj_hat.sum(dim=1)
        d_inv_sqrt = torch.diag(deg.pow(-0.5))
        # Normalize
        adj_normalized = d_inv_sqrt @ adj_hat @ d_inv_sqrt
        return adj_normalized
```
GraphSAGE (Inductive Learning)
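GraphSAGE is designed for inductive learning: instead of learning a fixed embedding per node, it learns functions that aggregate neighbor features, so it can embed nodes that were not seen during training. It also uses neighbor sampling, which enables mini-batch training instead of processing the entire graph at once.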
```python
from torch_geometric.nn import SAGEConv
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphSAGE(nn.Module):
    """GraphSAGE - Inductive Representation Learning"""
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=3, dropout=0.5, aggr="mean"):
        super().__init__()
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels, aggr=aggr))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels, aggr=aggr))
        self.convs.append(SAGEConv(hidden_channels, out_channels, aggr=aggr))
        self.bns = nn.ModuleList([
            nn.BatchNorm1d(hidden_channels)
            for _ in range(num_layers - 1)
        ])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x

# Mini-batch training with neighbor sampling
from torch_geometric.loader import NeighborLoader

# NeighborLoader: sample up to num_neighbors[k] neighbors per node at hop k
train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],  # 2 hops: 25 neighbors at the 1st hop, 10 at the 2nd
    batch_size=256,
    input_nodes=data.train_mask,
    shuffle=True
)

model_sage = GraphSAGE(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2  # one SAGEConv layer per sampled hop
).to(device)
optimizer_sage = torch.optim.Adam(model_sage.parameters(), lr=0.001)

def train_sage():
    model_sage.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer_sage.zero_grad()
        out = model_sage(batch.x, batch.edge_index)
        # Only the first batch_size nodes are the seed (training) nodes
        loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
        loss.backward()
        optimizer_sage.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)
```
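To see what `num_neighbors=[25, 10]` means, here is a simplified NumPy sketch of layer-wise uniform neighbor sampling. `sample_neighbors` is a hypothetical helper, not PyG's implementation; `NeighborLoader` additionally relabels nodes into a compact subgraph and handles features and labels.

```python
import numpy as np

def sample_neighbors(adj_list, seeds, fanouts, rng):
    """Layer-wise uniform neighbor sampling in the spirit of GraphSAGE (simplified sketch).

    adj_list: dict node -> list of neighbors
    seeds: target nodes of the mini-batch
    fanouts: maximum neighbors sampled per node at each hop, e.g. [25, 10]
    Returns one list of sampled (source, target) edges per hop.
    """
    frontier, hops = list(seeds), []
    for k in fanouts:
        edges = []
        for v in frontier:
            nbrs = list(adj_list[v])
            if len(nbrs) > k:                                  # sample k without replacement
                nbrs = rng.choice(nbrs, size=k, replace=False).tolist()
            edges.extend((u, v) for u in nbrs)                 # messages flow u -> v
        hops.append(edges)
        frontier = sorted({u for u, _ in edges})               # next hop expands from sampled nodes
    return hops

adj_list = {0: [1, 2, 3, 4], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2, 4], 4: [0, 3]}
hops = sample_neighbors(adj_list, seeds=[0], fanouts=[2, 2], rng=np.random.default_rng(0))
for depth, edges in enumerate(hops, start=1):
    print(f"hop {depth}: {edges}")
```

The cost of a mini-batch is bounded by the product of the fanouts rather than by the full neighborhood size, which is what makes training on large graphs feasible.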
GAT (Graph Attention Network)
GAT uses attention mechanisms to assign different weights to each neighbor. It implements the intuition that "not all neighbors are equally important."
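Attention coefficient computation for node $i$ and a neighbor $j \in \mathcal{N}(i)$:
- Attention score: $e_{ij} = \mathrm{LeakyReLU}\left(\mathbf{a}^\top [W h_i \,\Vert\, W h_j]\right)$
- Softmax normalization: $\alpha_{ij} = \dfrac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$
- Update: $h_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j\right)$

Here $\Vert$ denotes concatenation and $\mathbf{a}$ is a learnable attention vector.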
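The three formulas for a single attention head can be traced in NumPy. In this sketch the weights are random, σ is taken to be tanh (an arbitrary choice), and node i's own index is included in its neighborhood to mimic the self-loop that GAT implementations usually add.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_node_update(h, W, a, i, neighbors):
    """Single-head GAT update for node i (sketch).

    h: (N, F) node features, W: (F, F') shared weights, a: (2 * F',) attention vector
    neighbors: indices of N(i), here including i itself
    """
    Wh = h @ W                                                          # (N, F')
    e = np.array([leaky_relu(a @ np.concatenate([Wh[i], Wh[j]])) for j in neighbors])
    alpha = np.exp(e - e.max())                                         # stable softmax over N(i)
    alpha /= alpha.sum()
    h_i_new = np.tanh(alpha @ Wh[neighbors])                            # weighted sum, then sigma
    return alpha, h_i_new

rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))
W = rng.normal(size=(3, 2))
a = rng.normal(size=4)                 # length 2 * F' for the concatenated pair
alpha, h0_new = gat_node_update(h, W, a, i=0, neighbors=[0, 1, 2])
print(alpha, alpha.sum())              # one weight per neighbor; the weights sum to 1
```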
```python
from torch_geometric.nn import GATConv, GATv2Conv
import torch
import torch.nn as nn
import torch.nn.functional as F

class GAT(nn.Module):
    """Graph Attention Network"""
    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        # First layer: multi-head attention
        self.conv1 = GATConv(
            in_channels,
            hidden_channels,
            heads=heads,
            dropout=dropout,
            concat=True  # Concatenate heads
        )
        # Second layer: average heads
        self.conv2 = GATConv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            dropout=dropout,
            concat=False  # Average heads
        )

    def forward(self, x, edge_index):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

class GATv2(nn.Module):
    """
    GATv2 - Improved attention mechanism
    GATv2 computes dynamic attention, giving higher expressiveness
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads=8, dropout=0.6):
        super().__init__()
        self.conv1 = GATv2Conv(
            in_channels,
            hidden_channels,
            heads=heads,
            dropout=dropout,
            concat=True
        )
        self.conv2 = GATv2Conv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            dropout=dropout,
            concat=False
        )
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index)

# GAT training
model_gat = GAT(
    in_channels=dataset.num_features,
    hidden_channels=8,
    out_channels=dataset.num_classes,
    heads=8
).to(device)
optimizer_gat = torch.optim.Adam(model_gat.parameters(), lr=0.005, weight_decay=5e-4)
```
Graph Transformer
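Graph Transformers bring Transformer-style attention to graphs. PyG's TransformerConv, used below, computes scaled dot-product attention over each node's neighbors (with `beta=True` adding a learnable gate between the attention output and a skip connection); other graph Transformers attend over all node pairs and inject graph structure through positional or structural encodings.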
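The difference between neighbor-restricted and global attention is just a mask on the score matrix. This single-head NumPy sketch (random projection matrices, hypothetical helper name) masks scores outside the edge set to -inf before the softmax; passing an all-ones mask recovers ordinary global attention.

```python
import numpy as np

def neighbor_attention(X, adj, Wq, Wk, Wv):
    """Single-head scaled dot-product attention restricted to graph edges (sketch).

    adj[i, j] > 0 means node i may attend to node j; every row needs at least one entry.
    """
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])
    scores = np.where(adj > 0, scores, -np.inf)                    # mask out non-neighbors
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))   # exp(-inf) = 0
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
X = rng.normal(size=(4, 6))
Wq, Wk, Wv = (rng.normal(size=(6, 4)) for _ in range(3))
out, weights = neighbor_attention(X, A + np.eye(4), Wq, Wk, Wv)   # self-loops keep rows non-empty
_, global_weights = neighbor_attention(X, np.ones((4, 4)), Wq, Wk, Wv)
print(np.round(weights, 3))         # zeros wherever A + I has no edge
print(np.round(global_weights, 3))  # every pair gets a positive weight
```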
```python
from torch_geometric.nn import TransformerConv
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphTransformer(nn.Module):
    """Graph Transformer built from stacked TransformerConv layers"""
    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads=4, num_layers=3, dropout=0.3):
        super().__init__()
        # hidden_channels should be divisible by heads: each head outputs
        # hidden_channels // heads features and the heads are concatenated
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.convs.append(
            TransformerConv(in_channels, hidden_channels // heads, heads=heads,
                            dropout=dropout, beta=True)
        )
        for _ in range(num_layers - 2):
            self.convs.append(
                TransformerConv(hidden_channels, hidden_channels // heads,
                                heads=heads, dropout=dropout, beta=True)
            )
        # Output layer: a single head so the output has exactly out_channels features
        # (out_channels // heads would truncate, e.g. 7 // 4 = 1 for Cora's 7 classes)
        self.convs.append(
            TransformerConv(hidden_channels, out_channels,
                            heads=1, dropout=dropout, beta=True)
        )
        self.norms = nn.ModuleList([
            nn.LayerNorm(hidden_channels) for _ in range(num_layers - 1)
        ])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)
```
5. Graph-Level Prediction
While node classification predicts a label for each individual node, graph classification predicts a single label for an entire graph, for example whether a molecule is toxic.
Global Pooling
```python
from torch_geometric.nn import (
    global_mean_pool,
    global_max_pool,
    global_add_pool
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphClassifier(nn.Module):
    """Graph classification model"""
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=3, dropout=0.5, pooling="mean"):
        super().__init__()
        self.dropout = dropout
        self.pooling = pooling
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.bns = nn.ModuleList([
            nn.BatchNorm1d(hidden_channels) for _ in range(num_layers)
        ])
        # Graph-level classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, x, edge_index, batch):
        """
        batch: index vector indicating which graph each node belongs to
        """
        # Node embedding
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # Graph-level pooling
        if self.pooling == "mean":
            x = global_mean_pool(x, batch)
        elif self.pooling == "max":
            x = global_max_pool(x, batch)
        elif self.pooling == "sum":
            x = global_add_pool(x, batch)
        # Classification
        return self.classifier(x)
```
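The `batch` vector is what lets global pooling produce one row per graph. A NumPy sketch of sum and mean readouts (the helper name is hypothetical, and it assumes every graph in the batch has at least one node):

```python
import numpy as np

def global_pool(x, batch, num_graphs, reduce="mean"):
    """Pool node rows of x into one row per graph, using the batch vector."""
    out = np.zeros((num_graphs, x.shape[1]))
    np.add.at(out, batch, x)
    if reduce == "mean":
        out /= np.bincount(batch, minlength=num_graphs)[:, None]
    return out

# Two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
x = np.array([[1., 0.], [3., 2.], [5., 4.], [10., 10.], [20., 0.]])
batch = np.array([0, 0, 0, 1, 1])
print(global_pool(x, batch, 2, "sum"))   # rows [9, 6] and [30, 10]
print(global_pool(x, batch, 2, "mean"))  # rows [3, 2] and [15, 5]
```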
DiffPool (Differentiable Pooling)
```python
from torch_geometric.nn import dense_diff_pool
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffPoolLayer(nn.Module):
    """Hierarchical graph pooling (simplified)"""
    def __init__(self, in_channels, hidden_channels, num_clusters):
        super().__init__()
        # The DiffPool paper uses GNNs for both branches (e.g. torch_geometric.nn.DenseSAGEConv);
        # plain MLPs keep this sketch short but ignore the adjacency matrix.
        # Branch 1: node embeddings
        self.gnn_embed = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU()
        )
        # Branch 2: cluster-assignment logits
        self.gnn_pool = nn.Sequential(
            nn.Linear(in_channels, num_clusters),
        )

    def forward(self, x, adj, mask=None):
        embed = self.gnn_embed(x)
        # Cluster-assignment logits; dense_diff_pool applies the softmax internally,
        # so applying softmax here as well would normalize twice
        s = self.gnn_pool(x)
        # DiffPool: coarsened features, coarsened adjacency, and two auxiliary losses
        out, out_adj, link_loss, entropy_loss = dense_diff_pool(embed, adj, s, mask)
        return out, out_adj, link_loss, entropy_loss
```
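What `dense_diff_pool` computes can be sketched for one dense graph in NumPy: a soft assignment matrix S pools node embeddings into cluster features Sᵀ Z and the adjacency into cluster connectivity Sᵀ A S. The auxiliary loss normalizations below are assumptions modeled on DiffPool's link-prediction and entropy regularizers, not a copy of PyG's code. Because each row of S sums to 1, the total edge weight is preserved.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def diffpool_coarsen(Z, A, S_logits):
    """One DiffPool coarsening step for a single dense graph (sketch).

    Z: (N, F) node embeddings, A: (N, N) adjacency, S_logits: (N, C) assignment scores
    """
    S = softmax(S_logits, axis=-1)          # soft assignment of N nodes to C clusters
    X_coarse = S.T @ Z                      # (C, F) pooled cluster features
    A_coarse = S.T @ A @ S                  # (C, C) cluster-to-cluster connectivity
    link_loss = np.linalg.norm(A - S @ S.T) / A.size             # connected nodes should share clusters
    entropy_loss = (-S * np.log(S + 1e-15)).sum(axis=-1).mean()  # assignments should be confident
    return X_coarse, A_coarse, link_loss, entropy_loss

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
Z = rng.normal(size=(4, 3))
Xc, Ac, link_loss, entropy_loss = diffpool_coarsen(Z, A, rng.normal(size=(4, 2)))
print(Xc.shape, Ac.shape)    # (2, 3) (2, 2)
print(Ac.sum(), A.sum())     # equal: total edge weight is preserved
```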
6. Link Prediction
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

class LinkPredictor(nn.Module):
    """Link prediction model"""
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Node embedding encoder
        self.encoder = nn.ModuleList([
            GCNConv(in_channels, hidden_channels),
            GCNConv(hidden_channels, out_channels)
        ])
        # Edge decoder
        self.decoder = nn.Sequential(
            nn.Linear(out_channels * 2, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, 1)
        )

    def encode(self, x, edge_index):
        for i, conv in enumerate(self.encoder):
            x = conv(x, edge_index)
            if i < len(self.encoder) - 1:
                x = F.relu(x)
        return x

    def decode(self, z, edge_index):
        # Concatenate source/target node embeddings and score each pair
        src, dst = edge_index
        edge_feat = torch.cat([z[src], z[dst]], dim=1)
        return self.decoder(edge_feat).squeeze(-1)  # (num_edges,) logits, even for a single edge

    def forward(self, x, edge_index, pos_edge_index, neg_edge_index):
        z = self.encode(x, edge_index)
        pos_pred = self.decode(z, pos_edge_index)
        neg_pred = self.decode(z, neg_edge_index)
        return pos_pred, neg_pred

def train_link_prediction(model, data, optimizer):
    model.train()
    optimizer.zero_grad()
    # Node embedding
    z = model.encode(data.x, data.edge_index)
    # Positive edges (train_pos_edge_index is produced by the legacy
    # torch_geometric.utils.train_test_split_edges; newer PyG uses transforms.RandomLinkSplit)
    pos_edge = data.train_pos_edge_index
    # Negative edge sampling: node pairs that are not connected
    neg_edge = negative_sampling(
        edge_index=pos_edge,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edge.size(1)
    )
    pos_pred = model.decode(z, pos_edge)
    neg_pred = model.decode(z, neg_edge)
    # Binary cross-entropy loss on logits
    pred = torch.cat([pos_pred, neg_pred])
    labels = torch.cat([
        torch.ones(pos_pred.size(0)),
        torch.zeros(neg_pred.size(0))
    ]).to(pred.device)
    loss = F.binary_cross_entropy_with_logits(pred, labels)
    loss.backward()
    optimizer.step()
    return loss.item()
```
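Negative sampling is the piece that turns link prediction into binary classification. The pure-Python sketch below (a hypothetical helper, not PyG's `negative_sampling`) draws random node pairs and rejects self-loops and existing edges in either direction; it assumes the graph is sparse enough that rejection sampling finishes quickly.

```python
import random

def sample_negative_edges(pos_edges, num_nodes, num_samples, seed=0):
    """Uniformly sample ordered node pairs that are neither positive edges nor self-loops."""
    rng = random.Random(seed)
    existing = set(pos_edges) | {(v, u) for u, v in pos_edges}   # treat the graph as undirected
    negatives = set()
    while len(negatives) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u != v and (u, v) not in existing:
            negatives.add((u, v))
    return sorted(negatives)

pos = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
neg = sample_negative_edges(pos, num_nodes=10, num_samples=5)
print(neg)   # 5 node pairs with label 0; the positives get label 1
```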
7. PyTorch Geometric (PyG) Complete Guide
Installation
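```bash
pip install torch-geometric
# Optional compiled extensions (required for the torch_scatter example in section 3; PyG itself
# no longer needs them since version 2.3). The wheel URL must match your installed PyTorch and
# CUDA versions; this one is for PyTorch 2.1.0 with CUDA 12.1.
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
```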
Data Object
```python
from torch_geometric.data import Data
import torch

# Create graph data
x = torch.randn(6, 3)  # 6 nodes, 3-dimensional features
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 0],
    [1, 2, 3, 4, 0, 3]
], dtype=torch.long)
y = torch.tensor([0, 1, 0, 1, 0, 1])  # Node labels
edge_attr = torch.randn(6, 2)  # Edge features

data = Data(
    x=x,
    edge_index=edge_index,
    y=y,
    edge_attr=edge_attr
)
print(data)
print(f"Nodes: {data.num_nodes}")
print(f"Edges: {data.num_edges}")
print(f"Node feature dim: {data.num_node_features}")
print(f"Edge feature dim: {data.num_edge_features}")
print(f"Has self-loops: {data.has_self_loops()}")
print(f"Is directed: {data.is_directed()}")

# Validation
print(f"Valid data: {data.validate()}")
```
DataLoader and Mini-batching
```python
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated
import torch

# Create graph dataset
dataset = []
for _ in range(100):
    n = torch.randint(5, 20, (1,)).item()   # 5-19 nodes
    e = torch.randint(10, 40, (1,)).item()  # 10-39 edges
    data = Data(
        x=torch.randn(n, 8),
        edge_index=torch.randint(0, n, (2, e)),
        y=torch.randint(0, 3, (1,))  # Graph label
    )
    dataset.append(data)

# DataLoader: batch multiple graphs into one disconnected graph
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    print(f"Number of graphs in batch: {batch.num_graphs}")
    print(f"Total nodes: {batch.num_nodes}")
    print(f"Total edges: {batch.num_edges}")
    print(f"Batch vector: {batch.batch.shape}")  # Graph index per node
    break
```
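Batching into "one disconnected graph" comes down to concatenating node features and shifting each graph's edge indices by the number of nodes that precede it. A NumPy sketch of that collation step (the `collate` helper is hypothetical; PyG's loader also handles edge attributes, labels, and other fields):

```python
import numpy as np

def collate(graphs):
    """Merge (x, edge_index) pairs into one disconnected graph plus a batch vector."""
    xs, eis, batch, offset = [], [], [], 0
    for g, (x, ei) in enumerate(graphs):
        xs.append(x)
        eis.append(ei + offset)            # shift node ids so graphs do not collide
        batch.append(np.full(len(x), g))   # graph id of every node
        offset += len(x)
    return np.concatenate(xs), np.concatenate(eis, axis=1), np.concatenate(batch)

g1 = (np.ones((3, 2)), np.array([[0, 1], [1, 2]]))   # 3 nodes, edges 0->1, 1->2
g2 = (np.zeros((2, 2)), np.array([[0], [1]]))        # 2 nodes, edge 0->1
x, edge_index, batch = collate([g1, g2])
print(edge_index)  # columns (0,1), (1,2), (3,4): g2's edge shifted by 3
print(batch)       # [0 0 0 1 1]
```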
Built-in Datasets
```python
from torch_geometric.datasets import (
    Planetoid,   # Cora, Citeseer, PubMed
    TUDataset,   # Molecular datasets (MUTAG, ENZYMES, etc.)
)
from torch_geometric.transforms import NormalizeFeatures

# Cora citation network
cora = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
print(f"Cora - Nodes: {cora[0].num_nodes}, Edges: {cora[0].num_edges}")

# MUTAG molecular dataset
mutag = TUDataset(root='/tmp/TUDataset', name='MUTAG')
print(f"MUTAG - Graphs: {len(mutag)}, Classes: {mutag.num_classes}")

# Open Graph Benchmark (large-scale)
try:
    from ogb.nodeproppred import PygNodePropPredDataset
    dataset_ogb = PygNodePropPredDataset(name='ogbn-arxiv')
    split_idx = dataset_ogb.get_idx_split()
    data_ogb = dataset_ogb[0]
    print(f"OGB-Arxiv - Nodes: {data_ogb.num_nodes}, Edges: {data_ogb.num_edges}")
except ImportError:
    print("ogb not installed. pip install ogb")
```
Complete Node Classification Pipeline
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load data
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

class MultiLayerGNN(nn.Module):
    """Model combining multiple GNN layers"""
    def __init__(self, in_channels, hidden_channels, out_channels,
                 gnn_type="gcn", num_layers=3, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.gnn_type = gnn_type
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        # Input layer
        self.convs.append(self._make_conv(in_channels, hidden_channels, gnn_type))
        self.bns.append(nn.BatchNorm1d(hidden_channels))
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(self._make_conv(hidden_channels, hidden_channels, gnn_type))
            self.bns.append(nn.BatchNorm1d(hidden_channels))
        # Output layer
        self.convs.append(self._make_conv(hidden_channels, out_channels, gnn_type))

    def _make_conv(self, in_ch, out_ch, gnn_type):
        if gnn_type == "gcn":
            return GCNConv(in_ch, out_ch)
        elif gnn_type == "sage":
            return SAGEConv(in_ch, out_ch)
        elif gnn_type == "gat":
            return GATConv(in_ch, out_ch, heads=1)
        else:
            raise ValueError(f"Unknown GNN type: {gnn_type}")

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)

def run_experiment(gnn_type, epochs=200):
    model = MultiLayerGNN(
        in_channels=dataset.num_features,
        hidden_channels=64,
        out_channels=dataset.num_classes,
        gnn_type=gnn_type,
        num_layers=3
    ).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.01, weight_decay=5e-4
    )
    train_losses = []
    val_accs = []
    for epoch in range(epochs):
        # Training
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        # Evaluation
        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            pred = out.argmax(dim=1)
            val_acc = pred[data.val_mask].eq(data.y[data.val_mask]).sum().item()
            val_acc /= data.val_mask.sum().item()
            val_accs.append(val_acc)
    # Final test
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        test_acc = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
        test_acc /= data.test_mask.sum().item()
    return test_acc, train_losses, val_accs

# Compare different GNNs
results = {}
for gnn_type in ["gcn", "sage", "gat"]:
    test_acc, losses, val_accs = run_experiment(gnn_type)
    results[gnn_type] = test_acc
    print(f"{gnn_type.upper():10s}: Test Accuracy = {test_acc:.4f}")
```
Complete Graph Classification Example
```python
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    GINConv, global_mean_pool, global_add_pool
)
import torch
import torch.nn as nn
import torch.nn.functional as F

# Load MUTAG dataset
dataset = TUDataset(root='/tmp/TUDataset', name='MUTAG')
dataset = dataset.shuffle()

# Train/test split
n = len(dataset)
train_dataset = dataset[:int(0.8 * n)]
test_dataset = dataset[int(0.8 * n):]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

class GIN(nn.Module):
    """
    Graph Isomorphism Network (GIN)
    With sum aggregation and an MLP update it is as powerful as the 1-WL
    graph isomorphism test, the upper bound for standard message-passing GNNs,
    so it can distinguish graph structures that mean-aggregating GCN cannot
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=5, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_layers):
            in_ch = in_channels if i == 0 else hidden_channels
            # MLP for GIN
            mlp = nn.Sequential(
                nn.Linear(in_ch, hidden_channels),
                nn.BatchNorm1d(hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels)
            )
            self.convs.append(GINConv(mlp, train_eps=True))
            self.bns.append(nn.BatchNorm1d(hidden_channels))
        # Graph classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels * num_layers, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, x, edge_index, batch):
        # Store the pooled output of each layer
        xs = []
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            xs.append(global_add_pool(x, batch))  # Graph-level aggregation
        # Concatenate graph representations from all layers
        out = torch.cat(xs, dim=1)
        return self.classifier(out)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_gin = GIN(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes
).to(device)
optimizer = torch.optim.Adam(model_gin.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

def train_gin():
    model_gin.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model_gin(batch.x, batch.edge_index, batch.batch)
        loss = F.cross_entropy(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def test_gin(loader):
    model_gin.eval()
    correct = 0
    for batch in loader:
        batch = batch.to(device)
        with torch.no_grad():
            pred = model_gin(batch.x, batch.edge_index, batch.batch).argmax(dim=1)
        correct += pred.eq(batch.y).sum().item()
    return correct / len(loader.dataset)

for epoch in range(1, 201):
    loss = train_gin()
    train_acc = test_gin(train_loader)
    test_acc = test_gin(test_loader)
    scheduler.step()
    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | "
              f"Train: {train_acc:.4f} | Test: {test_acc:.4f}")
```
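Why GIN uses sum aggregation: mean and max cannot tell apart neighborhoods that contain the same features with different multiplicities, while sum can. The NumPy sketch below checks this on two toy neighborhoods and then applies the GIN-style update h_v' = MLP((1 + ε) h_v + Σ_u h_u), with a plain ReLU standing in for the MLP (an assumption for brevity).

```python
import numpy as np

# Two neighborhoods that differ only in how many times the same feature appears
nbrs_a = np.array([[1.0, 0.0]])               # one neighbor with feature [1, 0]
nbrs_b = np.array([[1.0, 0.0], [1.0, 0.0]])   # two neighbors, both [1, 0]

distinguishes = {}
for name, agg in [("mean", np.mean), ("max", np.max), ("sum", np.sum)]:
    distinguishes[name] = not np.allclose(agg(nbrs_a, axis=0), agg(nbrs_b, axis=0))
print(distinguishes)  # {'mean': False, 'max': False, 'sum': True}

def gin_update(h, adj, eps, mlp):
    """GIN-style update: MLP((1 + eps) * h_v + sum of neighbor features)."""
    return mlp((1 + eps) * h + adj @ h)

A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
h = np.array([[1.0], [2.0], [3.0]])
h_new = gin_update(h, A, eps=0.0, mlp=lambda z: np.maximum(z, 0.0))
print(h_new.ravel())  # [6. 3. 4.]
```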
8. DGL (Deep Graph Library) Comparison
```python
# DGL example - comparison with PyG
# pip install dgl
try:
    import dgl
    import dgl.nn as dglnn
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DGLGCN(nn.Module):
        """GCN implemented with DGL"""
        def __init__(self, in_feats, hidden_size, num_classes):
            super().__init__()
            self.conv1 = dglnn.GraphConv(in_feats, hidden_size)
            self.conv2 = dglnn.GraphConv(hidden_size, num_classes)

        def forward(self, g, features):
            x = F.relu(self.conv1(g, features))
            x = F.dropout(x, training=self.training)
            return self.conv2(g, x)

    # Create DGL graph
    src = torch.tensor([0, 1, 2, 3, 4])
    dst = torch.tensor([1, 2, 3, 4, 0])
    g = dgl.graph((src, dst))
    g.ndata['feat'] = torch.randn(5, 16)

    model_dgl = DGLGCN(16, 32, 4)
    out = model_dgl(g, g.ndata['feat'])
    print(f"DGL GCN output: {out.shape}")
except ImportError:
    print("DGL not installed. pip install dgl")
```
PyG vs DGL Comparison:
| Feature | PyTorch Geometric (PyG) | Deep Graph Library (DGL) |
|---|---|---|
| API style | PyTorch-native | Framework-agnostic |
| Data representation | edge_index (COO) | DGLGraph object |
| Speed | Very fast | Fast |
| Community | Large | Large |
| Available models | Very extensive | Extensive |
| Learning curve | Low | Medium |
9. Real-World Applications
Molecular Property Prediction (OGB)
```python
try:
    from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
    from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
    from torch_geometric.loader import DataLoader
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch_geometric.nn import GINEConv, global_mean_pool

    # Load HIV molecule dataset
    dataset_mol = PygGraphPropPredDataset(name='ogbg-molhiv')
    split_idx = dataset_mol.get_idx_split()
    train_loader_mol = DataLoader(
        dataset_mol[split_idx["train"]],
        batch_size=32,
        shuffle=True
    )

    class MoleculeGNN(nn.Module):
        """Molecular property prediction model"""
        def __init__(self, hidden_channels=300, num_layers=5):
            super().__init__()
            # OGB molecules carry several categorical features per atom (x) and per bond
            # (edge_attr); AtomEncoder/BondEncoder embed each one and sum the embeddings
            self.atom_encoder = AtomEncoder(hidden_channels)
            self.bond_encoder = BondEncoder(hidden_channels)
            self.convs = nn.ModuleList()
            for _ in range(num_layers):
                mlp = nn.Sequential(
                    nn.Linear(hidden_channels, hidden_channels * 2),
                    nn.BatchNorm1d(hidden_channels * 2),
                    nn.ReLU(),
                    nn.Linear(hidden_channels * 2, hidden_channels)
                )
                self.convs.append(GINEConv(mlp))
            self.pool = global_mean_pool
            self.predictor = nn.Linear(hidden_channels, 1)

        def forward(self, x, edge_index, edge_attr, batch):
            x = self.atom_encoder(x)
            edge_attr = self.bond_encoder(edge_attr)
            for conv in self.convs:
                x = conv(x, edge_index, edge_attr)
                x = F.relu(x)
            graph_embed = self.pool(x, batch)
            return self.predictor(graph_embed)  # one logit per molecule (binary HIV activity)

    print("OGB molecular dataset loaded successfully")
except ImportError:
    print("ogb not installed. pip install ogb")
```
Recommendation System
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import LightGCN
class RecommendationSystem(nn.Module):
    """
    LightGCN-based collaborative filtering.
    Learns embeddings on a user-item bipartite graph.
    Node ids: users are 0..num_users-1 and items are num_users..num_users+num_items-1,
    so edge_index must use these shifted item ids and contain both edge directions.
    """
    def __init__(self, num_users, num_items, embedding_dim=64, num_layers=3):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        # LightGCN: simple aggregation without non-linear transforms.
        # PyG's LightGCN owns one embedding table covering all users and items.
        self.lightgcn = LightGCN(
            num_nodes=num_users + num_items,
            embedding_dim=embedding_dim,
            num_layers=num_layers
        )
        self._init_weights()
    def _init_weights(self):
        nn.init.normal_(self.lightgcn.embedding.weight, std=0.01)
    def forward(self, edge_index):
        # LightGCN propagation; by default the result is the average of the
        # layer-0..K embeddings. LightGCN.forward expects edge_index, not node features.
        embeddings = self.lightgcn.get_embedding(edge_index)
        return embeddings[:self.num_users], embeddings[self.num_users:]
    def predict(self, user_ids, item_ids, edge_index):
        user_embs, item_embs = self(edge_index)
        u = user_embs[user_ids]
        i = item_embs[item_ids]
        return (u * i).sum(dim=1)  # inner-product preference score
# BPR Loss
def bpr_loss(pos_scores, neg_scores):
    """Bayesian Personalized Ranking loss: observed items should outscore sampled negatives"""
    return -F.logsigmoid(pos_scores - neg_scores).mean()
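To make the propagation concrete, the following pure-Python sketch runs the LightGCN rule on a tiny bipartite graph: each layer sets a node's embedding to the sum of its neighbors' previous embeddings weighted by 1 / sqrt(|N(u)| * |N(i)|), with no weight matrices or non-linearity; the final embedding is the uniform mean over layers 0..K; a user-item score is an inner product; and the BPR loss compares an observed item with an unobserved one. The interactions and starting embeddings are assumptions for illustration only.

```python
import math

def lightgcn_propagate(emb, edges, num_layers):
    """LightGCN propagation sketch.
    emb: list of equal-length vectors, users first, then items (shifted ids).
    edges: undirected (user, item) pairs."""
    n_nodes, dim = len(emb), len(emb[0])
    nbrs = [[] for _ in range(n_nodes)]
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    layers = [emb]
    for _ in range(num_layers):
        prev = layers[-1]
        nxt = []
        for n in range(n_nodes):
            acc = [0.0] * dim
            for m in nbrs[n]:
                # Symmetric normalization 1 / sqrt(|N(n)| * |N(m)|); no weights, no ReLU
                w = 1.0 / math.sqrt(len(nbrs[n]) * len(nbrs[m]))
                acc = [a + w * b for a, b in zip(acc, prev[m])]
            nxt.append(acc)
        layers.append(nxt)
    # Final embedding: uniform average over layers 0..K
    return [[sum(layer[n][d] for layer in layers) / len(layers) for d in range(dim)]
            for n in range(n_nodes)]

def score(u_vec, i_vec):
    return sum(a * b for a, b in zip(u_vec, i_vec))

def bpr(pos_score, neg_score):
    # -log(sigmoid(pos - neg)), written as log(1 + exp(-(pos - neg)))
    return math.log1p(math.exp(-(pos_score - neg_score)))

# Toy graph (assumed): users 0, 1; items 2, 3; interactions u0-i2, u1-i2, u1-i3
emb0 = [[1.0], [0.0], [0.0], [0.0]]
final = lightgcn_propagate(emb0, [(0, 2), (1, 2), (1, 3)], num_layers=1)
pos, neg = score(final[0], final[2]), score(final[0], final[3])
print(final, pos, neg, bpr(pos, neg))
```

User 0 ends up closer to the item it interacted with than to the one it did not, so the BPR loss falls below log 2, its value when both scores are equal.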
10. Graph Generative Models
GraphVAE
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GraphVAE(nn.Module):
    """Graph Variational Autoencoder (VGAE-style: GCN encoder, inner-product decoder)"""
    def __init__(self, in_channels, hidden_channels, latent_dim):
        super().__init__()
        # Encoder (GNN): shared first layer, then separate heads for mu and log-variance
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, latent_dim)
        self.conv_logvar = GCNConv(hidden_channels, latent_dim)
    def encode(self, x, edge_index):
        h = F.relu(self.conv1(x, edge_index))
        mu = self.conv_mu(h, edge_index)
        logvar = self.conv_logvar(h, edge_index)
        return mu, logvar
    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps keeps sampling differentiable w.r.t. mu and logvar
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu  # use the mean at inference time
    def decode(self, z):
        # Edge probability for every node pair via inner product: shape [N, N]
        adj_pred = torch.sigmoid(z @ z.t())
        return adj_pred
    def forward(self, x, edge_index):
        mu, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mu, logvar)
        adj_pred = self.decode(z)
        return adj_pred, mu, logvar
    def loss(self, adj_pred, adj_target, mu, logvar):
        # Reconstruction loss; adj_target is a dense float [N, N] 0/1 matrix,
        # e.g. torch_geometric.utils.to_dense_adj(edge_index, max_num_nodes=N)[0]
        recon_loss = F.binary_cross_entropy(adj_pred, adj_target)
        # KL divergence to N(0, I), averaged over nodes and latent dimensions
        kl_loss = -0.5 * torch.mean(
            1 + logvar - mu.pow(2) - logvar.exp()
        )
        return recon_loss + kl_loss
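The three ingredients of this model can be checked by hand with a small pure-Python sketch: the reparameterization z = mu + eps * exp(0.5 * logvar), the inner-product decoder sigmoid(z_i . z_j), and the mean-reduced KL term used in `loss`, which is zero exactly when mu = 0 and logvar = 0. The latent values are made up, and the noise eps is fixed at zero instead of sampled so the output is deterministic.

```python
import math

def reparameterize(mu, logvar, eps):
    # z = mu + eps * sigma with sigma = exp(0.5 * logvar); eps is normally N(0, 1) noise
    return [m + e * math.exp(0.5 * lv) for m, lv, e in zip(mu, logvar, eps)]

def decode(z):
    # Edge probability for each ordered node pair: sigmoid(z_i . z_j)
    def sigmoid(t):
        return 1.0 / (1.0 + math.exp(-t))
    return [[sigmoid(sum(a * b for a, b in zip(zi, zj))) for zj in z] for zi in z]

def kl_mean(mu_rows, logvar_rows):
    # -0.5 * mean(1 + logvar - mu^2 - exp(logvar)) over all nodes and latent dims,
    # mirroring the torch.mean reduction in GraphVAE.loss
    terms = [1 + lv - m * m - math.exp(lv)
             for mus, lvs in zip(mu_rows, logvar_rows) for m, lv in zip(mus, lvs)]
    return -0.5 * sum(terms) / len(terms)

# Hypothetical 3-node latent means (2 dims each); logvar = 0 means sigma = 1
mu = [[2.0, 0.0], [2.0, 0.0], [-2.0, 0.0]]
logvar = [[0.0, 0.0]] * 3
z = [reparameterize(m, lv, eps=[0.0, 0.0]) for m, lv in zip(mu, logvar)]  # eps = 0 gives z = mu
adj = decode(z)
print([[round(p, 3) for p in row] for row in adj])
print(kl_mean(mu, logvar))
```

Nodes whose latent vectors point the same way get a high edge probability, opposite vectors get a low one, and the KL term grows as the means move away from the origin.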
Quiz
Q1. What is the main difference between GCN and GAT?
Answer: GCN aggregates all neighbors with fixed weights based on node degree, while GAT dynamically learns different attention weights for each neighbor.
Explanation: In GCN, the aggregation weights are fixed by the degree normalization of the adjacency matrix. In GAT, attention coefficients are computed dynamically from the feature vectors of the two connected nodes, which lets GAT assign higher weights to more important neighbors. GAT also uses multi-head attention, which further stabilizes training.
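A small pure-Python sketch makes the contrast concrete for one node. The GCN coefficients depend only on degrees (1 / sqrt(d_i * d_j), self-loops included), while the GAT coefficients are a softmax over e_ij = LeakyReLU(a . [h_i || h_j]). To keep it short, the sketch assumes the shared linear transform W is the identity and uses a hand-picked attention vector `a`; both are assumptions, not learned values.

```python
import math

def gcn_coeffs(i, nbrs):
    # nbrs[n] includes n itself (self-loop); weights depend only on degrees
    return {j: 1.0 / math.sqrt(len(nbrs[i]) * len(nbrs[j])) for j in nbrs[i]}

def gat_coeffs(i, nbrs, h, a, slope=0.2):
    # e_ij = LeakyReLU(a . [h_i || h_j]), then softmax over i's neighborhood
    def leaky(t):
        return t if t > 0 else slope * t
    scores = {j: leaky(sum(w * v for w, v in zip(a, h[i] + h[j]))) for j in nbrs[i]}
    total = sum(math.exp(s) for s in scores.values())
    return {j: math.exp(s) / total for j, s in scores.items()}

# Star graph: node 0 linked to 1 and 2 (self-loops included); assumed 2-dim features
nbrs = {0: [0, 1, 2], 1: [1, 0], 2: [2, 0]}
h = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]}
a = [0.0, 0.0, 2.0, -2.0]  # hypothetical attention vector favoring neighbors like node 1

g = gcn_coeffs(0, nbrs)
t = gat_coeffs(0, nbrs, h, a)
print(g)  # nodes 1 and 2 get identical weights
print(t)  # node 1 gets more weight than node 2
```

Nodes 1 and 2 have the same degree, so GCN cannot tell them apart, whereas GAT's weights follow their features.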
Q2. Why is GraphSAGE better suited for inductive learning than GCN?
Answer: GraphSAGE learns an aggregator function that can generate embeddings for new nodes by sampling and aggregating their neighbors.
Explanation: As originally formulated, GCN is trained on the normalized adjacency matrix of a single fixed graph, making it a transductive method that typically needs retraining when new nodes are added. GraphSAGE instead learns an aggregation function that samples and aggregates neighbor features, so it can generate embeddings for unseen nodes by applying the same function to their neighborhoods. This is why GraphSAGE-style models are used in production systems such as those at Pinterest and LinkedIn, which deal with large, dynamically changing graphs.
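A minimal pure-Python sketch of the inductive idea, assuming a mean-aggregator layer of the form W_self . h_v + W_neigh . mean(h_u over sampled neighbors u), followed by ReLU. The weight matrices stand in for trained parameters and are hypothetical; the point is that the same function is reused unchanged for a node that did not exist during training.

```python
import random

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def sage_mean_layer(v, feats, nbrs, W_self, W_neigh, num_samples=2, rng=random):
    # Sample up to num_samples neighbors, average their features, then combine
    if len(nbrs[v]) <= num_samples:
        sampled = nbrs[v]
    else:
        sampled = rng.sample(nbrs[v], num_samples)
    dim = len(feats[v])
    mean = [sum(feats[u][d] for u in sampled) / len(sampled) for d in range(dim)]
    out = [a + b for a, b in zip(matvec(W_self, feats[v]), matvec(W_neigh, mean))]
    return [max(0.0, t) for t in out]  # ReLU

# Hypothetical "trained" parameters for 2-dim features
W_self = [[1.0, 0.0], [0.0, 1.0]]
W_neigh = [[0.5, 0.0], [0.0, 0.5]]

feats = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
nbrs = {0: [1, 2], 1: [0, 2], 2: [0, 1]}

# A node that appears after training: add its features and edges, no retraining
feats[3] = [2.0, 0.0]
nbrs[3] = [0, 1]
for u in nbrs[3]:
    nbrs[u].append(3)

h3 = sage_mean_layer(3, feats, nbrs, W_self, W_neigh)
print(h3)
```

Node 0 now has three neighbors, so computing its embedding would sample two of them; passing a seeded `random.Random` as `rng` makes that sampling reproducible.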
Q3. What are the three stages of the Message Passing Neural Network (MPNN) framework?
Answer: Message (computing messages), Aggregate (aggregating messages), and Update (updating node embeddings).
Explanation: In the Message stage, a message is computed for each edge from the source node's features (and, depending on the model, the target node's features and the edge features). In the Aggregate stage, each node combines all incoming messages from its neighbors with a permutation-invariant function such as sum, mean, or max. In the Update stage, the aggregated message is combined with the node's current embedding to produce its new embedding. Most GNN variants, including GCN, GAT, GraphSAGE, and GIN, can be expressed within this framework.
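The three stages map directly onto a few lines of code. This pure-Python sketch uses one concrete choice per stage, picked for illustration rather than taken from any particular model: the message is the source node's features scaled by an edge weight, aggregation is a sum, and the update averages the old embedding with the aggregated message.

```python
def mpnn_step(h, edges, edge_weight):
    """One message-passing step over directed edges (src, dst)."""
    dim = len(h[0])
    # 1) Message: one vector per edge, computed from the source node and the edge weight
    messages = [[edge_weight[k] * x for x in h[src]] for k, (src, _) in enumerate(edges)]
    # 2) Aggregate: sum incoming messages per target node (order does not matter)
    agg = [[0.0] * dim for _ in h]
    for (_, dst), m in zip(edges, messages):
        agg[dst] = [a + b for a, b in zip(agg[dst], m)]
    # 3) Update: combine each node's current embedding with its aggregated message
    return [[0.5 * (x + a) for x, a in zip(h[v], agg[v])] for v in range(len(h))]

h = [[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]]
edges = [(0, 1), (2, 1), (1, 0)]  # node 1 receives from 0 and 2; node 0 from 1
edge_weight = [1.0, 0.5, 1.0]
out = mpnn_step(h, edges, edge_weight)
print(out)
```

Swapping in a different message function, aggregator, or update rule is what distinguishes one GNN variant from another.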
Q4. What is the over-smoothing problem in GNNs, and how can it be mitigated?
Answer: As GNN depth increases, all node embeddings converge to similar values. It can be mitigated with residual connections, JK-Net, or DropEdge.
Explanation: A K-layer GNN aggregates information from K-hop neighborhoods. As layers are added, each node's receptive field grows until neighborhoods overlap heavily, and repeated averaging pushes all node representations toward a common value. Residual connections preserve node-specific information by directly passing previous layer outputs forward. Jumping Knowledge Networks (JK-Net) use embeddings from all layers in the final representation. DropEdge randomly removes edges during training to reduce neighbor overlap.
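Over-smoothing is easy to reproduce numerically. The pure-Python sketch below repeatedly applies mean aggregation over each node's neighborhood (self-loops included) on a 4-node path graph and tracks the spread (max minus min) of a scalar node feature. As one example of a residual-style fix, it also runs an initial-residual update in the style of APPNP, h = (1 - alpha) * aggregate(h) + alpha * x0, which keeps re-injecting each node's input features. The graph, features, and alpha = 0.1 are assumptions.

```python
def mean_aggregate(h, nbrs):
    # Each node averages over its neighborhood, self-loop included
    return [sum(h[u] for u in nbrs[v]) / len(nbrs[v]) for v in range(len(h))]

def spread(h):
    return max(h) - min(h)

# Path graph 0-1-2-3 with self-loops; one scalar feature per node (assumed)
nbrs = [[0, 1], [0, 1, 2], [1, 2, 3], [2, 3]]
x0 = [0.0, 0.0, 0.0, 1.0]

def run(num_layers, alpha=None):
    h = list(x0)
    for _ in range(num_layers):
        agg = mean_aggregate(h, nbrs)
        # alpha=None: plain layer stacking; otherwise APPNP-style initial residual
        h = agg if alpha is None else [(1 - alpha) * a + alpha * x for a, x in zip(agg, x0)]
    return h

for layers in (1, 4, 16, 64):
    print(layers, round(spread(run(layers)), 4), round(spread(run(layers, alpha=0.1)), 4))
```

With plain stacking the spread collapses toward zero as depth grows, while the initial-residual variant settles at a fixed point that still separates the nodes.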
Q5. What does it mean that GNN expressiveness is equivalent to the WL test?
Answer: Standard GNNs embed two graphs identically if the Weisfeiler-Leman (WL) graph isomorphism test cannot distinguish them.
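Explanation: The WL test is a heuristic for graph isomorphism that iteratively refines node labels by aggregating and hashing neighbor labels. If the resulting label histograms of two graphs differ, the graphs are certainly non-isomorphic, but identical histograms do not prove isomorphism. Xu et al. (2019) proved that message-passing GNNs are at most as powerful as the 1-WL test and introduced the Graph Isomorphism Network (GIN), which reaches this upper bound. This means GNNs cannot distinguish graph pairs on which the WL test also fails, such as non-isomorphic regular graphs with the same number of nodes and the same degree (for example, a 6-cycle versus two disjoint triangles). To overcome this limitation, research is ongoing into higher-order GNNs equivalent to k-WL, port numbering, and random feature augmentation.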
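The regular-graph failure case can be checked directly. This pure-Python sketch implements 1-WL color refinement (a node's new color is determined by its current color plus the multiset of its neighbors' colors) and compares the final color histograms of a 6-cycle and two disjoint triangles. Both graphs are 2-regular, so every node keeps the same color and the test cannot tell them apart, although they are not isomorphic. Refining both graphs against one shared color table per iteration is an implementation choice made here so that colors are comparable across graphs.

```python
from collections import Counter

def wl_histograms(graphs, iterations=3):
    """1-WL color refinement on several graphs with a shared color table.
    graphs: list of adjacency dicts {node: [neighbors]}; returns one histogram per graph."""
    colors = [{v: 0 for v in g} for g in graphs]  # uniform initial labels
    for _ in range(iterations):
        table = {}  # (own color, sorted neighbor colors) -> new color id
        new_colors = []
        for g, c in zip(graphs, colors):
            nc = {}
            for v in g:
                signature = (c[v], tuple(sorted(c[u] for u in g[v])))
                nc[v] = table.setdefault(signature, len(table))
            new_colors.append(nc)
        colors = new_colors
    return [Counter(c.values()) for c in colors]

def cycle(n, offset=0):
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n] for i in range(n)}

six_cycle = cycle(6)
two_triangles = {**cycle(3), **cycle(3, offset=3)}
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
star = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}

h_cycle, h_tri = wl_histograms([six_cycle, two_triangles])
print(h_cycle == h_tri)  # True: WL cannot separate these non-isomorphic graphs
h_path, h_star = wl_histograms([path, star])
print(h_path == h_star)  # False: a path and a star with 4 nodes and 3 edges are separated
```

Any message-passing GNN inherits the first blind spot, which is what motivates the higher-order and feature-augmentation approaches mentioned above.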
Summary
This guide covered the entire GNN ecosystem:
- Graph Theory Fundamentals: Nodes, edges, adjacency matrices, graph properties
- Message Passing Paradigm: The core principle of GNNs
- Key Architectures: GCN, GraphSAGE, GAT, Graph Transformer, GIN
- Graph-Level Prediction: Global Pooling, DiffPool
- Link Prediction: Knowledge graphs, recommendation systems
- PyTorch Geometric: Complete node and graph classification examples
- Real-World Applications: Molecular design, recommendation systems, fraud detection
- Graph Generative Models: GraphVAE
GNNs are delivering strong results in molecular design, drug discovery, social network analysis, traffic prediction, recommendation systems, and many other fields. Libraries like PyTorch Geometric and DGL make implementation increasingly accessible, and benchmarks like OGB provide standardized datasets, splits, and evaluators that enable fair comparisons across methods.