import torch
import torch.nn as nn
import math
from cg.src.constants import BASE_EN_DE

class BetaMAGAVAE(nn.Module):
    """VAE that pairs a main encoder with a second "symmetry" encoder.

    The first half of each batch is encoded by ``self.encoder`` (giving z1).
    The symmetry encoder encodes the whole batch (giving z'). The latents of
    the second half are predicted as ``z2 = z1 + z'2 - z'1``. Both halves are
    then decoded together by ``self.decoder``.

    Args:
        config: project config object. Reads ``config.beta`` and
            ``config.dataset``. ``BASE_EN_DE[config.dataset]`` must yield an
            ``(encoder_cls, decoder_cls)`` pair, each constructed with
            ``config``.
    """

    def __init__(self, config):
        super().__init__()
        # Stored for callers; none of the methods in this class read it.
        self.beta = config.beta
        encoder, decoder = BASE_EN_DE[config.dataset]
        # Construction order is kept as-is (affects parameter init RNG draws
        # and state_dict ordering).
        self.symmetry = encoder(config)
        self.encoder = encoder(config)
        self.decoder = decoder(config)

    def forward(self, input, loss_fn):
        """Encode, transform, decode and score a paired batch.

        Args:
            input: batch whose first and second halves are treated as paired
                samples; ``input.size(0)`` must be even.
            loss_fn: reconstruction criterion called as
                ``loss_fn(reconstruction, target)``. Its result is divided by
                the batch size (presumably a sum-reduced loss — confirm).

        Returns:
            ``(loss, encoder_output, decoder_output)``, where:

            - ``loss`` is the dict from :meth:`loss`, with
              ``loss["obj"]["latent_loss"]`` added.
            - ``encoder_output`` is the main encoder's output
              ``(z, mu, logvar, ...)`` for the first half of the batch.
            - ``decoder_output`` is the decoder's output for both halves;
              its element ``[0]`` is the reconstruction.

        Raises:
            ValueError: if the batch size is odd.
        """
        batch = input.size(0)
        if batch % 2:
            # With an odd batch the two halves differ in length, so the
            # z'2 - z'1 subtraction below would silently broadcast or fail
            # with an obscure shape error.
            raise ValueError(
                f"BetaMAGAVAE expects an even batch size, got {batch}"
            )
        half = batch // 2

        encoder_output = self.encoder(input[:half])  # z1 = encoder_output[0]
        symmetry_output = self.symmetry(input)
        z_prime = symmetry_output[0]
        symmetries = z_prime[half:] - z_prime[:half]  # z'2 - z'1
        transformed_z = encoder_output[0] + symmetries  # z2 = z1 + z'2 - z'1
        zs = torch.cat([encoder_output[0], transformed_z], dim=0)

        decoder_output = self.decoder(zs)
        # Re-encode the decoded second half with the symmetry encoder and
        # penalise its distance to the symmetry codes of the real second half.
        z_hat_2 = self.symmetry(decoder_output[0][half:])[0]
        latent_loss = ((z_prime[half:] - z_hat_2) ** 2).sum(dim=-1).mean()

        outputs = (encoder_output, decoder_output)
        # loss() reads symmetries[0][1] / symmetries[0][2] as the symmetry
        # posterior's mu / logvar, so hand it the symmetry encoder's output
        # (not the z'2 - z'1 difference tensor).
        loss = self.loss(input, outputs, (symmetry_output,), loss_fn)
        loss["obj"]["latent_loss"] = latent_loss
        return loss, encoder_output, decoder_output

    def loss(self, input, outputs, symmetries=None, loss_fn=None):
        """Compute the reconstruction and KL-divergence terms.

        Args:
            input: the full batch of reconstruction targets.
            outputs: ``(encoder_output, decoder_output)`` as built by
                :meth:`forward`. ``encoder_output`` is ``(z, mu, logvar, ...)``
                and ``decoder_output[0]`` is the reconstruction of the full
                batch.
            symmetries: optional 1-tuple ``(symmetry_output,)``, where
                ``symmetry_output`` is the symmetry encoder's
                ``(z, mu, logvar, ...)``. When given, its KL divergence to a
                standard normal is added to the ``"kld"`` term.
            loss_fn: reconstruction criterion
                ``loss_fn(reconstruction, target)``.

        Returns:
            ``{"elbo": {}, "obj": {"reconst": ..., "kld": ...}, "id": {}}``.
        """
        result = {"elbo": {}, "obj": {}, "id": {}}
        batch = input.size(0)
        reconsted_images = outputs[1][0]
        # NOTE(review): squeeze() drops every size-1 dim, including the batch
        # dim for a batch of one or a latent dim of size 1 — confirm shapes.
        mu = outputs[0][1].squeeze()
        logvar = outputs[0][2].squeeze()

        if self.training:
            reconst_err = loss_fn(reconsted_images, input) / batch
        else:
            # In eval mode only the transformed (second) half is scored and
            # normalised by the half-batch size (batch / 2).
            half = batch // 2
            reconst_err = (
                loss_fn(reconsted_images[half:], input[half:]) / batch * 2
            )

        kld_err = self._kl_to_standard_normal(mu, logvar)
        if symmetries is not None:
            smu = symmetries[0][1].squeeze()
            slogvar = symmetries[0][2].squeeze()
            # Averaged separately: the symmetry encoder sees the whole batch
            # while the main encoder sees only half, so the per-sample KL
            # vectors have different lengths.
            kld_err = kld_err + self._kl_to_standard_normal(smu, slogvar)

        result["obj"]["reconst"] = reconst_err
        result["obj"]["kld"] = kld_err
        return result

    @staticmethod
    def _kl_to_standard_normal(mu, logvar):
        """Return mean over rows of KL(N(mu, exp(logvar)) || N(0, I)).

        The KL is summed over the last dimension, then averaged.
        """
        return torch.mean(
            -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=-1)
        )

    def log_density_gaussian(self, x, mu, logvar):
        """Elementwise log-density of ``x`` under ``N(mu, exp(logvar))``.

        log f(x) = -0.5 * (log(2 * pi) + logvar) - 0.5 * (x - mu)^2 / var
        """
        norm = -0.5 * (math.log(2 * math.pi) + logvar)
        return norm - 0.5 * (x - mu) ** 2 * torch.exp(-logvar)

    def init_weights(self):
        """Xavier-uniform initialise every parameter with at least 2 dims.

        1-D parameters (e.g. biases) are left untouched.
        """
        for param in self.parameters():
            if param.dim() >= 2:
                nn.init.xavier_uniform_(param)