当前位置: 首页 > news >正文

神经网络卷积层

一、卷积操作

对应位置相乘相加,最终组成一个新的矩阵,实现了降维。

二、代码

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("../data", train= False, transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=64)# 搭建简单神经网络
class aying(nn.Module):def __init__(self): # 初始化super(aying, self).__init__() # 父类# 定义卷积层,self使用后,后面的参数在其他函数中也可以使用self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)# 定义forward函数def forward(self, x):# x是输出x = self.conv1(x) # x 已经放进卷积层中return x
# 初始化该网络A
Aying = aying()
print(Aying) # 查看该网络结构

运行该代码,就可以在运行结果中,查看到:

aying((conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1)))

 将加载的数据放进卷积神经网络中处理:

# 将加载的数据放进卷积神经网络中,对数据进行处理
for data in dataloader:imgs,targets = dataoutput = Aying(imgs)print(imgs.shape)print(output.shape)

处理后,得到的输出:

  (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
)
torch.Size([64, 3, 32, 32])
torch.Size([64, 6, 30, 30])

输入的imgs图片的batch_size为64,通道为3,图片大小为32*32

卷积后的输出output:batch_size为64,通道为6,图片大小为30*30(降维)

使用tensorboard查看:

writer = SummaryWriter("../juanji")
step =0

在for循环体中:

    # 卷积神经网络在tensorboard中查看writer.add_image("input", imgs, step)writer.add_image("output", output, step)step +=1writer.close()

使用这样的方法按理说,我们能够在tensorboard中查看我们所需要的图像,但是出现了报错,显示的是,输入为3通道的图像,输出的却为6通道的图像,于是使用reshape的方法对其分开:

 # torch.Size([64, 3, 32, 32])writer.add_image("input", imgs, step, dataformats="NCHW")# torch.Size([64, 6, 30, 30]) ,六个chanel,不能显示# 将其拆分,6通道的拆分为3通道,使用-1实现划分batch_sizeoutput = torch.reshape(output,(-1, 3, 30, 30))writer.add_image("output", output, step, dataformats="NCHW")step +=1

 同时,会显示报错,代码中写入(-1,3,30,30)型的,但是在    writer.add_image("output", output, step)却只有三个参数,于是还有需要加入dataformats="NCHW"。

成功运行,显示结果:

输出进行卷积后图像。


http://www.mrgr.cn/news/15058.html

相关文章:

  • 零基础一文学会Docker与Kubernetes
  • LVS工作模式
  • Python制作的桌面宠物-python实战-python源码-python项目练习
  • 《深入浅出WPF》读书笔记.9Command系统
  • Redis: 用于纯缓存模式需要注意的地方
  • ubuntu 更新网卡丢失
  • Java 入门指南:初识 Java NIO
  • 数据结构——归并排序
  • “npm run serve”到51%就卡住【完美解决】
  • redis的紧凑列表ziplist、quicklist、listpack
  • C语言阴阳迷宫
  • C# 实现傅里叶变化(DFT)
  • 38. 字符串的排列【难】
  • 工作中常用的100个知识点
  • centos yum 源停用整改
  • PostgreSQL支持的数据类型
  • 28 TreeView组件
  • MyBatis中#{}和 ${}的区别是什么?
  • VScode应用有哪些?
  • 设计模式 8 组合模式