当前位置: 首页 > news >正文

GPT2模型源码解析

# 注意力机制的核心思想是在给定的输入序列中,让模型能够关注(或分配权重)不同的位置。在Transformer模型中,注意力机制通常分为以下几个步骤:
# 计算注意力得分:通过查询(query)、键(key)和值(value)之间的矩阵乘法计算注意力得分。
# 缩放注意力得分:通常会对注意力得分进行缩放,以防止数值不稳定。
# 应用掩码(masking):根据不同的应用场景,可能需要对某些位置的注意力得分进行屏蔽。
# 归一化注意力得分:使用softmax函数将注意力得分转换为概率分布。
# 计算注意力输出:通过将归一化的注意力权重与值向量相乘得到注意力输出。

def _get_unpad_data(attention_mask):
    # 计算了每个样本的有效长度(即非填充部分的长度)
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # 返回了 attention_mask 中所有非零元素(即非填充token)的索引。首先,attention_mask 被展平成一维(flatten()),
    # 然后通过 torch.nonzero(..., as_tuple=False) 获取所有非零元素的位置,并且返回的是一个包含索引的张量。最后
    # .flatten() 将这些索引从二维转换成一维,方便后续处理。
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # 找到批次中的最大有效序列长度。seqlens_in_batch 包含了批次中每个样本的有效长度,通过对这些长度取最大值(max())
    # ,我们可以知道当前批次中最长的序列有多长,.item() 方法则是将张量转换为Python标量
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # seqlens_in_batch 是一个包含每个样本序列长度的一维张量。torch.cumsum 函数沿着指定维度(在这里是
    # dim=0,表示沿着第一个维度)对张量元素进行累积求和。
    # 累积求和的结果是一个新的张量,其中第 i 个元素是前 i+1 个序列长度的总和。这通常用于快速访问不同样本间非
    # 填充token的起始索引位置。padding 操作在累积求和得到的张量前面添加一个零元素这样做的目的是确保第一个元素为0,
    # 这可以方便地处理索引计算,尤其是在使用某些并行处理算法时,比如在实现 batched sequence 数据的高效处理时。
    # cu_seqlens 提供了每个样本(包括样本自身)及其之前的所有样本的非填充token的累计数量。这对于之后在没有填充元
    # 素的情况下处理变长序列是非常有用的。
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
# 用于从TensorFlow检查点文件加载权重到PyTorch模型中。为了能够成功加载权重,TensorFlow模型与PyTorch模型的
# 架构必须在结构上是一致的,至少在权重对应的层上要有一一对应的关系。
# 确保两个框架下的模型架构一致性是很重要的,特别是在转换预训练权重时。如果你尝试加载的模型有不同版本或者有架
# 构上的差异,那么你需要修改上述代码中的逻辑,以便正确地映射变量名和权重。
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    # 读取TensorFlow Checkpoint
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    # 列出和加载TensorFlow检查点文件中的变量名和对应的权重数组。
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    # 映射变量名称到PyTorch模型
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())
    # 根据变量名来定位PyTorch模型中的相应层。变量名被解析以确定如何导航到特定的权重或偏置项等。
    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        # 如果找到了匹配的层并且形状一致,则使用 torch.from_numpy 将NumPy数组转换为PyTorch张量,
        # 并将其赋值给模型中的相应参数。
        pointer.data = torch.from_numpy(array)
    return model
