当前位置: 首页 > news >正文

Transformer简明笔记:文本翻译

Bert和gpt都是基于transformer的,在此之前流行的是rnn,复杂度有限且效率不高,容易受到文本长度的限制。
项目地址:https://github.com/lansinuote/Transformer_Example
b站视频:https://www.bilibili.com/video/BV19Y411b7qx?p=9&spm_id_from=pageDriver&vd_source=eca9b4f9ea9577b666c089a010621a99

总体架构

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
编码器:自注意力层->全连接层
解码器:自注意力层->编码解码注意力->全连接层

计算注意力

在这里插入图片描述
词向量编码
x1*wq得到queries,以此类推,得到Q K V
在这里插入图片描述
除以8和词向量的编码有关
z1是自注意力计算的结果
在这里插入图片描述
在这里插入图片描述
得到多组QKV向量,就是多头注意力
在这里插入图片描述
图中有八组这样的矩阵

词向量编码

在这里插入图片描述
右边计算出的结果是一样的,transformer会做同样的处理

在这里插入图片描述

在这里插入图片描述
pos是第几个词,i是第几个向量,pos是行,i是列,偶数列是上面的式子计算,奇数列是下面的式子计算在这里插入图片描述
红色是大数,蓝色是小数,第0列是sin,第1列是cos,波动比较快,波动频率会逐渐降低

MASK

在这里插入图片描述
把a b pad理解为一句话,为了把各个句子保持相同长度,会补充pad。对pad的计算没有意义,把对pad的注意力全部替换成mask,但是pad对其他的词的注意力不做处理。
在这里插入图片描述
b和c是要预测的结果,所以计算b的时候不能让a看到。
在这里插入图片描述

对两个mask取一个并集

完整计算流程

在这里插入图片描述
layerNorm这部分是短接的计算,然后数据标准化,得到z1,z2,全连接运算
在这里插入图片描述
n个encoder上下串联,decoder拿到x1,x2,也要计算注意力,标准化,encoder-decoder这一层和self-attention其实一样,只不过qkv是拿encoder计算得出的结果当作kv,自己的自注意力层计算出的结果当作q,短接相加,标准化,全连接,标准化,decoder也会有n个,串联,最终做一个全连接层的输出。
在这里插入图片描述
翻译过程,不断预测下一个字

实验数据的生成策略在这里插入图片描述

词表是x语言的所有词汇,这里只有7个词,模仿自然语言,采样概率不等,x的长度随机,均为模仿自然语言。在这里插入图片描述
最终目的:x翻译成y
所以x和y要有关联性,这里的关系非常简单,黑色箭头表明y当中的每一个词是x逆序得到的,小写字母翻译成大写字母,y当中的数字用9-x得到在这里插入图片描述
虚线的箭头表明,y中的第一位取决于x的最后一位,这样y的第一位和第二位是相同的,这样做是为了让y中的数据长度比x多一位,同时增加映射复杂度
在这里插入图片描述
在这里插入图片描述

代码实现:定义数据

在这里插入图片描述
data.py
在这里插入图片描述
字典中共39个词
在这里插入图片描述
m的概率最高
生成数据的函数:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
定义数据集:

在这里插入图片描述
len固定返回10万,get_data生成一对x和y
数据加载器比较简单,每次调用生成8对x和y

代码实现:util.py

注意力计算函数:
在这里插入图片描述
几维向量就除以几的平方根

归一化层:
在这里插入图片描述
规范化,数值的均值是0,标准差是1,bn(batch normalization)取不同的样本做归一化,ln(layer nomalization)对不同通道做归一化。

多头注意力计算层:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
位置编码层:
在这里插入图片描述
在这里插入图片描述
全连接输出层:
在这里插入图片描述

mask.py

在这里插入图片描述
trilmask 上三角mask
在这里插入图片描述
在这里插入图片描述

model.py

编码器:
在这里插入图片描述
完整的编码器:
在这里插入图片描述
解码器:
在这里插入图片描述
完整的解码器:
在这里插入图片描述
主模型:
在这里插入图片描述
维度是变化的,注释有误

main.py

在这里插入图片描述
第一列是epoch 第二列是i 第三列是learning rate 第四列是loss 不断下降 第五列是正确率,97%
在这里插入图片描述
预测时不需要y的最后一个字符,y的第0个字符一定是SOS,不需要预测
在这里插入图片描述
在后面补上49个pad
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
预测结果几乎一模一样

一个更加复杂的翻译任务

用transformer做加法
在这里插入图片描述
y是x左右两边的相加得到的,这个难度要高一些,替换掉生成数据的函数就可以得到,训练10个epoch,learning rate decay也生效了,最终准确率是92%
在这里插入图片描述


http://www.mrgr.cn/news/17978.html

相关文章:

  • 绿色物流:TMS在节能减排中的角色
  • 深入理解MySQL慢查询优化(1) -- 优化策略
  • Maven的常用插件
  • 台球助教预约系统小程序源码开发
  • 字符分类函数
  • 2024.08.26 校招 实习 内推 面经
  • 公司企业大楼智慧厕所建设步骤和技术要求@卓振思众
  • 在线演示文稿应用PPTist本地化部署并实现无公网IP远程编辑PPT
  • 考试系统将来市场会如何
  • centos7使用ifconfig查看IP,终端无ens33信息解决办法
  • Java基于微信小程序的超市购物管理系统
  • 多重背包问题 模板 C++实现
  • 关于contextmenu-ui组件库
  • 【Python123题库】#统计文章字符数 #查询高校信息 #查询高校名
  • IM项目:进阶版即时通讯项目---项目总览
  • Trino大量查询会导致HDFS namenode主备频繁切换吗?
  • LRU Cache
  • 5.12 飞行控制——PID参数优化
  • Oracle手动误删物理上的数据文件解决办法
  • 多头切片的关键:Model 类 call解释;LlamaModel 类 call解释;多头切片的关键:cache的数据拼接