ResNet神经网络搭建

news/2024/5/20 1:34:16

一、定义残差结构

BasicBlock   

18层、34层网络对应的残差结构

浅层网络主线由两个3x3的卷积层链接,相加后通过relu激活函数输出。还有一个shortcut捷径

参数解释

        expansion = 1  : 判断对应主分支的残差结构有无变化

        downsample=None : 下采样参数,默认为none

        stride步距为1,对应实线残差结构 ; 步距为2,对应虚线残差结构

        self.conv2 = nn.Conv2d(in_channels=out_channel :卷积层1的输出即为输入

class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsample

定义正向传播

 identity = x : shortcut捷径上的输出值

identity = self.downsample(x) : 将输出特征矩阵x输入到下采样函数中得到捷径分支的输出 

 def forward(self, x):   #定义正向传播identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identity  # 相加后通过relu激活函数输出out = self.relu(out)return out

BottleBlock  

50层、101层、152层神经网路对应的残差结构

深层网络主线由一个1x1的降维卷积层,3x3卷积层、1x1升维卷积层和一个shortcut捷径组成。

 按照残差结构进行定义,大致与BasicBlock参数一样,不同的是expansion=4,卷积核个数是之前的4倍。

class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):super(Bottleneck, self).__init__()width = int(out_channel * (width_per_group / 64.)) * groupsself.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsample

 定义正向传播

 if self.downsample is not None :  is None是实线  is not None 是虚线

    def forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return out

二、 定义网络结构

block,  根据定义不同的层结构传入不同的block
blocks_num,  所使用残差结构的数目、参数列表
num_classes=1000, 分类个数
include_top=True, 在ResNet基础上搭建其他的网络

self.layer1 对应conv2_x
self.layer2 对应conv3_x
self.layer3 对应conv4_x
self.layer4 对应conv5_x 这一系列的残差结构都通过_make_layer函数实线

self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)  不管输入是多杀,通过自适应平均池化下采样都会输出(1,1)

self.fc = nn.Linear(512 * block.expansion, num_classes)  通过全连接输出节点层,输入的节点个数是通过平均池化下采样层后的特征矩阵展平后所得到的节点个数,但是由于节点的高和宽都是1,所以节点的个数=深度

class ResNet(nn.Module):def __init__(self,block,blocks_num,num_classes=1000,include_top=True,groups=1,width_per_group=64):super(ResNet, self).__init__()self.include_top = include_top   # 将参数参入变为类变量self.in_channel = 64  # 表格中通过maxpooling后得到的深度self.groups = groupsself.width_per_group = width_per_groupself.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, blocks_num[0])self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

三、定义_make_layer函数

block,  定义的残差结构

channel, 残差结构中卷积层使用卷积核的个数

block_num, 该层包含了几个残差结构