# 我们希望找到可以被安全移除的heads(也就是那些对模型性能贡献较小或冗余的heads)
# 同时也要确定哪些heads仍然是活跃的(即没有被修剪的)
# 传人参数:heads (List[int]):整数列表,要修剪的头的索引。n_heads (int):头数
# head_size: int,指的是每一个注意力头的维度大小。already_pruned_heads (Set[int]):这是一个整数集合,包含了已经被修剪
# 掉的heads的索引。集合的数据结构保证了元素的唯一性,所以不会有重复的索引
# 函数的目的是确定哪些heads是可以进一步修剪的(即它们不在already_pruned_heads集合中),并且计算出所有未被修剪的
# heads的索引。这对于优化模型的大小或者减少计算成本是有帮助的
# 返回的结果是一个元组,包含了两个元素:第一个元素是一个集合 (Set[int]),包含了被修剪的heads的索引。
# 第二个元素是一个torch.LongTensor,包含了所有未被修剪的头的嵌入的索引
# 修剪操作:当我们修剪掉一个或几个head时,实际上减少的是注意力机制的输出维度的一部分。这意味着在修剪之后,注意力机制的输
# 出维度将会是剩余未被修剪的heads的数量乘以每个head的head_size。
# 保持维度一致:为了保持模型其他部分的一致性,通常会有一个额外的线性层来将修剪后的注意力机制输出转换回原来的维度。这样做的
# 目的是确保修剪后的模型可以继续与其它层(如前馈神经网络层等)兼容。
def find_pruneable_heads_and_indices(
    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
    # 创建一个全为1的掩码,用于标记哪些头是可用的(未修剪的)
    mask = torch.ones(n_heads, head_size)
    # 转换输入的heads列表为集合,并去除已经修剪过的头的索引
    heads = set(heads) - already_pruned_heads
    # 遍历剩余的、需要修剪的head索引,更新掩码
    for head in heads:
        # 因为要根据这个head索引来设置把mask中哪些嵌入设置为0,这里把在当前头索引之前的
        # 设置为1,之后的为0,为1说明要移位
        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
        mask[head] = 0 # 将要修剪的头的对应位置在掩码中设为0
    # 将二维的掩码展平为一维,并只保留值为1的元素(即未修剪的头)的索引,注意:.eq(1) 生成一
    # 个布尔型tensor,其中True表示对应位置的值是1
    mask = mask.view(-1).contiguous().eq(1)
    # 使用torch.arange生成一个从0到len(mask)-1的tensor,然后通过布尔索引选出值为True的索引  
    # 这些索引对应于未修剪的头的位置  
    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
    # 函数返回当前修剪过的头索引列表,未修剪的头索引对应的嵌入位置
    return heads, index
# 什么是连续存储?
# 一个张量在内存中被认为是连续存储的(contiguous),如果它的元素按照其形状和步幅(strides)在内存中紧密排列在一起
# 一维张量的所有元素都是连续的。
# 多维张量按照其形状和步幅在内存中紧密排列,即每个维度的元素都是紧密相连的。
# 何时需要使用contiguous()?
# 当你需要确保张量在内存中是连续存储的,尤其是在以下情况时:
# 将张量传递给某些操作或函数时,这些操作可能要求输入是连续的。
# 使用.view()方法改变张量的形状时,如果原张量不是连续存储的,则需要先调用contiguous()。
# 在进行某些操作(如某些类型的内存拷贝)时,确保数据是连续的可以提高效率。
# 返回新的layer,new_layer就是一个新的Conv1D层,其权重和偏置已被修剪,并且保持了与原层相同的设备和计算梯度的能力。
def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
    index = index.to(layer.weight.device) # 确保索引张量位于与layer相同的设备上
    # layer.weight得到的是形状如[512, 1536]的权重矩阵,index_select(dim, index)
    # 根据提供的索引选择权重张量中的列或行。
    # Conv1D层的权重形状通常为 [input_features,output_features]。修剪时,根据dim参数的不同,可以
    # 选择沿输入特征维度(dim=1)或输出特征维度(dim=0)进行修剪。
    W = layer.weight.index_select(dim, index).clone().detach()
    if dim == 0:
        b = layer.bias.clone().detach()
    else: # 选中index对应的那些项
        b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size()) # 获取layer权重的形状[512, 1536]
    new_size[dim] = len(index)  # [512, 1152],改变new_size索引1的维度的长度
    #复制权重和偏置:复制选中的权重和偏置,并关闭梯度计算。
    # 创建新层:创建一个新的Conv1D层,并设置其权重和偏置
    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
    # 在复制权重和偏置时,先关闭梯度计算,然后再启用,这是为了避免在复制过程中引入不必要的计算图。
    new_layer.weight.requires_grad = False # 设置禁用w梯度
    # W.contiguous() 的使用是为了确保在复制权重数据时,数据是连续存储的。这样做可以确保在进行某些操作时不会出
    # 现问题,并且可以提高操作的效率。
    new_layer.weight.copy_(W.contiguous()) # 复制W的权重数据
    new_layer.weight.requires_grad = True # 启用梯度
    # 启用梯度计算:在复制完权重和偏置后,重新启用梯度计算。
    new_layer.bias.requires_grad = False # 禁用bias梯度
    new_layer.bias.copy_(b.contiguous())
    new_layer.bias.requires_grad = True
    return new_layer
