Pytorch手撸Attention

news/2024/5/18 20:21:03

Pytorch手撸Attention

注释写的很详细了,对照着公式比较下更好理解,可以参考一下知乎的文章

注意力机制

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, embed_size):super(SelfAttention, self).__init__()self.embed_size = embed_size# 定义三个全连接层,用于生成查询(Q)、键(K)和值(V)# 用Linear线性层让q、k、y能更好的拟合实际需求self.value = nn.Linear(embed_size, embed_size)self.key = nn.Linear(embed_size, embed_size)self.query = nn.Linear(embed_size, embed_size)def forward(self, x):# x 的形状应为 (batch_size批次数量, seq_len序列长度, embed_size嵌入维度)batch_size, seq_len, embed_size = x.shapeQ = self.query(x)K = self.key(x)V = self.value(x)# 计算注意力分数矩阵# 使用 Q 矩阵乘以 K 矩阵的转置来得到原始注意力分数# 注意力分数的形状为 [batch_size, seq_len, seq_len]# K.transpose(1,2)转置后[batch_size, embed_size, seq_len]# 为什么不直接使用 .T 直接转置?直接转置就成了[embed_size, seq_len,batch_size],不方便后续进行矩阵乘法attention_scores = torch.matmul(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(self.embed_size, dtype=torch.float32))# 应用 softmax 获取归一化的注意力权重,dim=-1表示基于最后一个维度做softmaxattention_weight = F.softmax(attention_scores, dim=-1)# 应用注意力权重到 V 矩阵,得到加权和# 输出的形状为 [batch_size, seq_len, embed_size]output = torch.matmul(attention_weight, V)return output

多头注意力机制

在这里插入图片描述

class MultiHeadAttention(nn.Module):def __init__(self, embed_size, num_heads):super().__init__()self.embed_size = embed_sizeself.num_heads = num_heads# 整除来确定每个头的维度self.head_dim = embed_size // num_heads# 加入断言,防止head_dim是小数,必须保证可以整除assert self.head_dim * num_heads == embed_sizeself.q = nn.Linear(embed_size, embed_size)self.k = nn.Linear(embed_size, embed_size)self.v = nn.Linear(embed_size, embed_size)self.out = nn.Linear(embed_size, embed_size)def forward(self, query, key, value):# N就是batch_size的数量N = query.shape[0]# *_len是序列长度q_len = query.shape[1]k_len = key.shape[1]v_len = value.shape[1]# 通过线性变换让矩阵更好的拟合queries = self.q(query)keys = self.k(key)values = self.v(value)# 重新构建多头的queries,permute调整tensor的维度顺序# 结合下文demo进行理解queries = queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)keys = keys.reshape(N, k_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)values = values.reshape(N, v_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# 计算多头注意力分数attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))attention = F.softmax(attention_scores, dim=-1)# 整合多头注意力机制的计算结果out = torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size)# 过一遍线性函数out = self.out(out)return out

demo测试

self-attention测试
# 测试自注意力机制
batch_size = 2
seq_len = 3
embed_size = 4# 生成一个随机数据 tensor
input_tensor = torch.rand(batch_size, seq_len, embed_size)# 创建自注意力模型实例
model = SelfAttention(embed_size)# print输入数据
print("输入数据 [batch_size, seq_len, embed_size]:")
print(input_tensor)# 运行自注意力模型
output_tensor = model(input_tensor)# print输出数据
print("输出数据 [batch_size, seq_len, embed_size]:")
print(output_tensor)

=======print=========

输入数据 [batch_size, seq_len, embed_size]:
tensor([[[0.7579, 0.7342, 0.1031, 0.8610],[0.8250, 0.0362, 0.8953, 0.1687],[0.8254, 0.8506, 0.9826, 0.0440]],[[0.0700, 0.4503, 0.1597, 0.6681],[0.8587, 0.4884, 0.4604, 0.2724],[0.5490, 0.7795, 0.7391, 0.9113]]])输出数据 [batch_size, seq_len, embed_size]:
tensor([[[-0.3714,  0.6405, -0.0865, -0.0659],[-0.3748,  0.6389, -0.0861, -0.0706],[-0.3694,  0.6388, -0.0855, -0.0660]],[[-0.2365,  0.4541, -0.1811, -0.0354],[-0.2338,  0.4455, -0.1871, -0.0370],[-0.2332,  0.4458, -0.1867, -0.0363]]], grad_fn=<UnsafeViewBackward0>)
MultiHeadAttention

多头注意力机制务必自己debug一下,主要聚焦在理解如何拆分成多头的,不结合代码你很难理解多头的操作过程

