import torch
import torch.nn as nn
from dl.src.constants import BASE_EN_DE


class FactorVAE(nn.Module):
    """VAE whose latent code is additionally scored by an auxiliary discriminator.

    The discriminator is trained (via the ``disc`` term) to tell encoder codes
    ``z`` apart from codes whose dimensions were shuffled independently across
    the batch; its logit difference on ``z`` is reported as the ``tc`` term.

    ``config`` must provide ``dataset`` (a key into ``BASE_EN_DE`` selecting the
    encoder/decoder constructors) plus whatever the encoder, decoder and
    discriminator read from it.
    """

    def __init__(self, config):
        super().__init__()
        encoder, decoder = BASE_EN_DE[config.dataset]
        self.encoder = encoder(config)
        self.decoder = decoder(config)
        self.discriminator = Discriminator(config)
        self.D_loss_fn = nn.CrossEntropyLoss()

    def forward(self, input, loss_fn):
        """Encode, decode and compute all loss terms for one batch.

        Args:
            input: batch of inputs; dim 0 is the batch dimension.
            loss_fn: reconstruction loss, called as ``loss_fn(reconstruction, input)``.

        Returns:
            ``(losses, encoder_output, decoder_output)``, where ``losses`` is the
            dict returned by :meth:`loss`.
        """
        encoder_output = self.encoder(input)
        # encoder_output[0] is the latent code z (see ``loss``).
        decoder_output = self.decoder(encoder_output[0])
        losses = self.loss(input, (encoder_output, decoder_output), loss_fn)
        return losses, encoder_output, decoder_output

    def loss(self, input, outputs, loss_fn):
        """Compute the per-batch loss terms.

        Args:
            input: the original batch; ``input.size(0)`` is the batch size.
            outputs: ``(encoder_output, decoder_output)`` where
                ``encoder_output[0..2]`` are ``z, mu, logvar`` and
                ``decoder_output[0]`` is the reconstruction.
            loss_fn: reconstruction loss ``loss_fn(reconstruction, input)``.

        Returns:
            ``{"elbo": {}, "obj": {...}, "id": {}}`` where ``obj`` holds the
            scalar tensors ``reconst``, ``kld``, ``tc`` and ``disc``.
        """
        result = {"elbo": {}, "obj": {}, "id": {}}
        batch = input.size(0)
        reconstructed = outputs[1][0]
        # Flatten each latent tensor to (batch, latent). Unlike a bare
        # ``squeeze()`` this keeps the batch dimension when batch == 1.
        z, mu, logvar = (outputs[0][i].reshape(batch, -1) for i in range(3))

        reconst_err = loss_fn(reconstructed, input) / batch
        kld_err = torch.mean(
            -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=-1)
        )

        # tc: mean difference of the two discriminator logits on z (class 0 =
        # "unpermuted code"). Its gradient reaches the encoder through z.
        # NOTE(review): it also reaches the discriminator's parameters; the
        # training loop presumably steps the discriminator only on ``disc`` —
        # confirm.
        D_z = self.discriminator(z).reshape(batch, -1)
        tc_err = (D_z[:, :1] - D_z[:, 1:]).mean()

        # disc: real codes -> class 0, dimension-permuted codes -> class 1.
        # z is detached so this objective trains only the discriminator and
        # never back-propagates into the encoder.
        z_detached = z.detach()
        D_z_real = self.discriminator(z_detached).reshape(batch, -1)
        D_z_permuted = self.discriminator(self.permute_dim(z_detached)).reshape(
            batch, -1
        )
        zeros = torch.zeros(batch, dtype=torch.long, device=D_z_real.device)
        ones = torch.ones(batch, dtype=torch.long, device=D_z_real.device)
        disc_err = 0.5 * (
            self.D_loss_fn(D_z_permuted, ones) + self.D_loss_fn(D_z_real, zeros)
        )

        result["obj"]["reconst"] = reconst_err
        result["obj"]["kld"] = kld_err
        result["obj"]["tc"] = tc_err
        result["obj"]["disc"] = disc_err

        return result

    def init_weights(self):
        """Xavier-uniform initialise every parameter with >= 2 dims.

        1-D parameters (e.g. biases) keep their current values.
        """
        for param in self.parameters():
            if param.dim() >= 2:
                nn.init.xavier_uniform_(param.data)

    def permute_dim(self, z):
        """Shuffle each latent dimension independently across the batch.

        Args:
            z: 2-D tensor ``(batch, latent)``.

        Returns:
            Tensor of the same shape in which column ``j`` is ``z[:, j]``
            reordered by its own random permutation.

        Raises:
            ValueError: if ``z`` is not 2-D.
        """
        if z.dim() != 2:
            raise ValueError(
                f"permute_dim expects a 2-D tensor, got shape {tuple(z.size())}"
            )
        batch = z.size(0)
        columns = [
            column[torch.randperm(batch).to(z.device)]
            for column in z.split(1, dim=1)
        ]
        return torch.cat(columns, dim=1)


class Discriminator(nn.Module):
    """MLP that maps latent codes to two unnormalised class logits.

    Architecture: ``Linear(config.latent_dim, 1000)`` followed by four
    ``Linear(1000, 1000)`` layers, each followed by ``LeakyReLU(0.2)``, and a
    final ``Linear(1000, 2)`` head.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 1000
        layers = [
            nn.Linear(config.latent_dim, hidden),
            nn.LeakyReLU(0.2, inplace=False),
        ]
        for _ in range(4):
            layers += [nn.Linear(hidden, hidden), nn.LeakyReLU(0.2, inplace=False)]
        layers.append(nn.Linear(hidden, 2))
        # Module order matches the original explicit list, so state_dict keys
        # (net.0 ... net.10) stay the same.
        self.net = nn.Sequential(*layers)
        self.init_weights()

    def forward(self, input):
        """Return the logits for ``input`` (last dim must be ``latent_dim``).

        Interior singleton dims are dropped (e.g. ``(B, 1, 2) -> (B, 2)``), but
        dim 0 is always kept so that a batch of one yields ``(1, 2)`` rather
        than ``(2,)``.
        """
        logits = self.net(input)
        # Walk backwards so earlier indices stay valid after each squeeze; the
        # last dim (size 2) and the batch dim 0 are never touched.
        for dim in range(logits.dim() - 2, 0, -1):
            if logits.size(dim) == 1:
                logits = logits.squeeze(dim)
        return logits

    def init_weights(self):
        """Xavier-uniform initialise every parameter with >= 2 dims.

        1-D parameters (biases) keep their default initialisation.
        """
        for param in self.parameters():
            if param.dim() >= 2:
                nn.init.xavier_uniform_(param.data)
