当前位置: 首页 > news >正文

深入理解Self-Attention - 原理与等价表示

概述

很多文章或论文已经很好的解释了神经网络中self-attention的原理,但是个人觉得还是有其他可解释的方面,主要原因是很多解释都是面向过程的,只解释了它是什么样的,这篇文章主要从其等价形式解释其原理。

Self-Attention原理

普适的self-attention的公式为:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( 1 d k Q K T ) V Attention(Q,K,V) = softmax(\frac{1}{\sqrt{d_k} } QK^\mathsf{T})V Attention(Q,K,V)=softmax(dk 1QKT)V其中Q,K,V分别表示Query,Key,Value,且
Q = x W q Q = xW_q Q=xWq K = x W k K = xW_k K=xWk V = x W v V = xW_v V=xWv即Q,K,V都是 x x x 经过线性变换后的数据,其中 x ∈ R m × n ; W q , W k , W v ∈ R n × j x\in R^{m\times n};W_q,W_k,W_v\in R^{n\times j} xRm×n;Wq,Wk,WvRn×j,即每一个token为 n n n 维向量, x x x是由 m m m 个单词构成的句子。而 1 d k \frac{1}{\sqrt{d_k}} dk 1是避免矩阵乘值过大的超参数,本质上说self-attention的核心内容为:
A t t e n t i o n ( x ) = s o f t m a x ( x W q ( x W k ) T ) V Attention(x) = softmax(xW_q(xW_k)^\mathsf{T})V Attention(x)=softmax(xWq(xWk)T)V

Self-Attention的等价表示

这里为了简便,我们先解释在softmax内的矩阵的等价形式即
M = x W q ( x W k ) T M=xW_q(xW_k)^\mathsf{T} M=xWq(xWk)T假设存在 W ~ \tilde{W} W~其中 W ~ ∈ R m × m \tilde{W}\in R^{m\times m} W~Rm×m使得以下等式成立
M = x W q ( x W k ) T = x x T ⊙ W ~ M=xW_q(xW_k)^\mathsf{T}=xx^\mathsf{T}\odot\tilde{W} M=xWq(xWk)T=xxTW~其中 ⊙ \odot 表示哈达玛积(Hadamard product,即逐点乘积),我们可以计算出等式左边和右边得到都是 m × m m\times m m×m的矩阵,根据哈达玛逆的性质:一个方阵的逆乘以其自身为单位矩阵,所以我们可以两边左乘 ( x x T ) ⊙ − 1 (xx^\mathsf{T})^{\odot-1} (xxT)⊙−1,得到
W ~ = ( x x T ) ⊙ − 1 ⊙ ( x W q ( x W k ) T ) \tilde{W} = (xx^\mathsf{T})^{\odot-1}\odot(xW_q(xW_k)^\mathsf{T}) W~=(xxT)⊙−1(xWq(xWk)T) 这里需要矩阵 x x T xx^\mathsf{T} xxT的任意元素不为0,即 [ x x T ] i j ≠ 0 [xx^\mathsf{T}]_{ij}\ne0 [xxT]ij=0,由于参数是可学习的所以原始矩阵 M M M中的参数可以使用 W ~ \tilde{W} W~来代替 W q , W k W_q,W_k Wq,Wk

当然我们也可以使用低秩矩阵来表示 W ~ \tilde{W} W~,即
W ~ = W a W b T \tilde{W}=W_aW_b^{\mathsf{T}} W~=WaWbT 其中 W a , W b ∈ R m × j W_a,W_b\in R^{m\times j} Wa,WbRm×j,所以我们可以得到另一种表示
M = x x T ⊙ ( W a W b ) T M = xx^\mathsf{T}\odot(W_aW_b)^{\mathsf{T}} M=xxT(WaWb)T 为了简单,这里我们不写成低秩的形式,所以self-attention可以写为
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( 1 d k x x T ⊙ W ~ ) V Attention(Q,K,V) = softmax(\frac{1}{\sqrt{d_k} } xx^\mathsf{T}\odot\tilde{W})V Attention(Q,K,V)=softmax(dk 1xxTW~)V 我们可以看到,其实Attention中本质上关键的是方阵 x x T xx^\mathsf{T} xxT,右边的参数 W ~ \tilde{W} W~只是对该方阵的线性变换。

    import torchk,m,n = 2,3,4x = torch.randn(m,n)wq = torch.randn(n,k)wk = torch.randn(n,k)Q = x@wqK = x@wkprint("Q@K.T:",Q@K.T)W_ = (1.0/(x@x.T))*(Q@K.T)print("x@x.T*W_:",x@x.T*W_)

x x T ⊙ W ~ xx^\mathsf{T}\odot\tilde{W} xxTW~ 的解释

