🗽Denoising Diffusion Probabilistic Models (Part 2)

type
status
date
slug
summary
tags
category
icon
password
“The sculpture is already complete within the marble block, before I start my work. It is already there, I just have to chisel away the superfluous material.” — Michelangelo
 
DDPM 的影响力是巨大的,所以后续在原论文的基础上有非常多的改进优化,衍生出了其他很多不同种类的扩散模型。为了加快推进速度,原始 DDPM 的代码实现我将参考 YouTube 博主 Umar Jamil 的实现,尽量以抓大框架理解为主。
 

Model.py

我们直入主题,先来看 model.py
是一个超参数,在 DDPM 中 是随 线性增长的。可以理解为, 越到后面,加噪加得越狠(一定程度上也保证了 是近似相同的)。
对于一个 batch 中的每一个样本,加噪的步数 是在 t_range 范围内以均匀分布任取的。这样模型可以学会处理各种加噪水平的图像,提高生成质量。
关于 U-Net
U-Net 的构建
后续应该会专门看一下 U-Net 的原论文
 
设定一些参数,更好地符合论文公式。
 
损失函数的计算,实际是在计算预测噪声和实际加噪间的 MSE (mean-square error,均方误差)。实际加噪通过 randn 在标准正态分布中采样;预测噪声通过把噪声图片和对应时间步喂给 U-Net 得到。
 
原论文给出了去噪公式,在 U-Net 预测出第 步噪声后,可以从 去噪得到
 
 

data.py

data.py 对数据集(这里以 MNIST,FashionMNIST,CIFAR10,CelebA 为例)进行预处理和加载。
 

train.py

原代码中,train.py 提供了将 denoise 过程可视化输出为 GIF 的扩展。这里我们直接看核心的 train_model部分。
Diffset 是自定义的数据集类,通过传入TrueFalse 来决定用于训练集还是验证集。Dataloader 用于从数据集中加载数据。
 
可以选择训练一个新模型或加载预训练权重(checkpoint)。
 
pl.Trainer 是 PyTorch Lightning 用于训练模型的核心类。pl.Trainer 是该类中的关键函数,用于启动训练过程,执行包括训练循环、验证循环、日志记录等任务。其中,model 是需要训练的模型,继承自pl.LightningModule ,在之前的代码中通过DiffusionModel 被创建,其中会包括forward()training_step()validation_step()configure_optimizers() 方法。
 
train.py完整代码

 
更为完整细致的 DDPM 代码即教程,可以参考 “The Annotated Diffusion Model”。
上一篇
Denoising Diffusion Probabilistic Models (Part 1)
下一篇
Flow Matching for Generative Modeling (Part 1)
Loading...