import math

import torch
from torch import nn


class Swish(nn.Module):
    """
    ### Swish activation function

    $$x \cdot \sigma(x)$$
    """

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """
    ### Embeddings for $t$
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding
        """
        super().__init__()
        self.n_channels = n_channels
        # First linear layer
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        # Create sinusoidal position embeddings
        # [same as those from the transformer](../../transformers/positional_encoding.html)
        #
        # \begin{align}
        # PE^{(1)}_{t,i} &= \sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
        # PE^{(2)}_{t,i} &= \cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
        # \end{align}
        #
        # where $d$ is `half_dim`
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        # This gives $\frac{1}{10000^{\frac{i}{d - 1}}}$ for $i = 0, \dots, d - 1$; shape `[half_dim]`
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        # `t[:, None]` has shape `[batch_size, 1]` and `emb[None, :]` has shape `[1, half_dim]`,
        # so the product broadcasts to shape `[batch_size, half_dim]`
        emb = t[:, None] * emb[None, :]
        # Concatenating `emb.sin()` and `emb.cos()`, each of shape `[batch_size, half_dim]`,
        # gives `emb` of shape `[batch_size, half_dim * 2]`
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
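        # Note that `emb` now has `2 * half_dim = n_channels // 4` features,
        # which matches the input size of `self.lin1`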

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        return emb
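

# A minimal usage sketch (not part of the original implementation): embed a batch of
# integer diffusion time steps and check the output shape. The choices of
# `n_channels = 64` and a batch size of 16 are arbitrary and only for illustration.
if __name__ == "__main__":
    time_emb = TimeEmbedding(n_channels=64)
    # A batch of 16 random time steps $t$
    t = torch.randint(0, 1_000, (16,))
    # Each time step is mapped to an embedding with `n_channels` features
    emb = time_emb(t)
    assert emb.shape == (16, 64)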