当前位置: 首页 > news >正文

【HuggingFace Transformers】BertIntermediate 和 BertPooler源码解析

BertIntermediate 和 BertPooler源码解析

  • 1. 介绍
    • 1.1 位置与功能
    • 1.2 相似点与不同点
  • 2. 源码解析
    • 2.1 BertIntermediate 源码解析
    • 2.2 BertPooler 源码解析

1. 介绍

1.1 位置与功能

(1) BertIntermediate

  • 位置:位于 BertLayer 的注意力层(BertSelfAttention)和输出层(BertOutput)之间。
  • 功能:它执行一个线性变换(通过全连接层)并跟随一个激活函数(通常是 ReLU),为后续层提供更高层次的特征表示。

(2) BertPooler

  • 位置:位于整个 BertModel 的最后一层之后,直接处理经过编码的序列表示。
  • 功能:从序列的第一个标记(即 [CLS] 标记)提取特征,并通过一个线性变换和 Tanh 激活函数来生成一个全局表示,通常用于分类任务中的最终输出。

1.2 相似点与不同点

(1) 相似点

  • 两者都涉及到线性变换,并且都通过激活函数来增强模型的表达能力。
  • 都是 BERT 模型中的重要组成部分,从不同的角度和层次上处理输入数据。

(2) 不同点

  • 应用层次:
    BertIntermediate 作用于每个 Transformer 层,用于构建更深的层级特征。
    BertPooler 只在模型的最后一层作用,用于提取全局特征。
  • 功能目标:
    BertIntermediate 增强中间层的非线性特征,助于后续的自注意力机制。
    BertPooler 为分类或回归任务提供一个紧凑的全局特征表示。

2. 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

2.1 BertIntermediate 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:17
import torchfrom torch import nn
from transformers.activations import ACT2FNclass BertIntermediate(nn.Module):def __init__(self, config):super().__init__()# 全连接层,将 hidden_size 映射到 intermediate_sizeself.dense = nn.Linear(config.hidden_size, config.intermediate_size)# 根据 config.hidden_act 定义激活函数if isinstance(config.hidden_act, str):self.intermediate_act_fn = ACT2FN[config.hidden_act]else:self.intermediate_act_fn = config.hidden_actdef forward(self, hidden_states: torch.Tensor) -> torch.Tensor:hidden_states = self.dense(hidden_states)  # 线性变换hidden_states = self.intermediate_act_fn(hidden_states)  # 激活函数return hidden_states

2.2 BertPooler 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/7/19 11:41import torchfrom torch import nnclass BertPooler(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 全连接层,将 hidden_size 映射回 hidden_sizeself.activation = nn.Tanh()  # 激活函数为 Tanh 函数def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:# We "pool" the model by simply taking the hidden state corresponding# to the first token.# 提取序列中的第一个 token,也就是 [CLS] 的 hidden statefirst_token_tensor = hidden_states[:, 0]pooled_output = self.dense(first_token_tensor)  # 线性变换pooled_output = self.activation(pooled_output)  # 激活函数return pooled_output

http://www.mrgr.cn/news/11003.html

相关文章:

  • 沈阳网站建设手机能看的网站
  • 0基础学习Python路径(29)collections模块
  • ubuntu系统在线安装下载firefox-esr流览器
  • WIN11环境下,如何在指定目录快速启动jupyter lab或jupyter notebook
  • MongoDB适用场景
  • 空气净化器怎么选能除猫毛?宠物空气净化器除味好的分享
  • GLM-4-Flash 大模型API免费了,手把手构建“儿童绘本”应用实战(附源码)
  • redis面试(二十三)写锁释放
  • 国产游戏技术的崛起:能否挑战全球引领地位?
  • 【单调栈】|代码随想录算法训练营第42天|42. 接雨水、 84.柱状图中最大的矩形
  • 王立铭脑科学50讲:29,敌对型社交,如何抑制自己的共攻击本能
  • 【搜索引擎】ElasticSearch 7.x版本
  • android aar适配uniapp
  • IDEA工具设置默认使用maven的settings.xml文件
  • 游戏开发设计模式之装饰模式
  • 汽车耐老化太阳跟踪聚光户外加速老化试验
  • 可集成多模型的机器人开发框架 dora:让机器人编程走向大众
  • HarmonyOS 鸿蒙获取微信授权和持续获取位置信息
  • 力扣1074.元素和为目标值的子矩阵数量
  • redisj集群之哨兵模式