import math
import torch


class Mish(torch.nn.Module):
    """Elementwise Mish activation: ``x * tanh(softplus(x))``.

    Has no parameters; the output has the same shape as the input.
    """

    def forward(self, x):
        # softplus(x) = log(1 + exp(x)); squashing it through tanh gives a
        # smooth gate in (-1, 1) that scales the input.
        gate = torch.nn.functional.softplus(x).tanh()
        return x * gate


class TimeEmbedding(torch.nn.Module):
    """Sinusoidal embedding of a (rescaled) step value, optionally passed through an MLP.

    The input is first rescaled from the range ``[0, num_steps]`` to
    ``[0, rescale_steps]``. Then ``half_dim = dim // 2`` geometrically spaced
    frequencies (from 1 down to ``1 / 10000``) are applied. The sines and
    cosines of the scaled inputs are concatenated. When ``dim`` is odd, one
    zero feature is appended so the output always has exactly ``dim`` features.

    Args:
        dim: Size of the output embedding (last dimension).
        num_steps: Value that the input range is normalised by.
        rescale_steps: Range the normalised input is stretched to before the
            sinusoids are applied.
        use_mlp: If True, pass the sinusoidal features through a
            ``Linear -> Mish -> Linear`` MLP that maps ``dim -> dim``.
        hidden_dim_factor: Width multiplier of the MLP hidden layer.
    """

    def __init__(self,
                 dim: int,
                 num_steps: int,
                 rescale_steps: int = 4000,
                 use_mlp: bool = False,
                 hidden_dim_factor: int = 4):
        super().__init__()
        self.dim = dim
        self.num_steps = float(num_steps)
        self.rescale_steps = float(rescale_steps)
        self.use_mlp = use_mlp
        if self.use_mlp:
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(dim, dim * hidden_dim_factor),
                Mish(),
                torch.nn.Linear(dim * hidden_dim_factor, dim),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` into a tensor of shape ``(*x.shape, dim)``.

        Args:
            x: Tensor of step values of any shape (a 1-D batch, a 0-d scalar,
                ...). Presumably diffusion timesteps in ``[0, num_steps]``;
                TODO confirm against callers.

        Returns:
            The sinusoidal features, or the MLP output when ``use_mlp`` is set.
            The last dimension has size ``dim`` in both cases.
        """
        x = x / self.num_steps * self.rescale_steps
        half_dim = self.dim // 2
        # The spacing denominator is ``half_dim - 1``, so the last frequency is
        # exactly 1/10000. Clamp it to 1 so dim in {2, 3} (a single frequency
        # of 1) does not raise ZeroDivisionError.
        exponent = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -exponent)
        # ``x[..., None]`` broadcasts against ``freqs`` for inputs of any rank.
        args = x[..., None] * freqs
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos pairs only fill an even width; pad so the output (and
            # the MLP's Linear(dim, ...) input) really has ``dim`` features.
            emb = torch.nn.functional.pad(emb, (0, 1))
        return self.mlp(emb) if self.use_mlp else emb
