MATLAB代码解析:利用DCGAN实现图像数据的生成 全网最细DCGAN设计-训练入门
0.摘要
本文介绍了如何利用MATLAB中的DCGAN(深度卷积生成对抗网络)实现图像数据的生成,具体以生成花朵图像为例。文章首先提供了训练效果的展示,包括训练1周期、11周期和56个周期的效果图。随后,文章详细解析了MATLAB官方给出的DCGAN生成花朵的示范代码。
代码部分主要包括以下几个部分:
- 数据获取与预处理:通过
imageDatastore
函数加载花朵图像数据集,并使用imageDataAugmenter
进行数据增强,生成增强后的图像数据存储augimds
。 - 生成器网络构建:定义了生成器的网络结构,包括输入层、投影与重塑层、多个转置卷积层、批量归一化层、ReLU激活层和tanh输出层。
- 判别器网络构建:定义了判别器的网络结构,包括输入层、dropout层、多个卷积层、批量归一化层、leakyReLU激活层和sigmoid输出层。
- 训练选项指定:设置了训练周期数、小批量大小、学习率等训练参数。
- 模型训练设计:使用
minibatchqueue
函数创建小批量队列,并定义了预处理函数preprocessMiniBatch
。随后,通过循环迭代进行模型训练,同时记录生成器和判别器的梯度平均值等信息。
此外,文章还提供了代码打包下载的链接,方便读者复现实验结果。通过本文的介绍和代码解析,读者可以深入了解DCGAN在图像生成方面的应用和实现过程。
1.经典代码:利用DCGAN生成花朵
MATLAB官方其实给出了DCGAN生成花朵的示范代码,原文地址:训练生成对抗网络 (GAN) - MATLAB & Simulink - MathWorks 中国
先看看训练效果:
训练1周期
训练11周期
训练56个周期
2.脚本文件
为了能让各位更好的复现,该代码已打包,下载后解压运行用MATLAB运行"gan.mlx"即可
链接: https://pan.baidu.com/s/1OyVYpMuve6KdKD81CZoyXQ?pwd=83sd 提取码: 83sd
3.代码详解:
首先是脚本gan:
数据获取
clear all
clc
imageFolder = fullfile("flower_photos");
imds = imageDatastore(imageFolder,IncludeSubfolders=true);
augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);
生成器
filterSize = 5;
numFilters = 64;
numLatentInputs = 100;projectionSize = [4 4 512];%layersGenerator = [featureInputLayer(numLatentInputs)projectAndReshapeLayer(projectionSize)transposedConv2dLayer(filterSize,4*numFilters)batchNormalizationLayerreluLayertransposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")batchNormalizationLayerreluLayertransposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")batchNormalizationLayerreluLayertransposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")tanhLayer];
netG = dlnetwork(layersGenerator);
判别器
dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;inputSize = [64 64 3];
filterSize = 5;layersDiscriminator = [imageInputLayer(inputSize,Normalization="none")dropoutLayer(dropoutProb)convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")leakyReluLayer(scale)convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")batchNormalizationLayerleakyReluLayer(scale)convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")batchNormalizationLayerleakyReluLayer(scale)convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")batchNormalizationLayerleakyReluLayer(scale)convolution2dLayer(4,1)sigmoidLayer];
netD = dlnetwork(layersDiscriminator);
指定训练选项
numEpochs = 500;
miniBatchSize = 128;
learnRate = 0.00008;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
flipProb = 0.35;
validationFrequency = 100;
训练模型
augimds.MiniBatchSize = miniBatchSize;mbq = minibatchqueue(augimds, ...MiniBatchSize=miniBatchSize, ...PartialMiniBatch="discard", ...MiniBatchFcn=@preprocessMiniBatch, ...MiniBatchFormat="SSCB");
trailingAvgG = [];
trailingAvgSqG = [];
trailingAvg = [];
trailingAvgSqD = [];
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");
ZValidation = dlarray(ZValidation,"CB");
if canUseGPUZValidation = gpuArray(ZValidation);
endf = figure;
f.Position(3) = 2*f.Position(3);imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);C = colororder;
lineScoreG = animatedline(scoreAxes,Color=C(1,:));
lineScoreD = animatedline(scoreAxes,Color=C(2,:));
legend("Generator","Discriminator");
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid oniteration = 0;
start = tic;% Loop over epochs.
for epoch = 1:numEpochs% Reset and shuffle datastore.shuffle(mbq);% Loop over mini-batches.while hasdata(mbq)iteration = iteration + 1;% Read mini-batch of data.X = next(mbq);% Generate latent inputs for the generator network. Convert to% dlarray and specify the format "CB" (channel, batch). If a GPU is% available, then convert latent inputs to gpuArray.Z = randn(numLatentInputs,miniBatchSize,"single");Z = dlarray(Z,"CB");if canUseGPUZ = gpuArray(Z);end% Evaluate the gradients of the loss with respect to the learnable% parameters, the generator state, and the network scores using% dlfeval and the modelLoss function.[L,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...dlfeval(@modelLoss,netG,netD,X,Z,flipProb);netG.State = stateG;%%show data%"epoch"%epoch%"scoreG-D"%[scoreG,scoreD]% Update the discriminator network parameters.[netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ...trailingAvg, trailingAvgSqD, iteration, ...learnRate, gradientDecayFactor, squaredGradientDecayFactor);% Update the generator network parameters.[netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...trailingAvgG, trailingAvgSqG, iteration, ...learnRate, gradientDecayFactor, squaredGradientDecayFactor);% Every validationFrequency iterations, display batch of generated% images using the held-out generator input.if mod(iteration,validationFrequency) == 0 || iteration == 1% Generate images using the held-out generator input.XGeneratedValidation = predict(netG,ZValidation);% Tile and rescale the images in the range [0 1].I = imtile(extractdata(XGeneratedValidation));I = rescale(I);% Display the images.subplot(1,2,1);image(imageAxes,I)xticklabels([]);yticklabels([]);title("Generated Images");end% Update the scores plot.subplot(1,2,2)scoreG = double(extractdata(scoreG));addpoints(lineScoreG,iteration,scoreG);scoreD = double(extractdata(scoreD));addpoints(lineScoreD,iteration,scoreD);% Update the title with training progress information.D = duration(0,0,toc(start),Format="hh:mm:ss");title(..."Epoch: " + epoch + ", " + ..."Iteration: " + iteration + ", " + ..."Elapsed: " + string(D))drawnowend
end
生成新图像
numObservations = 4;
ZNew = randn(numLatentInputs,numObservations,"single");
ZNew = dlarray(ZNew,"CB");
if canUseGPUZNew = gpuArray(ZNew);
endXGeneratedNew = predict(netG,ZNew);I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")
3.1生成器与判别器的设计
脚本gan中以及包含了生成器Generator和判别器Discriminator的结构设计,生成器利用装置卷积对特征进行上采样,最终生成了64*64*3的图像,而判别器则用卷积进行下采样,将输入提取至1*1的格式大小,利用sigmoid作为激活函数,判断输入图像的真假。
3.1.1projectAndReshapeLayer
其中,在生成器中,应用了一自定义的网络层:“projectAndReshapeLayer”,该层代码如下所示:
classdef projectAndReshapeLayer < nnet.layer.Layer ...& nnet.layer.Formattable ...& nnet.layer.Acceleratableproperties% Layer properties.OutputSizeendproperties (Learnable)% Layer learnable parameters.WeightsBiasendmethodsfunction layer = projectAndReshapeLayer(outputSize,NameValueArgs)% layer = projectAndReshapeLayer(outputSize)% creates a projectAndReshapeLayer object that projects and% reshapes the input to the specified output size.%% layer = projectAndReshapeLayer(outputSize,Name=name)% also specifies the layer name.% Parse input arguments.argumentsoutputSizeNameValueArgs.Name = "";end% Set layer name.name = NameValueArgs.Name;layer.Name = name;% Set layer description.layer.Description = "Project and reshape to size " + ...join(string(outputSize));% Set layer type.layer.Type = "Project and Reshape";% Set output size.layer.OutputSize = outputSize;endfunction layer = initialize(layer,layout)% layer = initialize(layer,layout) initializes the layer% learnable parameters.%% Inputs:% layer - Layer to initialize% layout - Data layout, specified as a % networkDataLayout object%% Outputs:% layer - Initialized layer% Layer output size.outputSize = layer.OutputSize;% Initialize fully connect weights.if isempty(layer.Weights)% Find number of channels.idx = finddim(layout,"C");numChannels = layout.Size(idx);% Initialize using Glorot.sz = [prod(outputSize) numChannels];numOut = prod(outputSize);numIn = numChannels;layer.Weights = initializeGlorot(sz,numOut,numIn);end% Initialize fully connect bias.if isempty(layer.Bias)% Initialize with zeros.layer.Bias = initializeZeros([prod(outputSize) 1]);endendfunction Z = predict(layer, X)% Forward input data through the layer at prediction time and% output the result.%% Inputs:% layer - Layer to forward propagate through% X - Input data, specified as a formatted dlarray% with a "C" and optionally a "B" dimension.% Outputs:% Z - Output of layer forward function returned as% a formatted dlarray with format "SSCB".% Fully connect.weights = layer.Weights;bias = layer.Bias;X = fullyconnect(X,weights,bias);% Reshape.outputSize = layer.OutputSize;Z = reshape(X,outputSize(1),outputSize(2),outputSize(3),[]);Z = dlarray(Z,"SSCB");endend
end
这个层的作用是对输入数据进行投影(通过一个全连接操作)并重塑为指定的输出尺寸(SSCB),S对应为空间维度、C代表为通道维度,B代表批次维度。
-1 类的定义和继承
classdef projectAndReshapeLayer < nnet.layer.Layer ... & nnet.layer.Formattable ... & nnet.layer.Acceleratable
这个类继承自 nnet.layer.Layer
(基本的神经网络层类)、nnet.layer.Formattable
(支持格式化的层)和 nnet.layer.Acceleratable
(支持加速的层)。这允许 projectAndReshapeLayer
使用和扩展 MATLAB 深度学习工具箱中层的标准功能。
-2 属性
properties OutputSize
end properties (Learnable) Weights Bias
end
OutputSize
:这是一个普通属性,用于存储层的输出尺寸。Weights
和Bias
:这些是可学习参数,用于全连接操作。
-3 构造方法
function layer = projectAndReshapeLayer(outputSize,NameValueArgs) arguments outputSize NameValueArgs.Name = ""; end name = NameValueArgs.Name; layer.Name = name; layer.Description = "Project and reshape to size " + join(string(outputSize)); layer.Type = "Project and Reshape"; layer.OutputSize = outputSize;
end
构造方法用于创建 projectAndReshapeLayer
对象。它接受输出尺寸和一个可选的名称参数。构造方法设置了层的名称、描述、类型和输出尺寸。具体而言,其对应脚本gan.m中的“projectAndReshapeLayer(projectionSize)”这段语句,使我们可以快速在网络构建中利用改层
-4 初始化方法
function layer = initialize(layer,layout) outputSize = layer.OutputSize; if isempty(layer.Weights) idx = finddim(layout,"C"); numChannels = layout.Size(idx); sz = [prod(outputSize) numChannels]; numOut = prod(outputSize); numIn = numChannels; layer.Weights = initializeGlorot(sz,numOut,numIn); end if isempty(layer.Bias) layer.Bias = initializeZeros([prod(outputSize) 1]); end
end
初始化方法用于设置层的可学习参数(权重和偏置)。它根据输入数据的布局(layout)来确定输入通道的数量,并使用 Glorot 初始化方法来初始化权重。偏置被初始化为零。
-5 预测方法
function Z = predict(layer, X) weights = layer.Weights; bias = layer.Bias; X = fullyconnect(X,weights,bias); outputSize = layer.OutputSize; Z = reshape(X,outputSize(1),outputSize(2),outputSize(3),[]); Z = dlarray(Z,"SSCB");
end
预测方法是层的前向传播函数。它首先对输入数据 X
应用全连接操作(使用层的权重和偏置),然后将结果重塑为指定的输出尺寸,并将输出格式设置为 "SSCB"(空间、空间、通道、批次)。
-6小结
这个 projectAndReshapeLayer
自定义层的主要作用是将输入数据通过全连接层进行投影,并将输出重塑为特定的多维尺寸。这种类型的层在深度学习模型中可能用于特征映射或维度变换,特别是在需要输出特定形状的特征图时。通过使用全连接操作和重塑,这个层能够灵活地适应不同的输入和输出需求。
3.2 调整生成图像的大小和尺寸
那么,我们如何自定义生成对抗网络?很简单,把握上采样和下采样的规模就行,利用MATLAB的DLtool(deep network designer)可以很好的观察到这一点,以刚刚的生成器为例,我们可以观察到,转置卷积后(步幅为2),输出的空间(S)长宽都翻倍,深度对应我们给定的filters数量,因此,我们想要生成特定大小的数据时,修改转置卷积的步幅、卷积核数量以及转置卷积层的数量就行,同时记得在添加的转置卷积层后连接新的BN层和ReLU激活函数。
比如我想生成128*128*3的图片,我只需要将刚刚示例中的其中一个转置卷积核的大小提高至7*7,同时步幅修改成4。或者,我直接添加一层步幅为2的转置卷积层。对于一些数据尺寸为非2倍数问题,如311*171*3,我们可以先生成312*172*3再resize一下,或者你提前将数据预处理成312*172.
同时我们也可以观察到“projectAndReshapeLayer”的layer_1可学习的权重参数为8192*100,其中,100为输入的特征数量,8192来自于4*4*512,和刚刚对projectAndReshapeLayer的分析一致,projectAndReshapeLayer本质就是全连接层,只不过将输出重新排列了
注意:定义完网络结构后,要用dlnetwork()函数将layer参数转变成可训练的dlnetwork。
3.3数据预处理
3.3.1数据库建立
在进行对抗训练前,我们需要先对数据进行预处理,在脚本gan.m中,我们先对数据库进行了构建:
imageFolder = fullfile("flower_photos");
imds = imageDatastore(imageFolder,IncludeSubfolders=true);
augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);
在这一过程中,我们利用文件夹flower_photos所有的图片子文件构建了数据库,再对数据进行了X轴的随机方法作为数据增强方法,最后,将图片的空间尺寸压缩至64*64
3.3.2归一化与模型训练准备
在训练脚本gan.m中,我们使用了minibatchqueue()这一函数:
mbq = minibatchqueue(augimds, ...MiniBatchSize=miniBatchSize, ...PartialMiniBatch="discard", ...MiniBatchFcn=@preprocessMiniBatch, ...MiniBatchFormat="SSCB");
注:minibatchqueue
是 MATLAB 中用于深度学习和数据处理的函数,特别是在处理大规模数据集时用于创建和管理小批量(minibatch)数据的队列。这个功能非常有用,因为它允许你高效地从大数据集中提取小批量数据进行训练,从而加速深度学习模型的训练过程。
在该函数中,augimds是输入的数据库,数据类型是datastore; PartialMiniBatch为是否运行拆分minibatch(当显存不足时),MiniBatchFcn为数据归一化,@后面跟着的是自定义函数preprocessMinibatch,也就是uint8格式的数据归一化方法:
function X = preprocessMiniBatch(data)% Concatenate mini-batch
X = cat(4,data{:});% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,InputMin=0,InputMax=255);end
该方法先是将图片数据在第四个维度(也就是SSCB中的B,“批次维度”)整合起来,再把 8位二进制整型(0~2^8-1)映射至-1~1间。
3.4自定义模型训练
3.4.1 训练逻辑
再看看训练前的数据准备,我们先看前半部分
% Loop over epochs.
for epoch = 1:numEpochs% Reset and shuffle datastore.shuffle(mbq);% Loop over mini-batches.while hasdata(mbq)iteration = iteration + 1;% Read mini-batch of data.X = next(mbq);% Generate latent inputs for the generator network. Convert to% dlarray and specify the format "CB" (channel, batch). If a GPU is% available, then convert latent inputs to gpuArray.Z = randn(numLatentInputs,miniBatchSize,"single");Z = dlarray(Z,"CB");if canUseGPUZ = gpuArray(Z);end
在我们定义的训练周期“numEpochs”中,我们会对数据集进行“洗牌”也就是shuffle(mbq),mbq是我们创建的小批量数据格式“minibatchqueue”,然后进行一个条件判断,观察mbq是否还有未训练的数据,如果没有则进入下一个Epoch,有的话则提取下一个最小批次也就是:
X = next(mbq);
这里的X就是我们对抗训练的真样本了,而我们还需要生成器提供的假样本,所以在每个批次中我们还要随机生成数据,批量和大小和该批次的X相同,也就是:
Z = randn(numLatentInputs,miniBatchSize,"single");
Z = dlarray(Z,"CB");
注:这里生成的数据类型未single,因为GPU的计算一般是单精度的浮点运算,再利用dlarray将数据传入GPU,“CB”是数据格式。
那么再训练过程中,模型是如何计算损失函数和更新数据的?
如下所示:训练过程中,程序要先计算生成器和判别器的梯度(gradientsG,gradientsD),这里引用了自定义函数modelLoss,这里使用dlfeval()则是利用GPU去有先执行modelLoss这一自定义函数,因为模型梯度的计算涉及到“反向传播过程”或者更加宏观地说“自动微分技术”。
反向传播与自动微分的关系
- 理论基础:反向传播算法实际上利用了自动微分中的链式法则来计算梯度。链式法则是微积分中的一个基本规则,它描述了复合函数的导数计算方式。在神经网络中,由于网络是由多个层级和非线性函数组成的复合函数,因此需要使用链式法则来计算每个参数相对于损失函数的梯度。
- 实现方式:在深度学习中,反向传播算法通常是通过自动微分技术来实现的。自动微分提供了一种高效且准确的方式来计算神经网络中每个参数相对于损失函数的梯度。这使得反向传播算法能够自动地、逐层地传递梯度并更新网络参数。
- 优化效率:通过自动微分技术,反向传播算法能够高效地处理大规模神经网络中的参数优化问题。自动微分技术不仅提高了梯度计算的准确性,还显著降低了计算复杂度,从而加速了神经网络的训练过程。
计算好梯度后,将计算好的梯度传输给优化器adam,计算新模型的参数。这里利用了adamupdate()函数,在更新参数时,我们仅需要输入模型参数以及梯度参数,其他参数则是提前预设好的,具体数值直接参考脚本原文件即可。
% Evaluate the gradients of the loss with respect to the learnable% parameters, the generator state, and the network scores using% dlfeval and the modelLoss function.[L,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...dlfeval(@modelLoss,netG,netD,X,Z,flipProb);netG.State = stateG;% Update the discriminator network parameters.[netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ...trailingAvg, trailingAvgSqD, iteration, ...learnRate, gradientDecayFactor, squaredGradientDecayFactor);% Update the generator network parameters.[netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...trailingAvgG, trailingAvgSqG, iteration, ...learnRate, gradientDecayFactor, squaredGradientDecayFactor);% Every validationFrequency iterations, display batch of generated% images using the held-out generator input.
注:adamupdate()
是 MATLAB 中的一个函数,用于实现 Adam 优化算法的一步更新。Adam(Adaptive Moment Estimation)是一种基于梯度的一阶优化算法,常用于机器学习和深度学习中的参数优化。它结合了动量(Momentum)和均方根传播(RMSProp)的思想,通过计算梯度的一阶矩估计和二阶矩估计来动态调整学习率,从而加速收敛并改善优化性能。
3.4.2损失函数及梯度计算-反向传播
在获取模型梯度时,利用了自定义函数modelLoss,用于计算损失函数并用于反向传播:
function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...modelLoss(netG,netD,X,Z,flipProb)% Calculate the predictions for real data with the discriminator network.
YReal = forward(netD,X);
% Calculate the predictions for generated data with the discriminator
% network.
[XGenerated,stateG] = forward(netG,Z);
YGenerated = forward(netD,XGenerated);% Calculate the score of the discriminator.
scoreD = (mean(YReal) + mean(1-YGenerated)) / 2;% Calculate the score of the generator.
scoreG = mean(YGenerated);% Randomly flip the labels of the real images.
numObservations = size(YReal,4);
idx = rand(1,numObservations) < flipProb;
YReal(:,:,:,idx) = 1 - YReal(:,:,:,idx);% Calculate the GAN loss.
[lossG, lossD] = ganLoss(YReal,YGenerated);% For each network, calculate the gradients with respect to the loss.
gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true);
gradientsD = dlgradient(lossD,netD.Learnables);end
modelLoss中,代码先是计算真样本再判别器中的预测结果,再利用生成器生成样本,再进行判断,随后计算这两对抗网络的真实得分,但是再训练过程中,我们设置了一定的翻转概率,避免对抗网络“模式坍缩”,所以损失函数要基于随机的标签翻转后计算,最后将计算的损失函数lossG和lossD及其对应模型输入至函数dlgradient()进行反向传播,计算梯度,反向传播和自动微分技术是一个很大的内容,这里就不再赘述。
注:模式坍缩指的是在GANs的训练过程中,生成器(Generator)将不同的输入映射到同一个或少数几个输出的情况。这意味着生成器无法再现训练数据中模式的全部多样性,导致生成的样本缺乏多样性,变得非常相似甚至完全相同
其中,损失函数计算代码:
function [lossG,lossD] = ganLoss(YReal,YGenerated)% Calculate the loss for the discriminator network.
lossD = -mean(log(YReal)) - mean(log(1-YGenerated));% Calculate the loss for the generator network.
lossG = -mean(log(YGenerated));end
具体而言,损失函数可以用以下式子解释:
生成器损失函数数学表达式:
判别器损失函数的数学表达式 :
其中为输入噪声,
为输入的真实样本,
4.总结
DCGAN的设计比较巧妙,但模型本身复杂程度并不算高,只是需要注意计算过程中的一些细节,通过对各个函数进行详解,望读者可以系统、快速的上手MATLAB的DCGAN设计和训练!