Attention is all you need
缩放点积注意力,到底在算什么
给输入序列里的每个token,分配一个权重,重要的token权重高,不重要的权重低,然后把所有token的表示加权求和,得到最终的上下文表示。
Q、K、V到底是什么
- Query(查询Q):你现在要处理的token,相当于你在搜索引擎里输入的「搜索词」,你想知道这个token和序列里的其他token有什么关系。
- Key(键K):序列里所有token的「名片」,相当于搜索引擎里网页的「标题」,用来和你的搜索词做匹配,计算相似度。
- Value(值V):序列里所有token的「实际内容」,相当于搜索引擎里网页的「正文」,匹配完成后,我们会根据相似度权重,把这些内容加权求和,得到最终的结果。
在自注意力里,Q、K、V都来自同一个输入序列,也就是每个token,都会生成自己的Q、K、V: - 用自己的Q,去和所有token的K做匹配,计算相似度;
- 用相似度做权重,给所有token的V加权求和,得到这个token的上下文表示
example
我爱中国
- 生成「我」的Q,以及「我」「爱」「中国」的K和V;
- 计算「我」的Q和「我」「爱」「中国」的K的点积,得到相似度,比如和「我」的相似度是0.7,和「爱」是0.2,和「中国」是0.1;
- 经过softmax归一化,权重变成[0.7, 0.2, 0.1];
- 用这个权重,给三个token的V加权求和,得到「我」的上下文表示,这个表示里,既包含了自己的信息,也包含了上下文的信息。
为什么一定要除以
防止点积的数值过大,把softmax推入梯度饱和区,导致梯度消失。
原因
假设 和 中的每个元素,都是独立的随机变量,均值为 ,方差为 。那么两个 维向量的点积:
根据方差的性质:
- 每个 的均值是:
- 每个 的方差是:
- 所以 个这样的变量相加,总方差是 ,标准差是 。
softmax函数的特性是:输入的数值差异越大,输出的分布就越尖锐,几乎所有的概率都会集中在最大的那个值上,其他值的概率趋近于0。比如,输入是[100, 1, 1],softmax之后的结果几乎是[1, 0, 0]。
这种尖锐的分布,会带来两个致命问题:
反向传播时,softmax的梯度会趋近于0,模型完全学不到东西;
注意力权重几乎全部集中在一个token上,丢失了上下文的信息。
而我们把点积除以就可以把点积的方差重新拉回1,让数值分布在一个合理的范围内,softmax的分布不会过于尖锐,梯度也能正常传播,模型才能收敛。
掩码(Mask)到底在做什么
前瞻掩码(Look-ahead Mask):解码器的因果约束
解码器是自回归生成的,生成第 个 token 的时候,只能看到前 个已经生成的 token,绝对不能看到 及之后的 token,否则就是“考试提前看了答案”,模型学不到任何东西。
前瞻掩码的实现非常简单:生成一个和注意力矩阵相同大小的上三角矩阵,对角线以下的元素是 ,对角线以上的元素是 。把这个掩码加到 的结果上,再做 softmax:
- 未来位置的元素,加上 之后,softmax 的结果就是 ,权重为 ,完全不会被关注;
- 过去和当前位置的元素,加上 ,不受影响,正常计算权重。
举个例子,序列长度是 ,掩码矩阵就是:
这样,第 个 token 只能看自己,第 个 token 能看第 、 个,以此类推,完美保证了因果性。
填充掩码(Padding Mask):处理变长序列
我们的训练批次里,句子的长度是不一样的,为了打包成矩阵,我们会给短句子做padding,填充无意义的token。这些padding的token,不应该被模型关注到,所以我们需要把它们对应的注意力权重,设置为0。
实现方式也很简单:生成一个和序列长度相同的掩码,padding的位置是,正常token的位置是0,加到的结果上,softmax之后,padding位置的权重就是0。
多头注意力:为什么一个头不够,非要8个头
单头注意力已经能实现上下文建模了,为什么论文非要用多头注意力?这是 Transformer 的第二个核心创新,我们把它彻底讲透。
先看多头注意力的公式:
其中:
它的核心逻辑是:
用 组不同的线性投影矩阵 、、,把 、、 分别投影到 个不同的低维子空间里。原文中每个头的维度是 , 个头加起来正好是 ,和单头的维度一致;
在每个子空间里,独立做一次缩放点积注意力计算,得到 个不同的输出;
把这 个输出拼接起来,再用一个线性矩阵 做投影,得到最终的输出。
为什么多头注意力效果更好?
用一句话总结:单头注意力只能学到一种依赖模式,而多头注意力,可以让模型在不同的表示子空间里,同时学到多种不同的依赖关系。
我们用一个具体的例子,就能瞬间理解:
比如句子:「The animal didn’t cross the street because it was too tired」,这里的「it」指代的是「animal」。
单头注意力计算的时候,可能会把「it」和「street」的权重算得很高,因为它们离得近,语法上相关,反而忽略了真正的指代对象「animal」。
而多头注意力就不一样了:
- 第1个头,专门关注指代关系,学到「it」和「animal」的强依赖,权重很高;
- 第2个头,专门关注语法结构,学到「it」和「was tired」的主谓关系;
- 第3个头,专门关注因果关系,学到「because」连接的前后两个分句的依赖;
- 剩下的头,还可以关注相邻词的关系、介词搭配、时态等等。
- 每个头都有自己的分工,从不同的角度,学习序列里不同的依赖关系,最后把所有信息汇总起来,模型对句子的理解,就会比单头注意力深刻得多。
论文里的可视化结果也证明了这一点:不同的注意力头,确实学到了完全不同的模式,有的头专门关注长距离的指代关系,有的头专门关注局部的语法结构。
头数越多越好吗
不是。论文里做了消融实验,头数从8变成16,每个头的维度从64变成32,模型的BLEU值反而下降了0.9。
核心原因是:头数太多,每个头的维度就会太小,单个头的表示能力不足,反而会影响模型效果。同时,头数太多,计算量也会上升,训练成本会增加。
原文选择8个头,每个头64维,是一个非常均衡的选择:既保证了每个头有足够的表示能力,又能让模型学到足够多的依赖模式,计算量也和单头注意力基本一致。
编码器-解码器架构:到底是怎么工作的?
Transformer的整体架构,是经典的编码器-解码器结构,我们用机器翻译的例子,把整个流程走一遍,你就彻底懂了。
比如我们要把中文「我爱中国」翻译成英文「I love China」,整个流程分为两步:
编码器由6层完全相同的层堆叠而成,每层有两个子层:
1.多头自注意力层:输入序列的每个token,都能和整个输入序列的所有token做注意力,包括自己。比如「我」这个token,能关注到「爱」和「中国」,「中国」也能关注到「我」和「爱」,实现双向的全局理解。
2. 逐位置前馈网络:对每个token的表示,做独立的非线性变换,进一步提取特征。
每个子层都有残差连接+层归一化:
- 残差连接:解决深度网络的梯度消失问题,让6层的深度网络可以正常训练;
- 层归一化:把每个token的表示,归一化到均值为0、方差为1的分布,稳定训练,加速收敛。
6层编码器处理完成后,会输出一个和输入序列长度相同的表示矩阵,这个矩阵里,每个token的表示,都包含了整个输入序列的全局上下文信息。
解码器:生成输出序列
解码器的作用,是基于编码器输出的原文表示,自回归地生成目标语言的句子,也就是「写出译文」。
解码器同样由6层完全相同的层堆叠而成,每层有三个子层:
- 带掩码的多头自注意力层:和编码器的自注意力不同,这里加入了前瞻掩码,保证生成第ttt个token的时候,只能看到前t−1t-1t−1个已经生成的token,不能看到未来的token,符合自回归生成的逻辑。
- 编码器-解码器注意力层(交叉注意力):这是连接编码器和解码器的核心。这里的Q来自解码器上一层的输出,K和V来自编码器的最终输出。它的作用是:在生成每个英文token的时候,让模型能关注到中文原文里的相关token。比如生成「China」的时候,模型会重点关注原文里的「中国」这个token。
- 逐位置前馈网络:和编码器里的一样,对每个token的表示做非线性变换。
同样,每个子层都有残差连接和层归一化。
具体流程
- 输入中文「我爱中国」,经过嵌入层+位置编码,输入到6层编码器,得到编码器的输出;
- 解码器的输入,先放入一个起始符,经过掩码自注意力层,再和编码器的输出做交叉注意力,经过FFN层,输出第一个token的概率分布,采样得到「I」;(起始符就是一个特殊 token,用来告诉解码器从这里开始生成目标句子。)
- 把「I」加入解码器的输入,重复上面的步骤,生成第二个token「love」;
- 再把「love」加入输入,生成「China」;
- 最后生成结束符,翻译完成。
位置编码:为什么没有循环结构,模型还能懂序列顺序
位置编码。我们必须给每个token,注入它在序列里的位置信息,让模型知道,哪个token在前,哪个在后,序列的顺序是什么。
论文里用的是正弦余弦位置编码,公式再拿出来:
我们先拆解这个公式的含义:
:token 在序列里的位置,从 开始,比如序列长度是 , 就是 ;
:维度的索引,从 到 ,比如 ,就是 到 ;
对于位置编码的每个偶数维度 ,用正弦函数;奇数维度 ,用余弦函数;
不同的维度,有不同的波长:维度越低,波长越短,频率越高;维度越高,波长越长,频率越低。
为什么用正弦余弦,不用可学习的位置编码
论文里也做了实验,可学习的位置编码,效果和正弦余弦的几乎一样,但最终还是选了正弦余弦,核心原因有两个:
- 完美的长序列泛化能力
可学习的位置编码,是给每个位置训练一个向量,比如训练时的最大序列长度是1000,那么模型只有0-999的位置编码。如果推理时来了一个长度2000的序列,1000-1999的位置,模型从来没见过,没有对应的编码,效果会急剧下降。
而正弦余弦位置编码,是用公式计算的,不管序列多长,都可以直接算出任意位置的编码,完美泛化到训练时没见过的长序列,这对于大模型的长上下文能力,至关重要。
对于任意固定的偏移量 , 可以表示为 的线性函数。我们用三角恒等式推导一下:
2. 天然支持相对位置关系的学习
这意味着,两个位置之间的相对偏移量 ,可以通过线性变换计算出来。模型可以比较轻松地学习到「token A 在 token B 前面 3 个位置」这种相对位置关系,这对于序列建模来说非常重要。
位置编码怎么用
位置编码的维度,和token的嵌入向量维度完全相同,都是,所以我们可以直接把位置编码和嵌入向量相加,得到最终的输入向量,再输入到编码器和解码器里。
很多人会问:为什么是相加,不是拼接?
- 拼接会让输入的维度翻倍,从512变成1024,增加了后续层的计算量;
- 相加不会改变维度,计算量不变,而且模型可以通过线性变换,自动学习如何分离嵌入信息和位置信息,效果完全不输拼接。
训练细节:为什么你复现的Transformer训练不收敛
- 学习率调度:warmup是关键
这个公式分为两个阶段:
- 预热阶段(前4000步):学习率随步数线性增长,从0慢慢涨到最大值。这是因为模型初始化的时候,参数都是随机的,如果一开始学习率太大,会导致模型参数更新过猛,直接崩掉,训练不收敛。warmup可以让模型在初期,用小学习率慢慢稳定下来,再逐步提高学习率,加速收敛。
- 衰减阶段(4000步之后):学习率随步数的平方根的倒数衰减,慢慢降低,让模型在训练后期,用小学习率精细调整参数,收敛到最优解。
- 残差连接+层归一化的顺序
论文里的残差连接,用的是Pre-LN结构,也就是:
Post-LN:
先把输入x做层归一化,再输入到子层里,然后和原始的x做残差相加,再做一次层归一化?不,论文里的顺序是:子层的输入先做LN,再过子层,再加残差,也就是Pre-LN。
而后来的很多模型,比如BERT,用的是Post-LN,也就是先过子层,加残差,再做LN。两者的区别是:
Pre-LN:训练更稳定,不需要warmup也能收敛,现在的大模型几乎都用Pre-LN;
Post-LN:最终效果更好,但训练不稳定,需要非常小心的warmup和初始化。
论文里用的Pre-LN,是Transformer能稳定训练的另一个关键。
3. 标签平滑:提升泛化能力的小技巧
论文里用了标签平滑:
标签平滑的核心逻辑是:
传统的交叉熵损失,会让模型对正确标签的置信度趋近于 (1),错误标签趋近于 (0)。这会导致模型过拟合,对自己的预测过于自信;
标签平滑会把正确标签的概率从 (1) 变成 (),把剩下的 () 平均分配给所有错误标签,让模型不要过于自信,从而提升泛化能力。
虽然标签平滑会让模型的困惑度(Perplexity)轻微上升,但 BLEU 值(模型生成的翻译结果和人工参考翻译有多相似。)会显著提升,这也是论文里的一个关键优化。
