前言
在实现 DDPM 代码的过程中,发现了很多可以补充自己代码能力的知识点,在此简单记录一下,顺带也是自己加深印象,不至于学完就忘。然后也借着这个文章,梳理一下在实现 DDPM 中很多自己遇到的困难,以及我自己的解决方案。
参考资料
1. 微信公众号文章:深入浅出扩散模型,代码篇
2. 科学空间:生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪
3. Labml 复现 DDPM 代码
其他
torch 数据类型转换
1. gather()函数
先学习一个小知识,torch.gather 用法,对于 Tensor 中的数据进行切片的一个方式。
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
含义:对于输入 input 的 dim 维度取 index 对应的元素。
输入:可以是 torch.gather() ,也可以是某个实例化张量调用 tensor.gather(),省去了 input 参数。
输出:对 input 张量进行数据切片的结果。
在官方文档中说的很清楚,但是实际理解还需要根据实例看。
这里链接一篇文章:图解PyTorch中的torch.gather函数 ,文章中所举的例子较为全面,但是实际上网友的评价更是加深了对知识的理解。
其中,某位网友提到的方式比较清晰,引用在此:
以下面代码为例:
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
(1) output.shape = index.shape # 确定最后输出的output的shape必须与index的相同,这里是1*3的tensor,那么output必须也是1*3的tensor,先把壳打起来torch.tensor([[?,?,?]])
(2) 对output所有值的索引,按shape方式排出来,也就是[[(0,0),(0,1),(0,2)]]
(3) 还是对output,拿index里的值替换上面dim指定位置,dim=0替换行,dim=1即替换列。变成[[(0,2),(0,1),(0,0)]]
(4) 按这个索引获取tensor_0相应位置的值,填进去就好了,得到torch.tensor([[5,4,3]])
最后,回到官方文档的举例中,对于一个 3-d Tensor:
1 2 3 out[i][j][k] = input [index[i][j][k]][j][k] out[i][j][k] = input [i][index[i][j][k]][k] out[i][j][k] = input [i][j][index[i][j][k]]
在代码中怎么用的呢?
1 2 3 4 5 6 7 8 9 def gather (consts: torch.Tensor, t: torch.Tensor ): """ Gather consts for $t$ and reshape to feature map shape consts: 固定的已知系数 t: [batchsize],一个批量的时间步 """ c = consts.gather(-1 , t) return c.reshape(-1 , 1 , 1 , 1 )
Denoise Diffusion 实现
这一部分较 U-net 网络模型而言较为清晰和直观,暂时搁置一下,主要的精力放在学习 U-net 网络中。在本笔记中主要参考的网络模型图取自参考资料[1],很感谢大牛朋友认真无私的奉献。
位置编码
首先上一道开胃菜,class TimeEmbeeding ,该类实现对时间步 t 编码,可以把一个 [batchsize] 的时间步 t 编码为大小 [batchsize, n_channels] 的位置向量。
口说无凭,或者说表述抽象,我们举例来看:
1 2 3 4 5 6 7 8 9 timeEmd = TimeEmbedding(128 ) time = torch.randint(0 ,1000 ,(64 ,)) emb = timeEmd(time) print (emb.shape)torch.Size([64 , 128 ])
现在来看代码实现,如下所示。在使用正弦编码对时间步进行编码是在低维中进行,最后使用线性变换为高维。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 class TimeEmbedding (nn.Module): """ ### Embeddings for $t$ """ def __init__ (self, n_channels: int ): """ * `n_channels` is the number of dimensions in the embedding """ super ().__init__() self.n_channels = n_channels self.lin1 = nn.Linear(self.n_channels // 4 , self.n_channels) self.act = Swish() self.lin2 = nn.Linear(self.n_channels, self.n_channels) def forward (self, t: torch.Tensor ): half_dim = self.n_channels // 8 emb = math.log(10_000 ) / (half_dim - 1 ) emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb) emb = t[:, None ] * emb[None , :] emb = torch.cat((emb.sin(), emb.cos()), dim=1 ) emb = self.act(self.lin1(emb)) emb = self.lin2(emb) return emb
Unet
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 class UNet (Module ): """ DDPM UNet去噪模型主体架构 """ def __init__ (self, image_channels: int = 3 , n_channels: int = 64 , ch_mults: Union [Tuple [int , ...], List [int ]] = (1 , 2 , 2 , 4 ), is_attn: Union [Tuple [bool , ...], List [int ]] = (False , False , True , True ), n_blocks: int = 2 ): """ Params: image_channels:原始输入图片的channel数,对RGB图像来说就是3 n_channels: 在进UNet之前,会对原始图片做一次初步卷积,该初步卷积对应的 out_channel数,也就是图中左上角的第一个墨绿色箭头 ch_mults: 在Encoder下采样的每一层的out_channels倍数, 例如ch_mults[i] = 2,表示第i层特征图的out_channel数, 是第i-1层的2倍。Decoder上采样时也是同理,用的是反转后的ch_mults is_attn: 在Encoder下采样/Decoder上采样的每一层,是否要在CNN做特征提取后再引入attention (会在下文对该结构进行详细说明) n_blocks: 在Encoder下采样/Decoder下采样的每一层,需要用多少个DownBlock/UpBlock(见图), Deocder层最终使用的UpBlock数=n_blocks + 1 【到此为止没有完全看懂注释也没关系,可以一遍打开示意图,一遍继续往下阅读源码,就能满满加深理解】 """ super ().__init__() n_resolutions = len (ch_mults) self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3 , 3 ), padding=(1 , 1 )) self.time_emb = TimeEmbedding(n_channels * 4 ) down = [] out_channels = in_channels = n_channels for i in range (n_resolutions): out_channels = in_channels * ch_mults[i] for _ in range (n_blocks): down.append(DownBlock(in_channels, out_channels, n_channels * 4 , is_attn[i])) in_channels = out_channels if i < n_resolutions - 1 : down.append(Downsample(in_channels)) self.down = nn.ModuleList(down) self.middle = MiddleBlock(out_channels, n_channels * 4 , ) up = [] in_channels = out_channels for i in reversed (range (n_resolutions)): out_channels = in_channels for _ in range (n_blocks): up.append(UpBlock(in_channels, out_channels, n_channels * 4 , is_attn[i])) out_channels = in_channels // ch_mults[i] up.append(UpBlock(in_channels, out_channels, n_channels * 4 , is_attn[i])) in_channels = out_channels if i > 0 : up.append(Upsample(in_channels)) self.up = nn.ModuleList(up) self.norm = nn.GroupNorm(8 , n_channels) self.act = Swish() self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3 , 3 ), padding=(1 , 1 )) def forward (self, x: torch.Tensor, t: torch.Tensor ): """ Params: x: 输入数据xt,尺寸大小为(batch_size, in_channels, height, width) t: 输入数据t,尺寸大小为(batch_size) """ t = self.time_emb(t) x = self.image_proj(x) h = [x] for m in self.down: x = m(x, t) h.append(x) x = self.middle(x, t) for m in self.up: if isinstance (m, Upsample): x = m(x, t) else : s = h.pop() x = torch.cat((x, s), dim=1 ) x = m(x, t) return self.final(self.act(self.norm(x)))