import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
import torch.distributions as dist
from Models.mvae import MVAE
from Models.encoder_decoder.resnet_MNIST import Enc_MNIST, Dec_MNIST
from Models.encoder_decoder.resnet_SVHN import Enc_SVHN, Dec_SVHN
        

class MVAE_MNIST_SVHN(MVAE):
    """MVAE specialised to the paired SVHN / MNIST modalities.

    Throughout the class modality "a" is SVHN and modality "b" is MNIST (see
    the encoder/decoder wiring in ``__init__``).

    NOTE(review): ``z_dim``, ``w_dim``, ``latent_dim``, ``device``,
    ``pseudo_samples_a``/``pseudo_samples_b`` and the ``a_to_a`` / ``a_to_b`` /
    ``b_to_a`` / ``b_to_b`` generation helpers are not defined here; they are
    presumably provided by the ``MVAE`` base class -- confirm there.
    """

    def __init__(self, args, pseudo_samples_a, pseudo_samples_b):
        """Build the SVHN/MNIST encoders and decoders and move the model to ``self.device``."""
        super(MVAE_MNIST_SVHN, self).__init__(args, pseudo_samples_a, pseudo_samples_b)
        # Multipliers applied to the element-wise log-likelihoods of each modality.
        self.mnist_scale_val = 1.0
        self.svhn_scale_val = 1.0
        # Modality "a" <-> SVHN, modality "b" <-> MNIST.
        self.a_to_z = Enc_SVHN(self.z_dim, self.w_dim)
        self.z_to_a = Dec_SVHN(self.latent_dim)
        self.b_to_z = Enc_MNIST(self.z_dim, self.w_dim)
        self.z_to_b = Dec_MNIST(self.latent_dim)
        self.num_steps = 0

        self.to(self.device)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _laplace_log_likelihood(pred, targ, scale):
        """Return the scaled Laplace log-likelihood of ``targ`` averaged per sample.

        ``pred`` is unpacked as the positional arguments of
        ``torch.distributions.Laplace`` (i.e. ``(loc, scale)``). The element-wise
        log-probabilities keep the first two batch dimensions, every remaining
        dimension is flattened, multiplied by ``scale`` and averaged.
        """
        px_u = dist.Laplace(*pred)
        log_prob = px_u.log_prob(targ).view(*px_u.batch_shape[:2], -1)
        return log_prob.mul(scale).mean(dim=-1)

    @staticmethod
    def _pad_mnist_to_svhn(images, rgb=True):
        """Zero-pad MNIST images from 28x28 to 32x32.

        The input is first reshaped to ``(-1, 1, 28, 28)``. With ``rgb=True`` the
        single channel is additionally expanded (as a view, no copy) to three
        channels so the result can be concatenated with SVHN tensors.
        """
        padded = F.pad(images.view(-1, 1, 28, 28), (2, 2, 2, 2), mode='constant', value=0)
        if rgb:
            return padded.expand(-1, 3, -1, -1)
        return padded

    def _assemble_grids(self, svhn, mnist, recons, N, dim):
        """Concatenate each input with its reconstructions and render 2x2 image grids.

        ``recons[r][c]`` is a list of tensors generated from input modality ``r``
        (0 = SVHN, 1 = MNIST) into output modality ``c``. Row ``r`` of the result
        uses the corresponding (CPU) input as the first element of the
        concatenation along ``dim``.
        """
        inputs = (svhn.cpu(), self._pad_mnist_to_svhn(mnist).cpu())
        return [
            [
                make_grid(torch.cat([inputs[row]] + recons[row][col], dim=dim), nrow=N)
                for col in range(2)
            ]
            for row in range(2)
        ]

    # -------------------------------------------------------------- likelihoods

    def svhn_likelihood(self, pred, targ):
        """Laplace log-likelihood of SVHN targets, scaled by ``self.svhn_scale_val``."""
        return self._laplace_log_likelihood(pred, targ, self.svhn_scale_val)

    def mnist_likelihood(self, pred, targ):
        """Laplace log-likelihood of MNIST targets, scaled by ``self.mnist_scale_val``."""
        return self._laplace_log_likelihood(pred, targ, self.mnist_scale_val)

    # ----------------------------------------------------------------- training

    def run(self, svhn_batch, mnist_batch, direction, fn):
        """Configure the model for ``direction`` and evaluate ``fn`` on the batches.

        Args:
            svhn_batch: batch of the SVHN modality.
            mnist_batch: batch of the MNIST modality.
            direction: ``'s2m'`` (SVHN is the source, MNIST the target),
                ``'m2s'`` (the reverse) or ``'bi'`` (average of both).
            fn: callable invoked as ``fn(source_batch, target_batch)`` after the
                direction-dependent attributes have been set on ``self``.

        Returns:
            The value returned by ``fn``; for ``'bi'`` the mean of the ``'s2m'``
            and ``'m2s'`` results.

        Raises:
            ValueError: if ``direction`` is not one of the supported values.

        Side effects:
            Sets ``self.direction``, ``self.encoder``, ``self.decoder``,
            ``self.cond_prior``, ``self.classifier``, ``self.likelihood_s``,
            ``self.likelihood_t`` and ``self.pseudo_samples``. After ``'bi'`` the
            model is left configured for ``'m2s'``, the last direction run.
        """
        if direction == 'bi':
            loss_s2m = self.run(svhn_batch, mnist_batch, 's2m', fn)
            loss_m2s = self.run(svhn_batch, mnist_batch, 'm2s', fn)
            return 0.5 * (loss_s2m + loss_m2s)

        if direction == 's2m':
            self.direction = 's2m'
            self.encoder = self.a_to_z
            self.decoder = self.z_to_a
            self.cond_prior = self.b_to_z
            self.classifier = self.z_to_b
            self.likelihood_s = self.svhn_likelihood
            self.likelihood_t = self.mnist_likelihood
            self.pseudo_samples = self.pseudo_samples_b
            data, targ = svhn_batch, mnist_batch
        elif direction == 'm2s':
            self.direction = 'm2s'
            self.encoder = self.b_to_z
            self.decoder = self.z_to_b
            self.cond_prior = self.a_to_z
            self.classifier = self.z_to_a
            self.likelihood_s = self.mnist_likelihood
            self.likelihood_t = self.svhn_likelihood
            self.pseudo_samples = self.pseudo_samples_a
            data, targ = mnist_batch, svhn_batch
        else:
            # Previously an unknown direction fell through and crashed with an
            # UnboundLocalError on ``data``; fail early with a clear message
            # and without touching any model state.
            raise ValueError(
                "Unknown direction {!r}; expected 's2m', 'm2s' or 'bi'.".format(direction)
            )
        return fn(data, targ)

    # --------------------------------------------------------------- generation

    def self_and_cross_modal_generation(self, svhn, mnist, num=1, N=10, dim=-1):
        """Render grids of self- and cross-modal generations, sampling ``num`` times.

        Each of the ``num`` rounds calls ``a_to_a``, ``a_to_b``, ``b_to_a`` and
        ``b_to_b`` once (in that order). MNIST-shaped outputs are padded to
        32x32 and expanded to three channels so they can be concatenated with
        the inputs along ``dim``.

        Args:
            svhn: batch of SVHN inputs.
            mnist: batch of MNIST inputs.
            num: number of generation rounds.
            N: ``nrow`` argument forwarded to ``make_grid``.
            dim: axis along which inputs and generations are concatenated.

        Returns:
            A 2x2 nested list ``grids[src][dst]`` of grid tensors, where index
            0 is SVHN and index 1 is MNIST.
        """
        recons = [[[] for _ in range(2)] for _ in range(2)]

        for _ in range(num):
            # Keep this call order: it fixes the order of any random sampling.
            recon_svhn = self.a_to_a(svhn).cpu()
            svhn_to_mnist = self.a_to_b(svhn).cpu()
            mnist_to_svhn = self.b_to_a(mnist).cpu()
            recon_mnist = self.b_to_b(mnist).cpu()

            recons[0][0].append(recon_svhn)
            recons[0][1].append(self._pad_mnist_to_svhn(svhn_to_mnist))
            recons[1][0].append(mnist_to_svhn)
            recons[1][1].append(self._pad_mnist_to_svhn(recon_mnist))

        return self._assemble_grids(svhn, mnist, recons, N, dim)

    def self_and_cross_modal_generation_spec(self, svhn, mnist, num=5, N=10, dim=-1):
        """Render grids from a single batched generation call per direction.

        Unlike ``self_and_cross_modal_generation`` the generation helpers are
        called once each with ``num`` as second argument, and the result is
        indexed ``[0] .. [num - 1]`` along its first axis.
        NOTE(review): presumably that first axis enumerates ``num`` samples --
        confirm against the base-class helpers.

        Args:
            svhn: batch of SVHN inputs.
            mnist: batch of MNIST inputs.
            num: value forwarded to the generation helpers and number of
                leading-axis slices used.
            N: ``nrow`` argument forwarded to ``make_grid``.
            dim: axis along which inputs and generations are concatenated.

        Returns:
            A 2x2 nested list ``grids[src][dst]`` of grid tensors, where index
            0 is SVHN and index 1 is MNIST.
        """
        recon_svhn = self.a_to_a(svhn, num).cpu()
        svhn_to_mnist = self.a_to_b(svhn, num).cpu()
        mnist_to_svhn = self.b_to_a(mnist, num).cpu()
        recon_mnist = self.b_to_b(mnist, num).cpu()

        recons = [[[] for _ in range(2)] for _ in range(2)]
        for i in range(num):
            recons[0][0].append(recon_svhn[i])
            recons[0][1].append(self._pad_mnist_to_svhn(svhn_to_mnist[i]))
            recons[1][0].append(mnist_to_svhn[i])
            recons[1][1].append(self._pad_mnist_to_svhn(recon_mnist[i]))

        return self._assemble_grids(svhn, mnist, recons, N, dim)

    def self_and_cross_modal_generation_for_fid_calculation(self, data, dataPath, fidPath, i):
        """Generate all four self/cross reconstructions and save them as PNG files.

        Images are written to ``{fidPath}/{src}/{dst}/{index}_{i}.png`` where
        ``src``/``dst`` are ``svhn`` or ``mnist``. MNIST-shaped outputs are
        zero-padded to 32x32 (kept single-channel). Missing output directories
        are created.

        Args:
            data: indexable whose items 0 and 1 are the SVHN and MNIST batches.
            dataPath: unused; kept so existing callers keep working.
            fidPath: root directory for the generated images.
            i: identifier appended to every file name (e.g. the batch index).
        """
        import os  # local import: only needed for creating output directories

        with torch.no_grad():
            svhn, mnist = data[0].to(self.device), data[1].to(self.device)
            # Keep this call order: it fixes the order of any random sampling.
            generated = [
                ('svhn/svhn', self.a_to_a(svhn).cpu()),
                ('svhn/mnist', self._pad_mnist_to_svhn(self.a_to_b(svhn), rgb=False).cpu()),
                ('mnist/svhn', self.b_to_a(mnist).cpu()),
                ('mnist/mnist', self._pad_mnist_to_svhn(self.b_to_b(mnist), rgb=False).cpu()),
            ]

        for subdir, images in generated:
            out_dir = '{}/{}'.format(fidPath, subdir)
            # save_image does not create parent directories itself.
            os.makedirs(out_dir, exist_ok=True)
            for index, image in enumerate(images):
                save_image(image, '{}/{}_{}.png'.format(out_dir, index, i))
