Skip to content
Published on

グラフニューラルネットワーク完全ガイド: GCNからGraph Transformerまで

Authors

はじめに

ソーシャルネットワーク、分子構造、知識グラフ、交通ネットワーク — 現実世界の多くのデータは**グラフ(Graph)**構造を持っています。従来のCNNやRNNはグリッド(格子)や系列(シーケンス)データ向けに設計されているため、このような非ユークリッド(non-Euclidean)データの処理に限界があります。

グラフニューラルネットワーク(Graph Neural Network, GNN) はこの問題を解決するために登場しました。GNNはノード(node)、エッジ(edge)、グラフ全体の各レベルで表現を学習し、分子特性の予測から推薦システム、物理シミュレーションまで幅広く活用されています。

このガイドでは、グラフ理論の基礎から最新のGraph Transformerまでを段階的に解説し、PyTorch Geometricを用いた実際の実装まで網羅します。


1. グラフ理論の基礎

グラフの定義

グラフ G はノード(頂点)集合 V とエッジ(辺)集合 E で定義されます。

G=(V,E)G = (V, E)

  • ノード(Node/Vertex): グラフの個体。例: ユーザー、原子、Webページ
  • エッジ(Edge): ノード間の関係。例: 友人関係、化学結合、ハイパーリンク
  • 有向グラフ(Directed Graph): エッジに方向がある (例: Twitterのフォロー)
  • 無向グラフ(Undirected Graph): エッジに方向がない (例: Facebookの友達)

隣接行列 (Adjacency Matrix)

ノード数が N のグラフの隣接行列 A は N×N の行列です。ノード i と j の間にエッジがあれば A[i][j] = 1、なければ 0 となります。

Aij={1if (i,j)E0otherwiseA_{ij} = \begin{cases} 1 & \text{if } (i, j) \in E \\ 0 & \text{otherwise} \end{cases}

次数行列 (Degree Matrix)

次数行列 D は対角行列で、各対角要素 D[i][i] はノード i の次数(接続されているエッジの数)です。

Dii=jAijD_{ii} = \sum_j A_{ij}

ラプラシアン行列 (Laplacian Matrix)

グラフのラプラシアンは、グラフのスペクトル分析の核心となる行列です。

L=DAL = D - A

正規化ラプラシアンは次のように定義されます。

L^=D1/2LD1/2=ID1/2AD1/2\hat{L} = D^{-1/2} L D^{-1/2} = I - D^{-1/2} A D^{-1/2}

ラプラシアンの固有値(eigenvalue)と固有ベクトル(eigenvector)はグラフのスペクトル特性を表し、GCNの理論的基盤となります。

import numpy as np
import networkx as nx

# グラフの作成
G = nx.karate_club_graph()

# 隣接行列の取得
A = nx.adjacency_matrix(G).toarray().astype(float)

# 次数行列の計算
degrees = np.array(A.sum(axis=1)).flatten()
D = np.diag(degrees)

# ラプラシアン行列
L = D - A
print(f"ラプラシアン行列の形状: {L.shape}")
print(f"最小固有値: {np.linalg.eigvalsh(L).min():.4f}")
print(f"最大固有値: {np.linalg.eigvalsh(L).max():.4f}")

2. なぜグラフニューラルネットワークなのか?

非ユークリッドデータの特性

画像やテキストと異なり、グラフデータには以下のような特性があります。

  • 可変サイズ: ノードとエッジの数がグラフごとに異なる
  • 順序がない: ノードに固定された順序がない (置換不変性: permutation invariance)
  • 不規則な近傍: 各ノードの近傍ノード数が異なる

グラフMLの主要タスク

タスク入力出力
ノード分類論文カテゴリ予測グラフ + 一部ラベル未ラベルノードのクラス
リンク予測友達推薦グラフエッジの存在確率
グラフ分類分子毒性予測グラフの集合グラフラベル
グラフ生成新薬設計条件情報新しいグラフ

CNNとの比較

CNNはグリッド構造の画像に適しています。すべてのピクセルが同じ数の近傍(3×3、5×5など)を持つためです。一方、グラフは各ノードの近傍数と構造が異なるため、固定されたカーネルサイズを適用できません。GNNはこれをメッセージパッシング(message passing)方式で解決します。


3. Graph Convolutional Network (GCN)

スペクトルグラフ畳み込み

GCN (Kipf & Welling, 2017) はスペクトルグラフ理論から出発します。グラフ信号 x に対する畳み込みはグラフフーリエ領域で定義され、ラプラシアンの固有ベクトルを基底として使用します。

Kipf & Welling は計算複雑度を下げるために1次近似に単純化しました。レイヤーごとの伝播則は次の通りです。

H(l+1)=σ(D~1/2A~D~1/2H(l)W(l))H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)