class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf # nf:输出特征的数量。nx:输入特征的数量
        # 初始化权重和偏置,并使用nn.init.normal_对权重进行初始化。
        self.weight = nn.Parameter(torch.empty(nx, nf))
        self.bias = nn.Parameter(torch.zeros(nf))
        nn.init.normal_(self.weight, std=0.02) # 权重初始化
    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,) # size_out:计算输出的形状
        # x.view(-1, x.size(-1)):将输入张量展平,以便进行矩阵乘法。
        # 使用addmm函数进行矩阵乘法,并加上偏置。
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out) # 将结果重塑为原来的形状
        return x
class GPT2Attention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        # 偏置矩阵 (bias) 和掩码偏置 (masked_bias) 的注册
        # bias 是一个布尔类型的下三角矩阵,用于实现因果掩码(即,防止当前位置看到未来的信息
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        # masked_bias 是一个用于掩码操作的极大负数,确保在softmax操作后,被掩码的位置概率接近于0。
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
        # 嵌入维度 (embed_dim), 头数 (num_heads), 和每头的维度 (head_dim) 的设置
        self.embed_dim = config.hidden_size # d
        self.num_heads = config.num_attention_heads # h
        self.head_dim = self.embed_dim // self.num_heads # dk
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # 注意力缩放和重新排序配置
        self.scale_attn_weights = config.scale_attn_weights # 决定是否在注意力计算时缩放权重。
        self.is_cross_attention = is_cross_attention # 是否是交叉注意力
        # Layer-wise attention scaling, reordering, and upcasting
        # 表示是否按层索引逆序缩放注意力权重。
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        # 表示是否重新排序和提升精度进行注意力计算。
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
        # 卷积层定义 (Conv1D)
        # c_attn 和 c_proj 层用于线性变换输入和输出。
        if self.is_cross_attention:
            # 在交叉注意力情况下,c_attn 和 q_attn 分别用于编码键/值对和查询向量
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
        # attn_dropout 和 resid_dropout 分别用于注意力和残差连接中的dropout。
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # 因果掩码和已剪枝头部集合 (pruned_heads):
        self.is_causal = True  # 因果掩码
        self.pruned_heads = set()  # 用于存储已被剪枝的注意力头
    # 剪枝 (prune_heads 方法),通过剪枝减少计算复杂度
    # 此方法用于删除指定的注意力头,减少模型参数数量。它通过调整卷积层的大小来实现这一点,并更新模型的超参数。
    def prune_heads(self, heads):
        # 如果没有要修剪的头,方法返回
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        # 这个是q,k,v对应的嵌入位置
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
        # Conv1D(3 * self.embed_dim, self.embed_dim),这里第一个是nf,表示输出,
        # 第二个是nx,表示输入,c_attn调整dim=1,是在调整输出特征
        # dim=0是在调整输入特征,因为权重形状是(in_feature,out_feature),
        # c_proj最后还要转换成512维的向量
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
        # Update hyper params,这里设置新的split_size
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads) # 更新新头数
        # 更新已经修剪的头的索引
        self.pruned_heads = self.pruned_heads.union(heads)
    # 注意力计算 (_attn 方法)
    # 此方法实现了多头注意力机制的核心计算步骤
    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # 计算注意力权重 (attn_weights),通过矩阵乘法计算原始注意力得分
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        # 根据配置进行注意力权重的缩放。
        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )
        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)
        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            # 生成一个非常小的数值,用于在注意力机制中掩码(mask)那些不应该被考虑的位置。这样做是为了确保在应
            # 用softmax函数之后,这些位置的注意力权重几乎为零
            # torch.finfo 是一个获取浮点数类型的数值信息的方法。它可以返回一个对象,包含了该类型的最小正数、最
            # 大正数、精度等信息。
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            # mask_value是个很小的负数
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            # condition,input,other,根据条件:When True (nonzero), yield input, otherwise yield other
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
        # 应用attention_mask,attention_mask 通常用于指示哪些位置在计算注意力得分时应该被忽略
        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask
        # 对注意力权重进行归一化处理。
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        # 在归一化后的注意力权重上应用dropout。
        attn_weights = self.attn_dropout(attn_weights)
        # 如果我们想要禁用某些注意力头,可以使用 head_mask 来指定哪些头应该被忽略。
        # 示例:假设我们有8个头,但只想保留前4个头,那么 head_mask 可能是 [1, 1, 1, 1, 0, 0, 0, 0]
        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask
        # 通过将注意力权重与值向量相乘得到注意力输出。
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with torch.amp.autocast(query.device.type, enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) # (b,s,h,dk)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        tensor = tensor.permute(0, 2, 1, 3).contiguous() # (b,s,h,dk)
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape) # (b,s,d)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None: # 如果有编码器输出
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )
            
            query = self.q_attn(hidden_states) # (b,s,d)
            # (b,s,d)-->(b,s,2d),之后再dim=2上拆分
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask # 编码器填充掩码
        else:
            # 在dim=2上拆分成query,key,value
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        # (b,s,d)-->(b,h,s,dk)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        # 在推理阶段,需要用到之前时间步的key,value表示
        # 这种缓存机制是用在推理阶段,因为训练时,推理具有并行性,各个时间步的预测同步进行
        # 而在推理时,因为新预测的token是基于已经生成的token进行预测的,这时候目标序列输入的
        # 自注意力是query是只有上一步新生成的token,key和value是把当前token(只有这么一个)
        # 的表示和之前缓存的token(当前token之前所有token)的表示在序列长度维度进行了合并
        # 当前新生成的token与之前的token进行交互,以确定其在上下文中的位置和意义,之后预测下个token
        # 而解码器跨注意力时,因为推理时,编码器中token嵌入不会再改变,这时候可以缓存编码器输出的表示
        # 目的就是能减少计算开销
        if layer_past is not None:
            past_key, past_value = layer_past
            # 在序列长度维度合并
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        # 如果使用缓存,就保存
        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        # 合并头的嵌入为一个整体
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        # 因为有裁剪头,需要把维度投影到embed_dim
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output) # dropout
        
        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs  # a, present, (attentions)
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
    # Check if the package spec exists and grab its version to avoid importing a local directory
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            # Primary method to get the package version
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # Fallback method: Only for "torch" and versions containing "dev"
            if pkg_name == "torch":
                try:
                    package = importlib.import_module(pkg_name)
                    temp_version = getattr(package, "__version__", "N/A")
                    # Check if the version contains "dev"
                    if "dev" in temp_version:
                        package_version = temp_version
                        package_exists = True
                    else:
                        package_exists = False
                except ImportError:
                    # If the package can't be imported, it's not available
                    package_exists = False
            else:
                # For packages other than "torch", don't attempt the fallback and set as not available
                package_exists = False
        logger.debug(f"Detected {pkg_name} version: {package_version}")
    if return_version:
        return package_exists, package_version
    else:
        return package_exists
