なぜ √dk で割るのか — softmax飽和問題

前章で先送りした「÷√dk って何?」を解き明かします。原論文の脚注4でひっそり説明されているこの一行が、実はTransformerが深層化できる隠れた立役者です。

問題:高次元の内積は大きくなりすぎる

Q と K の各成分が平均0・分散1の独立な確率変数だとします。dk次元の内積は次のようになります。

内積の分散

Var(Q · K) = dk

→ 標準偏差は √dk のオーダー

つまり、次元 dk が大きいほど、内積の絶対値はおおむね √dk のオーダーで増えていきます。dk=512なら内積は √512 ≒ 22 のスケールに、dk=4096なら √4096 = 64 のスケールに……どんどん大きくなるのです。

問題:softmaxが飽和する

内積スコアが極端に大きい状態でsoftmaxを通すと、何が起きるか。たとえばスコアが (10, 1, 1, 1) と (100, 91, 91, 91) ではどうでしょう。差は両方とも9ですが…

# スコア (10, 1, 1, 1) の場合
e^10 = 22026, e^1 = 2.72
softmax ≒ (22026, 2.72, 2.72, 2.72) / 22034
       ≒ (0.999, 0.0001, 0.0001, 0.0001)
       → ほぼ one-hot(最大候補だけ1、他は0)

# スコア (100, 91, 91, 91) の場合
e^100 = 2.7e43, e^91 = 3.4e39
softmax ≒ (2.7e43, 3.4e39, 3.4e39, 3.4e39) / 2.7e43
       ≒ (0.99996, 0.000013, 0.000013, 0.000013)
       → さらに極端に one-hot 化

softmaxの出力が 「ほぼ one-hot(=ほぼ最大候補だけが1で他がほぼ0)」 に飽和してしまいます。これの何が問題か?

解決策:√dkで割って分散を1に戻す

そこで、内積を √dk で割って、分散を1付近に戻します。

スケーリング後の分散

Var(Q · K / √dk) = dk / (√dk)² = 1

これにより、softmaxの入力が「鋭すぎず鈍すぎない」健全な範囲に収まり、勾配が消えずに学習が進みます。BatchNormLayerNorm と同じ「信号の標準化」の思想です。

graph LR
  A[高次元の内積\n大きすぎる] --> B[÷√dk\nスケーリング]
  B --> C[softmax\n健全な範囲]
  C --> D[勾配が流れる\n学習可能]

  A2[÷√dk なし] -.-> B2[softmax飽和\nほぼ one-hot]
  B2 -.-> C2[勾配ほぼ0\n学習停止]

  style A fill:#3b82f6,stroke:#1d4ed8,color:#fff
  style D fill:#14b8a6,stroke:#0d9488,color:#fff
  style C2 fill:#ef4444,stroke:#b91c1c,color:#fff
√dkスケーリングがあるかないかで、学習可能かどうかが決まる

Multi-Head Attention — 複数の視点を同時に

ここまで説明した「Q, K, V を作って内積→softmax→加重平均」は、実は 単一のヘッド(Single-Head) での話でした。実際のTransformerは、これを 複数並列で動かす Multi-Head Attention を使います。

単一ヘッドの限界

単一のAttentionは、softmaxによる重み付き平均で1つの「平均化された関係」を表現します。しかし自然言語には、同時に存在する複数の関係性があります。

  • 構文的関係:主語と動詞、修飾語と被修飾語
  • 共参照関係:代名詞と先行詞("it" → "animal")
  • 意味的関係:同義語、反義語、関連語
  • 位置的関係:近接する単語間

これらを1つのAttentionで平均化してしまうと、すべての関係が混ざってしまいます。Vaswaniらの原論文の言葉を借りると:

"With a single attention head, averaging inhibits this."
(単一のヘッドだと、平均化がこれを妨げてしまう)

Multi-Headの仕組み

Multi-Head Attentionは、d_model 次元を h 個のヘッド に分割し、それぞれが独立にAttentionを計算します。

# Multi-Head Attention の構造
# 原論文設定: d_model = 512, h = 8

各ヘッドの次元: d_k = d_v = d_model / h = 512 / 8 = 64

