最近遇到一个有趣的问题:就是Transformer中的MultiHeadAttention为什么使用scaled?打算在这个问题上展开来分析并做一些拓展思考。

这些分享一下~

首先我们还是谈谈MultiHeadAttention及其实现。

Transformer中的Attention

使用scaled的Attention的数学表达如下,

这里的scaled指的就是分母除以$\sqrt{d_k}$。$d_k$是$\boldsymbol{K} \in \mathbb{R}^{l \times d}$中$d$的大小。后续的论文也在Attention中提出很多的变形,包括不同形式的Mask、引入先验、线性化Attention(即去掉softmax)等等。

MultiHeadAttention的Python-based伪代码实现,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def MultiHeadAttention(Q, K, V):
"""O(n^2*d)"""
# 线性变换
Qw = Linear(Q)
Kw = Linear(K)
Vw = Linear(V)

# 形状变换
Qw = Qw.reshape((batch_size, seq_len, heads, hdims))
Kw = Kw.reshape((batch_size, seq_len, heads, hdims))
Vw = Vw.reshape((batch_size, seq_len, heads, hdims))

# 计算评分矩阵
scores = einsum("bjhd,bkhd->bhjk", Qw, Kw) # 爱因斯坦求和约定
scores = scores / sqrt(d_k)
scores = scores * mask - 1e12 * (1 - mask) # mask处理
A = softmax(scores, axis=-1) # (j,k)的k方向归一化

# 更多处理,如Dropout,先验Mask等等
A = Dropout()(A)

# 加权求和
attn = einsum("bhjk,bkhd->bjhd", A, Vw)
attn = attn.reshape((batch_size, seq_len, heads * hdims))
# 线性变换整合多头信息
attn = Linear(attn)
return attn

当然我们还有比较偷懒的实现,就是直接按照数学公式来实现,毕竟是伪代码。以上的实现我们注意到:

  • $\sqrt{d_k}$参数是固定的,即是一个常数
  • $\sqrt{d_k}$参数在多个头之间是共享的

注意到这两点,一会的分析会用到。那么线性的问题是Transformer中的MultiHeadAttention为什么使用scaled?换成注意力机制问题来说就是,为什么MultiHeadAttention为什么要引入点积缩放评分函数?

相关的评分函数可以参考过去的文章漫谈注意力机制(一):人类的注意力和注意力机制基础-常见评分函数

随机向量点积的方差

一切皆是随机变量。这里可以从随机向量上分析。

在点积缩放模型评分函数下,第$i$个查询向量$\boldsymbol{q}_i$对向量序列中第$j$个键$\boldsymbol{k}_j$​的评分值为$\alpha_{ij}$,

可以把$\boldsymbol{q}_i$和$\boldsymbol{k}_j$都看做是随机向量,

这里假定每个随机向量中的元素都是独立同分布的,在恰当的初始化(Transformer使用截断正太分布)下有$\operatorname{Var}(x_i) = 1, E[x_i] = 0$,对于$y_i$一样成立。回到Attention语境下,这里的$n$就是$d_k$。于是有,

那么有如下方差推导,

也就是说$\boldsymbol{q}_{i}\boldsymbol{k}_{j}^{\top}$的方差比$\operatorname{Var}(x_i y_i) $放大了$n$倍。这时候$\boldsymbol{q}_{i}\boldsymbol{k}_{j}^{\top}$有更大的概率取到大的值,进而有更大的概率落到softmax的饱和区间下,这时候的梯度几乎为0,因此不利于模型训练。可以类比到一维情况,$\sigma(x)$在$x$较大时落入到饱和区间来理解。

解决方案是$\boldsymbol{q}_{i}\boldsymbol{k}_{j}^{\top}$除以$\sqrt{n}$,于是有,

于是$\frac{\boldsymbol{q}_{i}\boldsymbol{k}_{j}^{\top}}{\sqrt{d_k}}$就是这么来的。还有其他角度吗?以下是一个启发的可以拓展的角度。

从光滑逼近角度理解

在过去的文章引入参数控制softmax的smooth程度中讨论过参数化softmax函数的方法。直觉上来看其实就是$\operatorname{softmax}(\alpha\boldsymbol{x})$,但是该直觉结果无法给予我们更多的解释以及参数化softmax的意义。那篇文章从光滑逼近的角度导出。

首先容易推导$\operatorname{one-hot}(\arg \max(\boldsymbol{x}))$的带参数的光滑逼近形式,

以上的推导需要说明三点:

  • 引入$x_i - \max(\boldsymbol{x})$使得最大值为0,使得$e^0 = 1$,对应one-hot中的1
  • 引入$e^x$​​​是考虑到$e^0=1, 0 \lt e^{x|_{x \lt 0}} \lt 1$​​​​​,更好适配one-hot特点
  • max不具有光滑性,被替换为其光滑近似logsumexp,可以参考函数光滑近似(1):maximum函数

根据以上的推导有极限,

因为,

这意味着参数$\alpha$可以控制$\operatorname{softmax}(\alpha \boldsymbol{x})$对$\operatorname{one-hot}(\arg \max(\boldsymbol{x}))$的逼近程度。因此,当$\alpha$越大,逼近程度越好,对应的就是输出向量越稀疏,能够容纳的上下文信息的范围就越小;类似地,当$\alpha$越小,逼近程度越差,对应的就是输出向量越稠密,能够容纳的上下文信息的范围越大。回到Attention语境下,

这里的$\alpha$就是$\frac{1}{\sqrt{d_k}}$。通过$\frac{1}{\sqrt{d_k}}$参数,控制Attention容纳上下文信息的范围。在漫谈注意力机制(二):硬性注意力机制与软性注意力机制也有类似的分析。

直观上来说,可以理解成是正太分布中的方差参数$\sigma^2$​,$\sigma$不同取值下的可视化,

当$\sigma$取值越大,图像越平缓。

论文TENER: Adapting Transformer Encoder for Named Entity Recognition中提到Un-scaled Dot-Product Attention,其实就是原来的Attention去掉$\sqrt{d_k}$参数,其在NER中的表现更好。论文是从经验上解释,去掉该参数后,Attention的图像会变得更sharper,进而仅仅关注token的若干个上下文即可而非全局上下文,更契合NER任务的特点,因此获得更好的性能。这个与以上的数学推导是一致的。

思考

基于以上分析,那么这里提出两个问题:

  • $\sqrt{d_k}$是否可以参数化?例如BERT在预训练时就参数化$\sqrt{d_k}$,或者预训练时是固定的参数,但是在具体任务fine-tune时是可学习的参数,这样可以根据任务本身的特点自适应地容纳多少上下文信息。
  • 不同的头是否可以使用不同的$\sqrt{d_k}$?这样可以极大地丰富不同头间的差异和表达能力。

第二个问题是第一个问题的很自然的延伸,既然$\sqrt{d_k}$可以参数化,那么不同的头使用不同的$\sqrt{d_k}$是很自然的事情。

总结

本文从随机向量点积的方差的性质上解释为什么Transformer中的MultiHeadAttention使用scaled。然后从光滑逼近从的角度启发式讨论这个scaled的意义。

感觉没有写完,待续~

转载请包括本文地址:https://allenwind.github.io/blog/16228
更多文章请参考:https://allenwind.github.io/blog/archives/