def is_flash_attn_greater_or_equal_2_10():
    if not _is_package_available("flash_attn"):
        return False
    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
# GPT2闪存注意模块。这个模块继承自GPT2Attention,因为模块的权重保持不变。唯一需要改动的地方是在前向传递过程中,
# 需要正确调用闪存注意的公共API,并在输入中包含任何填充标记的情况下处理这些填充标记
# 你提到的情况是在生成序列时,每次生成一个新的token,并且将这个新生成的token加入到现有的序列中,形成一个新的序列,
# 然后对这个新的序列进行自注意力计算。这是典型的序列生成过程,在每次生成新的token之后,都需要重新计算整个序列的注意力权重。
# 而我之前解释的情况是指另一种优化方法,即在生成过程中保留之前计算过的key和value对,以避免重复计算。这种方法主要用于加速推
# 理过程,特别是在长序列的情况下。在这种情况下,确实只是对当前的token做自注意力计算,并且通过保留之前的部分计算结果来节省计算资源。
# 具体来说,当我们说“将新的key/value与旧的key/value拼接起来”时,实际上是在更新一个缓存(cache),这个缓存包含了之前的计算结果
# 。每次生成新的token时,只需要计算这个新token的key和value,然后将它们添加到缓存中,而不是重新计算整个序列的key和value。
# 两种情况的区别在于:
# 生成新序列并计算自注意力:每次生成新的token后,将它加入到序列中,然后对整个序列重新计算自注意力权重。
# 使用缓存的key/value:只计算当前新token的key和value,并将它们添加到之前的缓存中,以供下一步计算使用。
# 这两种方式都可以用来进行序列生成,但使用缓存的方法在长序列任务中更加高效,因为它减少了重复计算
class GPT2FlashAttention2(GPT2Attention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 一旦针对RoCm的Flash Attention升级到2.1版本,就应该移除这部分
        # 我的是不存在,这个值设置成True,使用左上掩码
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        bsz, _, _ = hidden_states.size()
        # 如果encoder_hidden_states不是None,就是有编码器输出,模型应该做交叉注意力,这时如果q_attn为None,就要抛出异常
        # 因为跨注意力,query和vaue,key不同
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states) # q
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) # k,v
            attention_mask = encoder_attention_mask # 编码器掩码
        else:
            # 如果encoder_hidden_states is None,这种是自注意力,q=k=v
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        # 之后变形q,k,v,以便他们能做自注意力或交叉注意力
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        # 如果layer_past不为空,这个在训练时是没用的,在推理时的缓存机制,可以缓存解码器自注意力时目标
        # 输入序列之前的token表示,这时传进来的token只有新token,之后query只有一个token,而k,v
        # 和之前缓存的在序列上拼接,之后来预测下个token
        if layer_past is not None:
            past_key = layer_past[0] # 缓存的上次的key,就是当前token之前的token
            past_value = layer_past[1] # 缓存的上次的v
            # 在序列长度维度拼接之前的token表示和当前token表示
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = None
        if use_cache is True: #如果使用缓存
            present = (key, value) # 缓存当前的key,value

        query_length = query.shape[2] # q_len
        tgt_len = key.shape[2] # k_len和v_len

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) # (b,q_len,h,dk)
        key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
        value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

        attn_dropout = self.attn_dropout.p if self.training else 0.0
        
        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)
        if query.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype() # float16
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.c_proj.weight.dtype
            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query = query.to(target_dtype)
            key = key.to(target_dtype)
            value = value.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query, key, value, attention_mask, query_length, dropout=attn_dropout
        )

        attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
        attn_output = self.c_proj(attn_weights_reshaped)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights_reshaped,)

        return outputs

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        # 如果不是用左上因果掩码
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal # 布尔型,True
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            # 如果self.is_causal是True,但是query_length==1的话,这个值是False
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0] # b
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )
        return attn_output
        # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    # 实现的是一个用于消除输入张量中的填充(padding)的过程,即所谓的“unpadding”,这对于处理变长序列特别有用,因为在实际应用
    # 中,不同序列的长度往往是不一样的,而模型通常需要固定长度的输入,这就需要对较短的序列进行填充以匹配最长序列的长度。然而,在
    # 计算过程中,填充的部分实际上并不参与有效的计算,因此通过这种方式去除这些无效部分可以显著减少计算资源的需求。
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # 所有非零元素的位置,之前的所有样本的非填充token的累计数量,批次中最长的序列
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        # b,k_len,h,dk
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
        
        key_layer = index_first_axis( # (b*k_len,h,dk)
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1: # 如果q_len是1
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
# GPT-2 注意力模块现在使用了 PyTorch 提供的 scaled_dot_product_attention 函数来实现注意力机制。
# 局限性:GPT2SdpaAttention 类使用了 torch.nn.functional.scaled_dot_product_attention,但后者不支
# 持 output_attentions=True 或 head_mask 这两个特性。
# 实际用途:如果你的应用场景不需要这两个特性,那么这个类依然是有用的。它提供了基于 scaled_dot_product_attention
# 的实现,通常比手动实现更加高效。
# 回退到手动实现:由于 scaled_dot_product_attention 不支持 output_attentions=True 或 head_mask,因此程序
# 会回退到使用手动实现的注意力机制。
# 手动实现:指的是不使用 GPT2SdpaAttention 类,而是使用传统的手动实现,即完全自定义的注意力机制实现。
# attn_implementation="eager" 是一个参数,用于指定在加载模型时应该使用哪种注意力机制实现。
# 含义:如果你希望强制使用传统的手动实现,而不是尝试使用 scaled_dot_product_attention,可以在加载模型时通过设置 at
# tn_implementation="eeger" 来指定。
# 效果:这样做可以避免因 scaled_dot_product_attention 不支持某些特性而导致的回退警告,并确保始终使用手动实现。
# 通过设置 attn_implementation="eager",你可以显式地告诉模型使用传统手动实现,从而避免回退警告。
class GPT2SdpaAttention(GPT2Attention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 在使用非连续输入和自定义注意力掩码时,torch==2.1.2 版本中的 SDPA(scaled dot product attention)
        # 带有内存高效后端的功能存在问题,因此我们需要调用 .contiguous()。这个问题在 torch==2.2.0 版本中得到了修复。
        # 在 PyTorch 2.1.2 版本中,使用内存高效的 SDPA 时,如果输入张量是非连续的(即内存布局不是连续的),并且使用了自
        # 定义的注意力掩码,那么可能会出现问题。为了绕过这个问题,可以在调用 SDPA 之前确保输入张量是连续的,即通过调用
        # .contiguous() 方法使输入张量变为连续存储。
        # 在 PyTorch 2.2.0 版本中,这个问题已经被修复,因此在使用 2.2.0 及以上版本时,不需要显式地调用 .contiguous()。
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        # 如果要输出注意力权重,或者遮挡头的掩码不是None,调用父类的注意力机制
        if output_attentions or head_mask is not None:
            logger.warning_once(
                "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
        
        bsz, q_len, _ = hidden_states.size() # (b,q_len,d)

        # Initial attention projections
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Optional kv caching
        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = None
        if use_cache is True:
            present = (key, value)

        # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
        if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
        # 内联条件(inline condition)通常是指在代码中直接使用条件表达式(如三元运算符
        # condition ? true_expr : false_expr)来选择不同的执行路径。在某些情况下,
        # 内联条件可能会导致编译器或优化器无法有效地处理动态形状。
        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # 使用因果掩码的条件:attention_mask是None,并且q_len > 1,并且不是交叉注意力
        is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False
        #
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=is_causal,
        )

        # Reshape outputs
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.embed_dim)

        # Final projection
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, present, None
class GPT2MLP(nn.Module): # 前馈全连接层
    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size # d
        self.c_fc = Conv1D(intermediate_size, embed_dim)  
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)  # d-->hidden_d
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states) # hidden_d-->d
        hidden_states = self.dropout(hidden_states)
        return hidden_states
