なぜ √dk で割るのか — softmax飽和問題
前章で先送りした「÷√dk って何?」を解き明かします。原論文の脚注4でひっそり説明されているこの一行が、実はTransformerが深層化できる隠れた立役者です。
問題:高次元の内積は大きくなりすぎる
Q と K の各成分が平均0・分散1の独立な確率変数だとします。dk次元の内積は次のようになります。
内積の分散:
Var(Q · K) = dk
→ 標準偏差は √dk のオーダー
つまり、次元 dk が大きいほど、内積の絶対値はおおむね √dk のオーダーで増えていきます。dk=512なら内積は √512 ≒ 22 のスケールに、dk=4096なら √4096 = 64 のスケールに……どんどん大きくなるのです。
問題:softmaxが飽和する
内積スコアが極端に大きい状態でsoftmaxを通すと、何が起きるか。たとえばスコアが (10, 1, 1, 1) と (100, 91, 91, 91) ではどうでしょう。差は両方とも9ですが…
# スコア (10, 1, 1, 1) の場合
e^10 = 22026, e^1 = 2.72
softmax ≒ (22026, 2.72, 2.72, 2.72) / 22034
≒ (0.999, 0.0001, 0.0001, 0.0001)
→ ほぼ one-hot(最大候補だけ1、他は0)
# スコア (100, 91, 91, 91) の場合
e^100 = 2.7e43, e^91 = 3.4e39
softmax ≒ (2.7e43, 3.4e39, 3.4e39, 3.4e39) / 2.7e43
≒ (0.99996, 0.000013, 0.000013, 0.000013)
→ さらに極端に one-hot 化 softmaxの出力が 「ほぼ one-hot(=ほぼ最大候補だけが1で他がほぼ0)」 に飽和してしまいます。これの何が問題か?
解決策:√dkで割って分散を1に戻す
そこで、内積を √dk で割って、分散を1付近に戻します。
スケーリング後の分散:
Var(Q · K / √dk) = dk / (√dk)² = 1
これにより、softmaxの入力が「鋭すぎず鈍すぎない」健全な範囲に収まり、勾配が消えずに学習が進みます。BatchNorm や LayerNorm と同じ「信号の標準化」の思想です。
graph LR A[高次元の内積\n大きすぎる] --> B[÷√dk\nスケーリング] B --> C[softmax\n健全な範囲] C --> D[勾配が流れる\n学習可能] A2[÷√dk なし] -.-> B2[softmax飽和\nほぼ one-hot] B2 -.-> C2[勾配ほぼ0\n学習停止] style A fill:#3b82f6,stroke:#1d4ed8,color:#fff style D fill:#14b8a6,stroke:#0d9488,color:#fff style C2 fill:#ef4444,stroke:#b91c1c,color:#fff
Multi-Head Attention — 複数の視点を同時に
ここまで説明した「Q, K, V を作って内積→softmax→加重平均」は、実は 単一のヘッド(Single-Head) での話でした。実際のTransformerは、これを 複数並列で動かす Multi-Head Attention を使います。
単一ヘッドの限界
単一のAttentionは、softmaxによる重み付き平均で1つの「平均化された関係」を表現します。しかし自然言語には、同時に存在する複数の関係性があります。
- 構文的関係:主語と動詞、修飾語と被修飾語
- 共参照関係:代名詞と先行詞("it" → "animal")
- 意味的関係:同義語、反義語、関連語
- 位置的関係:近接する単語間
これらを1つのAttentionで平均化してしまうと、すべての関係が混ざってしまいます。Vaswaniらの原論文の言葉を借りると:
"With a single attention head, averaging inhibits this."
(単一のヘッドだと、平均化がこれを妨げてしまう)
Multi-Headの仕組み
Multi-Head Attentionは、d_model 次元を h 個のヘッド に分割し、それぞれが独立にAttentionを計算します。
# Multi-Head Attention の構造
# 原論文設定: d_model = 512, h = 8
各ヘッドの次元: d_k = d_v = d_model / h = 512 / 8 = 64
# 各ヘッドが独自のW^Q, W^K, W^Vを持つ
head_1 = Attention(Q·W_1^Q, K·W_1^K, V·W_1^V) # 64次元
head_2 = Attention(Q·W_2^Q, K·W_2^K, V·W_2^V) # 64次元
...
head_8 = Attention(Q·W_8^Q, K·W_8^K, V·W_8^V) # 64次元
# 全ヘッドの出力を連結 (concat)
multi_head = Concat(head_1, ..., head_8) · W^O
= 64 * 8 = 512次元 graph TD X[入力 d_model=512] --> H1[Head 1\nd_k=64\n構文に注目] X --> H2[Head 2\nd_k=64\n共参照に注目] X --> H3[Head 3\nd_k=64\n意味に注目] X --> H4[...] X --> H8[Head 8\nd_k=64\n位置に注目] H1 --> C[Concat\n8ヘッド連結\n512次元] H2 --> C H3 --> C H4 --> C H8 --> C C --> O[W^O 線形変換\n出力 512次元] style X fill:#3b82f6,stroke:#1d4ed8,color:#fff style H1 fill:#8b5cf6,stroke:#6d28d9,color:#fff style H2 fill:#8b5cf6,stroke:#6d28d9,color:#fff style H3 fill:#8b5cf6,stroke:#6d28d9,color:#fff style H8 fill:#8b5cf6,stroke:#6d28d9,color:#fff style O fill:#14b8a6,stroke:#0d9488,color:#fff
各ヘッドは何を学ぶか
訓練後、各ヘッドは実際に異なる役割を獲得することが分かっています。研究者がAttention重みを可視化すると、以下のようなパターンが観察されます。
| ヘッドの典型的な役割 | 具体例 | 可視化された傾向 |
|---|---|---|
| 構文的依存 | 主語と動詞、形容詞と名詞 | 係り受け関係のペアに強い注目 |
| 共参照解析 | "it" → "the animal" | 代名詞からその先行詞への矢印 |
| 位置的近接 | 隣接する数単語 | 対角線上に注目が集中 |
| 語彙意味 | 同義語・反義語 | 意味的に近い単語間に重み |
| 句読点・終端 | 文末や文の区切り | 特定の記号トークンに反応 |
計算コストはほぼ変わらない
「8倍の計算?」と思うかもしれませんが、実は 計算量はほぼ変わりません。なぜなら d_model 次元を 8等分して各ヘッドが64次元で計算するため、全ヘッドを合わせても元と同じ計算量になるからです。
# 単一ヘッド (d_model=512)
- QK^T の計算: O(n² × 512)
# Multi-Head (h=8, d_k=64)
- 1ヘッドの QK^T: O(n² × 64)
- 8ヘッド合計: O(n² × 64 × 8) = O(n² × 512)
→ 計算量は同じ! ヘッド数 h はどう決めるか
原論文では h=8 でしたが、モデルが大きくなるにつれてヘッド数も増えていきます。
| モデル | d_model | ヘッド数 h | 1ヘッドの次元 d_k |
|---|---|---|---|
| Original Transformer | 512 | 8 | 64 |
| BERT-base | 768 | 12 | 64 |
| BERT-large | 1024 | 16 | 64 |
| GPT-3 | 12,288 | 96 | 128 |
| Llama 3 8B | 4,096 | 32 | 128 |
| DeepSeek V3 | 7,168 | 128 | 56 |
傾向として、1ヘッドの次元 d_k は 64〜128 あたりが多く、ヘッド数 h でモデル全体のサイズを調整しています。
この章のまとめ
√dkスケーリング は、高次元の内積が大きくなりすぎる問題を解消し、softmaxが飽和して勾配が消えるのを防ぎます。Multi-Head Attention は、d_model次元を複数ヘッドに分割し、各ヘッドが異なる視点で関係性を学習することで、単一の平均化では捉えきれない多様な関係を表現します。
次の第8章では、GPTの心臓部である 因果マスク(Causal Mask) と 自己回帰生成 の仕組みに進みます。「未来を見ない」というシンプルな制約が、どうやって「次の単語を予測する」モデルを作るのかを解き明かします。
理解度チェック
√dk で割るスケーリングがないと、何が起きますか?
キーボード: 1〜4 で選択、Enter で回答