import torch.nn as nn
import torch
import torch.nn.functional as F
from utils import Shuffling

class FacDensityNet(nn.Module):
    """Autoencoder bundled with a discriminator over its latent codes.

    ``forward`` runs only the autoencoder and the shuffling step; the
    discriminator is built here but not called in ``forward`` (presumably the
    caller applies it to ``z`` / ``z_shuffle``).

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the autoencoder's hidden layers.
        z_dim: Size of the latent code (also the discriminator's input size).
        disc_hidden_dim: Width of the discriminator's hidden layers.
            Defaults to ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim, z_dim, disc_hidden_dim=None):
        super().__init__()
        if disc_hidden_dim is None:
            disc_hidden_dim = hidden_dim

        self.autoencoder = FacDensityAE(input_dim=input_dim, hidden_dim=hidden_dim, z_dim=z_dim)
        # Discriminator.__init__ requires hidden_dim; omitting it made this
        # constructor raise TypeError.
        self.discriminator = Discriminator(input_dim=z_dim, hidden_dim=disc_hidden_dim)
        # NOTE(review): presumably permutes each latent dimension across the
        # batch (like Discriminator.sampling) — confirm in utils.Shuffling.
        self.shuffle = Shuffling()

    def forward(self, x):
        """Encode/decode ``x`` and shuffle the latent code.

        Returns:
            ``(x_hat, z, z_shuffle)``: the reconstruction, the latent code,
            and the output of ``self.shuffle(z)``.
        """
        x_hat, z = self.autoencoder(x)
        z_shuffle = self.shuffle(z)
        return x_hat, z, z_shuffle


class FacDensityAE(nn.Module):
    """Fully connected autoencoder with an unbounded (linear) latent code.

    Encoder and decoder are mirror-image MLPs: two ReLU-activated hidden
    layers of width ``hidden_dim`` followed by a linear output layer.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every hidden layer.
        z_dim: Size of the latent code.
    """

    def __init__(self, input_dim, hidden_dim, z_dim=10):
        super().__init__()

        def mlp(d_in, d_out):
            # Linear -> ReLU -> Linear -> ReLU -> Linear. The Sequential
            # indices (0, 2, 4) appear in state_dict keys, so keep this layout.
            return nn.Sequential(
                nn.Linear(d_in, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, d_out),
            )

        # Encoder is built before decoder so parameter initialisation draws
        # from the RNG in the same order.
        self.encoder = mlp(input_dim, z_dim)
        self.decoder = mlp(z_dim, input_dim)

    def forward(self, x):
        """Encode ``x`` and reconstruct it from the code.

        Returns:
            ``(x_hat, z)``: the reconstruction and the latent code.
        """
        code = self.encoder(x)
        return self.decoder(code), code


class Discriminator(nn.Module):
    """MLP mapping each input row to a score in (0, 1).

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the two ReLU-activated hidden layers.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Return ``sigmoid(net(x))``; the last dimension has size 1."""
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return torch.sigmoid(self.net(x))

    @staticmethod
    def sampling(z):
        """Permute every column of ``z`` independently across rows.

        Each column gets its own random permutation, which breaks the joint
        structure between columns while keeping each column's values.

        Args:
            z: 2-D tensor of shape ``(n, d)``. It is not modified.

        Returns:
            ``(z_shuffle, z)``: a permuted copy and the original tensor.

        Raises:
            ValueError: if ``z`` is not 2-D (from unpacking ``z.shape``).
        """
        n, d = z.shape
        z_shuffle = z.clone()
        for i in range(d):
            # Build the index on z's device to avoid a host->device copy.
            z_shuffle[:, i] = z[torch.randperm(n, device=z.device), i]
        return z_shuffle, z