这里我们用简单的例子来阐述 x x T xx^\mathsf{T} xxT ,假设存在一个以 m = 3 m=3 m=3个单词组成的句子
x = [ a , b , c ] T x = [a,b,c]^{\mathsf{T}} x=[a,b,c]T 其中 a , b , c ∈ R n a,b,c\in R^n a,b,cRn ,那么自然 x ∈ R 3 × n x \in R^{3 \times n} xR3×n,我们可以得到
x x T = [ a a a b a c b a b b b c c a c b c c ] xx^\mathsf{T} = \begin{bmatrix} aa& ab&ac \\ ba& bb&bc \\ ca& cb&cc \end{bmatrix} xxT= aabacaabbbcbacbccc 这里向量 a a , a b , . . . , c c aa,ab,...,cc aa,ab,...,cc 为两个向量的点积,而点积也可以表示为如下形式
a ⋅ b = ∥ a ∥ ∥ b ∥ c o s ( θ ) a\cdot b = \left \| a \right \| \left \| b \right \| cos(\theta) ab=abcos(θ) 反过来 a ⋅ b ∥ a ∥ ∥ b ∥ = c o s ( θ ) \frac{a\cdot b}{\left \| a \right \| \left \| b \right \|} = cos(\theta) abab=cos(θ)所以点积可以表示非归一化的相似性,或者准确来说是表示的相关性。一般在神经网络中不用除以范数来表示相关性,经验性的结果来说,一是收敛比较困难、二是计算量增加了不少。而直接用点积也有其局限性(静态的相关性),所以我们可以带参数的形式,即可学习的相关性:
a ⋅ b ⋅ w ≈ c o s ( θ ) a\cdot b\cdot w \approx cos(\theta) abwcos(θ) 这里 ≈ \approx 只是表示一个近似操作而非约等于,所以本质上Attention中 x x T ⊙ W ~ xx^\mathsf{T}\odot\tilde{W} xxTW~表示的是学习后句子中各单词的相关关系程度矩阵,需要说明的是这种相关性是学习到的,而非像 c o s i n e cosine cosine静态的关系(比如这句话:“我喜欢小狗,它们很可爱”,token“小狗”与“小狗”的cosine相似度一定是最高的,而学习到的这种相关性并不一定是最高的,而可能是“小狗”与“他们”有更高的相关性)。

我们也要注意到 当句子很长时 W ~ \tilde{W} W~ 矩阵是 m × m m\times m m×m的,所以参数量将非常大,容易过拟合,而且计算也大了很多,所以等价和等计算是完全不同的,原始的attention中参数 W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv 本质是(token级别)共享的变换矩阵,类似于CNN中的卷积。

1 d k \frac{1}{\sqrt{d_k} } dk 1的解释

这个 1 d k \frac{1}{\sqrt{d_k} } dk 1其实是在解决softmax归一化存在的问题,或者说是指数函数存在的问题。

即假设一个向量 x = [ a , b , c ] x=[a,b,c] x=[a,b,c],其中各元素不等,我们将 x x x放大 s s s倍,这时随着放大倍数的增大, y = s o f t m a x ( x ) y=softmax(x) y=softmax(x)中向量中某一变量将趋向于1,换句话说,softmax不再是soft max而是hard max。而如果与Attention中的 V V V作点积,那么Attention就变为了简单的取某一个token,而不是token的加权混合。

import torch
x = torch.randn(3)
print(torch.softmax(x,dim=0))
print(torch.softmax(x*100,dim=0))

从连接的图结构看attention

实际上不管是MLP中的全连接层,还是CNN中的卷积操作,本身表示的是变量(或token)与变量聚合的操作,如:
y = w 1 x 1 + w 2 x 2 + b y = w_1x_1+w_2x_2+b y=w1x1+w2x2+b而attention不同,它首先所表示的是变量与变量间的关系,以及根据变量间关系的特征聚合的操作,某种程度来说它与全连接或卷积的连接方式有根本的不同。

总结

我们通过self-attetion的等价形式
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( 1 d k x x T ⊙ W ~ ) V Attention(Q,K,V) = softmax(\frac{1}{\sqrt{d_k} } xx^\mathsf{T}\odot\tilde{W})V Attention(Q,K,V)=softmax(dk 1xxTW~)V 可以更容易解释attention的作用,即 1 d k x x T \frac{1}{\sqrt{d_k}} xx^\mathsf{T} dk 1xxT 表示句子内各单词的相关关系程度矩阵,而 W ~ \tilde{W} W~表示对相关关系的线性变换,softmax则是对变换后的相关关系在句子内的归一化,乘以 V V V则是对线性变换后的token作加权特征混合(feature mixing)。

参考

  1. Attention Is All You Need

http://www.mrgr.cn/news/92175.html

相关文章:

  • 玩机日记 11 解决fnOS识别不了虚拟核显的问题
  • 【Python量化金融实战】-第2章:金融市场数据获取与处理:2.1 数据源概览:Tushare、AkShare、Baostock、通联数据(DataAPI)
  • 【初探数据结构】时间复杂度和空间复杂度
  • 玩机日记 12 fnOS使用lucky反代https转发到外网提供服务
  • 结构型模式--组合模式
  • Vue3 + Vite使用 vue-i18n
  • 安全面试5
  • 15.4 FAISS 向量数据库实战:构建毫秒级响应的智能销售问答系统
  • 谈谈 ES 6.8 到 7.10 的功能变迁(3)- 查询方法篇
  • MySQL-数据库的基本操作
  • X86_64位下的GS寄存器
  • 15.1 智能销售顾问系统架构与业务价值解析:AI 如何重塑销售流程
  • Windows 11【1001问】如何下载Windows 11系统镜像
  • 排序算法漫游:从冒泡到堆排的底层逻辑与性能厮杀
  • 系统学习算法:专题十二 记忆化搜索
  • 快速上手 Unstructured:安装、Docker部署及PDF文档解析示例
  • STM32-智能小车项目
  • 人工神经网络ANN入门学习笔记1
  • 前端防重复请求终极方案:从Loading地狱到精准拦截的架构升级
  • UE 跟着官方文档学习 容器TArray 系列三