🗽Denoising Diffusion Probabilistic Models (Part 2)
type
status
date
slug
summary
tags
category
icon
password
“The sculpture is already complete within the marble block, before I start my work. It is already there, I just have to chisel away the superfluous material.” — Michelangelo
DDPM 的影响力是巨大的,所以后续在原论文的基础上有非常多的改进优化,衍生出了其他很多不同种类的扩散模型。为了加快推进速度,原始 DDPM 的代码实现我将参考 YouTube 博主 Umar Jamil 的实现,尽量以抓大框架理解为主。
Model.py
我们直入主题,先来看
model.py 。 是一个超参数,在 DDPM 中 是随 线性增长的。可以理解为, 越到后面,加噪加得越狠(一定程度上也保证了 和 是近似相同的)。
对于一个 batch 中的每一个样本,加噪的步数 是在
t_range 范围内以均匀分布任取的。这样模型可以学会处理各种加噪水平的图像,提高生成质量。U-Net 的构建
后续应该会专门看一下 U-Net 的原论文
设定一些参数,更好地符合论文公式。
损失函数的计算,实际是在计算预测噪声和实际加噪间的 MSE (mean-square error,均方误差)。实际加噪通过
randn 在标准正态分布中采样;预测噪声通过把噪声图片和对应时间步喂给 U-Net 得到。原论文给出了去噪公式,在 U-Net 预测出第 步噪声后,可以从 去噪得到 。
data.py
data.py 对数据集(这里以 MNIST,FashionMNIST,CIFAR10,CelebA 为例)进行预处理和加载。train.py
原代码中,
train.py 提供了将 denoise 过程可视化输出为 GIF 的扩展。这里我们直接看核心的 train_model部分。Diffset 是自定义的数据集类,通过传入True 和False 来决定用于训练集还是验证集。Dataloader 用于从数据集中加载数据。可以选择训练一个新模型或加载预训练权重(checkpoint)。
pl.Trainer 是 PyTorch Lightning 用于训练模型的核心类。pl.Trainer 是该类中的关键函数,用于启动训练过程,执行包括训练循环、验证循环、日志记录等任务。其中,model 是需要训练的模型,继承自pl.LightningModule ,在之前的代码中通过DiffusionModel 被创建,其中会包括forward() 、training_step() 、validation_step() 、configure_optimizers() 方法。train.py完整代码
更为完整细致的 DDPM 代码即教程,可以参考 “The Annotated Diffusion Model”。
上一篇
Denoising Diffusion Probabilistic Models (Part 1)
下一篇
Flow Matching for Generative Modeling (Part 1)
Loading...



