当前位置: 首页 > news >正文

注意力机制中的三种掩码技术及其PyTorch实现

在深度学习中,特别是处理序列数据时,注意力机制是一种非常关键的技术,广泛应用于各种先进的神经网络架构中,如Transformer模型。为了确保模型能够正确处理序列数据,掩码技术发挥了重要作用。本文将介绍三种常见的掩码技术:填充掩码(Padding Mask)、序列掩码(Sequence Mask)和前瞻掩码(Look-ahead Mask),并提供相应的PyTorch代码实现。

1. 填充掩码(Padding Mask)

目的
填充掩码的主要目的是确保模型在处理填充的输入数据时,不会将这些无关的数据当作有效数据处理。在序列处理中,由于不同序列的长度可能不同,通常需要对较短的序列进行填充,以保证所有序列长度一致,便于批处理。然而,这些填充的部分并不包含实际信息,因此应该在模型处理时忽略。

PyTorch实现

import torchdef create_padding_mask(seq, pad_token=0):mask = (seq == pad_token).unsqueeze(1).unsqueeze(2)return mask  # (batch_size, 1, 1, seq_len)# 示例使用
seq = torch.tensor([[7, 6, 0, 0], [1, 2, 3, 0]])
padding_mask = create_padding_mask(seq)
print(padding_mask)
2. 序列掩码(Sequence Mask)

目的
序列掩码用于更广泛地控制模型应该关注的数据部分,不仅可以指示填充位置,还可以用于其他类型的掩蔽需求。例如,在序列到序列的任务中,可能需要隐藏未来信息,以确保模型在解码时不会“窥视”到未来信息。

PyTorch实现

def create_sequence_mask(seq):seq_len = seq.size(1)mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)return mask  # (seq_len, seq_len)# 示例使用
seq_len = 4
sequence_mask = create_sequence_mask(torch.zeros(seq_len, seq_len))
print(sequence_mask)
3. 前瞻掩码(Look-ahead Mask)

目的
前瞻掩码主要用于自回归模型中,以确保模型在生成序列时不会“看到”未来的符号。这保证了在给定位置的预测仅依赖于该位置之前的符号,维护了生成过程的时序正确性。

PyTorch实现

def create_look_ahead_mask(size):mask = torch.triu(torch.ones(size, size), diagonal=1)return mask  # (seq_len, seq_len)# 示例使用
look_ahead_mask = create_look_ahead_mask(4)
print(look_ahead_mask)
掩码在注意力机制中的应用

在注意力机制中,掩码被用来修改注意力得分,以确保模型在计算注意力权重时能够正确地考虑哪些部分应该被忽略。以下是一个使用掩码进行缩放点积注意力计算的示例:

import torch.nn.functional as Fdef scaled_dot_product_attention(q, k, v, mask=None):matmul_qk = torch.matmul(q, k.transpose(-2, -1))dk = q.size()[-1]scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(dk, dtype=torch.float32))if mask is not None:scaled_attention_logits += (mask * -1e9)attention_weights = F.softmax(scaled_attention_logits, dim=-1)output = torch.matmul(attention_weights, v)return output, attention_weights# 示例使用
d_model = 512
batch_size = 2
seq_len = 4
q = torch.rand((batch_size, seq_len, d_model))
k = torch.rand((batch_size, seq_len, d_model))
v = torch.rand((batch_size, seq_len, d_model))
mask = create_look_ahead_mask(seq_len)
attention_output, attention_weights = scaled_dot_product_attention(q, k, v, mask)
print(attention_output)

http://www.mrgr.cn/news/10358.html

相关文章:

  • 张宇36讲+1000题重点强化!保100冲120速刷攻略
  • 设计模式——状态模式
  • 环绕音效是什么意思,电脑环绕音效怎么开
  • plsql表格怎么显示中文 plsql如何导入表格数据
  • [ICLR-24] LRM: Large Reconstruction Model for Single Image to 3D
  • 机器学习:决策树之回归树的原理
  • redis分布式是如何实现的(面试版)
  • 完成客户端/浏览器可以请求到控制层
  • 我的sql我做主!Mysql 的集群架构详解之组从复制、半同步模式、MGR、Mysql路由和MHA管理集群组
  • 8.26算法训练
  • PHP酒店宾馆民宿预订系统小程序源码
  • 力扣2025.分割数组的最多方案数
  • linux内核链表
  • Three 物体(四)
  • Python编码系列—Python中的HTTPS与加密技术:构建安全的网络通信
  • 使用HTML实现贪吃蛇游戏
  • 为什么制造企业智能化升级需要MES管理系统
  • 【Material-UI】Radio Group中的独立单选按钮详解
  • JavaScript 手写仿instanceof
  • Blazor开发框架Known-V2.0.9