class GPT2Block(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size # d
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size # h_d
        # 如果config._attn_implementation是"eager",这个就是GPT2Attention
        attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation]
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        #注意力机制
        self.attn = attention_class(config=config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            # 交叉注意力
            self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config)
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states
        # 先进行层标准化
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection,自注意力前后残差连接
        hidden_states = attn_output + residual
        # 如果有编码器输出的话
        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            # 如果它没有crossattention这个属性的话,就抛出错误
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            
            residual = hidden_states # 目标序列自注意力的输出残差后的
            hidden_states = self.ln_cross_attn(hidden_states) # 层标准化
            # 获取交叉注意力
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output # 跨注意力前后残差
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
        
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states # 前馈前后残差

        if use_cache: # 如果使用缓存,就输出缓存
            outputs = (hidden_states,) + outputs
        else: # 否则只输出注意力权重
            outputs = (hidden_states,) + outputs[1:]
        return outputs  # hidden_states, present, (attentions, cross_attentions)
# 该类继承自 PreTrainedModel,用于处理模型权重初始化以及预训练模型的下载和加载
class GPT2PreTrainedModel(PreTrainedModel):
    config_class = GPT2Config # 指定模型的配置类,用于存储模型的各种超参数和配置信息。
    # load_tf_weights = load_tf_weights_in_gpt2 # 提供一个方法用于从 TensorFlow 模型加载权重。
    base_model_prefix = "transformer" # 指定模型的前缀,用于在加载或保存模型时标识模型的主体部分
    is_parallelizable = True # 表示模型支持并行化处理
    supports_gradient_checkpointing = True # 表示模型支持梯度检查点(gradient checkpointing),这是一种节省内存的技术
    _no_split_modules = ["GPT2Block"] # 指定不应在这些模块之间分割模型,这对于某些并行化策略很重要。
    _skip_keys_device_placement = "past_key_values" # 指定在放置设备时应跳过的键,这在多GPU设置中很有用。
    _supports_flash_attn_2 = True # 表示模型支持 Flash Attention 的第二版本。
    _supports_sdpa = True # 表示模型支持使用 scaled_dot_product_attention

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs) # 初始化父类 PreTrainedModel 的构造函数,没有额外的操作
    # 权重初始化方法,用于初始化模型的权重,处理不同类型的模块
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, Conv1D)):
            # 线性层(nn.Linear, Conv1D),使用正态分布初始化权重。
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果有偏置项,则将其初始化为零
            if module.bias is not None:
                module.bias.data.zero_()
        # 嵌入层(nn.Embedding)       
        elif isinstance(module, nn.Embedding):
            # 使用正态分布初始化权重。
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果有填充索引,则将填充索引对应的权重初始化为零。
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # 归一化层(nn.LayerNorm):
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_() # 初始化偏置项为零。
            module.weight.data.fill_(1.0) # 初始化权重为 1。
        # 特殊权重初始化
        for name, p in module.named_parameters(): # 循环
            if name == "c_proj.weight": # 对于名为 "c_proj.weight" 的参数,使用特殊的缩放初始化
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # 初始化权重时,考虑到了模型深度的影响,按照 GPT-2 论文中的建议进行缩放。
                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
