1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
class TimeEmbedding(nn.Module):
    """
    ### Embeddings for $t$

    Turns each time step into sinusoidal features (sines and cosines at
    geometrically spaced frequencies, from 1 down to 1/10,000) and then
    passes those features through a two-layer MLP.
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding

        Raises `ValueError` if `n_channels` is smaller than 16. The forward
        pass uses `n_channels // 8` frequencies and divides by that count
        minus one, so at least two frequencies are required.
        """
        # Fail at construction time with a clear message instead of a
        # ZeroDivisionError (or an empty embedding) on the first forward call.
        if n_channels // 8 < 2:
            raise ValueError(
                f"n_channels must be at least 16, got {n_channels}"
            )

        super().__init__()
        self.n_channels = n_channels
        # The sinusoidal features are a concatenation of `n_channels // 8`
        # sines and as many cosines. Size the first layer from that width so
        # it always matches what `forward` produces; for multiples of 8 this
        # is exactly `n_channels // 4`.
        self.lin1 = nn.Linear(2 * (self.n_channels // 8), self.n_channels)
        self.act = Swish()
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        * `t` holds the time steps; it is indexed as `t[:, None]`, so it is
          expected to be one-dimensional (one time step per batch element)

        Returns the output of the second linear layer, one row per entry of `t`.
        """
        # Recomputed from `n_channels` (rather than cached in `__init__`) so
        # the forward pass depends only on attributes the module always had.
        half_dim = self.n_channels // 8

        # Frequencies exp(-i * log(10000) / (half_dim - 1)) for
        # i = 0 .. half_dim - 1, i.e. from 1 down to 1/10,000.
        log_step = math.log(10_000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)

        # Outer product of time steps and frequencies.
        angles = t[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=1)

        # Two-layer MLP on top of the sinusoidal features.
        emb = self.act(self.lin1(emb))
        return self.lin2(emb)
|