DDPM 做了什么
本博客主要侧重点在于HOW也就是DDPM怎么做的而不是WHY为什么要这样做
那么第一个问题DDPM做了一件什么事:这个算法通过逐渐向原图像添加噪声来破坏图像,然后再学习如何从噪声成恢复图像。
第二件事如何做到的:通过训练一个网络,这个网络输入为加噪声图片和添加噪声的次数,输出为网络预测施加在图像上的噪声。
那么就分为两步:添加噪声和去除噪声来讲解
添加噪声
添加噪声的过程专业点来说叫“前向扩散” 满足下式:
逐步添加高斯噪声到数据 x 0 x_0 x0
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I\right) q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)
经过推导可以得到原图像和加噪t次图像的关系
最终隐式表达:
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t | x_0) = \mathcal{N}\left(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)I\right) q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)
其中:
- α t = 1 − β t \alpha_t = 1 - \beta_t αt=1−βt
- α ˉ t = ∏ i = 1 t α i \bar{\alpha}_t = \prod_{i=1}^t \alpha_i αˉt=∏i=1tαi
这边的 β t \beta_t βt是自己设的
这个式子用人话来说就是由原图像加噪t
次后产生的图像(就命名为 I t I_t It吧)要满足偏差为 α ˉ t x 0 \sqrt{\bar{\alpha}_t} x_0 αˉtx0 方差为 ( 1 − α ˉ t ) I (1-\bar{\alpha}_t)I (1−αˉt)I 的正态分布。
听起来是不是还是不像人话,没事代码一看便懂
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:mean = gather(self.alpha_bar, t) ** 0.5 * x0var = 1 - gather(self.alpha_bar, t)return mean, vardef q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):if eps is None:eps = torch.randn_like(x0)mean, var = self.q_xt_x0(x0, t)return mean + (var ** 0.5) * eps
也就是 I t I_t It是由 I 0 I_0 I0乘上一个系数然后加上由标准正态分布采样得到的和原图像大小一致的随机噪声乘上系数得到的。
那么为什么 m e a n + ( v a r ∗ ∗ 0.5 ) ∗ e p s ∼ N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) mean + (var ** 0.5) * eps \sim \mathcal{N}\left(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)I\right) mean+(var∗∗0.5)∗eps∼N(xt;αˉtx0,(1−αˉt)I) 呢?
因为这边的 e p s ∼ N ( 0 , I ) eps\sim \mathcal{N}(0, I) eps∼N(0,I) 所以 ( v a r ∗ ∗ 0.5 ) ∗ e p s ∼ N ( 0 , ( 1 − α ˉ t ) I ) (var ** 0.5) * eps \sim \mathcal{N}(0,(1-\bar{\alpha}_t)I) (var∗∗0.5)∗eps∼N(0,(1−αˉt)I) (这块看不懂去看看概率论吧)
那么 m e a n + ( v a r ∗ ∗ 0.5 ) ∗ e p s ∼ N ( α ˉ t x 0 , ( 1 − α ˉ t ) I ) mean + (var ** 0.5) * eps \sim N(\sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)I) mean+(var∗∗0.5)∗eps∼N(αˉtx0,(1−αˉt)I) 满足了隐式表达的式子 。
去除噪声
说完了添加噪声,那么自然来到了如何去除噪声,前面也说过,我们训练一个网络网络输入为 I t I_t It和t,输出为网络预测的第t次施加在图像上的噪声。我们把这个网络就记作 ϵ θ ( I t , t ) \epsilon_\theta(I_t, t) ϵθ(It,t) ,我们的目标是使得网络预测的噪声和添加在图像上的噪声越相似越好,就得到了网络的损失函数。
L ( θ ) = E t , x 0 , ϵ [ ∥ ϵ − ϵ θ ( I t , t ) ∥ 2 ] \mathcal{L}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(I_t, t) \|^2 \right] L(θ)=Et,x0,ϵ[∥ϵ−ϵθ(It,t)∥2]
训练过程就是采样,计算损失函数,反向传播更新参数。具体就不多说了
TODO:DDPM的噪声预测网络结构
问题
为啥不能训练个去噪器直接从完全噪声预测原图?