斯洛文尼亚布莱德湖上的布莱德岛鸟瞰图。

前言

在实现 DDPM 代码的过程中,发现了很多可以补充自己代码能力的知识点,在此简单记录一下,顺带也是自己加深印象,不至于学完就忘。然后也借着这个文章,梳理一下在实现 DDPM 中很多自己遇到的困难,以及我自己的解决方案。

参考资料

1. 微信公众号文章:深入浅出扩散模型,代码篇
2. 科学空间:生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪
3. Labml 复现 DDPM 代码

其他

torch 数据类型转换

1. gather()函数

先学习一个小知识,torch.gather 用法,对于 Tensor 中的数据进行切片的一个方式。

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

含义:对于输入 input 的 dim 维度取 index 对应的元素。
输入:可以是 torch.gather() ,也可以是某个实例化张量调用 tensor.gather(),省去了 input 参数。
输出:对 input 张量进行数据切片的结果。

在官方文档中说的很清楚,但是实际理解还需要根据实例看。

这里链接一篇文章:图解PyTorch中的torch.gather函数,文章中所举的例子较为全面,但是实际上网友的评价更是加深了对知识的理解。

其中,某位网友提到的方式比较清晰,引用在此:

以下面代码为例:

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

(1) output.shape = index.shape # 确定最后输出的output的shape必须与index的相同,这里是1*3的tensor,那么output必须也是1*3的tensor,先把壳打起来torch.tensor([[?,?,?]])

(2) 对output所有值的索引,按shape方式排出来,也就是[[(0,0),(0,1),(0,2)]]

(3) 还是对output,拿index里的值替换上面dim指定位置,dim=0替换行,dim=1即替换列。变成[[(0,2),(0,1),(0,0)]]

(4) 按这个索引获取tensor_0相应位置的值,填进去就好了,得到torch.tensor([[5,4,3]])

最后,回到官方文档的举例中,对于一个 3-d Tensor:

1
2
3
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

在代码中怎么用的呢?

1
2
3
4
5
6
7
8
9
def gather(consts: torch.Tensor, t: torch.Tensor):
"""
Gather consts for $t$ and reshape to feature map shape
consts: 固定的已知系数
t: [batchsize],一个批量的时间步
"""
c = consts.gather(-1, t)
# c: [batchsize] -> [batchsize,1,1,1] 方便后续广播
return c.reshape(-1, 1, 1, 1)

Denoise Diffusion 实现

这一部分较 U-net 网络模型而言较为清晰和直观,暂时搁置一下,主要的精力放在学习 U-net 网络中。在本笔记中主要参考的网络模型图取自参考资料[1],很感谢大牛朋友认真无私的奉献。

位置编码

首先上一道开胃菜,class TimeEmbeeding,该类实现对时间步 t 编码,可以把一个 [batchsize] 的时间步 t 编码为大小 [batchsize, n_channels] 的位置向量。

口说无凭,或者说表述抽象,我们举例来看:

1
2
3
4
5
6
7
8
9
# 128 dims per t
timeEmd = TimeEmbedding(128)
# batchsize = 64
time = torch.randint(0,1000,(64,))
emb = timeEmd(time)
print(emb.shape)

# output
torch.Size([64, 128])

现在来看代码实现,如下所示。在使用正弦编码对时间步进行编码是在低维中进行,最后使用线性变换为高维。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class TimeEmbedding(nn.Module):
"""
### Embeddings for $t$
"""

def __init__(self, n_channels: int):
"""
* `n_channels` is the number of dimensions in the embedding
"""
super().__init__()
self.n_channels = n_channels
# First linear layer
self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
# Activation
self.act = Swish()
# Second linear layer
self.lin2 = nn.Linear(self.n_channels, self.n_channels)

def forward(self, t: torch.Tensor):
# Create sinusoidal position embeddings
# [same as those from the transformer](../../transformers/positional_encoding.html)
#
# \begin{align}
# PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
# PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
# \end{align}
#
# where $d$ is `half_dim`
half_dim = self.n_channels // 8
emb = math.log(10_000) / (half_dim - 1)
# emb: [half_dim]
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
# t[:, None] [batchsize, 1]
# emb[None, :] [1, half_dim]
# emb [batchsize, half_dim]
emb = t[:, None] * emb[None, :]
# emb.sin() [batchsize, half_dim]
# emb.cos() [batchsize, half_dim]
# emb [batchsize, half_dim*2]
emb = torch.cat((emb.sin(), emb.cos()), dim=1)

