import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.mlp import MLP
from utils.fourier_mlp import FourierMLP


def get_timestep_embedding(timesteps, embedding_dim=128):
    """Build sinusoidal embeddings of ``timesteps``.

    Adapted from Fairseq. This matches the implementation in tensor2tensor,
    but differs slightly from the description in Section 3.5 of
    "Attention Is All You Need": here the sine half and the cosine half are
    concatenated along the last dimension.
    https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py

    Args:
        timesteps: tensor of time values, any shape. It is flattened, so the
            output has one row per element (N = ``timesteps.numel()``).
        embedding_dim: number of output columns. For odd values, one zero
            column is appended at the end.

    Returns:
        float32 tensor of shape ``(N, embedding_dim)`` on ``timesteps.device``.
        With ``half = embedding_dim // 2``, columns ``[0, half)`` hold
        ``sin(t * f_i)`` and columns ``[half, 2 * half)`` hold ``cos(t * f_i)``.
        The frequencies are ``f_i = exp(-i * log(10000) / max(half - 1, 1))``,
        a geometric ladder from 1 down to 1/10000.
    """
    half_dim = embedding_dim // 2
    # Clamp the denominator so that embedding_dim in {2, 3} (half_dim == 1)
    # yields the single frequency 1 instead of raising ZeroDivisionError.
    scale = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * -scale
    )

    # reshape (not view) so that non-contiguous inputs, e.g. transposed
    # tensors, are accepted; the result is identical wherever view worked.
    args = timesteps.float().reshape(-1, 1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if embedding_dim % 2 == 1:
        # Odd width: append one zero column so the output has embedding_dim columns.
        emb = F.pad(emb, [0, 1])

    return emb


class TimeFourierMLP(nn.Module):
    """Network over inputs whose columns are ``[x..., t]`` (space, then time).

    The first ``space_dim`` input columns are encoded by a ``FourierMLP``. The
    remaining column(s) are turned into a sinusoidal embedding of width
    ``pos_dim`` (see ``get_timestep_embedding``) and passed through an ``MLP``.
    Both encodings have width ``2 * pos_dim``. They are concatenated and fed
    to a final ``MLP`` whose last layer has ``output_dim`` units.

    Args:
        encoder_layers: hidden layer sizes of the time-encoder ``MLP``.
            ``None`` (default) means ``[16]``. Any iterable of ints is accepted.
        pos_dim: width of the sinusoidal timestep embedding.
        decoder_layers: hidden layer sizes of the output ``MLP``.
            ``None`` (default) means ``[128, 128]``. Any iterable of ints is
            accepted.
        space_dim: number of leading input columns treated as spatial inputs.
        output_dim: width of the final layer of the output ``MLP``.
        act_fn: activation module passed to all three sub-networks. The
            default ``nn.SiLU()`` instance is created once, when the class is
            defined, and is shared by every model built with the default.
    """

    def __init__(self, encoder_layers=None, pos_dim=16,
                 decoder_layers=None,
                 space_dim=2, output_dim=1, act_fn=nn.SiLU()):
        super().__init__()
        # None sentinels replace shared mutable list defaults. Copying caller
        # input also lets tuples through, which list concatenation below
        # would otherwise reject.
        encoder_layers = [16] if encoder_layers is None else list(encoder_layers)
        decoder_layers = [128, 128] if decoder_layers is None else list(decoder_layers)

        self.temb_dim = pos_dim
        self.space_dim = space_dim
        t_enc_dim = pos_dim * 2

        # Sub-modules are created in the original order (net, t_encoder,
        # x_encoder) so that state_dict key order and parameter-init RNG
        # consumption stay unchanged.
        # The output MLP consumes the concatenation of the x and t encodings,
        # each t_enc_dim wide.
        self.net = MLP(
            [2 * t_enc_dim] + decoder_layers + [output_dim],
            activation=act_fn)

        self.t_encoder = MLP(
            [pos_dim] + encoder_layers + [t_enc_dim],
            activation=act_fn)

        self.x_encoder = FourierMLP(
            input_dim=space_dim, output_dim=t_enc_dim,
            n_hidden=128, act=act_fn)

    def forward(self, x_t):
        """Evaluate the network on concatenated space/time input.

        Args:
            x_t: 2-D tensor ``(batch, space_dim + k)``. Columns
                ``[:space_dim]`` are the spatial inputs; the remaining
                columns are the time input. A 1-D tensor is treated as a
                single sample and given a leading batch dimension of size 1,
                which is kept in the output.

        Returns:
            Output of ``self.net`` applied to the concatenation of the
            spatial encoding and the time encoding.
        """
        if x_t.dim() == 1:
            x_t = x_t.unsqueeze(0)
        x = x_t[:, :self.space_dim]
        # NOTE(review): get_timestep_embedding flattens its input, so this
        # assumes exactly one time column (k == 1) per row — confirm with callers.
        t = x_t[:, self.space_dim:]

        temb = self.t_encoder(get_timestep_embedding(t, self.temb_dim))
        xemb = self.x_encoder(x)
        return self.net(torch.cat([xemb, temb], dim=-1))
