import torch
from torch import nn
from agents.helpers import SinusoidalPosEmb


class Denoiser(nn.Module):
    """MLP denoiser conditioned on a timestep and, optionally, a context vector.

    The timestep is passed through ``time_mlp`` (``SinusoidalPosEmb`` followed
    by a two-layer MLP). The result is concatenated with the action ``x``, and
    with ``c`` when the model was built with ``c_dim > 0``. The concatenation
    then goes through three hidden ``Linear`` + ``Mish`` layers and a final
    linear projection back to ``action_dim``.

    Args:
        action_dim: Width of the action vector (both input and output).
        t_dim: Width of the timestep embedding.
        c_dim: Width of the conditioning vector; ``0`` disables conditioning.
        hidden_dim: Width of the hidden layers. Defaults to 256, the value
            that was previously hard-coded.
    """

    def __init__(self, action_dim, t_dim=16, c_dim=0, hidden_dim=256):
        super().__init__()

        self.has_c = c_dim > 0

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, t_dim * 2),
            nn.Mish(),
            nn.Linear(t_dim * 2, t_dim),
        )

        input_dim = action_dim + t_dim + c_dim
        # The layer order is kept exactly as before, so the state_dict keys
        # (mid_layer.0 / .2 / .4) are unchanged and saved weights still load.
        self.mid_layer = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
        )

        self.final_layer = nn.Linear(hidden_dim, action_dim)

    def forward(self, x, time, c=None):
        """Predict an ``action_dim``-sized output for a batch.

        Args:
            x: Tensor of shape (batch_size, action_dim).
            time: Tensor of shape (batch_size, 1). A 1-D (batch_size,) tensor
                is passed through unchanged. ``time_mlp`` is presumably
                expected to receive a 1-D tensor (TODO confirm against
                ``SinusoidalPosEmb``).
            c: Conditioning tensor of shape (batch_size, c_dim). It is
                required when the model was built with ``c_dim > 0`` and
                ignored otherwise.

        Returns:
            Tensor of shape (batch_size, action_dim).

        Raises:
            ValueError: If the model expects ``c`` but ``c`` is None.
        """
        # Drop the trailing singleton dim of a (batch_size, 1) tensor. The
        # squeeze is skipped for 1-D input, because a length-1 tensor would
        # otherwise collapse to a 0-d scalar and lose its batch dimension.
        if time.dim() > 1:
            time = time.squeeze(-1)
        t = self.time_mlp(time)

        features = [x, t]
        if self.has_c:
            # Raise instead of assert: assert statements are stripped under
            # `python -O`, which would turn this into an obscure cat error.
            if c is None:
                raise ValueError(
                    "c must be provided when the model is built with c_dim > 0"
                )
            features.append(c)

        hidden = self.mid_layer(torch.cat(features, dim=1))
        return self.final_layer(hidden)



