"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import torch.distributed as dist
import torchvision.transforms
import torch

from lavis.common.logger import MetricLogger
from lavis.datasets.data_utils import prepare_sample
from lavis.common.dist_utils import is_dist_avail_and_initialized
from lavis.common.registry import registry
from lavis.tasks.captioning import CaptionTask

@registry.register_task("text_gene")
class TextGenerationTask(CaptionTask):
    """Caption-style task registered under the name ``"text_gene"``.

    On top of ``CaptionTask`` it provides:

    * ``generate`` / ``generate_step``: run ``model.generate`` over a data
      loader and collect ``{"caption", "image_id"}`` records.
    * ``mixed_generate_step``: same output format, but the Q-Former output of
      each image is concatenated with an externally supplied per-image
      conditioning tensor before being projected into the T5 input space.
    * ``get_img_specs`` / ``img_spec_ext_step``: extract the raw Q-Former
      outputs (and the matching sample ids) for every batch of a data loader.
    """

    @staticmethod
    def _get_sample_ids(samples):
        """Return the per-sample identifiers of a batch.

        ``samples["image_id"]`` is preferred; ``samples["index"]`` is the
        fallback.

        Raises:
            RuntimeError: if the batch carries neither key. (The exception
                type matches what the previous bare ``raise`` produced, but
                now comes with an actionable message.)
        """
        if "image_id" in samples:
            return samples["image_id"]
        if "index" in samples:
            return samples["index"]
        raise RuntimeError(
            "samples must contain either an 'image_id' or an 'index' entry."
        )

    @staticmethod
    def _pair_captions_with_ids(captions, img_ids, num_captions):
        """Build result records from a flat caption list.

        ``captions`` is expected to hold ``num_captions`` consecutive entries
        per sample, in the same order as ``img_ids``.
        """
        return [
            {"caption": caption, "image_id": int(img_id)}
            for i, img_id in enumerate(img_ids)
            for caption in captions[i * num_captions:(i + 1) * num_captions]
        ]

    @staticmethod
    def _encode_image(model, image):
        """Run the vision encoder and the Q-Former on ``image``.

        Returns the output object of ``model.Qformer.bert`` (called with
        ``return_dict=True``).
        """
        with model.maybe_autocast():
            image_embeds = model.ln_vision(model.visual_encoder(image))
        image_embeds = image_embeds.float()
        # Attend to every image token: mask of ones over all but the last dim.
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image.device
        )

        query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
        return model.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

    def generate(self, model, data_loader, cuda_enabled=True, spec_conds=None, **kwargs):
        """Generate captions for every batch yielded by ``data_loader``.

        Args:
            model: model handed to ``generate_step`` / ``mixed_generate_step``.
            data_loader: sized iterable of sample batches.
            cuda_enabled: forwarded to ``prepare_sample``.
            spec_conds: optional container indexed by ``int(image_id)`` that
                yields a conditioning tensor per image. When ``None`` the
                plain ``generate_step`` is used, otherwise
                ``mixed_generate_step``.
            **kwargs: generation options forwarded to the step function.

        Returns:
            list of ``{"caption": str, "image_id": int}`` dicts.
        """
        metric_logger = MetricLogger(delimiter="  ")
        header = "Evaluation"
        print_freq = 10

        results = []
        for i, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            # Identity check: `== None` would be evaluated element-wise (or
            # via a custom __eq__) for array-like containers.
            if spec_conds is None:
                eval_output = self.generate_step(model=model, samples=samples, **kwargs)
            else:
                eval_output = self.mixed_generate_step(
                    model=model, samples=samples, spec_conds=spec_conds, **kwargs
                )
            results.extend(eval_output)
            # Stop after len(data_loader) batches even if the iterator keeps
            # yielding.
            if i >= len(data_loader) - 1:
                break

        if is_dist_avail_and_initialized():
            dist.barrier()

        return results

    def generate_step(self, model, samples, **kwargs):
        """Generate captions for one batch through ``model.generate``.

        Args:
            model: object exposing ``generate(samples, **kwargs)`` returning a
                flat sequence of captions.
            samples: batch dict containing ``"image_id"`` or ``"index"``.
            **kwargs: forwarded verbatim to ``model.generate``;
                ``num_captions`` (default 1) is also used to group captions
                per sample.

        Returns:
            list of ``{"caption": ..., "image_id": int}`` dicts.

        Raises:
            RuntimeError: if the batch has neither ``"image_id"`` nor
                ``"index"``.
        """
        captions = model.generate(
            samples,
            **kwargs,
        )

        img_ids = self._get_sample_ids(samples)
        num_captions = kwargs.get("num_captions", 1)
        results = self._pair_captions_with_ids(captions, img_ids, num_captions)

        # Preview: the captions generated for the first sample of the batch.
        print(results[:num_captions])

        return results

    @torch.no_grad()
    def mixed_generate_step(self, model, samples, spec_conds, **kwargs):
        """Generate captions for one batch with extra per-image conditioning.

        The Q-Former's last hidden state is concatenated along ``dim=1`` with
        the conditioning tensors looked up in ``spec_conds``; the result is
        projected by ``model.t5_proj`` and prepended to the prompt embeddings
        fed to ``model.t5_model.generate``.

        Args:
            model: model exposing ``visual_encoder``, ``ln_vision``,
                ``Qformer``, ``query_tokens``, ``t5_proj``, ``t5_tokenizer``,
                ``t5_model``, ``prompt`` and ``maybe_autocast``.
            samples: batch dict with ``"image"``, ``"image_id"`` or
                ``"index"``, and optionally ``"prompt"`` (a string or one
                string per image).
            spec_conds: container indexed by ``int(sample id)`` returning one
                tensor per image. NOTE(review): entries presumably share the
                Q-Former hidden size so they can be concatenated - confirm
                against the producer of ``spec_conds``.
            **kwargs: must contain ``use_nucleus_sampling``, ``top_p``,
                ``temperature``, ``num_beams``, ``max_length``,
                ``min_length``, ``repetition_penalty`` and ``length_penalty``;
                ``num_captions`` is optional and defaults to 1.

        Returns:
            list of ``{"caption": ..., "image_id": int}`` dicts.

        Raises:
            RuntimeError: if the batch has neither ``"image_id"`` nor
                ``"index"``.
            KeyError: if a required generation option is missing.
        """
        # Resolve ids once so the spec_conds lookup and the result records
        # use the same identifiers (including the "index" fallback).
        img_ids = self._get_sample_ids(samples)
        # Single source of truth for both num_return_sequences and the
        # per-sample grouping of decoded captions.
        num_captions = kwargs.get("num_captions", 1)

        image = samples["image"]
        query_states = self._encode_image(model, image).last_hidden_state

        spec_cond = torch.stack(
            [spec_conds[int(img_id)] for img_id in img_ids]
        ).to(query_states.device)

        mixed_hidden_state = torch.cat([query_states, spec_cond], dim=1)
        inputs_t5 = model.t5_proj(mixed_hidden_state)
        atts_t5 = torch.ones(
            inputs_t5.size()[:-1], dtype=torch.long, device=image.device
        )

        prompt = samples["prompt"] if "prompt" in samples else model.prompt
        if isinstance(prompt, str):
            prompt = [prompt] * image.size(0)
        else:
            assert len(prompt) == image.size(
                0
            ), "The number of prompts must be equal to the batch size."

        input_tokens = model.t5_tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).to(image.device)

        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

        with model.maybe_autocast(dtype=torch.bfloat16):
            inputs_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            # Visual (+ conditioning) tokens come first, then the prompt.
            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

            outputs = model.t5_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                do_sample=kwargs["use_nucleus_sampling"],
                top_p=kwargs["top_p"],
                temperature=kwargs["temperature"],
                num_beams=kwargs["num_beams"],
                max_new_tokens=kwargs["max_length"],
                min_length=kwargs["min_length"],
                repetition_penalty=kwargs["repetition_penalty"],
                length_penalty=kwargs["length_penalty"],
                num_return_sequences=num_captions,
            )
            output_text = model.t5_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )

        results = self._pair_captions_with_ids(output_text, img_ids, num_captions)

        # Preview: the captions generated for the first sample of the batch.
        print(results[:num_captions])

        return results

    def get_img_specs(self, model, data_loader, cuda_enabled=True):
        """Extract Q-Former outputs for every batch of ``data_loader``.

        Args:
            model: model handed to ``img_spec_ext_step``.
            data_loader: sized iterable of sample batches.
            cuda_enabled: forwarded to ``prepare_sample``.

        Returns:
            ``(results, id_queue)``: ``results`` holds one Q-Former output
            object per batch and ``id_queue`` the matching per-batch id
            container, in the same order.
        """
        metric_logger = MetricLogger(delimiter="  ")
        header = "Image specific extracting: "
        print_freq = 10

        results = []
        id_queue = []
        for i, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)

            ext_output, image_id = self.img_spec_ext_step(model=model, samples=samples)
            results.extend(ext_output)
            id_queue.extend(image_id)
            # Stop after len(data_loader) batches even if the iterator keeps
            # yielding.
            if i >= len(data_loader) - 1:
                break

        if is_dist_avail_and_initialized():
            dist.barrier()

        return results, id_queue

    @torch.no_grad()
    def img_spec_ext_step(self, model, samples):
        """Compute the Q-Former output for one batch.

        Args:
            model: model exposing ``visual_encoder``, ``ln_vision``,
                ``Qformer``, ``query_tokens`` and ``maybe_autocast``.
            samples: batch dict with ``"image"`` and ``"image_id"`` or
                ``"index"``.

        Returns:
            ``([query_output], [img_ids])``: single-element lists so that the
            caller can ``extend`` its accumulators with them.

        Raises:
            RuntimeError: if the batch has neither ``"image_id"`` nor
                ``"index"``.
        """
        # Validate the batch before spending time on the forward pass.
        img_ids = self._get_sample_ids(samples)
        query_output = self._encode_image(model, samples["image"])
        return [query_output], [img_ids]


