Transformer与Seq2Seq中的Mask和Pad
Pad 填充和截断的应用
语言模型中的序列样本都有一个固定的长度, 无论这个样本是一个句子的一部分还是跨越了多个句子的一个片断。 这个固定长度是由 num_steps(时间步数或词元数量)参数指定的。 在机器翻译中,每个样本都是由源和目标组成的文本序列对, 其中的每个文本序列可能具有不同的长度。
为了提高计算效率,可以通过截断(truncation)和 填充(padding)方式实现一次只处理一个小批量的文本序列。 假设同一个小批量中的每个序列都应该具有相同的长度num_steps, 那么如果文本序列的词元数目少于num_steps时, 我们将继续在其末尾添加特定的“”词元, 直到其长度达到num_steps; 反之,我们将截断文本序列时,只取其前num_steps 个词元, 并且丢弃剩余的词元。这样,每个文本序列将具有相同的长度, 以便以相同形状的小批量进行加载。
Truncate的缺点:
- 丢弃上下文信息:在自然语言中,句子的意义往往依赖于上下文。如果截断了一个句子或文本片段,可能会丢失关键信息,导致模型无法理解完整的语义。例如,截断一个包含条件或因果关系的句子,可能会使得模型无法捕捉到这些关系。
- 重要信息在尾部:某些情况下,重要的信息可能位于文本的尾部。例如,在长句子中,结尾部分可能包含总结、结论或关键细节。如果截断只保留开头部分,模型将无法获取这些重要信息。
Pad
def truncate_pad(line, num_steps, padding_token):"""截断或填充文本序列"""if len(line) > num_steps:return line[:num_steps] # 截断return line + [padding_token] * (num_steps - len(line)) # 填充
-
这样的操作直接截断可能导致丢失一些信息
-
注意力机制:在深度学习模型中,尤其是使用循环神经网络(RNN)或Transformer模型时,可以利用注意力机制来关注序列中的重要部分。这样即使进行了截断,模型也能通过注意力机制捕捉到关键信息。
-
优化填充策略:在填充时,可以选择更有意义的填充词,或者使用特殊的填充策略,如在填充词中加入噪声,以减少模型对填充词的过度依赖。
-
使用变长序列模型:有些模型设计可以处理变长序列,例如Transformer模型,它不依赖于固定长度的输入,因此可以减少截断和填充的需要。
在这里补一下为什么Transformer可以截断之后,仍然可以捕捉到关键信息.
Mask
-
在Transformer中用到了masked_softmax
-
是为了Softmax之后,将Mask的位置置为0
def sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相关的项"""maxlen = X.size(1)mask = torch.arange((maxlen), dtype=torch.float32,device=X.device)[None, :] < valid_len[:, None]X[~mask] = valuereturn X
接下来学习一下Transformer中用到的三种Mask的机制
Encoder Mask
Encoder对输入序列的长度进行Pad一直到Max_len;
而在计算自注意力的时候,只需要对有效的序列长度进行Attention的计算,因此Pad的部分需要被Mask;
Decoder Mask-1
Decoder中的第一个Masked多头注意力模块输入序列,不能看到当前token之后的信息,所以要对当前token之后的tokens进行mask;
Decoder Mask-2
Decoder中的第二个Masked多头注意力模块中Query来源于Decoder当前输入的token,而Key-Value来源于Encoder的输出,因此,需要对不需要计算注意力的位置进行Mask;
Masked Softmax
def masked_softmax(X, valid_lens):"""通过在最后一个轴上掩蔽元素来执行softmax操作"""# X:3D张量,valid_lens:1D或2D张量if valid_lens is None:return nn.functional.softmax(X, dim=-1)else:shape = X.shapeif valid_lens.dim() == 1:valid_lens = torch.repeat_interleave(valid_lens, shape[1])else:valid_lens = valid_lens.reshape(-1)# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,value=-1e6)return nn.functional.softmax(X.reshape(shape), dim=-1)