import numpy as np
import numpy as np
import torch
from torchvision.utils import make_grid, save_image
import torch.distributions as dist
import os
import json
from Models.mvae import MVAE
from Models.encoder_decoder.resnet_Image import Enc_Image, Dec_Image
from Models.encoder_decoder.cnn_Caption import Enc_Caption, Dec_Caption
import matplotlib.pyplot as plt


class MVAE_CUBICC(MVAE):
    def __init__(self, args, pseudo_samples_a, pseudo_samples_b):
        """Build the CUBICC image/caption model on top of the ``MVAE`` base.

        Args:
            args: Run configuration. Forwarded to ``MVAE.__init__``; this
                method additionally reads ``args.device``.
            pseudo_samples_a: Forwarded unchanged to ``MVAE.__init__``.
            pseudo_samples_b: Forwarded unchanged to ``MVAE.__init__``.
        """
        super().__init__(args, pseudo_samples_a, pseudo_samples_b)

        # Plain (non-module) settings.
        self.num_steps = 0
        self.vocab_file = 'data/datasets/CUBICC/cub.vocab'
        # Multipliers applied to the per-element log-probabilities in
        # image_likelihood / sentence_likelihood.
        self.image_scale_val = 1 / 384
        self.sentence_scale_val = 5

        # Modality "a" is the image, modality "b" is the caption.
        # NOTE(review): z_dim, w_dim and latent_dim are presumably set by
        # MVAE.__init__ -- confirm against the base class.
        # The construction order below is kept as-is: it determines both the
        # submodule registration order and the order of random initialisation.
        z_dim, w_dim = self.z_dim, self.w_dim
        self.a_to_z = Enc_Image(z_dim, w_dim)
        self.z_to_a = Dec_Image(self.latent_dim)
        self.b_to_z = Enc_Caption(z_dim, w_dim)
        self.z_to_b = Dec_Caption(z_dim, w_dim)

        # Move everything constructed above to the requested device.
        self.to(args.device)

    def image_likelihood(self, pred, targ):
        px_u = dist.Laplace(*pred)
        return px_u.log_prob(targ).view(*px_u.batch_shape[:2], -1).mul(self.image_scale_val).mean(dim=-1)

    def sentence_likelihood(self, pred, targ):
        px_u = dist.OneHotCategorical(*pred)
        return px_u.log_prob(targ).view(*px_u.batch_shape[:2], -1).mul(self.sentence_scale_val).mean(dim=-1)

    def run(self, image_batch, sentence_batch, direction, fn):
        if direction == 'i2s':
            self.direction = 'i2s'
            self.encoder = self.a_to_z
            self.decoder = self.z_to_a
            self.cond_prior = self.b_to_z
            self.classifier = self.z_to_b
            self.likelihood_s = self.image_likelihood
            self.likelihood_t = self.sentence_likelihood
            self.pseudo_samples = self.pseudo_samples_b
            data, targ = image_batch, sentence_batch
        elif direction == 's2i':
            self.direction = 's2i'
            self.encoder = self.b_to_z
            self.decoder = self.z_to_b
            self.cond_prior = self.a_to_z
            self.classifier = self.z_to_a
            self.likelihood_s = self.sentence_likelihood
            self.likelihood_t = self.image_likelihood
            self.pseudo_samples = self.pseudo_samples_a
            data, targ = sentence_batch, image_batch
        elif direction == 'bi':
            loss_i2s = self.run(image_batch, sentence_batch, 'i2s', fn)
            loss_s2i = self.run(image_batch, sentence_batch, 's2i', fn)
            return 0.5 * (loss_i2s + loss_s2i)
        return fn(data, targ, 10)
    
    def load_vocab(self):
        assert os.path.exists(self.vocab_file)
        with open(self.vocab_file, 'r') as vocab_file:
            vocab = json.load(vocab_file)
        return vocab['i2w']
    
    def _sent_process(self, sentences):
        fn_trun = lambda s: s[:np.where(s == 2)[0][0] + 1] if 2 in s else s
        fn_2i = lambda t: t.cpu().numpy().astype(int)
        return [fn_trun(fn_2i(s)) for s in sentences]
    
    def _plot_sentences_as_tensor(self, batched_text_modality):
        i2w = self.load_vocab()
        sentences_processed = self._sent_process(batched_text_modality.argmax(-1))
        sentences_worded = [' '.join(i2w[str(word)] for word in sent if i2w[str(word)] != '<pad>') for sent in sentences_processed]
        return self.plot_text_as_image_tensor(sentences_worded, pixel_width=64, pixel_height=384)
    
    def plot_text_as_image_tensor(self, sentences_lists_of_words, pixel_width=64, pixel_height=384):
        """Render each sentence as a small RGB image with one word per line.

        Args:
            sentences_lists_of_words: Iterable of sentence strings. Each is
                split on whitespace and ``'<eos>'`` tokens are dropped before
                rendering.
            pixel_width: Figure width in pixels at the current
                ``rcParams['figure.dpi']``.
            pixel_height: Figure height in pixels at the current
                ``rcParams['figure.dpi']``.

        Returns:
            Float tensor stacking one channel-first RGB image per sentence
            along dim 0, with values scaled to ``[0, 1]``.

        Raises:
            RuntimeError: From ``torch.stack`` when no sentences are given.
        """
        # The figure size is the same for every sentence, so compute it once.
        inches_per_pixel = 1 / plt.rcParams['figure.dpi']
        figsize = (pixel_width * inches_per_pixel, pixel_height * inches_per_pixel)

        imgs = []
        for sentence in sentences_lists_of_words:
            words = [word for word in sentence.split() if word != '<eos>']
            # Every word is followed by a newline and the pieces are joined
            # with a space. This is exactly the string the former
            # `(n + 1) % 1 == 0` test (always true) produced.
            label = ' '.join(word + '\n' for word in words)

            fig = plt.figure(figsize=figsize)
            try:
                plt.text(
                    x=1,
                    y=0.5,
                    s=label,
                    fontsize=7,
                    verticalalignment='center_baseline',
                    horizontalalignment='right'
                )
                plt.axis('off')

                # Rasterise and read the pixels back. `tostring_rgb()` was
                # deprecated in Matplotlib 3.8 and later removed, so use the
                # RGBA buffer and drop the alpha channel. The copy yields a
                # contiguous, writable array that outlives the figure.
                fig.canvas.draw()
                rgba = np.asarray(fig.canvas.buffer_rgba())
                image_np = rgba[..., :3].copy()
            finally:
                # Always release the figure, even if rendering failed.
                plt.close(fig)

            # Height-width-channel uint8 -> channel-first float in [0, 1].
            image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255
            imgs.append(image_tensor)
        return torch.stack(imgs, dim=0)

    def self_and_cross_modal_generation(self, image, sentence, num=1, N=8, dim=2):
        """Build 2x2 image grids of self- and cross-modal generations.

        The four generators are each called ``num`` times with a single
        argument. Caption outputs are turned into images through
        ``_plot_sentences_as_tensor`` so they can be tiled next to pictures.

        Args:
            image: Image batch fed to ``a_to_a`` and ``a_to_b``.
            sentence: Caption batch fed to ``b_to_a`` and ``b_to_b``.
            num: Number of generation rounds.
            N: ``nrow`` argument passed to ``make_grid``.
            dim: Dimension along which the input and its generations are
                concatenated before tiling.

        Returns:
            Nested list ``grids[source][target]`` with index 0 for image and
            1 for caption. Row 0 is headed by the input image, row 1 by the
            rendered input caption.
        """
        image_from_image, caption_from_image = [], []
        image_from_caption, caption_from_caption = [], []

        for _ in range(num):
            # The four generator calls keep this exact order: it fixes the
            # order in which random samples are drawn.
            recon_image = self.a_to_a(image).cpu()
            image_to_sentence = self.a_to_b(image).cpu()
            sentence_to_image = self.b_to_a(sentence).cpu()
            recon_sentence = self.b_to_b(sentence).cpu()

            image_from_image.append(recon_image)
            caption_from_image.append(self._plot_sentences_as_tensor(image_to_sentence))
            image_from_caption.append(sentence_to_image)
            caption_from_caption.append(self._plot_sentences_as_tensor(recon_sentence))

        input_image = image.cpu()
        input_sentence = self._plot_sentences_as_tensor(sentence).cpu()

        def tile(reference, generated):
            # The reference goes first, followed by every generation round.
            return make_grid(torch.cat([reference] + generated, dim=dim), nrow=N)

        return [
            [tile(input_image, image_from_image), tile(input_image, caption_from_image)],
            [tile(input_sentence, image_from_caption), tile(input_sentence, caption_from_caption)],
        ]
    
    def self_and_cross_modal_generation_spec(self, image, sentence, num=5, N=16, dim=2):
        """Build 2x2 image grids from ``num`` generations drawn in one call.

        Each generator is called once with ``num`` as its second argument,
        and the result is indexed along its first dimension for
        ``range(num)``. Caption outputs are turned into images through
        ``_plot_sentences_as_tensor``.

        Args:
            image: Image batch fed to ``a_to_a`` and ``a_to_b``.
            sentence: Caption batch fed to ``b_to_a`` and ``b_to_b``.
            num: Number of generations requested from each generator.
            N: ``nrow`` argument passed to ``make_grid``.
            dim: Dimension along which the input and its generations are
                concatenated before tiling.

        Returns:
            Nested list ``grids[source][target]`` with index 0 for image and
            1 for caption. Row 0 is headed by the input image, row 1 by the
            rendered input caption.
        """
        # The four generator calls keep this exact order: it fixes the order
        # in which random samples are drawn.
        recon_image = self.a_to_a(image, num).cpu()
        image_to_sentence = self.a_to_b(image, num).cpu()
        sentence_to_image = self.b_to_a(sentence, num).cpu()
        recon_sentence = self.b_to_b(sentence, num).cpu()

        image_from_image, caption_from_image = [], []
        image_from_caption, caption_from_caption = [], []
        for k in range(num):
            image_from_image.append(recon_image[k])
            caption_from_image.append(self._plot_sentences_as_tensor(image_to_sentence[k]))
            image_from_caption.append(sentence_to_image[k])
            caption_from_caption.append(self._plot_sentences_as_tensor(recon_sentence[k]))

        input_image = image.cpu()
        input_sentence = self._plot_sentences_as_tensor(sentence).cpu()

        def tile(reference, generated):
            # The reference goes first, followed by every generation.
            return make_grid(torch.cat([reference] + generated, dim=dim), nrow=N)

        return [
            [tile(input_image, image_from_image), tile(input_image, caption_from_image)],
            [tile(input_sentence, image_from_caption), tile(input_sentence, caption_from_caption)],
        ]
    
    
    def self_and_cross_modal_generation_for_fid_calculation(self, data, dataPath, fidPath, i):
        """Generate images from both modalities and save them as PNG files.

        Images reconstructed from the image input are written to
        ``<fidPath>/image/image/`` and images generated from the caption
        input to ``<fidPath>/sentence/image/``. Files are named
        ``<index in batch>_<i>.png``. Output directories are created when
        missing.

        Args:
            data: Indexable with the image batch at ``data[0]`` and the
                caption batch at ``data[1]``; both are moved to
                ``self.device``.
            dataPath: Unused; kept so existing callers keep working.
            fidPath: Root directory for the generated images.
            i: Identifier appended to every file name.
                NOTE(review): presumably the batch index -- confirm against
                the caller.
        """
        with torch.no_grad():
            image, sentence = data[0].to(self.device), data[1].to(self.device)
            # Image generation first, then caption-to-image, so the sampling
            # order stays as before.
            from_image = self.a_to_a(image).cpu()
            from_sentence = self.b_to_a(sentence).cpu()

        for source, batch in (('image', from_image), ('sentence', from_sentence)):
            out_dir = '{}/{}/image'.format(fidPath, source)
            # Without this, save_image fails on a fresh output tree after the
            # generation work has already been done.
            os.makedirs(out_dir, exist_ok=True)
            for im in range(batch.size(0)):
                save_image(batch[im, :, :, :],
                           '{}/{}_{}.png'.format(out_dir, im, i))