# 它继承自 ModelOutput 并使用了 dataclass 装饰器来简化类的定义。
# 这个类主要用于封装 GPT-2 双头模型的输出结果,其中包含了模型预测结果、损失、过去的键值对以及其他中间结果
# 使用 dataclass 装饰器的好处
# 使用 dataclass 装饰器可以简化类的定义,自动为类添加一些常用的方法,如 __init__、__repr__、__eq__ 等。此外,
# dataclass 还支持默认值和类型注解,使得类的定义更加简洁明了。
# GPT2DoubleHeadsModelOutput 类主要用于封装 GPT-2 双头模型的输出结果,其中包括了损失、原始输出、过去的键值对以
# 及中间计算的结果。使用 dataclass 装饰器可以简化类的定义,并提供便捷的数据封装功能。通过这种方式,可以方便地管理和
# 访问模型的输出数据。
@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None # 模型的总损失(如果有)
    mc_loss: Optional[torch.FloatTensor] = None # 多分类(Multi-Class Classification)损失
    logits: torch.FloatTensor = None # 模型的原始输出(通常用于后续处理,如分类)
    mc_logits: torch.FloatTensor = None # 多分类任务的原始输出
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None # 过去的键值对(用于解码阶段的缓存)
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None # 每一层的隐藏状态
    attentions: Optional[Tuple[torch.FloatTensor]] = None # 每一层的注意力权重
