🗽Flow Matching for Generative Modeling (Part 2)
type
status
date
slug
summary
tags
category
icon
password
Coding
coding 部分参照的是 Flow Matching Guide and Code 。这本书(雾)极为详细,制作精美,实在制作精美,但现在实在没时间👻
loss
__init__.py __all__ 是一个列表。它决定了当使用 from package import * 声明时,会被导入的模块或类。generalized_loss.py该方法继承了
_Loss 父类,用于计算由 path 决定的真实的条件概率路径 和由 logist 决定的模型预测的后验概率 之间的 KL 散度。公式在 Flow Matching Guide and Code 中有给出(大概),这里就不深究下去了reduction 有三种模式:
'none' | 'mean' | 'sum' ,这里因为是加权平均所以选的是 mean 。logits 是模型的原始输出,输出 shape 是 (batch, d, K) batch代表批次大小,即一次前项传播中模型处理的样本数;
d是每个样本的特征纬度或数据点纬度,比如输入的数据如果是图像,d可以表示像素点数;
K是分类任务中类别的数量;
像是一个三维空间,上述三个值可以确定空间一个确定的点,该点的值就是模型输出的原始得分。比如
(1, 8, 9) 是第一个样本第八个特征在第九类的原始得分。后三行计算了一个 。原始得分不是概率,通过激活函数
log_softmax (比 softmax 数值更稳定)可以约束到 且和为 。x_1.unsqueeze(-1) 会将 x_1 的形状从 (batch, d) 转换为 (batch, d, 1) ,使得它能够作为 index 参数传递给 torch.gather 。torch.gather 函数的作用是根据给定的索引从 log_p_1t 中选择相应的值。这里, x_1 包含的是每个样本的真实类别索引,所以 gather 会在每个数据点的类别维度中选择对应于 x_1 的对数概率值。最后把
log_p_1t_x1 的形状转回 (batch, d) 。最后三行同理。
[(...,) * (x_1.dim() - 1)] 是对 jump_coefficient 进行广播,使之能与 x_1 的形状符合。总的来说是翻译了一下损失函数的公式。
path
__init__.py affine.py< to be done >
时间有限,有缘来补🤫
上一篇
Flow Matching for Generative Modeling (Part 1)
下一篇
SOFAR: Language-Grounded Orientation Bridges
Spatial Reasoning and Object Manipulation
Loading...