# 各ヘッドが独自のW^Q, W^K, W^Vを持つ
head_1 = Attention(Q·W_1^Q, K·W_1^K, V·W_1^V)  # 64次元
head_2 = Attention(Q·W_2^Q, K·W_2^K, V·W_2^V)  # 64次元
...
head_8 = Attention(Q·W_8^Q, K·W_8^K, V·W_8^V)  # 64次元

# 全ヘッドの出力を連結 (concat)
multi_head = Concat(head_1, ..., head_8) · W^O
           = 64 * 8 = 512次元
graph TD
  X[入力 d_model=512] --> H1[Head 1\nd_k=64\n構文に注目]
  X --> H2[Head 2\nd_k=64\n共参照に注目]
  X --> H3[Head 3\nd_k=64\n意味に注目]
  X --> H4[...]
  X --> H8[Head 8\nd_k=64\n位置に注目]
  H1 --> C[Concat\n8ヘッド連結\n512次元]
  H2 --> C
  H3 --> C
  H4 --> C
  H8 --> C
  C --> O[W^O 線形変換\n出力 512次元]

  style X fill:#3b82f6,stroke:#1d4ed8,color:#fff
  style H1 fill:#8b5cf6,stroke:#6d28d9,color:#fff
  style H2 fill:#8b5cf6,stroke:#6d28d9,color:#fff
  style H3 fill:#8b5cf6,stroke:#6d28d9,color:#fff
  style H8 fill:#8b5cf6,stroke:#6d28d9,color:#fff
  style O fill:#14b8a6,stroke:#0d9488,color:#fff
Multi-Head Attention: 8つのヘッドが異なる視点で並列に注意を計算し、最後に連結する

各ヘッドは何を学ぶか

訓練後、各ヘッドは実際に異なる役割を獲得することが分かっています。研究者がAttention重みを可視化すると、以下のようなパターンが観察されます。

ヘッドの典型的な役割 具体例 可視化された傾向
構文的依存 主語と動詞、形容詞と名詞 係り受け関係のペアに強い注目
共参照解析 "it" → "the animal" 代名詞からその先行詞への矢印
位置的近接 隣接する数単語 対角線上に注目が集中
語彙意味 同義語・反義語 意味的に近い単語間に重み
句読点・終端 文末や文の区切り 特定の記号トークンに反応

計算コストはほぼ変わらない

「8倍の計算?」と思うかもしれませんが、実は 計算量はほぼ変わりません。なぜなら d_model 次元を 8等分して各ヘッドが64次元で計算するため、全ヘッドを合わせても元と同じ計算量になるからです。

# 単一ヘッド (d_model=512)
- QK^T の計算: O(n² × 512)

# Multi-Head (h=8, d_k=64)
- 1ヘッドの QK^T: O(n² × 64)
- 8ヘッド合計:    O(n² × 64 × 8) = O(n² × 512)

→ 計算量は同じ!

ヘッド数 h はどう決めるか

原論文では h=8 でしたが、モデルが大きくなるにつれてヘッド数も増えていきます。

モデル d_model ヘッド数 h 1ヘッドの次元 d_k
Original Transformer 512 8 64
BERT-base 768 12 64
BERT-large 1024 16 64
GPT-3 12,288 96 128
Llama 3 8B 4,096 32 128
DeepSeek V3 7,168 128 56

傾向として、1ヘッドの次元 d_k は 64〜128 あたりが多く、ヘッド数 h でモデル全体のサイズを調整しています。

この章のまとめ

√dkスケーリング は、高次元の内積が大きくなりすぎる問題を解消し、softmaxが飽和して勾配が消えるのを防ぎます。Multi-Head Attention は、d_model次元を複数ヘッドに分割し、各ヘッドが異なる視点で関係性を学習することで、単一の平均化では捉えきれない多様な関係を表現します。

次の第8章では、GPTの心臓部である 因果マスク(Causal Mask)自己回帰生成 の仕組みに進みます。「未来を見ない」というシンプルな制約が、どうやって「次の単語を予測する」モデルを作るのかを解き明かします。

理解度チェック

問題 0 / 50%
Q1

√dk で割るスケーリングがないと、何が起きますか?

キーボード: 1〜4 で選択、Enter で回答