1、queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 处理之后的 size = torch.Size([64, 8, 10, 16])

  • 通过上述操作,queries 张量的最终形状变为 [N, self.num_heads, q_len, self.head_dim]。这样的排列方式使得每个注意力头可以单独处理对应的序列部分,而每个头的处理仅关注其分配到的特定维度 self.head_dim
  • 这个形状是为了后续的矩阵乘法操作准备的,其中每个头的查询将与对应的键进行点乘,以计算注意力分数

2、attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt( torch.tensor(self.head_dim, dtype=torch.float32)) 将reshape后的quries的后两个维度进行转置后点乘,对应了 Q ⋅ K T Q \cdot K^T QKT ;根据demo这里的头数为8,所以公式中对应的下标 i i i 为8

3、在进行完多头注意力机制的计算后通过 torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size) 整合,变回原来的 [batch_size,seq_length,embed_size]形状

# 测试多头注意力
embed_size = 128  # 嵌入维度
num_heads = 8    # 头数
attention = MultiHeadAttention(embed_size, num_heads)# 创建随机数据模拟 [batch_size, seq_length, embedding_dim]
batch_size = 64
seq_length = 10
dummy_values = torch.rand(batch_size, seq_length, embed_size)
dummy_keys = torch.rand(batch_size, seq_length, embed_size)
dummy_queries = torch.rand(batch_size, seq_length, embed_size)# 计算多头注意力输出
output = attention(dummy_values, dummy_keys, dummy_queries)
print(output.shape)  # [batch_size, seq_length, embed_size]

=======print=========

torch.Size([64, 10, 128])

如果你难以理解权重矩阵的拼接和拆分,推荐李宏毅的attention课程(YouTobe)


http://www.mrgr.cn/p/78658262

相关文章

架构师核心-云计算云上实战(云计算基础、云服务器ECS、云设施实战、云上高并发Web架构)

文章目录 云计算基础1. 概念1. 云平台优势2. 公有云3. 私有云4. IaaS、PaaS、SaaS 2. 云设施1. 概览2. 核心组件 云服务器ECS1. ECS介绍1. 简介2. 组件3. 概念4. 图解5. 规格6. 场景 2. ECS服务器开通1. 开通服务器2. 连接服务器 3. 云部署准备1. 1Panel介绍2. 安装1Panel3.安全…

读天才与算法:人脑与AI的数学思维笔记09_分形

分形1. 分形 1.1. 1904年,瑞典数学家科赫(Helge von Koch)首次发表了雪花图案的结构——科赫曲线(又称雪花曲线),它被认为是一种数学怪胎,一种奇怪的人工构造 1.1.1. 但实际上并不是,自然界中到处都是以分形结构存在着的图形 1.1.2. 既不能说科赫曲线是一维的,也不能说…

C/C++程序设计实验报告4 | 函数实验

本文整理自博主本科大一《C/C程序设计》专业课的课内实验报告&#xff0c;适合C语言初学者们学习、练习。 编译器&#xff1a;gcc 10.3.0 ---- 注&#xff1a; 1.虽然课程名为C程序设计&#xff0c;但实际上当时校内该课的内容大部分其实都是C语言&#xff0c;C的元素最多可能只…

[转帖] 银河麒麟系统安全机制-KYSEC

https://zhuanlan.zhihu.com/p/349663329 麒麟系统为什么称为国内最安全的Linux系统?秘密就在于KYSEC,麒麟系统安全机制。一般情况下Linux下默认的接入控制是DAC,其特点是资源的拥有者可以对他进行任何操作(读、写、执行)。当一个进程准备操作资源时,Linux内核会比较进程…

[转帖]数据库系列之简要对比下GaussDB和OpenGauss数据库

https://blog.csdn.net/solihawk/article/details/134939941GaussDB作为一款企业级的数据库产品,和开源数据库OpenGauss之间又是什么样的关系,刚开始接触的时候是一头雾水,因此本文简要对比下二者的区别,以加深了解。 1、GaussDB和OpenGauss数据库简要对比 GaussDB是华为基…

分布式与一致性协议之CAP(二)

CAP CAP不可能三角 CAP不可能三角是指对于一个分布式系统而言&#xff0c;一致性、可用性、分区容错性指标不可兼得&#xff0c;只能从中选择两个&#xff0c; 如图所示。CAP不可能三角最初是埃里克布鲁尔(Eric Brewer)基于自己的工程实践提出的一个猜想&#xff0c;后被塞斯吉…

如何在CentOS搭建docker compose ui可视化工具并无公网IP远程管理容器

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

在一台笔记本电脑上试用Ubuntu22.04

在一台笔记本电脑上试用Ubuntu22.04。 本来想看以下该操作系统能否识别笔记本电脑上的硬盘&#xff0c;于是下载试一下。选了一个国内镜像网站下载。下载速度很快。下载以后用软件win image 将下载的iso文件写到U盘上&#xff0c;用的是usb2.0的U盘&#xff0c;该操作用时11分…

【八股】计算机网络篇

