import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from datetime import datetime
from utils import *


class DiffusionModel(nn.Module):
    """MLP noise-prediction network conditioned on an integer timestep.

    The timestep ``t`` is looked up in a learned embedding table and
    concatenated (along dim 1) to the input vector. The result is passed
    through a three-layer ReLU MLP whose output has the same width as the
    input.

    Args:
        input_dim: Width of the data vectors ``x``; also the output width.
        hidden_dim: Width of both hidden layers; also the width of the
            features returned by :meth:`encode`.
        time_emb_dim: Width of the timestep embedding.
        num_timesteps: Number of rows in the timestep embedding table, i.e.
            valid ``t`` values are ``0 .. num_timesteps - 1``. Defaults to
            1000, the value that was previously hard-coded.
    """

    def __init__(self, input_dim, hidden_dim, time_emb_dim, num_timesteps=1000):
        super().__init__()
        # Registration order (time_emb, fc1, fc2, fc3) is kept unchanged so
        # parameter ordering and state_dict keys match existing checkpoints.
        self.time_emb = nn.Embedding(num_timesteps, time_emb_dim)
        self.fc1 = nn.Linear(input_dim + time_emb_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def _first_hidden(self, x, t):
        """Embed ``t``, concatenate it to ``x`` along dim 1, apply fc1 + ReLU.

        Shared by :meth:`forward` and :meth:`encode` so both always build the
        conditioned input the same way.
        """
        t_emb = self.time_emb(t)
        x_in = torch.cat([x, t_emb], dim=1)
        return F.relu(self.fc1(x_in))

    def forward(self, x_noisy, t):
        """Predict the noise component of ``x_noisy`` at timestep ``t``.

        Args:
            x_noisy: Input tensor; presumably shaped ``(batch, input_dim)``,
                since it is concatenated with the timestep embedding on dim 1.
            t: Integer (long) tensor of timestep indices, one per row.

        Returns:
            Tensor with ``input_dim`` features per row (the noise prediction).
        """
        h = self._first_hidden(x_noisy, t)
        h = F.relu(self.fc2(h))
        return self.fc3(h)

    def encode(self, x):
        """Return first-hidden-layer features of ``x``, conditioned on ``t = 0``.

        A zero timestep index is created for every row, on the same device as
        ``x``. The output has ``hidden_dim`` features per row and is
        non-negative because of the ReLU.
        """
        t = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        return self._first_hidden(x, t)

class MutualInfoEstimator(nn.Module):
    """Three-layer ReLU MLP that scores a ``(z, s)`` pair with one scalar.

    ``z`` and ``s`` are concatenated along dim 1 and mapped to a single output
    unit per row. Judging by the class name, this is presumably the critic
    of a neural mutual-information bound between latents and sensitive
    attributes. That use is not established here; confirm with the training
    code.

    Args:
        latent_dim: Width of ``z``.
        sensitive_dim: Width of ``s``.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, latent_dim, sensitive_dim, hidden_dim):
        super().__init__()
        joint_width = latent_dim + sensitive_dim
        self.fc1 = nn.Linear(joint_width, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, z, s):
        """Return a ``(batch, 1)``-shaped score for each row pair of ``z`` and ``s``."""
        pair = torch.cat([z, s], dim=1)
        hidden = F.relu(self.fc1(pair))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
    
class Classifier(nn.Module):
    """Three-layer ReLU MLP that maps feature vectors to class logits.

    Args:
        input_dim: Accepted for interface compatibility but currently unused
            (see the note in ``__init__``).
        hidden_dim: Width of the input features and of both hidden layers.
        num_classes: Number of output logits per row.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        # NOTE(review): input_dim is never used. fc1 expects hidden_dim-wide
        # input, which fits features produced by DiffusionModel.encode but
        # not raw input_dim-wide data. Confirm which is intended before
        # changing this, because it alters the layer shape and checkpoints.
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return unnormalized logits of width ``num_classes`` for each row of ``x``."""
        features = x
        for layer in (self.fc1, self.fc2):
            features = F.relu(layer(features))
        return self.fc3(features)