GNN(Graph Neural Network)とGCN(Graph Convolutional Network)の理論と実装

GNN(Graph Neural Network)とGCN(Graph Convolutional Network)はGNNの一種で、グラフ畳み込み演算を用いる。 個人的なメモとして、GNNとGCNについて、簡単な理論と実装についてまとめる。

ソースコード

github

  • jupyter notebook形式のファイルはこちら

google colaboratory

  • google colaboratory で実行する場合はこちら

実行環境

OSはmacOSである。LinuxやUnixのコマンドとはオプションが異なるので注意していただきたい。

!sw_vers
ProductName:		macOS
ProductVersion:		15.2
BuildVersion:		24C101
!python -V
Python 3.9.17

基本的なライブラリをインポートし watermark を利用してそのバージョンを確認しておきます。 ついでに乱数のseedの設定をします。

%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import random

import scipy
import numpy as np

import matplotlib
import matplotlib.pyplot as plt

seed = 123
random_state = 123

random.seed(seed)
np.random.seed(seed)


from watermark import watermark

print(watermark(python=True, watermark=True, iversions=True, globals_=globals()))
Python implementation: CPython
Python version       : 3.9.17
IPython version      : 8.17.2

scipy     : 1.11.2
matplotlib: 3.8.1
numpy     : 1.25.2

Watermark: 2.4.3

概要

GNN(Graph Neural Network)はグラフ構造を扱うニューラルネットワークである。 GCN(Graph Convolutional Network)はGNNの一種で、グラフ畳み込み演算を用いる。

ノード同士の接続情報や特徴量を学習し、様々なタスクに応用可能である。 グラフ上のノード分類やリンク予測、グラフ全体の分類などに活用されおり、推薦システムやSNS分析など、大規模データにも利用が進む。

本記事ではGNNの基礎からGCNの仕組み、実装例まで幅広く取り上げる。

あくまでも個人的な備忘録であるので、間違い等は勘弁してください。


1. GNNの基礎

1-1. GNNとは何か

GNNはグラフを入力として処理するネットワークの総称である。 グラフは頂点(ノード)と辺(エッジ)で構成される離散構造である。 画像やテキストとは異なり、ノード間のつながりを明示的に扱う。 ノードごとに特徴ベクトルが与えられる場合も多い。

リンク予測などのタスクでは、ノード同士の関係が重視される。 GNNはノード情報と近傍ノードの情報を集約して特徴を学習する。

代表的なアプローチに、Message Passing機構などが挙げられる。 この機構によって、各ノードは隣接ノードからの情報を受け取り、更新する。 層を重ねることで、より遠いノードの情報も伝播させることができる。

ノード分類、グラフ分類、リンク予測など、応用範囲は広い。

1-2. GNNの一般的な構造

GNNの一般的なフレームワークは以下で示される。

  1. ノード特徴を初期化し、$ \mathbf{h}_v^{(0)} $として与える。
  2. 各層でメッセージを隣接ノードから集約する演算が行われる。
  3. そのメッセージを用いてノードの埋め込みを更新する。
  4. 最終層の出力埋め込みをタスクに応じて用いる。

GNNでは隣接ノードとの情報交換が鍵となる。

ノード $ v $ の近傍を $ \mathcal{N}(v) $ とすると、更新式は一般的に $$ \mathbf{h}_v^{(l+1)} = \phi\bigl(\mathbf{h}_v^{(l)}, \text{Aggregate} (\mathbf{h}_u^{(l)}: u \in \mathcal{N}(v))\bigr) $$ のようになる。

このAggregate部分に加重平均や畳み込みなどが入ることで、モデルが変化する。 学習可能パラメータを設け、誤差逆伝播によってパラメータを更新する。

1-3. GNNの応用例

SNS分析では、ユーザーノード同士のつながりを解析できる。 推薦システムでは、ユーザーとアイテムを二部グラフとして扱う例がある。 化合物などの分子構造をグラフで表し、薬剤設計にも活用される。 物流ネットワークなどの最適化問題にも応用が広がる。

ノードの中心性やコミュニティ構造を学習することで分析精度が向上する。 企業においても人事データを活用し、組織内コミュニケーションの可視化にも利用される。 GNNの汎用性は広いため、今後さらに導入が加速していく見込みである。

