import torch
import torch.nn as nn
import torch.nn.functional as F

# simple MLP denoiser
class MLPDenoiser(nn.Module):
    """Small fully-connected network mapping ``(x, t)`` to an output vector.

    The conditioning value ``t`` is appended to ``x`` as one extra input
    feature, so ``in_dim`` must equal ``x.shape[-1] + 1``.

    Architecture: ``Linear(in_dim, hid_dim)`` followed by ``num_hid_layers``
    ``Linear(hid_dim, hid_dim)`` layers, each followed by ``activation`` and
    dropout, then a final ``Linear(hid_dim, out_dim)`` with no activation.

    Args:
        in_dim: Number of input features *including* the appended ``t``.
        hid_dim: Width of every hidden layer.
        out_dim: Number of output features.
        num_hid_layers: Number of extra ``hid_dim -> hid_dim`` layers after
            the input layer (0 gives a single hidden layer).
        dropout: Dropout probability applied after every hidden activation;
            only active while the module is in training mode.
        activation: Elementwise callable applied after each hidden layer.
    """

    def __init__(self, *, in_dim=3, hid_dim=64, out_dim=2, num_hid_layers=1, dropout=0.1, activation=F.relu):
        super().__init__()

        self.num_hid_layers = num_hid_layers
        self.dropout = dropout
        self.activation = activation
        # Input layer first, then num_hid_layers hidden->hidden layers (same
        # creation order as before, so init RNG order and state-dict keys match).
        self.layers = nn.ModuleList(
            [nn.Linear(in_dim, hid_dim)]
            + [nn.Linear(hid_dim, hid_dim) for _ in range(num_hid_layers)]
        )
        self.out = nn.Linear(hid_dim, out_dim)

    def forward(self, x, t):
        """Run the network on ``x`` conditioned on ``t``.

        Args:
            x: Tensor of shape ``(..., in_dim - 1)``, e.g. ``[n, d]``.
            t: Conditioning value(s), concatenated to ``x`` as one extra
                feature. Accepts a tensor with one value per row of ``x``
                (e.g. ``[n]`` or ``[n, 1]`` for ``x`` of shape ``[n, d]``), or
                a scalar (Python number or 0-d tensor) shared by every row.
                ``t`` is cast to ``x``'s dtype and device first, so integer
                or float64 values are accepted.

        Returns:
            Tensor of shape ``(..., out_dim)``.
        """
        batch_shape = x.shape[:-1]
        # Cast up front: a float64 t would otherwise promote h to float64 in
        # torch.cat (breaking the float32 Linear layers), and a t on another
        # device would make the cat fail.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.dim() == 0:
            # A single value shared by the whole batch.
            t = t.expand(batch_shape)
        # [n] or [n, 1] -> [n, 1] (generally batch_shape + (1,)), then append
        # t as an extra feature column.
        h = torch.cat([x, t.reshape(*batch_shape, 1)], dim=-1)
        for layer in self.layers:
            h = self.activation(layer(h))
            h = F.dropout(h, p=self.dropout, training=self.training)
        return self.out(h)
