import torch
import numpy as np
import torch.nn as nn


class GaussianFourierProjection(nn.Module):
    """Gaussian random Fourier features for encoding time steps.

    Each input value ``x`` is mapped to
    ``[sin(2*pi*x*W), cos(2*pi*x*W)]``, where ``W`` holds ``embed_dim // 2``
    frequencies drawn once from ``N(0, scale**2)``. ``W`` is registered as a
    buffer rather than an ``nn.Parameter``, so it is not trained.

    Args:
        embed_dim: Requested embedding size. The output has
            ``2 * (embed_dim // 2)`` features, i.e. ``embed_dim - 1`` when
            ``embed_dim`` is odd.
        scale: Standard deviation of the random frequencies.
    """

    def __init__(self, embed_dim, scale=30.0):
        super().__init__()
        # Fixed random frequencies, sampled once at construction time.
        self.register_buffer("W", torch.randn(embed_dim // 2) * scale)

    def forward(self, x):
        """Embed every element of ``x``.

        Args:
            x: Tensor of any shape, e.g. a 0-d scalar or a ``(batch,)`` vector.

        Returns:
            Tensor of shape ``x.shape + (2 * (embed_dim // 2),)``: the sine
            features followed by the cosine features.
        """
        # A trailing frequency axis lets inputs of any rank (including 0-d
        # scalars) broadcast against W instead of requiring a 1-d batch.
        x_proj = x[..., None] * self.W * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


class MLP(nn.Module):
    """Score network: a small fully connected net conditioned on time.

    The input ``x`` and the time ``t`` are each embedded into ``hid_dim``
    features and summed. The sum then goes through three hidden ``Linear``
    layers, each preceded by a ReLU, and a final ReLU + ``Linear`` projection
    back to ``in_dim``.

    Args:
        in_dim: Size of the last dimension of ``x`` (and of the output).
        hid_dim: Hidden width. Must be even: the Gaussian Fourier time
            embedding yields ``2 * (hid_dim // 2)`` features, and these are
            added elementwise to the ``hid_dim``-wide space embedding.

    Raises:
        ValueError: If ``hid_dim`` is odd.
    """

    def __init__(self, in_dim, hid_dim):
        super().__init__()
        # Fail fast: with an odd width the time and space embeddings differ
        # in size, and every forward pass would die on a shape mismatch.
        if hid_dim % 2 != 0:
            raise ValueError(
                f"hid_dim must be even to match the Fourier time embedding, got {hid_dim}"
            )
        self.time_embed = GaussianFourierProjection(hid_dim)
        self.space_embed = nn.Linear(in_dim, hid_dim)
        self.layer_1 = nn.Linear(hid_dim, hid_dim)
        self.layer_2 = nn.Linear(hid_dim, hid_dim)
        self.layer_3 = nn.Linear(hid_dim, hid_dim)
        self.out_layer = nn.Linear(hid_dim, in_dim)

        self.activation = nn.ReLU()

    def forward(self, x, t):
        """Evaluate the network at inputs ``x`` and times ``t``.

        Args:
            x: Tensor whose last dimension is ``in_dim``; presumably
                ``(batch, in_dim)`` — TODO confirm against callers.
            t: Times. Accepted forms:
                - a ``(batch,)`` vector;
                - a ``(batch, 1)`` column, which is squeezed to ``(batch,)``;
                - a 0-d scalar shared by all samples, which is broadcast to
                  ``x.shape[:-1]``.

        Returns:
            Tensor whose last dimension is ``in_dim``.
        """
        if t.dim() == 2:
            t = t.squeeze(-1)
        elif t.dim() == 0:
            # One time for the whole batch: expand (a view, no copy) to the
            # batch shape of x so the time embedding lines up per sample.
            t = t.expand(x.shape[:-1])
        embed = self.time_embed(t) + self.space_embed(x)
        for layer in (self.layer_1, self.layer_2, self.layer_3):
            embed = layer(self.activation(embed))
        return self.out_layer(self.activation(embed))