1-4. GNNのメリット・デメリット

GNNはグラフ構造を直接扱える点が最大の利点である。 ノード同士の豊富な関係性を考慮した学習が可能となる。 また、疎なグラフではスケールが大きくなっても演算を工夫できる。 一方で、大規模かつ密なグラフでは計算負荷が高騰しやすい。

隣接行列が巨大になれば、メモリ消費も無視できない。

多層GNNではオーバースムージング(over-smoothing)も課題となる。 オーバースムージングとは、層を深くするとノード埋め込みが均質化する現象。 これにより、ノード間の区別が失われ、タスク性能が下がる。 研究ではジャンプ構造やスキップ接続などで回避を図る試みがある。 また、学習におけるハイパーパラメータ調整もGNNの難しさの一つとなる。


2. GCNの基礎

2-1. GCNとは

GCNはGNNの一形態であり、グラフ畳み込み演算を導入したモデルである。 画像解析などで利用されるなCNN(畳み込みニューラルネットワーク)の考え方をグラフに拡張する。 ノード毎に定義された特徴量の行列を $ \mathbf{X} $、隣接行列を $ \mathbf{A} $ とする。 畳み込み処理には正規化された隣接行列を用いるのが一般的である。 簡単なGCN層では以下のような更新式が用いられることが多い。 $$ \mathbf{H}^{(l+1)} = \sigma\bigl(\hat{\mathbf{D}}^{-\frac{1}{2}}\hat{\mathbf{A}}\hat{\mathbf{D}}^{-\frac{1}{2}} \mathbf{H}^{(l)} \mathbf{W}^{(l)}\bigr) $$ ただし、$ \hat{\mathbf{A}} = \mathbf{A} + \mathbf{I} $、 $ \hat{\mathbf{D}} $ は $ \hat{\mathbf{A}} $ の対角要素を用いて定義される。 $ \mathbf{W}^{(l)} $ は学習対象のパラメータ行列、 $ \sigma $ は活性化関数(例: ReLU)である。 GCNは近傍ノード情報を畳み込み的に集約し、ノード埋め込みを更新する。 画像における畳み込みとは異なり、グラフの局所構造を扱う。

2-2. GCNのメリット

GCNはグラフデータを自然に扱える。 ノード分類やグラフ分類への適用実績が豊富である。 畳み込み演算自体がシンプルなので、実装が比較的容易である。 頂点数と辺数が適度に大きい場合でも、工夫次第でスケール可能である。 グラフという離散構造を連続ベクトルとして扱う枠組みを与える。

また、GCNは層構造を深くしすぎなければ計算量の制御がしやすい。 メモリ効率を改善したサンプリング手法も研究されている。

Node2Vecなどの埋め込み手法と比較しても高い性能を示す場合が多い。 GCNは他のGNN手法の基礎になるため、個人的にはとても理解しやすいと思っている。

2-3. GCNのデメリット

単純なGCNは、大規模グラフでの計算負荷が増大しがちである。 フルバッチで隣接行列を扱うため、メモリを大量に消費しやすい。

深い層構造にした場合、オーバースムージングが発生するリスクがある。

さらに、GCNはノードの特徴が一様ではない場合、更新がうまく機能しない可能性もある。 畳み込みの定義の仕方によって表現力が制限される面もある。

一方で、Attention機構を導入したGATなども登場し、性能向上が図られている。 隣接行列の構造に強く依存するため、疎な連結成分の扱いが難しいケースもある。 データ前処理やグラフクリーニングの手間が、従来の機械学習より増える可能性がある。

超巨大グラフの場合、ミニバッチ的なサンプリングが必要となる。 その際、サンプリングの仕方によって性能が大きく変わるという課題もある。

2-4. GCNと他手法との比較

GCNはGraphSAGEやGATなどの先駆け的役割を果たした。 GraphSAGEではサンプリングにより大規模グラフにも対応しやすくなっている。 GAT(Graph Attention Network)は注意機構によりノードごとの重みを学習する。 GCNは実装がシンプルで安定的に動くため、研究の入り口として最適。 実務でも、まずはGCNを試してから他手法を検討する流れが多い。 性能要求やデータ特性によって、モデル選択は変動する。 GCNはグラフ全体が静的で更新が少ない場合に強みがある。 ダイナミックに変わるグラフ(例: SNSのリアルタイム更新)では他手法に軍配が上がる場合もある。 GCNはベースラインとして実験を行いやすい位置にある。 論文やオープンソース実装が豊富に存在するため、導入も容易といえる。