网络模型 应用层【HTTP&#x1f449;报文/消息】 传输层【TCP或UDP&#x1f449;段&#x1f449;MSS】网络层【IP、寻址和路由&#x1f449;MTU】 ①IP&#xff08;Internet Protocol&#xff0c;网际协议&#xff09;主要作用是定义数据包的格式、对数据包进行路由和寻址&…

9.Eureka服务发现+Ribbon+RestTemplate服务调用

order-service服务通过服务名称来代替 ip:port的方式访问user-service服务的接口。 原来的请求代码&#xff1a; Service public class OrderServiceImpl implements OrderService {Autowiredprivate OrderMapper orderMapper;Autowiredprivate RestTemplate restTemplate;Ov…

深度学习中的熵、交叉熵、相对熵(KL散度)、极大释然估计之间的联系与区别

熵的最初来源于热力学。在热力学中&#xff0c;熵代表了系统的无序程度或混乱程度&#xff0c;也可以理解为系统的热力学状态的一种度量。后来被广泛引用于各个领域中&#xff0c;如信息学、统计学、AI等&#xff0c;甚至社会学当中。接下来将大家领略一下深度学习中熵的应用。…

npm、yarn与pnpm详解

&#x1f525; npm、yarn与pnpm详解 &#x1f516; 一、npm &#x1f50d; 简介&#xff1a; npm是随Node.js一起安装的官方包管理工具&#xff0c;它为开发者搭建了一个庞大的资源库&#xff0c;允许他们在这个平台上搜索、安装和管理项目所必需的各种代码库或模块。 &#…

面向对象三大特征(python)

目录 1. 封装 为什么使用封装&#xff1f; 如何实现封装&#xff1f; 一个简单的封装示例 二.继承 为什么使用继承&#xff1f; 如何实现继承&#xff1f; 一个简单的继承示例 使用继承的好处 三.多态 为什么使用多态&#xff1f; 如何实现多态&#xff1f; 一个简…

Django模型继承之多表继承

在Django模型继承中&#xff0c;支持的第二种模型继承方式是层次结构中的每个模型都是一个单独的模型。每个模型都指向分离的数据表&#xff0c;并且可以被独立查询和创建。在继承关系中&#xff0c;子类和父类之间通过一个自动创建的OneToOneField进行连接。示例代码如下&…

4.26日学习记录

[湖湘杯 2021 final]Penetratable SUID提权 SUID是一种对二进制程序进行设置的特殊权限&#xff0c;可以让二进制程序的执行者临时拥有属主的权限 SUID具有一定的限制&#xff1a; 1.仅对于二进制有效&#xff1b; 2.执行者在程序中有可以执行的权限&#xff1b; 3.权限仅在程序…

公开课学习——基于索引B+树精准建立高性能索引

文章目录 遇到慢查询怎么办&#xff1f;—— 创建索引联合索引的底层的数据存储结构长什么样&#xff1f; mysql脑图 阿里开发手册 遇到慢查询怎么办&#xff1f;—— 创建索引 不用索引的话一个一个找太慢了&#xff0c;用索引就快的多。 假如使用树这样的结构建立索引&#x…

[Algorithm][前缀和][和为K的子数组][和可被K整除的子数组][连续数组][矩阵区域和]详细讲解

目录 1.和为 K 的子数组1.题目链接2.算法原理详解3.代码实现 2.和可被 K 整除的子数组1.题目链接2.算法原理详解3.代码实现 3.连续数组1.题目链接2.算法原理详解3.代码实现 4.矩阵区域和1.题目链接2.算法原理详解3.代码实现 1.和为 K 的子数组 1.题目链接 和为 K 的子数组 2.…

数据可视化-ECharts Html项目实战(14)

在之前的文章中&#xff0c;我们深入学习ECharts鼠标左键触发。想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&#xff0c;请留下你宝贵的点赞&#xff0c;谢谢。 数据可视化-ECharts Html项目实战&#xff08;…

【蓝桥杯2025备赛】素数判断:从O(n^2)到O(n)学习之路

素数判断:从O( n 2 n^2 n2)到O(n)学习之路 背景:每一个初学计算机的人肯定避免不了碰到素数&#xff0c;素数是什么&#xff0c;怎么判断&#xff1f; 素数的概念不难理解:素数即质数&#xff0c;指的是在大于1的自然数中&#xff0c;除了1和它本身不再有其他因数的自然数。 …

动手写sql 《牛客网80道sql》

第1章&#xff1a;SQL编写基础逻辑和常见问题 基础逻辑 SELECT语句: 选择数据表中的列。FROM语句: 指定查询将要从哪个表中检索数据。WHERE语句: 过滤条件&#xff0c;用于提取满足特定条件的记录。GROUP BY语句: 对结果进行分组。HAVING语句: 对分组后的结果进行条件过滤。O…