block(传入第一层残差结构

        self.in_channel,   
        channel,   主分支第一个卷积层卷积核的个数
        downsample=downsample,  下采样函数

 for _ in range(1, block_num): 通过循环,将剩下一系列的实线残差结构压入进去

 return nn.Sequential(*layers)  非关键字传入,将定义的一系列层结构组合并返回。

    def _make_layer(self, block, channel, block_num, stride=1):downsample = Noneif stride != 1 or self.in_channel != channel * block.expansion:   # 18、34层会跳过这一部分 ;50、101、152层不会downsample = nn.Sequential(  # 生成下采样函数nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []  # 定义空的列表layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride,groups=self.groups,width_per_group=self.width_per_group))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel,channel,groups=self.groups,width_per_group=self.width_per_group))return nn.Sequential(*layers)

 定义正向传播

    def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:x = self.avgpool(x)x = torch.flatten(x, 1)  # 展平后链接x = self.fc(x)return x

四、建立不同的层结构 

传入的参数分别按照定义的顺序传入。

def resnet34(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet50(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet50-19c8e357.pthreturn ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

后续继续补充。。。


http://www.mrgr.cn/p/81432256

相关文章

解锁AI的神秘力量:LangChain4j带你步入智能化实践之门

关注微信公众号 “程序员小胖” 每日技术干货,第一时间送达! 引言 在数字化转型的浪潮中,人工智能(AI)正逐渐成为推动企业创新和增长的关键力量。然而,将AI技术融入到日常业务流程并非易事,它…

ubuntu下多jdk环境轻松却换

在实际coding生活中,维护老项目与开发新项目常是并行的。快速企业java开发jdk版本,收首先要解决的问题。 今天看到一篇Blog,参考配置后完美实现了一键快速切换,nice!!!!!! 环境: 1、ubuntu 22 2、openjdk1.8、openjdk17 具体操作步骤: 1、安装openjdk(略),安装位…

CentOS 7 部署 NET6.0 项目过程

1、环境配置NET6.0 环境搭建主要是SDK 和 runtime 的安装,下图截自官网说明了SDK 和 runtime 的关系CentOS7 安装SDK 方法第一步:rpm -Uvh https://packages.microsoft.com/config/centos/7/packages-microsoft-prod.rpm第二部:yum install dotnet-sdk-6.0也可以只安装对应的…

接入大量设备后,视频汇聚系统EasyCVR安防监控视频融合平台是如何实现负载均衡的?

一、负载均衡 随着技术的不断进步和监控需求的日益增长,企业视频监控系统的规模也在不断扩大,接入大量监控设备已成为一项常态化的挑战。为确保企业能够有效应对这一挑战,视频汇聚系统EasyCVR视频融合平台凭借其卓越的高并发处理能力&#x…

软件杯 深度学习花卉识别 - python 机器视觉 opencv

文章目录 0 前言1 项目背景2 花卉识别的基本原理3 算法实现3.1 预处理3.2 特征提取和选择3.3 分类器设计和决策3.4 卷积神经网络基本原理 4 算法实现4.1 花卉图像数据4.2 模块组成 5 项目执行结果6 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 &a…

细说夜莺监控系统告警自愈机制

虽说监控系统最侧重的功能是指标采集、存储、分析、告警,为了能够快速恢复故障,告警自愈机制也是需要重点投入建设的,所有可以固化为脚本的应急预案都可以使用告警自愈机制来快速驱动。夜莺开源项目从 v7 版本开始内置了告警自愈模块,本文将详细介绍告警自愈的原理和实现。…

掌握Android Fragment开发之魂:Fragment的深度解析(上)

Fragment是Android开发中用于构建动态和灵活界面的基石。它不仅提升了应用的模块化程度,还增强了用户界面的动态性和交互性,允许开发者将应用界面划分为多个独立、可重用的部分,每个部分都可以独立于其他部分进行操作。本文将从以下几个方面深…

【web网页制作】html+css旅游家乡河南开封主题网页制作(4页面)【附源码】

HTMLCSS家乡河南主题网页目录 🍔涉及知识🥤写在前面🍧一、网页主题🌳二、页面效果Page1 首页Page2 开封游玩Page 3 开封美食Page4 留言 🌈 三、网页架构与技术3.1 脑海构思3.2 整体布局3.3 技术说明书 🐋四…

程序员副业创富:业余时间解锁首笔财富里程碑

在这个充满机遇的数字时代,我,一个普通的程序猿,编程爱好者,终于在云端源想这个平台上收获了属于我的第一桶金。这是一个关于兼职、学习与成长的故事,希望能激发同在编程路上的你,勇敢迈出那一步。 先晒晒我的首笔收入:一个普通的周末,我像往常一样,泡上一杯咖啡,坐在…

(一)文本分类经典模型之CNN篇

这篇blog对NLP领域的基本任务文本分类的CNN经典模型做了梳理CNN源于计算机视觉研究,后来诸多学者将其应用于短文本分类,其基本结构如下图所示:由上图可知,基于CNN的短文本分类模型,通常包括输入层、卷积层、池化层、全连接层和输出层五部分,其中卷积层和池化层是最为关键…

抖音小店是什么?它和直播带货有什么区别和联系?一篇详解!

大家好,我是电商糖果 在网上大家都说抖音的流量大,在抖音做电商比较赚钱。 可是有很多人对抖音电商并不了解。 甚至搞不懂抖音小店是什么?它和直播带货的区别和联系也不清楚。 下面,糖果就来给大家好好解答一下这个问题。 抖音…

Django 4.x 智能分页get_elided_page_range

Django智能分页 分页效果 第1页的效果 第10页的效果 带输入框的效果 主要函数 # 参数解释 # number: 当前页码,默认:1 # on_each_side:当前页码前后显示几页,默认:3 # on_ends:首尾固定显示几页&#…

Apache DolphinScheduler 3.3.0 版本重磅更新提前看!

Apache DolphinScheduler 3.3.0版本终于要在万众期待中发布啦!本次发版将有重大功能更新,包括架构上的调整。 为了让广大用户提前尝鲜,社区特别准备了直播活动提前揭秘3.3.0版本中的重要更新,到时候你将会了解到这些信息:3.3.0版本的工作流引擎改进 任务执行流程的优化 架…

激光雕刻优化:利用RLE压缩技术提高雕刻效率与节省能源成本

什么是 RLE ?RLE 在激光雕刻应用实现代码:总结 什么是 RLE ? RLE 是 Run-Length Encoding(游程长度编码)的缩写。这是一种数据压缩技术,它通过减少连续重复的数据来减小文件的大小。RLE 在图像处理、无损…

【重塑世界的火种】制造业:从匠人之心到智能未来之旅

在人类文明的宏伟乐章中,有一段旋律始终激昂,它既古老又现代,既是力量的象征,也是智慧的结晶——这就是制造业,一个将梦想变为现实,将创意铸就为生活的神奇领域。今天,让我们一起走进这个塑造世…

【触想智能】工业级平板电脑五大特征与应用领域分析

工业级平板电脑是专供工业环境使用的工业控制计算机,也被称为工控一体机。工业级平板电脑基本性能及兼容性与商用平板电脑几乎相同,但是工业级平板电脑更注重在不同环境下的稳定性能,因此,工业级平板电脑与普通的商用平板电脑存在一定的区别。一、工业级平板电脑的五大特征…

2024软件测试自动化面试题(含答案)

1.如何把自动化测试在公司中实施并推广起来的? 选择长期的有稳定模块的项目 项目组调研选择自动化工具并开会演示demo案例,我们主要是演示selenium和robot framework两种。 搭建自动化测试框架,在项目中逐步开展自动化。 把该项目的自动化…

58微聊消息自动回复 – 58微聊自动回复机器人 – 浏览器插件

58同城上发布了产品,有咨询客户通过微聊联系我们,我们插件可以实现自动回复消息效果演示 58微聊消息自动回复,浏览器插件实现 #自动回复 #58同城 #58 – 抖音 (douyin.com) 功能列表关键词自动回复AI知识库自动回复下载插件 请联系微信:llike620 付费获取浏览器插件 原文地…

企业网站从传统服务器迁移到弹性云有什么优势呢?

现代企业对于网站和应用程序的可用性和性能要求越来越高,传统基础设施可能无法满足这些需求。弹性云作为一种新兴的云计算服务模式,对于企业网站的运行和管理带来了许多优势。下面是企业网站从传统服务器迁移到弹性云的五大优势: 灵活弹性&a…

黑马点评项目总结

登录 基于session登录 短信验证码登录 配置登录拦截器 向 Spring MVC 框架中添加拦截器,LoginInterceptor 是一个自定义的拦截器,用于拦截用户的登录请求。 excludePathPatterns这一句是设置拦截器需要放行的请求路径列表。 "/user/code", …