3. GCNの実装例 (Python)

以下に簡単なGCNの実装例を示す。

わかりやすくするため、クラス定義とフォワード処理を簡略化する。

import torch
import torch.nn as nn
import torch.nn.functional as F


class GCNLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.W_mat = nn.Parameter(torch.FloatTensor(in_dim, out_dim))
        nn.init.xavier_uniform_(self.W_mat.data)

    def forward(self, X_tensor, adj_mat):
        # 正規化隣接行列の生成
        I_tensor = torch.eye(adj_mat.size(0))
        A_hat = adj_mat + I_tensor
        D_hat = torch.diag(torch.sum(A_hat, dim=1))
        D_hat_inv_sqrt = torch.sqrt(torch.inverse(D_hat))
        A_norm = D_hat_inv_sqrt @ A_hat @ D_hat_inv_sqrt

        # GCN演算
        H_tensor = A_norm @ X_tensor @ self.W_mat
        return H_tensor


class GCN(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.layer1 = GCNLayer(in_dim, hidden_dim)
        self.layer2 = GCNLayer(hidden_dim, out_dim)

    def forward(self, X_tensor, adj_mat):
        H_tensor = self.layer1(X_tensor, adj_mat)
        H_tensor = F.relu(H_tensor)
        H_tensor = self.layer2(H_tensor, adj_mat)
        return H_tensor


# ノード特徴行列X_tensorと隣接行列adj_matの例
X_list = [[1.0, 0.5], [0.3, 0.8], [0.9, 0.1], [0.2, 0.6]]
adj_list = [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]]
X_tensor = torch.tensor(X_list, dtype=torch.float)
adj_mat = torch.tensor(adj_list, dtype=torch.float)

model = GCN(in_dim=2, hidden_dim=4, out_dim=2)
output_tensor = model(X_tensor, adj_mat)
print(output_tensor)
tensor([[ 0.3374, -0.1291],
        [ 0.3729, -0.1495],
        [ 0.3290, -0.1271],
        [ 0.2402, -0.0990]], grad_fn=<MmBackward0>)

上記コードは非常に簡素化したものである。

実際にはミニバッチやGPU対応、損失関数などを含める必要がある。 PyTorch GeometricやDGLといったライブラリを利用するのが一般的である。(個人的には使った事がないが…)

この例では、隣接行列は小規模なのでフルバッチで扱っている。 さらに大きなグラフではサンプリングやデータローダーが必要になる。 研究や企業での実際の実装時には最適化手法にも注意を払う必要がある。


4. GNNの実装例 (Python)

GCN以外のシンプルなGNN実装例を示す。 ここではMessage Passingの概念を簡易的に表現した例を紹介する。 ノードごとに近傍ノードから情報を集約し、埋め込みを更新する流れである。

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleGNNLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.W_mat = nn.Linear(in_dim, out_dim)

    def forward(self, X_tensor, adj_mat):
        # 隣接行列の行方向に和をとって近傍の特徴を集約
        agg_tensor = torch.matmul(adj_mat, X_tensor)
        # 重み付け線形変換
        updated_tensor = self.W_mat(agg_tensor)
        return updated_tensor


class SimpleGNN(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.layer1 = SimpleGNNLayer(in_dim, hidden_dim)
        self.layer2 = SimpleGNNLayer(hidden_dim, out_dim)

    def forward(self, X_tensor, adj_mat):
        H_tensor = self.layer1(X_tensor, adj_mat)
        H_tensor = F.relu(H_tensor)
        H_tensor = self.layer2(H_tensor, adj_mat)
        return H_tensor


# ノード特徴行列X_tensorと隣接行列adj_matの例
X_list = [[1.0, 0.5], [0.3, 0.8], [0.9, 0.1], [0.2, 0.6]]
adj_list = [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]]
X_tensor = torch.tensor(X_list, dtype=torch.float)
adj_mat = torch.tensor(adj_list, dtype=torch.float)

