当前位置: 首页 > news >正文

Conditional Generative Adversarial Nets

条件生成对抗网络

1.生成对抗网络

生成对网络由两个“对抗性”模型组成:一个生成模型 G,用于捕获数据分布,另一个判别模型 D,用于估计样本来自训练数据而不是 G 的概率。G 和 D 都可以是非线性映射函数。
为了学习数据 x 上的生成器分布 Pg,生成器构建从先验噪声分布 pz(z) 到数据空间的映射函数 G(z; θg)。判别器 D(x; θd) 输出一个标量,表示 x 来自训练数据而不是 pg 的概率。
G 和 D 都是同时训练的:我们调整 G 的参数以最小化 log(1 − D(G(z)) 并调整 D 的参数以最小化 logD(X),就好像它们遵循两人的最小-最大一样价值函数 V (G, D) 的博弈:
在这里插入图片描述

G(Generator) -> 生成模块
D (Discriminator) -> 鉴别模块(输出就结果可以是二进制也可以是一维的置信度)
在这里插入图片描述

2.条件生成对抗网络

如果生成器和判别器都以一些额外的信息 y 为条件,则生成对抗网络可以扩展到条件模型。y 可以是任何类型的辅助信息,例如类标签或来自其他模态的数据。我们可以通过将 y 作为额外的输入层输入到判别器和生成器中来执行调节。
在生成器中,先验输入噪声 pz(z) 和 y 被组合在联合隐藏表示中,并且对抗性训练框架允许在如何组成该隐藏表示方面具有相当大的灵活性。
在判别器中,x 和 y 作为输入呈现给判别函数(在本例中再次由 MLP 体现)。两人迷你最大游戏的目标函数为:
在这里插入图片描述
在这里插入图片描述

3.判别器损失函数

判别器(Discriminator)
判别器的目标是区分生成器生成的假数据和真实数据。它接受来自生成器的输出或真实数据集的样本作为输入,并输出一个概率值,表示输入样本是真实数据的概率。
生成器(Generator)
生成器(Generator)的损失函数是它在对抗过程中试图最小化的目标。生成器的目标是产生尽可能接近真实数据分布的假数据,以便判别器(Discriminator)难以区分真假数据。
训练过程

  • 初始化:生成器和判别器的参数随机初始化。
  • 对抗训练:
    生成器生成假数据。
    判别器尝试区分真假数据。
    判别器的损失函数是它对真实数据和生成数据的预测误差的总和

生成器的损失函数是它欺骗判别器的成功率,即判别器错误地将生成数据识别为真实数据的概率。
在这里插入图片描述

  • 参数更新:
    判别器根据损失函数更新参数,以更好地区分真假数据。
    生成器根据损失函数更新参数,以生成更逼真的数据,以欺骗判别器。

代码实现

#以fashionMNist
# 损失函数
def d_loss_fn(r_logit, f_logit):r_loss = torch.nn.functional.binary_cross_entropy_with_logits(r_logit, torch.ones_like(r_logit))f_loss = torch.nn.functional.binary_cross_entropy_with_logits(f_logit, torch.zeros_like(f_logit))return r_loss, f_lossdef g_loss_fn(f_logit):f_loss = torch.nn.functional.binary_cross_entropy_with_logits(f_logit, torch.ones_like(f_logit))return f_loss
# 生成模型
class GeneratorCGAN(nn.Module):def __init__(self, z_dim, c_dim, dim=128):super(GeneratorCGAN, self).__init__()def dconv_bn_relu(in_dim, out_dim, kernel_size=4, stride=2, padding=1, output_padding=0):return nn.Sequential(nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, padding, output_padding),nn.BatchNorm2d(out_dim),nn.ReLU())self.ls = nn.Sequential(dconv_bn_relu(z_dim + c_dim, dim * 4, 4, 1, 0, 0),  # (N, dim * 4, 4, 4)dconv_bn_relu(dim * 4, dim * 2),  # (N, dim * 2, 8, 8)dconv_bn_relu(dim * 2, dim),   # (N, dim, 16, 16)nn.ConvTranspose2d(dim, 3, 4, 2, padding=1), nn.Tanh()  # (N, 3, 32, 32))def forward(self, z, c):# z: (N, z_dim), c: (N, c_dim) ->[64, 110]x = torch.cat([z, c], 1)# [64, 110] -> [64, 3, 32, 32]x = self.ls(x.view(x.size(0), x.size(1), 1, 1))# print(x.shape)# 输出生成的图像结果return xclass DiscriminatorCGAN(nn.Module):def __init__(self, x_dim, c_dim, dim=96, norm='none', weight_norm='spectral_norm'):super(DiscriminatorCGAN, self).__init__()norm_fn = _get_norm_fn_2d(norm)weight_norm_fn = _get_weight_norm_fn(weight_norm)def conv_norm_lrelu(in_dim, out_dim, kernel_size=3, stride=1, padding=1):return nn.Sequential(weight_norm_fn(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)),norm_fn(out_dim),nn.LeakyReLU(0.2))self.ls = nn.Sequential(  # (N, x_dim+c_dim, 32, 32)conv_norm_lrelu(x_dim + c_dim, dim),conv_norm_lrelu(dim, dim),conv_norm_lrelu(dim, dim, stride=2),  # (N, dim , 16, 16)conv_norm_lrelu(dim, dim * 2),conv_norm_lrelu(dim * 2, dim * 2),conv_norm_lrelu(dim * 2, dim * 2, stride=2),  # (N, dim*2, 8, 8)conv_norm_lrelu(dim * 2, dim * 2, kernel_size=3, stride=1, padding=0),conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),  # (N, dim*2, 6, 6)nn.AvgPool2d(kernel_size=6),  # (N, dim*2, 1, 1)torchlib.Reshape(-1, dim * 2),  # (N, dim*2)weight_norm_fn(nn.Linear(dim * 2, 1))  # (N, 1))def forward(self, x, c):# x: (N, x_dim, 32, 32), c: (N, c_dim)# [64, 10] -> [64, 10, 32, 32]c = c.view(c.size(0), c.size(1), 1, 1) * torch.ones([c.size(0), c.size(1), x.size(2), x.size(3)], dtype=c.dtype, device=c.device)# 常规损失函数 [64, 10, 32, 32] ->[64, 1]logit = self.ls(torch.cat([x, c], 1))# 输出置信度return logit
# model:鉴别器输入维度3:三通道图像,输出维度10:对应类别
D = DiscriminatorCGAN(x_dim=3, c_dim=c_dim)
# 生成器模型:编码维度,输出维度10:对应类别
G = GeneratorCGAN(z_dim=z_dim, c_dim=c_dim)

训练架构

  # 训练鉴别器模型输入与输出# 图像x = x.to(device)# 对应类别c_dense = c_dense.to(device)# 随机图像z = torch.randn(batch_size, z_dim).to(device)# 条件标签c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)# 随机数与条件输入生成器生成伪图像x_f = G(z, c).detach()# 原始图像与条件输入鉴别器计算标签图像分数x_gan_logit = D(x, c)  # [batchsize,1]# 输入伪图像与条件计算伪图像分数x_f_gan_logit = D(x_f, c) # [batchsize,1]
_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
  # 训练生成器模型输入与输出z = torch.randn(batch_size, z_dim).to(device)# 生成器中计算损失函数x_f = G(z, c)x_f_gan_logit = D(x_f, c)g_gan_loss = g_loss_fn(x_f_gan_logit)

在这里插入图片描述


http://www.mrgr.cn/news/41462.html

相关文章:

  • Java try-catch结构异常处理机制与 IllegalArgumentException 详解
  • docker 部署 filebeat 采集日志导入到elasticsearch 设置pipeline
  • ADRC与INDI的关系
  • 过滤器 Filter 详解
  • C++【类和对象】(再探构造函数、类型转换与static成员)
  • 如何选择与运用编程工具提升工作效率的秘密武器
  • 基于物理信息神经网络(PINN)求解Burgers方程(附PyTorch源代码)
  • 进程和线程之间的通用方式
  • [20241002] OpenAI融资文件曝光,ChatGPT年收入涨4倍,月费5年内翻倍
  • OpenGL笔记十九之相机系统
  • WSL--安装各种软件包
  • CompletableFuture常用方法
  • 计算机网络思维导图
  • 【微服务】组件、基础工程构建(day2)
  • C++中substr用法记录
  • 云原生(四十一)| 阿里云ECS服务器介绍
  • 什么是 Supply chain attack(供应链攻击)
  • 差分基准站
  • MySQL高阶2051-商店中每个成员的级别
  • Blazor开发框架Known-V2.0.13