ここで:

  • A~=A+I\tilde{A} = A + I (自己ループを追加した隣接行列)
  • D~\tilde{D}A~\tilde{A} の次数行列
  • H(l)H^{(l)} は第 l 層のノード埋め込み
  • W(l)W^{(l)} は学習可能な重み行列
  • σ\sigma は活性化関数

PyTorch Geometricによる実装

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# データセットの読み込み
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]

print(f"ノード数: {data.num_nodes}")
print(f"エッジ数: {data.num_edges}")
print(f"ノード特徴量の次元: {data.num_node_features}")
print(f"クラス数: {dataset.num_classes}")

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss

# 学習の実行
for epoch in range(200):
    loss = train()
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss:.4f}')

Coraデータセット(2708本の論文、5429件の引用リンク)では、GCNは約81%のノード分類精度を達成します。


4. Graph Attention Network (GAT)

アテンションメカニズムのグラフへの応用

GAT (Velickovic et al., 2018) はGCNの限界を克服します。GCNはすべての近傍ノードを同等に扱いますが、GATはアテンション(attention)メカニズムを通じて各近傍の重要度を動的に計算します。

ノード i からノード j へのアテンション係数は次のように計算されます。

eij=LeakyReLU(aT[WhiWhj])e_{ij} = \text{LeakyReLU}\left(\vec{a}^T [W\vec{h}_i \| W\vec{h}_j]\right)

Softmax正規化後:

αij=exp(eij)kN(i)exp(eik)\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}

最終的なノード表現:

hi=σ(jN(i)αijWhj)\vec{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W \vec{h}_j\right)

マルチヘッドアテンション

GATはK個の独立したアテンションヘッドを使用して表現力を高めます。

PyTorch Geometricによる実装

from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, dataset.num_classes, heads=1, dropout=0.6)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

model = GAT()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

GATはCoraで約83%の精度を達成し、GCNを若干上回ります。


5. GraphSAGE

帰納的学習 (Inductive Learning)

GCNとGATは**トランスダクティブ(transductive)学習方式です。つまり、学習時に使用したグラフのノードに対してのみ予測できます。GraphSAGE (Hamilton et al., 2017) は帰納的(inductive)**学習をサポートし、学習時に見たことのない新しいノードにも汎化できます。

近傍サンプリング

大規模グラフですべての近傍を使用すると計算量が爆発的に増加します。GraphSAGEは近傍を固定数だけサンプリングします。

集約関数

GraphSAGEはさまざまな集約関数をサポートします。

平均集約 (Mean Aggregator):

hvk=σ(WMEAN({hvk1}{huk1:uN(v)}))h_v^k = \sigma\left(W \cdot \text{MEAN}(\{h_v^{k-1}\} \cup \{h_u^{k-1} : u \in \mathcal{N}(v)\})\right)

from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 近傍サンプリングによるミニバッチ学習
train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],
    batch_size=256,
    input_nodes=data.train_mask,
    shuffle=True
)

6. Message Passing Neural Networks (MPNN)

統合フレームワーク

Gilmer et al. (2017) はさまざまなGNN変種を**メッセージパッシングニューラルネットワーク(MPNN)**フレームワークで統合しました。このフレームワークは3つの段階で構成されます。

ステップ1 - メッセージ (Message): 各エッジのメッセージを計算

mijt=Mt(hit,hjt,eij)m_{ij}^t = M_t(h_i^t, h_j^t, e_{ij})

ステップ2 - 集約 (Aggregate): 近傍からのメッセージを集約

mit=jN(i)mijtm_i^t = \sum_{j \in \mathcal{N}(i)} m_{ij}^t

ステップ3 - 更新 (Update): ノード埋め込みを更新

hit+1=Ut(hit,mit)h_i^{t+1} = U_t(h_i^t, m_i^t)

カスタムMPNNの実装

PyTorch Geometricの MessagePassing 基底クラスを使ってカスタムGNNを実装できます。

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch

class CustomGNNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # 自己ループの追加
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # 正規化係数の計算
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # メッセージの計算: 正規化された近傍特徴量
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # 更新: 線形変換
        return self.lin(aggr_out)

7. Graph Transformer

TransformerとグラフのFusion

TransformerのSelf-Attentionはすべてのトークンペア間の関係を考慮します。これをグラフに適用すると、隣接ノードだけでなくグラフ全体のノード間の関係を学習できます。