model_gnn = SimpleGNN(in_dim=2, hidden_dim=4, out_dim=2)
output_gnn_tensor = model_gnn(X_tensor, adj_mat)
print(output_gnn_tensor)
tensor([[-0.0374,  0.0476],
        [-0.4478,  0.1145],
        [-0.1853, -0.0557],
        [-0.2854,  0.0934]], grad_fn=<AddmmBackward0>)

GCNとの違いは、正規化などを施していない点である。 また、GCNはさらに畳み込みの概念を数式に組み込んでいる。 このシンプルGNNはノード近傍情報の単純な加算で表している。 実運用では、正規化の有無などで学習性能や安定性が変化する。


5. 数式の例と定義域

GNNやGCNで用いられる関数を数式で定義する場合を考える。 ノード集合を $V$、辺集合を $E$ とし、グラフを $G = (V, E)$ とする。 各ノード $ v \in V $ に特徴ベクトル $ \mathbf{x}_v \in \mathbb{R}^d $ が与えられる。 ここで、$ v \mapsto \mathbf{x}_v $ という写像を考える。 メッセージパッシングでは、ノード $ v $ に対して近傍 $ \mathcal{N}(v) $ の特徴を集約する。 一般に集約関数を $ \text{AGG} $ とすると、 $$ \mathbf{m}_v^{(l)} = \text{AGG}\bigl({\mathbf{h}_u^{(l)} : u \in \mathcal{N}(v)}\bigr) $$ という形になる。 その後、更新関数を $ \text{UPDATE} $ として、 $$ \mathbf{h}_v^{(l+1)} = \text{UPDATE}\bigl(\mathbf{h}_v^{(l)}, \mathbf{m}_v^{(l)}\bigr) $$ とする。 これらの関数の定義域は、

  • $ \text{AGG}: (\mathbb{R}^d)^{|\mathcal{N}(v)|} \rightarrow \mathbb{R}^d $
  • $ \text{UPDATE}: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}^d $ となることが多い。 GCNの場合は、集約関数が行列演算で表現され、$ \text{UPDATE} $ は線形変換と非線形活性化となる。 数式表現をきちんと理解することで、GNNの動作を深く把握できる。

6. 実際のビジネス応用

6-1. レコメンドシステム

ユーザーとアイテムをノードとして扱う二部グラフを構築する例がある。 ノード同士の類似度や共起情報をエッジでつなぐ。 GNNを用いることでユーザーの嗜好を学習し、推薦精度を向上させる。 従来のMatrix Factorizationよりも、多様な特徴を取り込める利点がある。 企業のECサイトや動画配信、音楽ストリーミングなどで活用が進む。 ノード数が数千万、辺数が数億を超えるケースもあり、サンプリング必須である。 GraphSAGEなどの大規模手法と組み合わせる場合が多い。 アイテムノードの属性情報(価格帯、ジャンルなど)を統合して特徴化を図る。 ユーザー側も年齢や性別、行動履歴など多くの情報をノード特徴に落とし込む。 GNNでノード埋め込みを獲得し、そのベクトル近傍のアイテムを推薦候補に挙げる。

6-2. SNS分析

SNSのユーザー間つながりは、自然にグラフとして表現できる。 GNNを使えば、コミュニティやインフルエンサーを発見できる。 投稿内容やフォロー関係をノード特徴やエッジ特徴に含めることが可能。 GCNでユーザー埋め込みを学習し、新規アカウントの興味を推定する例がある。 不正アカウントの検出やトレンド予測にも活用されている。 大規模SNSではマルチGPUや分散処理を利用し、リアルタイム更新を可能にする。 GATのように注目すべき近傍ノードを強調する仕組みも有効。 Edgeとしていいね関係やコメント関係を同時に扱うケースもある。 SNS上の頻繁につながりが変化する環境では動的GNNへの研究が重要。 企業のマーケティング部門での活用事例が増えている分野でもある。

6-3. 分子構造解析

