【HuggingFace Transformers】BertSdpaSelfAttention源码解析
BertSdpaSelfAttention源码解析
- 1. BertSdpaSelfAttention类 介绍
- 2. BertSdpaSelfAttention类 源码解析
1. BertSdpaSelfAttention类 介绍
BertSdpaSelfAttention类是 BERT 模型自注意力层的实现,继承 BertSelfAttention 类。它的目的是在特定的版本和条件下,使用 PyTorch 中的 scaled_dot_product_attention 函数来计算注意力分数,以提高效率。
2. BertSdpaSelfAttention类 源码解析
源码地址:transformers/src/transformers/models/bert/modeling_bert.py
# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:30import torchfrom typing import Optional, Tuple
from packaging import version
from transformers.models.bert.modeling_bert import BertSelfAttention
from transformers.utils import get_torch_version, logginglogger = logging.get_logger(__name__)class BertSdpaSelfAttention(BertSelfAttention):def __init__(self, config, position_embedding_type=None):super().__init__(config, position_embedding_type=position_embedding_type) # 初始化父类,继承其配置和属性self.dropout_prob = config.attention_probs_dropout_prob # 存储注意力机制中的dropout概率# 检查当前的PyTorch版本,决定是否需要连续的qkv输入,低于2.2.0版本时设为Trueself.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")# Adapted from BertSelfAttentiondef forward(self,hidden_states: torch.Tensor, # 输入的隐藏状态张量attention_mask: Optional[torch.Tensor] = None, # 注意力掩码,可选head_mask: Optional[torch.FloatTensor] = None, # 头部掩码,可选encoder_hidden_states: Optional[torch.FloatTensor] = None, # 编码器的隐藏状态,仅在交叉注意力中使用encoder_attention_mask: Optional[torch.FloatTensor] = None, # 编码器的注意力掩码,仅在交叉注意力中使用past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, # 过去的键和值,用于缓存output_attentions: Optional[bool] = False, # 是否输出注意力权重) -> Tuple[torch.Tensor]:# 1. 如果位置嵌入类型不是 "absolute" 或者需要输出注意力权重或使用头部掩码,则调用BertSelfAttention的forward方法来处理if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.# 记录一个警告,提示用户在未来版本中需要手动指定注意力实现logger.warning_once("BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support ""non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to ""the manual attention implementation, but specifying the manual implementation will be required from ""Transformers version v5.0.0 onwards. This warning can be removed using the argument "'`attn_implementation="eager"` when loading the model.')return super().forward(hidden_states,attention_mask,head_mask,encoder_hidden_states,encoder_attention_mask,past_key_value,output_attentions,)# -------------------- 2. 位置嵌入类型position_embedding_type是 "absolute" -------------------------# ----------------- 2.1 获取输入的批次大小(bsz)、目标序列长度(tgt_len), 后面会用于shape的调整-----------------bsz, tgt_len, _ = hidden_states.size()# ---------------- 2.2 获取query_layer, key_layer, value_layer, attention_mask, is_causal, 用于注意力的计算-------# 将hidden_states投影到查询向量(query_layer),并使用transpose_for_scores方法调整其形状query_layer = self.transpose_for_scores(self.query(hidden_states))# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention# mask needs to be such that the encoder's padding tokens are not attended to.""" 如果这是一个跨注意力模块实例化的情况,键和值来自编码器;注意力掩码需要确保编码器中的填充标记不被关注。 """# 判断是否是交叉注意力,并为current_states和attention_mask赋值is_cross_attention = encoder_hidden_states is not Nonecurrent_states = encoder_hidden_states if is_cross_attention else hidden_statesattention_mask = encoder_attention_mask if is_cross_attention else attention_mask# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning# 如果是交叉注意力且有past_key_value,并且它的序列长度与current_states一致,直接使用缓存的键和值if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:key_layer, value_layer = past_key_valueelse:# 否则,计算新的键值对,并在非交叉注意力情况下,将它们与过去的键值对拼接key_layer = self.transpose_for_scores(self.key(current_states))value_layer = self.transpose_for_scores(self.value(current_states))if past_key_value is not None and not is_cross_attention:key_layer = torch.cat([past_key_value[0], key_layer], dim=2)value_layer = torch.cat([past_key_value[1], value_layer], dim=2)# 如果是解码器,则缓存当前的键和值,以便在后续步骤中使用if self.is_decoder:# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.# Further calls to cross_attention layer can then reuse all cross-attention# key/value_states (first "if" case)# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of# all previous decoder key/value_states. Further calls to uni-directional self-attention# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)# if encoder bi-directional self-attention `past_key_value` is always `None`"""如果是交叉注意力,将所有交叉注意力的 key/value 状态保存为一个包含两个 torch.Tensor 的元组。后续对交叉注意力层的调用可以重用所有交叉注意力的 key/value 状态(即第一个 "if" 情况)。如果是单向自注意力(解码器),则保存所有先前解码器的 key/value 状态为一个包含两个 torch.Tensor 的元组。后续对单向自注意力层的调用可以将先前解码器的 key/value 状态与当前投影的 key/value 状态拼接起来(即第三个 "elif" 情况)。如果是编码器的双向自注意力,则 `past_key_value` 始终为 `None`。"""past_key_value = (key_layer, value_layer)# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.# Reference: https://github.com/pytorch/pytorch/issues/112577"""在 torch==2.1.2 中,当使用非连续的输入和自定义的注意力掩码时,带有内存高效后端的 SDPA(缩放点积注意力)是有问题的,因此我们需要在这里调用 `.contiguous()` 方法来确保输入是连续的。这个问题在 torch==2.2.0 中已被修复。'参考:https://github.com/pytorch/pytorch/issues/112577"""# 如果 PyTorch 版本低于 2.2.0 且设备类型为 CUDA 且 attention_mask 不为空,则确保 qkv 输入是连续的if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:query_layer = query_layer.contiguous()key_layer = key_layer.contiguous()value_layer = value_layer.contiguous()# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal# mask in case tgt_len == 1."""需要保证 tgt_len > 1,以匹配 AttentionMaskConverter.to_causal_4d 的行为,因为当 tgt_len == 1 时,它不会创建因果掩码。"""# 如果是解码器且没有注意力掩码且目标序列长度大于1,则启用因果注意力is_causal = self.is_decoder and attention_mask is None and tgt_len > 1# ----------- 2.3 使用 torch.nn.functional.scaled_dot_product_attention 计算注意力 -------------# 使用 PyTorch 的 scaled_dot_product_attention 函数计算注意力输出,传入查询、键、值、注意力掩码和dropout概率attn_output = torch.nn.functional.scaled_dot_product_attention(query_layer,key_layer,value_layer,attn_mask=attention_mask,dropout_p=self.dropout_prob if self.training else 0.0,is_causal=is_causal,)# ----------- 2.4 调整 attn_output 的形状, 用于返回计算后的注意力输出, 以及在解码器模式下缓存的键值 -----------------# 将注意力输出张量转置并调整形状以匹配原始输入的形状attn_output = attn_output.transpose(1, 2)attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)# 返回包含注意力输出的元组,如果是解码器,还会返回缓存的键和值outputs = (attn_output,)if self.is_decoder:outputs = outputs + (past_key_value,)return outputs
