import torch
import torch.nn as nn
import numpy as np


# class GaussianFourierProjection(nn.Module):
#     """Gaussian random features for encoding time steps."""

#     def __init__(self, embed_dim, scale=30.0):
#         super().__init__()
#         self.register_buffer("W", torch.randn(embed_dim // 2) * scale)

#     def forward(self, x):
#         x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
#         return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


# class MLP(nn.Module):
#     """Score network"""

#     def __init__(self, in_dim, hid_dim):
#         super().__init__()
#         self.time_embed = GaussianFourierProjection(hid_dim)
#         self.space_embed = nn.Linear(in_dim, hid_dim)
#         self.layer_1 = nn.Linear(hid_dim, hid_dim)
#         self.layer_2 = nn.Linear(hid_dim, hid_dim)
#         self.layer_3 = nn.Linear(hid_dim, hid_dim)
#         self.out_layer = nn.Linear(hid_dim, in_dim)

#         self.activation = nn.ReLU()

#     def forward(self, x, t):
#         if t.dim() == 2:
#             t = t.squeeze(-1)
#         embed = self.time_embed(t) + self.space_embed(x)
#         embed = self.layer_1(self.activation(embed))
#         embed = self.layer_2(self.activation(embed))
#         embed = self.layer_3(self.activation(embed))
#         embed = self.activation(embed)
#         return self.out_layer(embed)


class MLP(nn.Module):
    """Score network.

    Maps a state ``x`` and a time ``t`` to an output whose last dimension is
    ``in_dim``. The time is encoded with ``GaussianFourierProjection(hid_dim)``
    and added to a linear embedding of ``x``. The sum passes through three
    ReLU -> Linear stages, then a final ReLU, then an output projection back
    to ``in_dim``.

    Args:
        in_dim: Last dimension of ``x`` and of the returned tensor.
        hid_dim: Hidden width. The time embedding must come out exactly
            ``hid_dim`` wide for the sum in ``forward`` to broadcast.
    """

    def __init__(self, in_dim, hid_dim):
        super().__init__()
        # Construction order is kept as-is so seeded initialisation of the
        # random Fourier buffer and the Linear weights stays reproducible.
        self.time_embed = GaussianFourierProjection(hid_dim)
        self.space_embed = nn.Linear(in_dim, hid_dim)

        # NOTE(review): bn1, bn4 and bn5 are LayerNorms (despite the "bn"
        # prefix) and are never used in forward(). They are kept only so the
        # state_dict keys are unchanged and checkpoints saved from this class
        # still load with strict=True.
        self.bn1 = nn.LayerNorm(hid_dim)
        self.layer_1 = nn.Linear(hid_dim, hid_dim)

        self.bn4 = nn.LayerNorm(hid_dim)
        self.layer_4 = nn.Linear(hid_dim, hid_dim)
        self.layer_4_1 = nn.Linear(hid_dim, hid_dim)

        self.bn5 = nn.LayerNorm(hid_dim)
        self.out_layer = nn.Linear(hid_dim, in_dim)

        self.activation = nn.ReLU()

    def forward(self, x, t):
        """Evaluate the score network.

        Args:
            x: Input tensor whose last dimension is ``in_dim``; presumably
                ``(B, in_dim)`` — confirm with callers.
            t: Time values of shape ``(B,)`` or ``(B, 1)``, or a 0-d scalar
                tensor shared by every row of ``x``.

        Returns:
            Tensor whose last dimension is ``in_dim``.
        """
        if t.dim() == 0:
            # A single scalar time: give it a length-1 batch axis so its
            # embedding broadcasts against every row of the space embedding.
            t = t.reshape(1)
        elif t.dim() == 2:
            # Accept column vectors (B, 1) by dropping the trailing axis.
            t = t.squeeze(-1)
        embed = self.time_embed(t) + self.space_embed(x)
        embed = self.layer_1(self.activation(embed))
        embed = self.layer_4(self.activation(embed))
        embed = self.layer_4_1(self.activation(embed))
        embed = self.activation(embed)
        return self.out_layer(embed)

class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps.

    Each input value ``x`` is multiplied by fixed random frequencies ``W``.
    These are drawn once as ``randn * scale`` and stored as a non-trainable
    buffer. The result is encoded as ``[sin(2*pi*x*W), cos(2*pi*x*W)]``.

    Args:
        embed_dim: Width of the output embedding. Odd values are supported:
            ``ceil(embed_dim / 2)`` frequencies are drawn, and the trailing
            cosine feature is dropped so the output is exactly ``embed_dim``
            wide.
        scale: Multiplier applied to the standard-normal frequencies.
    """

    def __init__(self, embed_dim, scale=30.0):
        super().__init__()
        self.embed_dim = embed_dim
        # Round up so an odd embed_dim still yields embed_dim features after
        # slicing. For even embed_dim this equals embed_dim // 2, so the
        # buffer shape and the RNG draw are unchanged.
        self.register_buffer(
            "W", torch.randn((embed_dim + 1) // 2) * scale
        )

    def forward(self, x):
        """Encode ``x`` with sinusoidal random features.

        Args:
            x: 1-D tensor of time values of length ``B``, or a 0-d scalar
                tensor (treated as a batch of one).

        Returns:
            Tensor of shape ``(B, embed_dim)``: sine features first, then
            cosine features.
        """
        if x.dim() == 0:
            # x[:, None] below needs at least one axis.
            x = x.reshape(1)
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        # No-op for even embed_dim; drops the one extra cosine for odd.
        return emb[..., : self.embed_dim]