分子を原子ノードと化学結合エッジからなるグラフとみなす。 GNNを用いると、新薬候補分子の性質や反応性を予測できる。 結合関係の種類をエッジ特徴として設定し、学習に用いる。 GCNやGraphSAGEに加え、位置情報を組み込む手法も研究されている。 シミュレーションコストを削減し、高速に有望化合物を探せるメリットがある。 化学分野や製薬企業での導入事例が増えてきている。 ノードの原子種、エッジの結合種類など多次元特徴が絡む問題設定となる。 深いGNNを使い分子の高階な特徴を捉える試みが盛んである。 QSAR(定量的構造活性相関)解析などの伝統的手法を補完する役割も果たす。 探索空間が膨大でも、効率的にアプローチできる点が重要視される。

6-4. その他の応用

サプライチェーン管理では、企業間取引や物流経路をグラフとして扱う例がある。 保険会社では、契約者間の関連性やクレーム情報をグラフとして解析する取り組みがある。 電力網や通信ネットワークの障害解析にも活用可能である。 銀行や証券取引では、取引ネットワークからリスクを検知する事例が報告されている。 ロジスティックスや交通計画において、最適ルート探索にGNNが応用される場合もある。 IoTセンサーの配置やデバイス同士の通信関係をグラフ化し、故障検知に活かす例もある。 各種産業での応用可能性が広がっており、今後の発展が期待される。


7. 計算例: 小規模グラフでのGCN計算

次の小規模例を考える。 ノード数3、特徴次元2、隣接行列は以下とする。

$$ \mathbf{X} = \begin{pmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{pmatrix}, \quad \mathbf{A} = \begin{pmatrix} 0 & 1 & 0 \\ 1 & 0 & 1 \\ 0 & 1 & 0 \end{pmatrix} $$

GCNの1層目を考える。 正規化した隣接行列を $ \tilde{\mathbf{A}} $ とする。 $ \tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I} $ なので、

$$ \tilde{\mathbf{A}} = \begin{pmatrix} 1 & 1 & 0 \\ 1 & 1 & 1 \\ 0 & 1 & 1 \end{pmatrix} $$

対角行列 $ \tilde{\mathbf{D}} $ は、各行の総和を対角に置く。

$$ \tilde{\mathbf{D}} = \begin{pmatrix} 2 & 0 & 0 \\ 0 & 3 & 0 \\ 0 & 0 & 2 \end{pmatrix} $$

$ \tilde{\mathbf{D}}^{-\frac{1}{2}} $ を計算すると、

$$ \tilde{\mathbf{D}}^{-\frac{1}{2}} = \begin{pmatrix} \frac{1}{\sqrt{2}} & 0 & 0 \\ 0 & \frac{1}{\sqrt{3}} & 0 \\ 0 & 0 & \frac{1}{\sqrt{2}} \end{pmatrix} $$

すると、

$$ \hat{\mathbf{A}} = \tilde{\mathbf{D}}^{-\frac{1}{2}} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-\frac{1}{2}} $$

を求めればいい。 次に、GCNのパラメータ行列 $\mathbf{W}^{(0)}$ を仮に

$$ \mathbf{W}^{(0)} = \begin{pmatrix} 1 & -1 \\ 2 & 1 \end{pmatrix} $$

とする。 このとき、1層目の出力は

$$ \mathbf{H}^{(1)} = \sigma(\hat{\mathbf{A}} \mathbf{X} \mathbf{W}^{(0)}) $$

となる。 具体的に数値を入れれば、各要素が求まる。 ここでは単に行列積を順番に計算して、最後に活性化関数(ReLUなど)を適用する。 実務でも、GCNはこのような行列演算の組み合わせになっている。 $ |\mathbf{H}^{(1)}|_F $ などのノルムをとって勾配計算をする場面もある。 小さな例で手計算してみると、理解が深まる。


8. 結論

この記事では、GNNとGCNの基礎から実装例、ビジネス応用まで解説した。 GNNはグラフ構造を扱う強力なフレームワークであり、GCNはその代表的手法である。 企業のレコメンドやSNS分析、分子構造解析など、多彩な領域で利用されている。 計算コストやオーバースムージングなど課題も存在するが、研究は活発である。 Pythonにおける実装例を通じて、コード化が容易である点も示した。

GNNを始める際は、まずGCNを基礎に学習し、さらに派生手法を検討すると良いと個人的には思っている。