Graphormer (Ying et al., Microsoft Research, ICLR 2022) はこのアイデアを発展させ、分子グラフで優れた性能を示しました。

構造エンコーディング

Transformerには位置エンコーディングが必要です。グラフでは次のような構造情報をエンコードします。

  • 次数エンコーディング (Degree Encoding): ノードの次数情報
  • 空間エンコーディング (Spatial Encoding): ノード間の最短距離
  • エッジエンコーディング (Edge Encoding): エッジタイプと特性

ラプラシアン固有ベクトルに基づく位置エンコーディング:

グラフラプラシアンの固有ベクトルを位置エンコーディングとして使用します。これによりグラフの幾何学的構造をエンコードできます。

import torch
import torch.nn as nn
from torch_geometric.nn import TransformerConv

class GraphTransformer(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()
        self.conv1 = TransformerConv(
            in_channels, hidden_channels, heads=heads, dropout=0.1
        )
        self.conv2 = TransformerConv(
            hidden_channels * heads, out_channels, heads=1, concat=False
        )
        self.norm1 = nn.LayerNorm(hidden_channels * heads)
        self.norm2 = nn.LayerNorm(out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.norm1(x)
        x = F.elu(x)
        x = F.dropout(x, p=0.1, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.norm2(x)
        return F.log_softmax(x, dim=1)

8. 分子ML (Molecular Machine Learning)

化学分子をグラフとして表現

分子は原子(ノード)と化学結合(エッジ)からなる自然なグラフです。

  • ノード特徴量: 原子番号、電気陰性度、混成状態、芳香族かどうか
  • エッジ特徴量: 結合タイプ(単結合、二重結合、三重結合)、芳香族結合かどうか

SMILESとRDKit

SMILES (Simplified Molecular Input Line Entry System) は分子を文字列として表現する形式です。RDKitとPyTorch Geometricを使ってSMILESをグラフに変換できます。

from rdkit import Chem
from torch_geometric.utils import from_smiles

# SMILESをグラフに変換
mol = from_smiles('CC(=O)Oc1ccccc1C(=O)O')  # アスピリン
print(f"ノード(原子)数: {mol.num_nodes}")
print(f"エッジ(結合)数: {mol.num_edges}")
print(f"ノード特徴量の次元: {mol.x.shape}")

MoleculeNetベンチマーク

MoleculeNetは分子MLの標準ベンチマークです。

データセットタスク分子数タスク数
BBBP血液脳関門透過性20391
Tox21毒性予測783112
ESOL溶解度予測11281
QM9量子化学特性13388519

創薬(Drug Discovery)におけるGNN

GNNは次のような創薬タスクに活用されます。

  • 分子特性予測: 毒性、溶解度、結合親和性の予測
  • 分子生成: 望ましい特性を持つ新しい分子の設計 (例: Junction Tree VAE)
  • 仮想スクリーニング: 数百万の分子から薬物候補のフィルタリング
  • タンパク質-リガンド相互作用: 結合エネルギーの予測

9. 知識グラフ (Knowledge Graph)

トリプル構造

知識グラフは (head, relation, tail) 形式のトリプルで知識を表現します。例: (アインシュタイン, 発明した, 一般相対性理論)。

KG埋め込みモデル

モデルスコア関数特徴
TransEdist(h+r, t)シンプルで効率的
RotatE複素数空間での回転対称/反対称関係の処理
ComplEx複素数内積非対称関係の処理
DistMult三線形スコアシンプルで効果的

知識グラフ補完 (KG Completion)

知識グラフには欠損トリプルが多数あります。KG補完はこの欠損している関係を予測します。

import torch
from torch_geometric.nn import TransE

# TransEモデルの初期化
model = TransE(
    num_nodes=1000,
    num_relations=50,
    hidden_channels=50,
    margin=1.0,
    p_norm=1
)

10. GNNの限界と発展方向

過平滑化 (Over-smoothing)

GNNのレイヤー数が増えるにつれてノード埋め込みが互いに似通ってくる過平滑化問題が発生します。K層GNNはK-hop近傍の情報を集約しますが、Kが大きくなるとすべてのノードの表現が収束してしまいます。

解決策:

  • 残差接続(Residual connections)の使用
  • JK-Net (Jumping Knowledge Networks)
  • DropEdge: 学習時に一部のエッジをランダムに削除

表現力: WLテスト

GNNの表現力はWeisfeiler-Leman (WL) グラフ同型テストと等価であることが証明されています。つまり、WLテストで区別できない2つのグラフは、標準GNNでも区別できません。

より強力なGNN: k-WLテストに相当する高次元GNN、ランダム特徴量の追加、ポートナンバリングなどの方法が研究されています。

スケーラビリティ (Scalability)

数百万個のノードを持つグラフでは、グラフ全体を一度に処理することが困難です。

解決策:

  • GraphSAGEミニバッチ: 近傍サンプリングベースのミニバッチ
  • Cluster-GCN: グラフをクラスターに分割してクラスター内で学習
  • GraphSAINT: 重要度ベースのサブグラフサンプリング

GNNを超えて

  • Hypergraph Neural Networks: ハイパーエッジ(3つ以上のノードを接続)の処理
  • 3D分子構造学習: SchNet、DimeNetなど3D座標の活用
  • 時系列グラフ: TGN (Temporal Graph Networks) による動的グラフの処理

11. 実践プロジェクト: ソーシャルネットワークのノード分類

Redditデータセット

Redditデータセットは232965件の投稿(ノード)と約114万件のエッジで構成される大規模グラフです。各投稿がどのコミュニティ(subreddit)に属するかを分類するタスクです。

Cluster-GCNミニバッチ学習

大規模グラフでは全隣接行列をメモリに載せることができません。Cluster-GCNはMETISアルゴリズムでグラフをクラスターに分割し、各クラスター内でミニバッチ学習を行います。

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Reddit
from torch_geometric.loader import ClusterData, ClusterLoader
from torch_geometric.nn import SAGEConv

# データ読み込み
dataset = Reddit(root='/tmp/Reddit')
data = dataset[0]

# クラスター分割
cluster_data = ClusterData(data, num_parts=1500, recursive=False)
train_loader = ClusterLoader(
    cluster_data, batch_size=20, shuffle=True, num_workers=2
)

class ClusterGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv3(x, edge_index)
        return F.log_softmax(x, dim=1)

model = ClusterGCN(
    in_channels=dataset.num_features,
    hidden_channels=256,
    out_channels=dataset.num_classes
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

def train():
    model.train()
    total_loss = total_nodes = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()
        nodes = batch.train_mask.sum().item()
        total_loss += loss.item() * nodes
        total_nodes += nodes
    return total_loss / total_nodes

評価指標

ノード分類で使用する主な評価指標は以下の通りです。

  • Accuracy (精度): 全体の正解率
  • Macro F1: クラス不均衡時に有用
  • Micro F1: 全TP/FP/FNに基づく
  • ROC-AUC: 二値分類時に有用

Redditデータセットでは、Cluster-GCNは95%以上のF1スコアを達成します。


12. GNNアーキテクチャの比較

主要GNNの性能比較

モデルCora精度特徴適用場面
GCN~81%シンプル、高速小中規模グラフ
GAT~83%アテンション重み近傍重要度が異なる場合
GraphSAGE~82%帰納的学習大規模動的グラフ
GIN~80%+最大表現力グラフ分類
Graph Transformer~84%+グローバルアテンション分子・知識グラフ

グラフ分類の完全な例: GIN

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool

# MUTAGデータセット
dataset = TUDataset(root='/tmp/TUDataset', name='MUTAG')
dataset = dataset.shuffle()
n = len(dataset)
train_dataset = dataset[:int(0.8 * n)]
test_dataset = dataset[int(0.8 * n):]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

class GIN(torch.nn.Module):
    """Graph Isomorphism Network - 最大表現力を持つGNN"""

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        for i in range(num_layers):
            in_ch = in_channels if i == 0 else hidden_channels
            mlp = torch.nn.Sequential(
                torch.nn.Linear(in_ch, hidden_channels),
                torch.nn.BatchNorm1d(hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels)
            )
            self.convs.append(GINConv(mlp, train_eps=True))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels * num_layers, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, x, edge_index, batch):
        xs = []
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            xs.append(global_add_pool(x, batch))

        out = torch.cat(xs, dim=1)
        return self.classifier(out)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_gin = GIN(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes
).to(device)

クイズ

Q1. GCNとGATの最も大きな違いは何ですか?

正解: GCNはすべての近傍ノードを次数に基づく固定重みで集約しますが、GATはアテンションメカニズムを通じて各近傍の重みを動的に学習します。

説明: GCNの集約重みは隣接行列の次数正規化によって固定されています。一方、GATのアテンション係数は接続された2つのノードの特徴ベクトルに基づいて動的に計算されるため、より重要な近傍により多くの注意を払うことができます。マルチヘッドアテンションで安定性を高めることもGATの利点です。

Q2. GraphSAGEがGCNよりも帰納的学習に優れている理由は何ですか?

正解: GraphSAGEは集約関数(aggregator)を学習することで、新しいノードの近傍から埋め込みを生成できるからです。

説明: GCNは学習時にグラフ全体の隣接行列が必要なため、新しいノードが追加されると再学習が必要なトランスダクティブ方式です。GraphSAGEは近傍をサンプリングして集約する関数を学習するため、学習時に見ていない新しいノードにもこの関数を適用して埋め込みを生成できます。PinterestやLinkedInなど動的に変化する大規模グラフで実際に活用されています。

Q3. MPNNフレームワークの3つのステップは何ですか?

正解: Message(メッセージ計算)、Aggregate(集約)、Update(更新)の3ステップです。

説明: Messageステップでは各エッジに対して近傍ノードから送るメッセージを計算します。Aggregateステップではノードが受信したすべての近傍メッセージを和、平均、最大値などで集約します。Updateステップでは集約されたメッセージと現在のノード埋め込みを組み合わせて新しいノード埋め込みを生成します。GCN、GAT、GraphSAGE、GINなどほとんどのGNNがこのフレームワークで統一できます。

Q4. 過平滑化(Over-smoothing)問題とは何ですか? どのように解決できますか?

正解: レイヤーが深くなるにつれてすべてのノードの埋め込みが似通ってくる現象です。残差接続、JK-Net、DropEdgeなどで軽減できます。

説明: K層GNNはK-hop近傍の情報を集約します。レイヤーが多くなるほどより広い近傍を含むようになり、最終的にすべてのノードが同じグローバル平均に収束します。残差接続は前のレイヤーの情報を直接伝えて固有情報を保持します。JK-Net(Jumping Knowledge Networks)はすべてのレイヤーの埋め込みを最終表現に活用します。DropEdgeは学習時に一部のエッジをランダムに削除します。

Q5. GNNの表現力がWLテストと等価であるとはどういう意味ですか?

正解: 標準GNNはWeisfeiler-Leman(WL)グラフ同型テストで区別できない2つのグラフを同じように埋め込むということです。

説明: WLテストは2つのグラフが同型(isomorphic)かどうかを判定するアルゴリズムで、反復的に近傍のラベルを集約してハッシュ化します。Xu et al. (2019)はGIN(Graph Isomorphism Network)を通じて、標準GNNの表現力が1-WLテストと等価であることを証明しました。WLテストが失敗するグラフペア(例: 同じ次数列を持つ非同型グラフ)ではGNNも2つのグラフを区別できません。これを克服するためにより強力なk次WLテストに相当する高次元GNNの研究が進んでいます。


まとめ

このガイドではグラフニューラルネットワークの全体的なエコシステムを解説しました。

  1. グラフ理論の基礎: ノード、エッジ、隣接行列、グラフ特性
  2. メッセージパッシングパラダイム: GNNの核心原理
  3. 主要アーキテクチャ: GCN、GraphSAGE、GAT、Graph Transformer、GIN
  4. 分子ML: 創薬・分子特性予測へのGNNの応用
  5. 知識グラフ: KG埋め込みとKG補完
  6. スケーラビリティ: Cluster-GCNによる大規模グラフの学習
  7. GNNの限界: 過平滑化、表現力、スケーラビリティ

GNNは分子設計、創薬、ソーシャルネットワーク分析、交通予測、推薦システムなど多岐にわたる分野で革新的な成果を上げています。PyTorch GeometricやDGLなどのライブラリにより実装がますます容易になっており、OGBなどのベンチマークにより公平な比較も可能になっています。

参考資料

  • Kipf, T. N., & Welling, M. (2017). Semi-Supervised Classification with Graph Convolutional Networks. ICLR 2017. arXiv:1609.02907
  • Velickovic, P., et al. (2018). Graph Attention Networks. ICLR 2018. arXiv:1710.10903
  • Hamilton, W. L., et al. (2017). Inductive Representation Learning on Large Graphs (GraphSAGE). NeurIPS 2017.
  • Gilmer, J., et al. (2017). Neural Message Passing for Quantum Chemistry (MPNN). ICML 2017.
  • Ying, C., et al. (2021). Do Transformers Really Perform Bad for Graph Representation? (Graphormer). NeurIPS 2021.
  • Xu, K., et al. (2019). How Powerful are Graph Neural Networks? (GIN). ICLR 2019.
  • PyTorch Geometric 公式ドキュメント
  • Deep Graph Library (DGL) ドキュメント
  • Open Graph Benchmark (OGB)