
MATLAB Code Walkthrough: Generating Image Data with a DCGAN

Overview

Classic example: generating flowers with a DCGAN

MathWorks actually provides official example code for generating flowers with a DCGAN. Original article: Train Generative Adversarial Network (GAN) - MATLAB & Simulink - MathWorks China

First, a look at the training results (generated samples shown as figures in the original post):

After 1 epoch

After 11 epochs

After 56 epochs

Script files

To make the results easier to reproduce, the code has been packaged. Download it, unzip it, and open and run "gan.mlx" in MATLAB.
Link: https://pan.baidu.com/s/1hNYLw1xku2AdKf5CanoFzA?pwd=fb7n Access code: fb7n
 

Code walkthrough:

First, the script gan:

Data acquisition
clear all
clc
imageFolder = fullfile("flower_photos");
imds = imageDatastore(imageFolder,IncludeSubfolders=true);
augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);
Generator
filterSize = 5;
numFilters = 64;
numLatentInputs = 100;
projectionSize = [4 4 512];

layersGenerator = [
    featureInputLayer(numLatentInputs)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(filterSize,4*numFilters)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
    tanhLayer];
netG = dlnetwork(layersGenerator);
Discriminator
dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;
inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")
    dropoutLayer(dropoutProb)
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(4,1)
    sigmoidLayer];
netD = dlnetwork(layersDiscriminator);
Specify training options
numEpochs = 500;
miniBatchSize = 128;
learnRate = 0.00008;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
flipProb = 0.35;
validationFrequency = 100;
Train the model
augimds.MiniBatchSize = miniBatchSize;
mbq = minibatchqueue(augimds, ...
    MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard", ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
trailingAvgG = [];
trailingAvgSqG = [];
trailingAvg = [];
trailingAvgSqD = [];
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");
ZValidation = dlarray(ZValidation,"CB");
if canUseGPU
    ZValidation = gpuArray(ZValidation);
end

f = figure;
f.Position(3) = 2*f.Position(3);

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

C = colororder;
lineScoreG = animatedline(scoreAxes,Color=C(1,:));
lineScoreD = animatedline(scoreAxes,Color=C(2,:));
legend("Generator","Discriminator");
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Reset and shuffle datastore.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        X = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the format "CB" (channel, batch). If a GPU is
        % available, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,"single");
        Z = dlarray(Z,"CB");
        if canUseGPU
            Z = gpuArray(Z);
        end

        % Evaluate the gradients of the loss with respect to the learnable
        % parameters, the generator state, and the network scores using
        % dlfeval and the modelLoss function.
        [L,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
            dlfeval(@modelLoss,netG,netD,X,Z,flipProb);
        netG.State = stateG;

        %%show data
        %"epoch"
        %epoch
        %"scoreG-D"
        %[scoreG,scoreD]

        % Update the discriminator network parameters.
        [netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ...
            trailingAvg, trailingAvgSqD, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Update the generator network parameters.
        [netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...
            trailingAvgG, trailingAvgSqG, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Every validationFrequency iterations, display batch of generated
        % images using the held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            XGeneratedValidation = predict(netG,ZValidation);

            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(XGeneratedValidation));
            I = rescale(I);

            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end

        % Update the scores plot.
        subplot(1,2,2)
        scoreG = double(extractdata(scoreG));
        addpoints(lineScoreG,iteration,scoreG);
        scoreD = double(extractdata(scoreD));
        addpoints(lineScoreD,iteration,scoreD);

        % Update the title with training progress information.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        title( ...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))
        drawnow
    end
end
Generate new images
numObservations = 4;
ZNew = randn(numLatentInputs,numObservations,"single");
ZNew = dlarray(ZNew,"CB");
if canUseGPU
    ZNew = gpuArray(ZNew);
end

XGeneratedNew = predict(netG,ZNew);
I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

Designing the generator and discriminator

The script gan already contains the structural design of both the generator and the discriminator. The generator uses transposed convolutions to upsample its features, ultimately producing a 64×64×3 image, while the discriminator downsamples with ordinary convolutions, reducing the input to a 1×1 map and applying a sigmoid activation to judge whether the input image is real or fake.
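This up/down-sampling symmetry is easy to verify. The following is a minimal sketch (assuming `netG`, `netD`, and `numLatentInputs` have been built as in the script above) that pushes one latent vector through both networks and inspects the sizes:

```matlab
% Sketch: sanity-check the generator/discriminator shapes.
% Assumes netG, netD, and numLatentInputs from the script above.
Z = dlarray(randn(numLatentInputs,1,"single"),"CB");  % one latent vector

XGen = predict(netG,Z);     % generator upsamples the latent vector
size(XGen)                  % expect 64x64x3 spatial output (one image)

YProb = predict(netD,XGen); % discriminator downsamples to a 1x1 score
size(YProb)                 % expect a single sigmoid probability per image
```

Even before training, this confirms that the two networks are shape-compatible, which is the most common place custom GAN designs go wrong.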

How do you customize a generative adversarial network? It's simple: get the scale of the upsampling and downsampling right. MATLAB's Deep Network Designer makes this easy to observe. Taking the generator above as an example, after each stride-2 transposed convolution the spatial (S) height and width double, and the depth equals the number of filters we specified. So to generate data of a particular size, adjust the stride, the number of filters, and the number of transposed convolution layers, and remember to follow each added transposed convolution layer with a new batch normalization layer and ReLU activation.
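The doubling rule can be made explicit. For transposedConv2dLayer, the uncropped output size is Stride×(in − 1) + filterSize, and with Cropping="same" it is simply Stride×in. A sketch of the spatial-size bookkeeping for the generator above:

```matlab
% Spatial-size bookkeeping for the generator above (a sketch).
% Uncropped transposed conv: out = Stride*(in - 1) + filterSize
% Cropping="same":           out = Stride * in
in   = 4;               % projectAndReshapeLayer output: 4x4x512
out1 = 1*(in - 1) + 5;  % first transposed conv, stride 1, kernel 5 -> 8
out2 = 2*out1;          % stride-2, Cropping="same" -> 16
out3 = 2*out2;          % stride-2, Cropping="same" -> 32
out4 = 2*out3;          % stride-2, Cropping="same" -> 64
[out1 out2 out3 out4]   % 8 16 32 64
```

Tracing the sizes this way before building the network saves a round trip through Deep Network Designer.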

For example, if I want to generate 128×128×3 images, I only need to enlarge one of the transposed convolution kernels in the example to 7×7 and change its stride to 4, or alternatively just add one more transposed convolution layer with stride 2. For sizes that are not powers of two, such as 311×171×3, we can generate 312×172×3 and resize afterwards, or preprocess the data to 312×172 in advance.
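The add-a-layer variant could look like the sketch below (a hypothetical layer array, assuming projectAndReshapeLayer and the other settings stay as above; the extra stride-2 block takes the spatial path 4 → 8 → 16 → 32 → 64 → 128):

```matlab
% Hypothetical 128x128x3 generator: the original stack plus one more
% stride-2 transposed-conv block (with its own BN + ReLU) before the
% final 3-channel output layer.
layersGenerator128 = [
    featureInputLayer(numLatentInputs)
    projectAndReshapeLayer(projectionSize)               % 4x4x512
    transposedConv2dLayer(filterSize,4*numFilters)       % -> 8x8
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")  % -> 16x16
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")    % -> 32x32
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")    % -> 64x64 (added block)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")             % -> 128x128x3
    tanhLayer];
netG128 = dlnetwork(layersGenerator128);
```

Note that the discriminator would need matching changes, e.g. inputSize = [128 128 3] and one extra stride-2 convolution block, so that its output still reduces to a 1×1 score.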

Note: after defining the network structure, use the dlnetwork() function to convert the layer array into a trainable dlnetwork.

I've been quite busy lately, so I'll stop here for now and fill in the rest later. (2024-10-14)

Data preprocessing

Custom training loop

Loss function and gradient descent

Optimizer and parameter updates

Summary


