"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import torch.distributed as dist
import torchvision.transforms
import torch
from PIL import Image
from einops import rearrange
import numpy as np

from lavis.common.logger import MetricLogger
from lavis.datasets.data_utils import prepare_sample
from lavis.common.dist_utils import is_dist_avail_and_initialized
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask
from lavis.models.ldm_models.models.diffusion.ddim import DDIMSampler
from lavis.models.ldm_models.models.modules.distributions.distributions import DiagonalGaussianDistribution

@registry.register_task("image_gene")
class ImageGenerationTask(BaseTask):
    """Generate images from dataset captions with DDIM sampling and save them as PNGs.

    Two generation paths are provided: a plain one conditioned only on the
    captions (``generate_step``) and a "mixed" one where externally supplied
    embeddings are concatenated to the caption conditioning
    (``mixed_generate_step``). ``get_txt_specs`` extracts such embeddings.
    """

    # DDIM sampling settings shared by the plain and mixed generation paths.
    _DDIM_STEPS = 50
    _DDIM_ETA = 0.0
    _GUIDANCE_SCALE = 9.0
    # Per-sample shape handed to the sampler (presumably the latent's
    # channels/height/width for 512x512 images -- confirm against the model).
    _LATENT_SHAPE = (4, 512 // 8, 512 // 8)
    # Number of caption characters embedded in each output filename.
    _CAPTION_STEM_LEN = 32
    # Output directories for the two generation paths.
    _ORIG_OUTPUT_DIR = "./results/mixed_mllmu/T2I_orig"
    _MIXED_OUTPUT_DIR = "./results/mixed_mllmu/T2I_mixed"

    def before_evaluation(self, model, dataset, **kwargs):
        """Hook called before evaluation; intentionally does nothing for this task."""
        pass

    def generate(self, model, data_loader, cuda_enabled=True, spec_ebds=None, **kwargs):
        """Run image generation over every batch of ``data_loader``.

        Args:
            model: Model exposing ``get_learned_conditioning`` and
                ``decode_first_stage``.
            data_loader: Loader whose ``dataset`` provides ``text`` and
                ``img2txt``; batches must contain ``"image_id"`` and ``"index"``.
            cuda_enabled: Forwarded to ``prepare_sample``.
            spec_ebds: Optional embeddings with one row per caption, laid out in
                the order the loader yields captions (the layout produced by
                ``get_txt_specs``). When ``None`` the plain path is used.
            **kwargs: Forwarded unchanged to the per-batch step.

        Returns:
            list[dict]: One record per generated image with the keys
            ``"caption"``, ``"image_id"`` and ``"filename"``.

        Raises:
            ValueError: If ``spec_ebds`` has fewer rows than the captions
                consumed so far.
        """
        metric_logger = MetricLogger(delimiter="  ")
        header = "Evaluation"
        print_freq = 10

        results = []
        # Running row offset into spec_ebds. It must persist across batches,
        # otherwise every batch would reuse the first batch's embeddings.
        spec_offset = 0
        for i, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)

            # Identity check: "== None" on a tensor goes through tensor comparison.
            if spec_ebds is None:
                eval_output = self.generate_step(
                    model=model, samples=samples, data_loader=data_loader, **kwargs
                )
            else:
                n_captions = sum(
                    len(data_loader.dataset.img2txt[int(ind)]) for ind in samples["index"]
                )
                batch_specs = spec_ebds[spec_offset:spec_offset + n_captions]
                if len(batch_specs) != n_captions:
                    raise ValueError(
                        f"spec_ebds provides {len(batch_specs)} rows at offset {spec_offset} "
                        f"but the batch has {n_captions} captions"
                    )
                spec_offset += n_captions
                eval_output = self.mixed_generate_step(
                    model=model,
                    samples=samples,
                    data_loader=data_loader,
                    spec_ebds=batch_specs,
                    **kwargs,
                )
            results.extend(eval_output)
            # Stop after len(data_loader) batches even if the iterator keeps yielding.
            if i >= len(data_loader) - 1:
                break

        if is_dist_avail_and_initialized():
            dist.barrier()

        return results

    @torch.no_grad()
    def generate_step(self, model, samples, data_loader, **kwargs):
        """Generate and save one image per caption for every image in the batch.

        Args:
            model: Model exposing ``get_learned_conditioning`` and
                ``decode_first_stage``.
            samples: Batch dict with ``"image_id"`` and ``"index"`` entries.
            data_loader: Loader whose ``dataset`` provides ``text`` and ``img2txt``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            list[dict]: Records with ``"caption"``, ``"image_id"``, ``"filename"``.
        """
        results = []
        # NOTE(review): the sampler device is hard-coded to CUDA, as before.
        sampler = DDIMSampler(model, device=torch.device("cuda"))
        for ID, ind in zip(samples["image_id"], samples["index"]):
            text = self._captions_for(data_loader, ind)
            batch_size = len(text)

            cond = model.get_learned_conditioning(text)
            # Empty prompts act as the unconditional branch for guidance.
            uncond = model.get_learned_conditioning(batch_size * [""])

            x_samples = self._sample_and_decode(sampler, model, cond, uncond, batch_size)
            results.extend(
                self._save_images(x_samples, text, ID, self._ORIG_OUTPUT_DIR, "orig")
            )

        return results

    @torch.no_grad()
    def mixed_generate_step(self, model, samples, data_loader, spec_ebds, **kwargs):
        """Generate images conditioned on captions concatenated with extra embeddings.

        Args:
            model: Model exposing ``get_learned_conditioning`` and
                ``decode_first_stage``.
            samples: Batch dict with ``"image_id"`` and ``"index"`` entries.
            data_loader: Loader whose ``dataset`` provides ``text`` and ``img2txt``.
            spec_ebds: Embeddings for this batch, one row per caption, ordered
                like the captions of the batch; rows are consumed from index 0.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            list[dict]: Records with ``"caption"``, ``"image_id"``, ``"filename"``.
        """
        results = []
        # NOTE(review): the sampler device is hard-coded to CUDA, as before.
        sampler = DDIMSampler(model, device=torch.device("cuda"))

        spec_start = 0
        for ID, ind in zip(samples["image_id"], samples["index"]):
            text = self._captions_for(data_loader, ind)
            batch_size = len(text)

            # Take the rows belonging to this image's captions.
            spec_end = spec_start + batch_size
            spec_ebd = spec_ebds[spec_start:spec_end, ...]
            spec_start = spec_end

            cond = model.get_learned_conditioning(text)
            mixed_cond = torch.cat([cond, spec_ebd], dim=1)
            uncond = model.get_learned_conditioning(batch_size * [""])
            # Duplicate along dim 1 so the unconditional input matches the mixed
            # one (assumes spec_ebd has the same dim-1 length as cond -- confirm).
            uncond = torch.cat([uncond, uncond], dim=1)

            x_samples = self._sample_and_decode(sampler, model, mixed_cond, uncond, batch_size)
            results.extend(
                self._save_images(x_samples, text, ID, self._MIXED_OUTPUT_DIR, "mixed")
            )

        return results

    def get_txt_specs(self, model, data_loader, cuda_enabled=True):
        """Extract text embeddings for every caption reachable through ``data_loader``.

        Args:
            model: Model accepted by ``txt_spec_ext_step``; must expose ``device``.
            data_loader: Loader whose ``dataset`` provides ``text`` and ``img2txt``.
            cuda_enabled: Forwarded to ``prepare_sample``.

        Returns:
            tuple: ``(embeddings, id_queue)`` where ``embeddings`` concatenates
            the per-batch outputs along dim 0 on ``model.device`` and
            ``id_queue`` lists the image ids in processing order, one per image.
        """
        metric_logger = MetricLogger(delimiter="  ")
        header = "Text specific extracting: "
        print_freq = 10

        batch_outputs = []
        id_queue = []
        for i, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)

            ext_output, image_id = self.txt_spec_ext_step(
                model=model, samples=samples, data_loader=data_loader
            )
            batch_outputs.append(ext_output)
            id_queue.extend(image_id)
            # Stop after len(data_loader) batches even if the iterator keeps yielding.
            if i >= len(data_loader) - 1:
                break

        if is_dist_avail_and_initialized():
            dist.barrier()

        # Concatenating batches along dim 0 gives the same tensor as stacking
        # every individual row.
        results = torch.cat(batch_outputs, dim=0).to(model.device)

        return results, id_queue

    @torch.no_grad()
    def txt_spec_ext_step(self, model, samples, data_loader):
        """Encode the captions of every image in one batch.

        Uses ``model.encode`` when the model has a callable ``encode`` (taking
        the mode when a ``DiagonalGaussianDistribution`` is returned), otherwise
        ``model.cond_stage_model``.

        Args:
            model: Text encoder wrapper; must expose ``device``.
            samples: Batch dict with ``"image_id"`` and ``"index"`` entries.
            data_loader: Loader whose ``dataset`` provides ``text`` and ``img2txt``.

        Returns:
            tuple: ``(embeddings, image_ids)``; ``embeddings`` is the dim-0
            concatenation of the per-image encoder outputs on ``model.device``
            and ``image_ids`` holds one id per image of the batch. An image with
            several captions therefore contributes several embedding rows but a
            single id.
        """
        embeddings = []
        image_ids = []
        has_encode = hasattr(model, 'encode') and callable(model.encode)

        for ID, ind in zip(samples["image_id"], samples["index"]):
            text = self._captions_for(data_loader, ind)

            if has_encode:
                c = model.encode(text)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = model.cond_stage_model(text)

            embeddings.append(c)
            # One id per processed image; appending the whole batch's ids here
            # would duplicate them once per image.
            image_ids.append(ID)

        results = torch.cat(embeddings, dim=0).to(model.device)
        return results, image_ids

    @staticmethod
    def _captions_for(data_loader, ind):
        """Return the captions linked to dataset index ``ind`` through ``img2txt``."""
        dataset = data_loader.dataset
        return [dataset.text[t_ind] for t_ind in dataset.img2txt[int(ind)]]

    @staticmethod
    def _filename_safe(text):
        """Replace path separators and control characters in ``text`` with ``_``."""
        return "".join("_" if ch in "/\\" or ord(ch) < 32 else ch for ch in text)

    def _sample_and_decode(self, sampler, model, cond, uncond, batch_size):
        """Run DDIM sampling and decode the result to image tensors in [0, 1]."""
        latent_samples, _ = sampler.sample(
            S=self._DDIM_STEPS,
            conditioning=cond,
            batch_size=batch_size,
            shape=list(self._LATENT_SHAPE),
            verbose=False,
            unconditional_guidance_scale=self._GUIDANCE_SCALE,
            unconditional_conditioning=uncond,
            eta=self._DDIM_ETA,
        )
        x_samples = model.decode_first_stage(latent_samples)
        # Map the decoder output from [-1, 1] to [0, 1] and clip overshoot.
        return torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

    def _save_images(self, x_samples, captions, image_id, out_dir, prefix):
        """Write each sample as a PNG under ``out_dir`` and return its result record."""
        os.makedirs(out_dir, exist_ok=True)
        # Number new files after those already present in the directory.
        base_count = len(os.listdir(out_dir))

        assert len(captions) == x_samples.shape[0], \
            f"len(text): {len(captions)}  x_samples.shape: {x_samples.shape}"

        safe_id = self._filename_safe(f"{image_id}")
        records = []
        for x_sample, cpt in zip(x_samples, captions):
            pixels = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            stem = self._filename_safe(cpt[:self._CAPTION_STEM_LEN])
            filename = f"{prefix}_{base_count:05}_{safe_id}_{stem}.png"
            Image.fromarray(pixels.astype(np.uint8)).save(os.path.join(out_dir, filename))
            records.append({"caption": cpt, "image_id": str(image_id), "filename": filename})
            # Advance per image so captions sharing a prefix never overwrite
            # each other's file.
            base_count += 1
        return records