# Transform with the MLP
emb = self.act(self.lin1(emb))
emb = self.lin2(emb)

return emb

Unet

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

class UNet(Module):
"""
DDPM UNet去噪模型主体架构
"""
def __init__(self, image_channels: int = 3, n_channels: int = 64,
ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
n_blocks: int = 2):
"""
Params:
image_channels:原始输入图片的channel数,对RGB图像来说就是3

n_channels: 在进UNet之前,会对原始图片做一次初步卷积,该初步卷积对应的
out_channel数,也就是图中左上角的第一个墨绿色箭头

ch_mults: 在Encoder下采样的每一层的out_channels倍数,
例如ch_mults[i] = 2,表示第i层特征图的out_channel数,
是第i-1层的2倍。Decoder上采样时也是同理,用的是反转后的ch_mults

is_attn: 在Encoder下采样/Decoder上采样的每一层,是否要在CNN做特征提取后再引入attention
(会在下文对该结构进行详细说明)

n_blocks: 在Encoder下采样/Decoder下采样的每一层,需要用多少个DownBlock/UpBlock(见图),
Deocder层最终使用的UpBlock数=n_blocks + 1

【到此为止没有完全看懂注释也没关系,可以一遍打开示意图,一遍继续往下阅读源码,就能满满加深理解】
"""
super().__init__()

# 在Encoder下采样/Decoder上采样的过程中,图像依次缩小/放大,
# 每次变动都会产生一个新的图像分辨率
# 这里指的就是不同图像分辨率的个数,也可以理解成是Encoder/Decoder的层数
n_resolutions = len(ch_mults)

# 对原始图片做预处理,例如图中,将32*32*3 -> 32*32*64
self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

# time_embedding,TimeEmbedding是nn.Module子类,我们会在下文详细讲解它的属性和forward方法
self.time_emb = TimeEmbedding(n_channels * 4)

# --------------------------
# 定义Encoder部分
# --------------------------
# down列表中的每个元素表示Encoder的每一层
down = []
# 初始化out_channel和in_channel
out_channels = in_channels = n_channels
# 遍历每一层
for i in range(n_resolutions):
# 根据设定好的规则,得到该层的out_channel
out_channels = in_channels * ch_mults[i]
# 根据设定好的规则,每一层有n_blocks个DownBlock
for _ in range(n_blocks):
down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels
# 对Encoder来说,每一层结束后,我们都做一次下采样,但Encoder的最后一层不做下采样
if i < n_resolutions - 1:
down.append(Downsample(in_channels))

# self.down即是完整的Encoder部分
self.down = nn.ModuleList(down)

# --------------------------
# 定义Middle部分
# --------------------------
self.middle = MiddleBlock(out_channels, n_channels * 4, )

# --------------------------
# 定义Decoder部分
# --------------------------

# 和Encoder部分基本一致,可对照绘制的架构图阅读
up = []
in_channels = out_channels
for i in reversed(range(n_resolutions)):
# `n_blocks` at the same resolution
out_channels = in_channels
for _ in range(n_blocks):
up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))

out_channels = in_channels // ch_mults[i]
up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels

if i > 0:
up.append(Upsample(in_channels))

# self.up即是完整的Decoder部分
self.up = nn.ModuleList(up)

# 定义group_norm, 激活函数,和最后一层的CNN(用于将Decoder最上一层的特征图还原成原始尺寸)
self.norm = nn.GroupNorm(8, n_channels)
self.act = Swish()
self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
Params:
x: 输入数据xt,尺寸大小为(batch_size, in_channels, height, width)
t: 输入数据t,尺寸大小为(batch_size)
"""

# 取得time_embedding
t = self.time_emb(t)

# 对原始图片做初步CNN处理
x = self.image_proj(x)

# -----------------------
# Encoder
# -----------------------
h = [x]
# First half of U-Net
for m in self.down:
x = m(x, t)
h.append(x)

# -----------------------
# Middle
# -----------------------
x = self.middle(x, t)

# -----------------------
# Decoder
# -----------------------
for m in self.up:
if isinstance(m, Upsample):
x = m(x, t)
else:
s = h.pop()
# skip_connection
x = torch.cat((x, s), dim=1)
x = m(x, t)

return self.final(self.act(self.norm(x)))