# GPT-2 双头模型(Double Heads Model)中的“双头”并不是指多头注意力(multi-head attention)中的“头”。在这里,
# “双头”指的是模型具有两个不同的输出头(output heads),分别用于不同的任务。
# 在自然语言处理(NLP)任务中,双头模型通常是指一个模型同时具有两个不同的输出,每个输出头负责一个特定的任务。这样的
# 设计可以让模型在一次前向传播中完成多项任务,从而提高效率并共享底层特征表示。
# 在 GPT-2 的上下文中,双头模型通常包括以下两个输出头:
# LM Head(语言模型头):用于生成下一个词的概率分布。通常用于语言建模任务,如文本生成。
# MC Head(多分类头),用于多分类任务,如文本分类或其他监督学习任务。通常用于下游任务,如情感分析、问答等。
# 在许多 NLP 任务中,第一个 token(通常是 BOS 或 [CLS])被用来表示整个句子的语义信息。
# 在 GPT-2 双头模型中,确实应该取句子的第一个 token(通常是 BOS)的输出作为句子级别的表示。这样可以更好地
# 与句子相关的任务(如文本分类)相结合,并且与许多其他预训练模型的做法保持一致


http://www.mrgr.cn/news/36746.html

相关文章:

  • 静态多态和动态多态
  • cross apply 和 outer apply 的区别
  • Redis相关知识
  • docker(一)之cgroup详解
  • Excel怎么自动排序?4种方法任君选择
  • 【IOS】申请开发者账号(公司)
  • SLM2304S 600V, 130mA/270mA 高压半桥驱动芯片,隐藏着哪些强大功能?
  • 为什么我安装了open3d但是在调用的时候没有报错但是什么都没有发生呢
  • 详解swoole框架快速入门
  • MyBatis-Plus的使用基础入门案例
  • 3d gaussian splatting公式推导
  • 使用 Llama 3.1 和 Qdrant 构建多语言医疗保健聊天机器人的步骤
  • 智能雷达AI名片小程序源码系统 基于PHP+MySQL组合开发 带完整的安装代码包以及搭建部署教程
  • 设计模式之观察者模式
  • Git提示信息 Pulling is not possible because you have unmerged files.
  • 桌面PDU插座应用于工业自动化场景
  • AOT源码解析4.4 -decoder生成预测mask并计算loss
  • 《Linux从小白到高手》开篇:脱胎换骨之为什么要深度学习Linux?
  • 一七零、GORM值为0或者空字符串的时候不能被更新创建的五种解决办法
  • 【JavaEE初阶】深入解析死锁的产生和避免以及内存不可见问题