Graph Neural Networks Complete Guide: GCN, GAT, GraphSAGE to Molecular Design


Social networks, molecular structures, knowledge graphs, recommendation systems — countless real-world datasets are naturally represented as graphs. Graph Neural Networks (GNNs) are the core tool for applying deep learning to this non-Euclidean data. This guide systematically covers everything from graph theory fundamentals to the latest GNN architectures and hands-on implementations with PyTorch Geometric.

1. Graph Theory Fundamentals

Graph Definition

A graph G consists of a set of nodes V and a set of edges E, expressed as G = (V, E). Nodes represent entities, while edges represent relationships between entities.

  • Nodes (Vertices): Represent entities. Examples: users, atoms, papers
  • Edges: Represent relationships. Examples: friendships, chemical bonds, citations
  • Node Features: Feature vectors attached to each node
  • Edge Features: Feature vectors attached to each edge

Directed vs Undirected Graphs

import networkx as nx
import numpy as np

# Undirected Graph
G_undirected = nx.Graph()
G_undirected.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (0, 2)])

# Directed Graph
G_directed = nx.DiGraph()
G_directed.add_edges_from([(0, 1), (1, 2), (2, 0), (0, 3)])

print(f"Undirected - Nodes: {G_undirected.number_of_nodes()}, Edges: {G_undirected.number_of_edges()}")
print(f"Directed - Nodes: {G_directed.number_of_nodes()}, Edges: {G_directed.number_of_edges()}")

Adjacency Matrix and Edge List

import torch
import numpy as np

# Adjacency Matrix
# A[i][j] = 1 means an edge exists between nodes i and j
adj_matrix = torch.tensor([
    [0, 1, 1, 0],
    [1, 0, 1, 0],
    [1, 1, 0, 1],
    [0, 0, 1, 0]
], dtype=torch.float32)

# Edge Index (used by PyG)
# shape: (2, num_edges) - first row: source nodes, second row: target nodes
edge_index = torch.tensor([
    [0, 0, 1, 1, 2, 2, 2, 3],  # Source nodes
    [1, 2, 0, 2, 0, 1, 3, 2]   # Target nodes
], dtype=torch.long)

print(f"Adjacency matrix shape: {adj_matrix.shape}")  # (4, 4)
print(f"Edge list shape: {edge_index.shape}")          # (2, 8)

# Convert adjacency matrix to edge index
def adj_to_edge_index(adj):
    """Convert adjacency matrix to edge index"""
    row, col = torch.where(adj > 0)
    return torch.stack([row, col], dim=0)

converted = adj_to_edge_index(adj_matrix)
print(f"Converted edge list:\n{converted}")
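The reverse conversion is just as short. A complementary sketch (the helper name `edge_index_to_adj` is our own, not a PyG function):

```python
import torch

def edge_index_to_adj(edge_index, num_nodes):
    """Convert an edge index back to a dense adjacency matrix"""
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    return adj

edge_index = torch.tensor([
    [0, 0, 1, 1, 2, 2, 2, 3],
    [1, 2, 0, 2, 0, 1, 3, 2]
], dtype=torch.long)

adj = edge_index_to_adj(edge_index, 4)
print(adj.sum())  # tensor(8.) - one entry per edge
```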

Graph Properties

import networkx as nx
import numpy as np

def analyze_graph(G):
    """Analyze key properties of a graph"""

    # Degree
    degrees = dict(G.degree())
    avg_degree = np.mean(list(degrees.values()))

    # Clustering Coefficient
    clustering = nx.average_clustering(G)

    # Average Path Length
    if nx.is_connected(G):
        avg_path = nx.average_shortest_path_length(G)
    else:
        # Use the largest connected component
        largest_cc = max(nx.connected_components(G), key=len)
        subgraph = G.subgraph(largest_cc)
        avg_path = nx.average_shortest_path_length(subgraph)

    # Centrality
    betweenness = nx.betweenness_centrality(G)
    pagerank = nx.pagerank(G)

    print(f"Nodes: {G.number_of_nodes()}")
    print(f"Edges: {G.number_of_edges()}")
    print(f"Average degree: {avg_degree:.2f}")
    print(f"Clustering coefficient: {clustering:.3f}")
    print(f"Average path length: {avg_path:.2f}")

    return {
        "degrees": degrees,
        "clustering": clustering,
        "avg_path": avg_path,
        "betweenness": betweenness,
        "pagerank": pagerank
    }

# Social network example (Karate Club)
G = nx.karate_club_graph()
stats = analyze_graph(G)

Real-World Graphs

Domain | Nodes | Edges | Task
--- | --- | --- | ---
Social network | Users | Friendships | Community detection
Molecular structure | Atoms | Chemical bonds | Property prediction
Knowledge graph | Entities | Relations | Link prediction
Citation network | Papers | Citations | Node classification
Traffic network | Intersections | Roads | Route prediction
Recommendation | Users/Items | Interactions | Recommendation

2. Motivation for Graph Machine Learning

Why CNNs and RNNs Fall Short

Traditional CNNs assume a grid structure. Images work naturally because pixels are arranged in a regular 2D grid. RNNs assume sequential structure.

But graphs are:

  • Irregular structure: Each node has a different number of neighbors
  • No canonical ordering: models must be invariant to permutations of the nodes
  • Global dependencies: Distant nodes can still influence each other

# Illustrating graph data characteristics
import torch

# Images: fixed-size grids
image = torch.randn(3, 224, 224)  # channels, height, width

# Sequences: ordered data
sequence = torch.randn(100, 512)  # sequence length, feature dim

# Graphs: variable neighborhood structure
# Node features: (num_nodes, feature_dim)
node_features = torch.randn(34, 16)  # 34 nodes, 16-dim features
# Edges: (2, num_edges) - sparse connections
edge_index = torch.randint(0, 34, (2, 78))
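The permutation-invariance point can be checked directly: a graph-level readout such as the feature mean is unchanged when nodes are relabeled. A minimal sketch with random toy features:

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 4)        # 5 nodes, 4-dim features

perm = torch.randperm(5)     # relabel the nodes
x_perm = x[perm]

# A mean readout over nodes is unaffected by the relabeling
readout = x.mean(dim=0)
readout_perm = x_perm.mean(dim=0)
print(torch.allclose(readout, readout_perm))  # True
```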

Message Passing Paradigm

The fundamental principle of all GNNs is message passing. Each node receives messages from its neighbors and updates its own representation.

The Message Passing Neural Network (MPNN) framework:

  1. Message computation: Compute message to send from node u to v along edge (u, v)
  2. Aggregation: Each node aggregates all incoming neighbor messages
  3. Update: Update the node representation using the aggregated message
m_v^(l) = AGGREGATE({h_u^(l-1) : u in N(v)})
h_v^(l) = UPDATE(h_v^(l-1), m_v^(l))

where N(v) is the set of neighbors of node v.
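Before the vectorized implementation in the next section, one message-passing step can be written as an explicit loop. A minimal sketch on a 4-node path graph, using mean aggregation and a residual update:

```python
import torch

h = torch.tensor([[1.], [2.], [3.], [4.]])          # h_v^(l-1), one feature per node
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}  # N(v) for the path 0-1-2-3

h_new = torch.zeros_like(h)
for v, nbrs in neighbors.items():
    m_v = torch.stack([h[u] for u in nbrs]).mean(dim=0)  # AGGREGATE (mean)
    h_new[v] = h[v] + m_v                                # UPDATE (residual)

print(h_new.squeeze())  # tensor([3., 4., 6., 7.])
```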

3. GNN Fundamental Equations

Aggregation and Update

import torch
import torch.nn as nn
from torch_scatter import scatter_mean, scatter_sum, scatter_max  # pip install torch-scatter

def manual_message_passing(node_features, edge_index, aggregation="mean"):
    """
    Manually implemented message passing
    node_features: (N, F) - N nodes, F-dimensional features
    edge_index: (2, E) - E edges
    """
    src, dst = edge_index[0], edge_index[1]
    num_nodes = node_features.size(0)

    # Use source node features as messages
    messages = node_features[src]  # (E, F)

    if aggregation == "mean":
        aggregated = scatter_mean(messages, dst, dim=0, dim_size=num_nodes)
    elif aggregation == "sum":
        aggregated = scatter_sum(messages, dst, dim=0, dim_size=num_nodes)
    elif aggregation == "max":
        aggregated, _ = scatter_max(messages, dst, dim=0, dim_size=num_nodes)

    # Update: original features + aggregated messages
    updated = node_features + aggregated
    return updated

# Example
N, F = 6, 8
node_features = torch.randn(N, F)
edge_index = torch.tensor([[0,1,2,3,4,0,1], [1,2,3,4,0,3,4]])

output = manual_message_passing(node_features, edge_index, "mean")
print(f"Input shape: {node_features.shape}")
print(f"Output shape: {output.shape}")

4. Key GNN Architectures

GCN (Graph Convolutional Network)

GCN, proposed by Kipf and Welling in 2017, derives an efficient layer-wise propagation rule starting from spectral graph theory.

Layer-wise propagation rule:

Using a normalized adjacency matrix: A_tilde = D_tilde^(-1/2) * (A + I) * D_tilde^(-1/2)

H^(l+1) = sigma(A_tilde * H^(l) * W^(l))

where D_tilde is the degree matrix of A + I, I is the identity matrix, and W^(l) is the layer's learnable weight matrix.
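The normalization can be verified numerically on the 4-node graph from section 1; the result is symmetric, with each nonzero entry equal to 1/sqrt(d_i * d_j):

```python
import torch

A = torch.tensor([[0., 1., 1., 0.],
                  [1., 0., 1., 0.],
                  [1., 1., 0., 1.],
                  [0., 0., 1., 0.]])

A_hat = A + torch.eye(4)                  # add self-loops
deg = A_hat.sum(dim=1)                    # degrees of A + I: [3, 3, 4, 2]
D_inv_sqrt = torch.diag(deg.pow(-0.5))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt  # D^(-1/2) (A + I) D^(-1/2)

print(A_norm[0, 1])  # 1/sqrt(3*3) = 0.3333
```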

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]

print(f"Nodes: {data.num_nodes}")
print(f"Edges: {data.num_edges}")
print(f"Node feature dim: {data.num_node_features}")
print(f"Num classes: {dataset.num_classes}")
print(f"Train nodes: {data.train_mask.sum().item()}")
print(f"Val nodes: {data.val_mask.sum().item()}")
print(f"Test nodes: {data.test_mask.sum().item()}")


class GCN(nn.Module):
    """Graph Convolutional Network"""

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        # First GCN layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Second GCN layer
        x = self.conv2(x, edge_index)
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes
).to(device)
data = data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


def train_gcn():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


def test_gcn():
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)

    results = {}
    for split, mask in [("train", data.train_mask),
                         ("val", data.val_mask),
                         ("test", data.test_mask)]:
        correct = pred[mask].eq(data.y[mask]).sum().item()
        results[split] = correct / mask.sum().item()
    return results


# Training loop
best_val_acc = 0
for epoch in range(200):
    loss = train_gcn()
    accs = test_gcn()

    if accs["val"] > best_val_acc:
        best_val_acc = accs["val"]

    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1:03d} | Loss: {loss:.4f} | "
              f"Train: {accs['train']:.4f} | Val: {accs['val']:.4f} | "
              f"Test: {accs['test']:.4f}")

Manual GCN Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class ManualGCNLayer(nn.Module):
    """Manual GCN layer implementation - for understanding internals"""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, adj):
        """
        x: node features (N, F_in)
        adj: normalized adjacency matrix (N, N)
        """
        # Linear transform: X * W
        support = x @ self.weight
        # Graph convolution: A_hat * X * W
        output = adj @ support + self.bias
        return output

    @staticmethod
    def normalize_adjacency(adj):
        """D^(-1/2) * A * D^(-1/2) normalization"""
        # Add self-loops
        N = adj.size(0)
        adj_hat = adj + torch.eye(N, device=adj.device)

        # Compute degree matrix
        deg = adj_hat.sum(dim=1)
        d_inv_sqrt = torch.diag(deg.pow(-0.5))

        # Normalize
        adj_normalized = d_inv_sqrt @ adj_hat @ d_inv_sqrt
        return adj_normalized

GraphSAGE (Inductive Learning)

GraphSAGE is designed for inductive learning. It uses neighbor sampling to enable mini-batch training instead of processing the entire graph.
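The sampling idea itself can be sketched in a few lines; the `adjacency` dict and `sample_neighbors` helper below are illustrative, not PyG API:

```python
import random

random.seed(0)
adjacency = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0, 3]}  # toy neighbor lists

def sample_neighbors(node, k):
    """Keep at most k neighbors per node, chosen uniformly at random"""
    nbrs = adjacency[node]
    return nbrs if len(nbrs) <= k else random.sample(nbrs, k)

print(len(sample_neighbors(0, 3)))  # 3 - high-degree node is subsampled
print(sample_neighbors(1, 3))       # [0] - low-degree node kept whole
```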

from torch_geometric.nn import SAGEConv
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphSAGE(nn.Module):
    """GraphSAGE - Inductive Representation Learning"""

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=3, dropout=0.5, aggr="mean"):
        super().__init__()
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels, aggr=aggr))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels, aggr=aggr))
        self.convs.append(SAGEConv(hidden_channels, out_channels, aggr=aggr))

        self.bns = nn.ModuleList([
            nn.BatchNorm1d(hidden_channels)
            for _ in range(num_layers - 1)
        ])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.convs[-1](x, edge_index)
        return x


# Mini-batch training with neighbor sampling
from torch_geometric.loader import NeighborLoader

# NeighborLoader: sample num_neighbors neighbors per layer
train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],  # 2-hop: 25 at 1st hop, 10 at 2nd hop
    batch_size=256,
    input_nodes=data.train_mask,
    shuffle=True
)

model_sage = GraphSAGE(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes
).to(device)

optimizer_sage = torch.optim.Adam(model_sage.parameters(), lr=0.001)

def train_sage():
    model_sage.train()
    total_loss = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer_sage.zero_grad()
        out = model_sage(batch.x, batch.edge_index)
        # Only the first batch_size nodes are training nodes
        loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
        loss.backward()
        optimizer_sage.step()
        total_loss += loss.item()

    return total_loss / len(train_loader)

GAT (Graph Attention Network)

GAT uses attention mechanisms to assign different weights to each neighbor. It implements the intuition that "not all neighbors are equally important."

Attention coefficient computation:

Attention score: e_ij = LeakyReLU(a^T [Wh_i || Wh_j])

Softmax normalization: alpha_ij = exp(e_ij) / sum_k(exp(e_ik))

Update: h_i' = sigma(sum_j alpha_ij * W * h_j)
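These three steps can be traced numerically for a single node and its neighborhood. The weights W and a below are random stand-ins, with the LeakyReLU slope of 0.2 used in the GAT paper:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
F_in, F_out = 4, 8
W = torch.randn(F_in, F_out)   # shared linear transform
a = torch.randn(2 * F_out)     # attention vector

h = torch.randn(3, F_in)       # node 0 plus its neighbors 1 and 2
Wh = h @ W

# e_0j = LeakyReLU(a^T [Wh_0 || Wh_j]) for each j in the neighborhood
e = torch.stack([
    F.leaky_relu(a @ torch.cat([Wh[0], Wh[j]]), negative_slope=0.2)
    for j in range(3)
])
alpha = torch.softmax(e, dim=0)                 # normalize over neighbors
h0_new = (alpha.unsqueeze(1) * Wh).sum(dim=0)   # weighted aggregation

print(alpha.sum())   # attention weights sum to one
```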

from torch_geometric.nn import GATConv, GATv2Conv
import torch
import torch.nn as nn
import torch.nn.functional as F

class GAT(nn.Module):
    """Graph Attention Network"""

    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout

        # First layer: multi-head attention
        self.conv1 = GATConv(
            in_channels,
            hidden_channels,
            heads=heads,
            dropout=dropout,
            concat=True  # Concatenate heads
        )

        # Second layer: average heads
        self.conv2 = GATConv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            dropout=dropout,
            concat=False  # Average heads
        )

    def forward(self, x, edge_index):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x


class GATv2(nn.Module):
    """
    GATv2 - Improved attention mechanism
    GATv2 computes dynamic attention, giving higher expressiveness
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads=8, dropout=0.6):
        super().__init__()
        self.conv1 = GATv2Conv(
            in_channels,
            hidden_channels,
            heads=heads,
            dropout=dropout,
            concat=True
        )
        self.conv2 = GATv2Conv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            dropout=dropout,
            concat=False
        )
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index)


# GAT training
model_gat = GAT(
    in_channels=dataset.num_features,
    hidden_channels=8,
    out_channels=dataset.num_classes,
    heads=8
).to(device)

optimizer_gat = torch.optim.Adam(model_gat.parameters(), lr=0.005, weight_decay=5e-4)

Graph Transformer

Graph Transformer applies global Transformer attention to graphs.

from torch_geometric.nn import TransformerConv
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphTransformer(nn.Module):
    """Graph Transformer Layer"""

    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads=4, num_layers=3, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.convs.append(
            TransformerConv(in_channels, hidden_channels // heads, heads=heads,
                           dropout=dropout, beta=True)
        )

        for _ in range(num_layers - 2):
            self.convs.append(
                TransformerConv(hidden_channels, hidden_channels // heads,
                               heads=heads, dropout=dropout, beta=True)
            )

        self.convs.append(
            TransformerConv(hidden_channels, out_channels, heads=heads,
                            concat=False, dropout=dropout, beta=True)
        )

        self.norms = nn.ModuleList([
            nn.LayerNorm(hidden_channels) for _ in range(num_layers - 1)
        ])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        return self.convs[-1](x, edge_index)

5. Graph-Level Prediction

While node classification predicts individual nodes, graph classification predicts entire graphs. For example: predicting whether a molecule is toxic.

Global Pooling

from torch_geometric.nn import (
    global_mean_pool,
    global_max_pool,
    global_add_pool
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphClassifier(nn.Module):
    """Graph classification model"""

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=3, dropout=0.5, pooling="mean"):
        super().__init__()
        self.dropout = dropout
        self.pooling = pooling

        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))

        self.bns = nn.ModuleList([
            nn.BatchNorm1d(hidden_channels) for _ in range(num_layers)
        ])

        # Graph-level classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, x, edge_index, batch):
        """
        batch: index vector indicating which graph each node belongs to
        """
        # Node embedding
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Graph-level pooling
        if self.pooling == "mean":
            x = global_mean_pool(x, batch)
        elif self.pooling == "max":
            x = global_max_pool(x, batch)
        elif self.pooling == "sum":
            x = global_add_pool(x, batch)

        # Classification
        return self.classifier(x)
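What the pooling operators do can be reproduced in plain torch: the batch vector maps each node to its graph, and pooling reduces features per graph index. A sketch with a hypothetical 2-graph batch:

```python
import torch

x = torch.randn(7, 4)                         # 7 nodes across 2 graphs
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])   # graph index per node

num_graphs = int(batch.max()) + 1
sums = torch.zeros(num_graphs, 4).index_add_(0, batch, x)
counts = torch.bincount(batch).unsqueeze(1).float()
graph_mean = sums / counts                    # equivalent to global_mean_pool

print(graph_mean.shape)  # torch.Size([2, 4])
```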

DiffPool (Differentiable Pooling)

from torch_geometric.nn import dense_diff_pool
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffPoolLayer(nn.Module):
    """Hierarchical graph pooling"""

    def __init__(self, in_channels, hidden_channels, num_clusters):
        super().__init__()
        # GNN for node embedding
        self.gnn_embed = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU()
        )
        # GNN for cluster assignment
        self.gnn_pool = nn.Sequential(
            nn.Linear(in_channels, num_clusters),
        )

    def forward(self, x, adj, mask=None):
        embed = self.gnn_embed(x)
        # Cluster assignment matrix
        s = torch.softmax(self.gnn_pool(x), dim=-1)
        # DiffPool
        out, out_adj, link_loss, entropy_loss = dense_diff_pool(embed, adj, s, mask)
        return out, out_adj, link_loss, entropy_loss
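The coarsening step inside dense_diff_pool reduces both features and adjacency through the assignment matrix S: X' = S^T X and A' = S^T A S. A plain-torch sketch with random inputs:

```python
import torch

torch.manual_seed(0)
N, F_dim, K = 6, 8, 2                          # 6 nodes pooled into 2 clusters
X = torch.randn(N, F_dim)
A = (torch.rand(N, N) > 0.5).float()
S = torch.softmax(torch.randn(N, K), dim=-1)   # soft cluster assignment

X_pooled = S.T @ X        # (K, F_dim) - cluster features
A_pooled = S.T @ A @ S    # (K, K)     - coarsened adjacency
print(X_pooled.shape, A_pooled.shape)
```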
6. Link Prediction

Link prediction estimates whether an edge should exist between a pair of nodes, e.g. friend recommendation in a social network or knowledge-graph completion.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

class LinkPredictor(nn.Module):
    """Link prediction model"""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Node embedding encoder
        self.encoder = nn.ModuleList([
            GCNConv(in_channels, hidden_channels),
            GCNConv(hidden_channels, out_channels)
        ])

        # Edge decoder
        self.decoder = nn.Sequential(
            nn.Linear(out_channels * 2, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, 1)
        )

    def encode(self, x, edge_index):
        for i, conv in enumerate(self.encoder):
            x = conv(x, edge_index)
            if i < len(self.encoder) - 1:
                x = F.relu(x)
        return x

    def decode(self, z, edge_index):
        # Concatenate source/target node embeddings
        src, dst = edge_index
        edge_feat = torch.cat([z[src], z[dst]], dim=1)
        return self.decoder(edge_feat).squeeze()

    def forward(self, x, edge_index, pos_edge_index, neg_edge_index):
        z = self.encode(x, edge_index)

        pos_pred = self.decode(z, pos_edge_index)
        neg_pred = self.decode(z, neg_edge_index)

        return pos_pred, neg_pred


def train_link_prediction(model, data, optimizer):
    model.train()
    optimizer.zero_grad()

    # Node embedding
    z = model.encode(data.x, data.edge_index)

    # Positive edges
    pos_edge = data.train_pos_edge_index

    # Negative edge sampling
    neg_edge = negative_sampling(
        edge_index=pos_edge,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edge.size(1)
    )

    pos_pred = model.decode(z, pos_edge)
    neg_pred = model.decode(z, neg_edge)

    # Binary cross-entropy loss
    pred = torch.cat([pos_pred, neg_pred])
    labels = torch.cat([
        torch.ones(pos_pred.size(0)),
        torch.zeros(neg_pred.size(0))
    ]).to(pred.device)

    loss = F.binary_cross_entropy_with_logits(pred, labels)
    loss.backward()
    optimizer.step()

    return loss.item()
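Link prediction is usually evaluated with ROC-AUC over positive and negative edges. A sketch assuming scikit-learn is available; the logits here are synthetic stand-ins for decoder outputs:

```python
import torch
from sklearn.metrics import roc_auc_score

torch.manual_seed(0)
pos_pred = torch.randn(100) + 1.0   # synthetic logits for positive edges
neg_pred = torch.randn(100) - 1.0   # synthetic logits for sampled negatives

pred = torch.cat([pos_pred, neg_pred]).sigmoid()
labels = torch.cat([torch.ones(100), torch.zeros(100)])

auc = roc_auc_score(labels.numpy(), pred.numpy())
print(f"ROC-AUC: {auc:.3f}")
```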

7. PyTorch Geometric (PyG) Complete Guide

Installation

pip install torch-geometric
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html

Data Object

from torch_geometric.data import Data
import torch

# Create graph data
x = torch.randn(6, 3)           # 6 nodes, 3-dimensional features
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 0],
    [1, 2, 3, 4, 0, 3]
], dtype=torch.long)
y = torch.tensor([0, 1, 0, 1, 0, 1])   # Node labels
edge_attr = torch.randn(6, 2)           # Edge features

data = Data(
    x=x,
    edge_index=edge_index,
    y=y,
    edge_attr=edge_attr
)

print(data)
print(f"Nodes: {data.num_nodes}")
print(f"Edges: {data.num_edges}")
print(f"Node feature dim: {data.num_node_features}")
print(f"Edge feature dim: {data.num_edge_features}")
print(f"Has self-loops: {data.has_self_loops()}")
print(f"Is directed: {data.is_directed()}")

# Validation
print(f"Valid data: {data.validate()}")

DataLoader and Mini-batching

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch

# Create graph dataset
dataset = []
for _ in range(100):
    n = torch.randint(5, 20, (1,)).item()   # 5-19 nodes
    e = torch.randint(10, 40, (1,)).item()  # 10-39 edges
    data = Data(
        x=torch.randn(n, 8),
        edge_index=torch.randint(0, n, (2, e)),
        y=torch.randint(0, 3, (1,))  # Graph label
    )
    dataset.append(data)

# DataLoader: batch multiple graphs into one disconnected graph
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    print(f"Number of graphs in batch: {batch.num_graphs}")
    print(f"Total nodes: {batch.num_nodes}")
    print(f"Total edges: {batch.num_edges}")
    print(f"Batch vector: {batch.batch.shape}")  # Graph index per node
    break
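The batching rule itself is simple: node indices of each subsequent graph are offset by the node counts before it, so the edge indices can be concatenated into one block-diagonal graph. A sketch with two tiny hand-built graphs:

```python
import torch

edge_index_1 = torch.tensor([[0, 1], [1, 0]])        # graph 1: 2 nodes
edge_index_2 = torch.tensor([[0, 1, 2], [1, 2, 0]])  # graph 2: 3 nodes

offset = 2  # graph 1 contributes 2 nodes
batched_edge_index = torch.cat([edge_index_1, edge_index_2 + offset], dim=1)
batch = torch.tensor([0, 0, 1, 1, 1])  # graph index per node

print(batched_edge_index)
```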

Built-in Datasets

from torch_geometric.datasets import (
    Planetoid,    # Cora, Citeseer, PubMed
    TUDataset,    # Molecular datasets (MUTAG, ENZYMES, etc.)
)
from torch_geometric.transforms import NormalizeFeatures

# Cora citation network
cora = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
print(f"Cora - Nodes: {cora[0].num_nodes}, Edges: {cora[0].num_edges}")

# MUTAG molecular dataset
mutag = TUDataset(root='/tmp/TUDataset', name='MUTAG')
print(f"MUTAG - Graphs: {len(mutag)}, Classes: {mutag.num_classes}")

# Open Graph Benchmark (large-scale)
try:
    from ogb.nodeproppred import PygNodePropPredDataset
    dataset_ogb = PygNodePropPredDataset(name='ogbn-arxiv')
    split_idx = dataset_ogb.get_idx_split()
    data_ogb = dataset_ogb[0]
    print(f"OGB-Arxiv - Nodes: {data_ogb.num_nodes}, Edges: {data_ogb.num_edges}")
except ImportError:
    print("ogb not installed. pip install ogb")

Complete Node Classification Pipeline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load data
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)


class MultiLayerGNN(nn.Module):
    """Model combining multiple GNN layers"""

    def __init__(self, in_channels, hidden_channels, out_channels,
                 gnn_type="gcn", num_layers=3, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.gnn_type = gnn_type

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        # Input layer
        self.convs.append(self._make_conv(in_channels, hidden_channels, gnn_type))
        self.bns.append(nn.BatchNorm1d(hidden_channels))

        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(self._make_conv(hidden_channels, hidden_channels, gnn_type))
            self.bns.append(nn.BatchNorm1d(hidden_channels))

        # Output layer
        self.convs.append(self._make_conv(hidden_channels, out_channels, gnn_type))

    def _make_conv(self, in_ch, out_ch, gnn_type):
        if gnn_type == "gcn":
            return GCNConv(in_ch, out_ch)
        elif gnn_type == "sage":
            return SAGEConv(in_ch, out_ch)
        elif gnn_type == "gat":
            return GATConv(in_ch, out_ch, heads=1)
        else:
            raise ValueError(f"Unknown GNN type: {gnn_type}")

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)


def run_experiment(gnn_type, epochs=200):
    model = MultiLayerGNN(
        in_channels=dataset.num_features,
        hidden_channels=64,
        out_channels=dataset.num_classes,
        gnn_type=gnn_type,
        num_layers=3
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.01, weight_decay=5e-4
    )

    train_losses = []
    val_accs = []

    for epoch in range(epochs):
        # Training
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

        # Evaluation
        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            pred = out.argmax(dim=1)
            val_acc = pred[data.val_mask].eq(data.y[data.val_mask]).sum().item()
            val_acc /= data.val_mask.sum().item()
            val_accs.append(val_acc)

    # Final test
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        test_acc = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
        test_acc /= data.test_mask.sum().item()

    return test_acc, train_losses, val_accs


# Compare different GNNs
results = {}
for gnn_type in ["gcn", "sage", "gat"]:
    test_acc, losses, val_accs = run_experiment(gnn_type)
    results[gnn_type] = test_acc
    print(f"{gnn_type.upper():10s}: Test Accuracy = {test_acc:.4f}")

Complete Graph Classification Example

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    GINConv, global_mean_pool, global_add_pool
)
import torch
import torch.nn as nn
import torch.nn.functional as F

# Load MUTAG dataset
dataset = TUDataset(root='/tmp/TUDataset', name='MUTAG')
dataset = dataset.shuffle()

# Train/test split
n = len(dataset)
train_dataset = dataset[:int(0.8 * n)]
test_dataset = dataset[int(0.8 * n):]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)
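GIN's node update is h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} h_u); the sum aggregation (rather than mean or max) is what preserves multiset information about the neighborhood. A one-node numeric sketch of the pre-MLP aggregation:

```python
import torch

eps = 0.1
h = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])  # features of nodes 0, 1, 2
neighbors_of_0 = [1, 2]

# (1 + eps) * h_0 + sum of neighbor features, before the MLP
agg = (1 + eps) * h[0] + h[neighbors_of_0].sum(dim=0)
print(agg)  # tensor([2.1000, 2.0000])
```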


class GIN(nn.Module):
    """
    Graph Isomorphism Network (GIN) - most expressive GNN
    Has stronger discriminative power than GCN
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=5, dropout=0.5):
        super().__init__()
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        for i in range(num_layers):
            in_ch = in_channels if i == 0 else hidden_channels
            # MLP for GIN
            mlp = nn.Sequential(
                nn.Linear(in_ch, hidden_channels),
                nn.BatchNorm1d(hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels)
            )
            self.convs.append(GINConv(mlp, train_eps=True))
            self.bns.append(nn.BatchNorm1d(hidden_channels))

        # Graph classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels * num_layers, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, x, edge_index, batch):
        # Store outputs from each layer
        xs = []
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            xs.append(global_add_pool(x, batch))  # Graph-level aggregation

        # Concatenate graph representations from all layers
        out = torch.cat(xs, dim=1)
        return self.classifier(out)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_gin = GIN(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes
).to(device)

optimizer = torch.optim.Adam(model_gin.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

def train_gin():
    model_gin.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model_gin(batch.x, batch.edge_index, batch.batch)
        loss = F.cross_entropy(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def test_gin(loader):
    model_gin.eval()
    correct = 0
    for batch in loader:
        batch = batch.to(device)
        with torch.no_grad():
            pred = model_gin(batch.x, batch.edge_index, batch.batch).argmax(dim=1)
        correct += pred.eq(batch.y).sum().item()
    return correct / len(loader.dataset)

for epoch in range(1, 201):
    loss = train_gin()
    train_acc = test_gin(train_loader)
    test_acc = test_gin(test_loader)
    scheduler.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | "
              f"Train: {train_acc:.4f} | Test: {test_acc:.4f}")

8. DGL (Deep Graph Library) Comparison

# DGL example - comparison with PyG
# pip install dgl

try:
    import dgl
    import dgl.nn as dglnn
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DGLGCN(nn.Module):
        """GCN implemented with DGL"""
        def __init__(self, in_feats, hidden_size, num_classes):
            super().__init__()
            self.conv1 = dglnn.GraphConv(in_feats, hidden_size)
            self.conv2 = dglnn.GraphConv(hidden_size, num_classes)

        def forward(self, g, features):
            x = F.relu(self.conv1(g, features))
            x = F.dropout(x, training=self.training)
            return self.conv2(g, x)

    # Create DGL graph
    src = torch.tensor([0, 1, 2, 3, 4])
    dst = torch.tensor([1, 2, 3, 4, 0])
    g = dgl.graph((src, dst))
    g.ndata['feat'] = torch.randn(5, 16)

    model_dgl = DGLGCN(16, 32, 4)
    out = model_dgl(g, g.ndata['feat'])
    print(f"DGL GCN output: {out.shape}")

except ImportError:
    print("DGL not installed. pip install dgl")

PyG vs DGL Comparison:

| Feature | PyTorch Geometric (PyG) | Deep Graph Library (DGL) |
| --- | --- | --- |
| API style | PyTorch-native | Framework-agnostic |
| Data representation | edge_index (COO) | DGLGraph object |
| Speed | Very fast | Fast |
| Community | Large | Large |
| Available models | Very extensive | Extensive |
| Learning curve | Low | Medium |
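The data-representation row is easy to see in plain torch: PyG stores a graph as a COO `edge_index` tensor, which converts to and from a dense adjacency matrix in a few lines. (PyG also ships utilities such as `to_dense_adj` for this; the sketch below deliberately uses only torch.)

```python
import torch

# A 4-node undirected graph, starting from one direction per edge
edge_index = torch.tensor([[0, 0, 1, 2],
                           [1, 2, 2, 3]])
num_nodes = 4

# COO -> dense adjacency (symmetrized, since the graph is undirected)
adj = torch.zeros(num_nodes, num_nodes)
adj[edge_index[0], edge_index[1]] = 1.0
adj = torch.maximum(adj, adj.t())

# dense adjacency -> COO (each undirected edge appears in both directions,
# which is how PyG stores undirected graphs)
src, dst = adj.nonzero(as_tuple=True)
edge_index_back = torch.stack([src, dst])
print(edge_index_back.shape)  # torch.Size([2, 8])
```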

9. Real-World Applications

Molecular Property Prediction (OGB)

try:
    from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
    from torch_geometric.loader import DataLoader
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch_geometric.nn import GINEConv, global_mean_pool

    # Load HIV molecule dataset
    dataset_mol = PygGraphPropPredDataset(name='ogbg-molhiv')
    split_idx = dataset_mol.get_idx_split()

    train_loader_mol = DataLoader(
        dataset_mol[split_idx["train"]],
        batch_size=32,
        shuffle=True
    )

    class MoleculeGNN(nn.Module):
        """Molecular property prediction model"""
        def __init__(self, hidden_channels=300, num_layers=5):
            super().__init__()
            # 119 atom-type codes / 10 bond-type codes cover OGB's vocabulary
            self.atom_encoder = nn.Embedding(119, hidden_channels)
            self.bond_encoder = nn.Embedding(10, hidden_channels)

            self.convs = nn.ModuleList()
            for _ in range(num_layers):
                mlp = nn.Sequential(
                    nn.Linear(hidden_channels, hidden_channels * 2),
                    nn.BatchNorm1d(hidden_channels * 2),
                    nn.ReLU(),
                    nn.Linear(hidden_channels * 2, hidden_channels)
                )
                self.convs.append(GINEConv(mlp))

            self.pool = global_mean_pool
            self.predictor = nn.Linear(hidden_channels, 1)

        def forward(self, x, edge_index, edge_attr, batch):
            # OGB node/edge features are multi-column integer codes; for
            # brevity we embed only the first column (atom type / bond type).
            # ogb also provides AtomEncoder/BondEncoder that embed and sum
            # every column.
            x = self.atom_encoder(x[:, 0])
            edge_attr = self.bond_encoder(edge_attr[:, 0])

            for conv in self.convs:
                x = conv(x, edge_index, edge_attr)
                x = F.relu(x)

            graph_embed = self.pool(x, batch)
            return self.predictor(graph_embed)

    print("OGB molecular dataset loaded successfully")

except ImportError:
    print("ogb not installed. pip install ogb")

Recommendation System

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import LightGCN

class RecommendationSystem(nn.Module):
    """
    LightGCN-based collaborative filtering
    Learn embeddings on a user-item bipartite graph
    """

    def __init__(self, num_users, num_items, embedding_dim=64, num_layers=3):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items

        # LightGCN owns a single embedding table covering users first, then
        # items, and propagates it with simple layer-averaged aggregation
        # (no non-linear transforms)
        self.lightgcn = LightGCN(
            num_nodes=num_users + num_items,
            embedding_dim=embedding_dim,
            num_layers=num_layers
        )
        nn.init.normal_(self.lightgcn.embedding.weight, std=0.01)

    def forward(self, edge_index):
        # edge_index covers the bipartite graph; item node ids must be
        # offset by num_users so users and items share one index space
        embeddings = self.lightgcn.get_embedding(edge_index)
        return embeddings[:self.num_users], embeddings[self.num_users:]

    def predict(self, user_ids, item_ids, edge_index):
        user_embs, item_embs = self(edge_index)
        u = user_embs[user_ids]
        i = item_embs[item_ids]
        return (u * i).sum(dim=1)


# BPR Loss
def bpr_loss(pos_scores, neg_scores):
    """Bayesian Personalized Ranking Loss"""
    return -F.logsigmoid(pos_scores - neg_scores).mean()
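A hedged sketch of how `bpr_loss` is typically used: for each user, score one observed (positive) item and one randomly sampled (negative) item with a dot product, then backpropagate so positives rank above negatives. All tensors here are toy stand-ins, not part of the model above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_items = 100
user_embs = torch.randn(8, 16, requires_grad=True)   # toy user embeddings
item_embs = torch.randn(num_items, 16, requires_grad=True)

users = torch.tensor([0, 1, 2, 3])
pos_items = torch.tensor([5, 17, 42, 99])            # observed interactions
neg_items = torch.randint(0, num_items, (4,))        # random negatives

pos_scores = (user_embs[users] * item_embs[pos_items]).sum(dim=1)
neg_scores = (user_embs[users] * item_embs[neg_items]).sum(dim=1)

# BPR: maximize the margin between positive and negative scores
loss = -F.logsigmoid(pos_scores - neg_scores).mean()
loss.backward()
```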

10. Graph Generative Models

GraphVAE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphVAE(nn.Module):
    """Graph Variational Autoencoder"""

    def __init__(self, in_channels, hidden_channels, latent_dim):
        super().__init__()

        # Encoder (GNN)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, latent_dim)
        self.conv_logvar = GCNConv(hidden_channels, latent_dim)

    def encode(self, x, edge_index):
        h = F.relu(self.conv1(x, edge_index))
        mu = self.conv_mu(h, edge_index)
        logvar = self.conv_logvar(h, edge_index)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def decode(self, z):
        # Compute edge probabilities via inner product
        adj_pred = torch.sigmoid(z @ z.t())
        return adj_pred

    def forward(self, x, edge_index):
        mu, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mu, logvar)
        adj_pred = self.decode(z)
        return adj_pred, mu, logvar

    def loss(self, adj_pred, adj_target, mu, logvar):
        # Reconstruction loss (in practice positive entries are often
        # re-weighted, since real adjacency matrices are sparse)
        recon_loss = F.binary_cross_entropy(adj_pred, adj_target)

        # KL divergence
        kl_loss = -0.5 * torch.mean(
            1 + logvar - mu.pow(2) - logvar.exp()
        )

        return recon_loss + kl_loss

Quiz

Q1. What is the main difference between GCN and GAT?

Answer: GCN aggregates all neighbors with fixed weights based on node degree, while GAT dynamically learns different attention weights for each neighbor.

Explanation: In GCN, the aggregation weights are fixed by the degree normalization of the adjacency matrix. In GAT, attention coefficients are computed dynamically based on the feature vectors of the two connected nodes. This allows GAT to assign higher weights to more important neighbors. Multi-head attention further improves stability, which is another advantage of GAT.
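The difference is visible in a few lines of plain torch. Below, single-head GAT-style attention coefficients for node 0's neighbors are computed from the feature vectors themselves (random matrices stand in for learned weights); GCN would instead fix these weights by node degree.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
h = torch.randn(4, 8)          # node features; node 0 has neighbors {1, 2, 3}
W = torch.randn(8, 8)          # shared linear transform (stand-in for learned)
a = torch.randn(16)            # attention vector applied to [Wh_i || Wh_j]

Wh = h @ W
neighbors = [1, 2, 3]
# e_ij = LeakyReLU(a^T [Wh_i || Wh_j]) for each neighbor j of node 0
e = torch.stack([
    F.leaky_relu(torch.cat([Wh[0], Wh[j]]) @ a, negative_slope=0.2)
    for j in neighbors
])
alpha = torch.softmax(e, dim=0)                    # weights sum to 1
h0_new = (alpha.unsqueeze(1) * Wh[neighbors]).sum(dim=0)
```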

Q2. Why is GraphSAGE better suited for inductive learning than GCN?

Answer: GraphSAGE learns an aggregator function that can generate embeddings for new nodes by sampling and aggregating their neighbors.

Explanation: GCN requires the full adjacency matrix of the graph at training time, making it a transductive method that needs retraining when new nodes are added. GraphSAGE learns an aggregation function that samples and aggregates neighbor features, so it can generate embeddings for unseen nodes by applying the same function to their neighborhoods. This is why GraphSAGE is used in production systems like Pinterest and LinkedIn that deal with dynamically changing large-scale graphs.
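The inductive property boils down to this: a GraphSAGE layer is just a function of a node's own features and an aggregate of its sampled neighbors, so it applies unchanged to nodes never seen in training. A minimal sketch with a mean aggregator and random weights standing in for trained ones:

```python
import torch

torch.manual_seed(0)
W_self = torch.randn(8, 8)    # "trained" weights (random stand-ins here)
W_neigh = torch.randn(8, 8)

def sage_embed(x_self, x_neighbors):
    """GraphSAGE mean aggregator: works for any node, seen or unseen."""
    agg = x_neighbors.mean(dim=0)
    return torch.relu(x_self @ W_self + agg @ W_neigh)

# A brand-new node with 5 sampled neighbors: no retraining needed,
# we simply apply the same learned function to its neighborhood
new_node = torch.randn(8)
sampled_neighbors = torch.randn(5, 8)
emb = sage_embed(new_node, sampled_neighbors)
print(emb.shape)  # torch.Size([8])
```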

Q3. What are the three stages of the Message Passing Neural Network (MPNN) framework?

Answer: Message (computing messages), Aggregate (aggregating messages), and Update (updating node embeddings).

Explanation: In the Message stage, a message is computed for each edge based on the source node features. In the Aggregate stage, each node collects all incoming messages from its neighbors using sum, mean, or max aggregation. In the Update stage, the aggregated message is combined with the current node embedding to produce a new embedding. Most GNN variants including GCN, GAT, GraphSAGE, and GIN can all be unified under this framework.
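The three stages map directly onto a few tensor operations. The sketch below (plain torch, random weights, sum aggregation) is a minimal instance of the framework, not any particular library's implementation:

```python
import torch

# Toy directed graph: 4 nodes, edges src -> dst
edge_index = torch.tensor([[0, 1, 2, 2],
                           [1, 2, 0, 3]])
x = torch.randn(4, 8)          # node features
W_msg = torch.randn(8, 8)
W_upd = torch.randn(16, 8)

# 1) Message: one message per edge, computed from the source node
messages = x[edge_index[0]] @ W_msg                       # (num_edges, 8)

# 2) Aggregate: sum the messages arriving at each destination node
agg = torch.zeros(4, 8).index_add_(0, edge_index[1], messages)

# 3) Update: combine the aggregate with the current node embedding
x_new = torch.relu(torch.cat([x, agg], dim=1) @ W_upd)
print(x_new.shape)  # torch.Size([4, 8])
```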

Q4. What is the over-smoothing problem in GNNs, and how can it be mitigated?

Answer: As GNN depth increases, all node embeddings converge to similar values. It can be mitigated with residual connections, JK-Net, or DropEdge.

Explanation: A K-layer GNN aggregates information from K-hop neighborhoods. As layers increase, increasingly larger neighborhoods are included, causing all node representations to converge toward the same global average. Residual connections preserve unique node information by directly passing previous layer outputs. Jumping Knowledge Networks (JK-Net) use embeddings from all layers in the final representation. DropEdge randomly removes edges during training to reduce neighbor overlap.
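The convergence is easy to demonstrate without a GNN at all: repeatedly applying row-normalized (mean) neighborhood aggregation to random features on a random connected graph collapses the variance across nodes toward zero.

```python
import torch

torch.manual_seed(0)
n = 20
# Random symmetric adjacency with self-loops
adj = (torch.rand(n, n) < 0.3).float()
adj = torch.clamp(torch.maximum(adj, adj.t()) + torch.eye(n), max=1.0)
p = adj / adj.sum(dim=1, keepdim=True)   # row-normalized: mean aggregation

x = torch.randn(n, 8)
before = x.var(dim=0).mean().item()
for _ in range(50):                      # 50 rounds of pure averaging
    x = p @ x
after = x.var(dim=0).mean().item()
print(f"{before:.4f} -> {after:.6f}")    # per-feature variance collapses
```

Real GNN layers interleave learned transforms and non-linearities with this averaging, which slows but does not stop the effect; hence the mitigations above.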

Q5. What does it mean that GNN expressiveness is equivalent to the WL test?

Answer: Standard GNNs embed two graphs identically if the Weisfeiler-Leman (WL) graph isomorphism test cannot distinguish them.

Explanation: The WL test is an algorithm for determining whether two graphs are isomorphic by iteratively aggregating and hashing neighbor labels. Xu et al. (2019) proved through the Graph Isomorphism Network (GIN) that standard GNNs are at most as powerful as the 1-WL test. This means GNNs cannot distinguish graph pairs that the WL test also fails on, such as regular graphs with the same degree sequence. To overcome this limitation, research is ongoing into higher-order k-WL equivalent GNNs, port numbering, and random feature augmentation.
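A concrete failure case from the explanation: a 6-cycle and two disjoint triangles are both 2-regular graphs on 6 nodes, and 1-WL color refinement (sketched below in plain Python) assigns them identical color histograms even though they are not isomorphic.

```python
from collections import Counter

def wl_colors(edges, num_nodes, iters=3):
    """1-dimensional Weisfeiler-Leman color refinement (undirected)."""
    nbrs = [[] for _ in range(num_nodes)]
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    colors = [0] * num_nodes
    for _ in range(iters):
        # New color = (own color, sorted multiset of neighbor colors)
        sigs = [(colors[i], tuple(sorted(colors[j] for j in nbrs[i])))
                for i in range(num_nodes)]
        table = {s: k for k, s in enumerate(sorted(set(sigs)))}
        colors = [table[s] for s in sigs]
    return Counter(colors)

hexagon = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]    # one 6-cycle
triangles = [(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3)]  # two 3-cycles

# True: 1-WL (and hence any standard GNN) cannot tell them apart
print(wl_colors(hexagon, 6) == wl_colors(triangles, 6))  # True
```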


Summary

This guide covered the entire GNN ecosystem:

  1. Graph Theory Fundamentals: Nodes, edges, adjacency matrices, graph properties
  2. Message Passing Paradigm: The core principle of GNNs
  3. Key Architectures: GCN, GraphSAGE, GAT, Graph Transformer, GIN
  4. Graph-Level Prediction: Global Pooling, DiffPool
  5. Link Prediction: Knowledge graphs, recommendation systems
  6. PyTorch Geometric: Complete node and graph classification examples
  7. Real-World Applications: Molecular design, recommendation systems, fraud detection
  8. Graph Generative Models: GraphVAE

GNNs are delivering revolutionary results in molecular design, drug discovery, social network analysis, traffic prediction, recommendation systems, and many other fields. Libraries like PyTorch Geometric and DGL make implementation increasingly accessible, and benchmarks like OGB enable fair comparisons across methods.

References