"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging

import torch
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from transformers import T5TokenizerFast
import itertools
import numpy as np
import gc
from collections import Counter

from lavis.common.registry import registry
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
from lavis.models.blip2_models.modeling_t5 import T5Config, T5ForConditionalGeneration
from transformers import BertTokenizer

from transformers import PegasusForConditionalGeneration, PegasusTokenizer

# Module-level switches for the paraphrase-augmented question answering below.
# True: paraphrase questions with the external Pegasus model; False: prompt T5 itself.
ext_paraphrase: bool = True
# Reduce per-variant answers by picking the best length-normalized beam score.
perform_selection: bool = False
# Reduce per-variant answers by majority vote (consulted only when selection is off).
perform_ensembling: bool = True

@registry.register_model("blip2_t5")
class Blip2T5(Blip2Base):
    """
    BLIP2 T5 model, extended with question paraphrasing for VQA.

    At answer-prediction time each question can be expanded into itself plus
    ``num_paraphrases`` rephrasings. The rephrasings come either from an external
    Pegasus paraphraser or from prompting the T5 model itself; the module-level
    ``ext_paraphrase`` flag picks which. The per-variant answers are then reduced
    to one answer per question. This is done either by score-based selection
    (``perform_selection``) or by majority voting (``perform_ensembling``).

    Supported model types:
        - pretrain_flant5xl: pretrained model with FlanT5-XL
        - pretrain_flant5xl_vitL: pretrained model with FlanT5-XL
        - pretrain_flant5xxl: pretrained model with FlanT5-XXL
        - caption_coco_flant5xl: finetuned image captioning model with FlanT5-XL
    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip2_t5", "pretrain_flant5xl")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_flant5xl": "configs/models/blip2/blip2_pretrain_flant5xl.yaml",
        "pretrain_flant5xl_vitL": "configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml",
        "pretrain_flant5xxl": "configs/models/blip2/blip2_pretrain_flant5xxl.yaml",
        "caption_coco_flant5xl": "configs/models/blip2/blip2_caption_flant5xl.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        num_query_token=32,
        t5_model="google/flan-t5-xl",
        prompt="",
        max_txt_len=32,
        apply_lemmatizer=False,
        paraphrase_prompt="Paraphrase: {}",
        context_paraphrase_prompt="\nQuestion: {}\nBased on the context, rephrase the question.",
        par_model_name='tuner007/pegasus_paraphrase',
        par_num_beams=10,
        num_paraphrases=10,
    ):
        """
        apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.
        paraphrase_prompt: format string (one ``{}`` slot) asking T5 to paraphrase a
            question; used for internal (T5-based) paraphrasing.
        context_paraphrase_prompt: format string used for internal paraphrasing when
            the image query features are prepended as context.
        par_model_name: name of the external Pegasus paraphrase model to load.
        par_num_beams: beam width used when generating paraphrases.
        num_paraphrases: paraphrases generated per question. It is passed as
            ``num_return_sequences`` together with ``num_beams=par_num_beams``, and
            HF beam search requires it to be <= ``par_num_beams``.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            logging.info("freeze vision encoder")

        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # Only the query path of the Q-Former is used, so the text-side modules are dropped.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model)
        t5_config = T5Config.from_pretrained(t5_model)
        t5_config.dense_act_fn = "gelu"
        self.t5_model = T5ForConditionalGeneration.from_pretrained(
            t5_model, config=t5_config
        )

        # The language model is frozen and kept in bfloat16.
        for name, param in self.t5_model.named_parameters():
            param.requires_grad = False
            param.data = param.data.bfloat16()

        self.t5_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.t5_model.config.hidden_size
        )

        self.max_txt_len = max_txt_len
        self.prompt = prompt

        self._apply_lemmatizer = apply_lemmatizer
        self._lemmatizer = None

        # Paraphrasing configuration. The prompts are stored in every mode so the
        # attributes always exist, even when the external paraphraser is active.
        self.par_model_name = par_model_name
        self.par_num_beams = par_num_beams
        self.num_paraphrases = num_paraphrases
        self.paraphrase_prompt = paraphrase_prompt
        self.context_paraphrase_prompt = context_paraphrase_prompt
        if ext_paraphrase:
            self.init_ext_paraphrase_model()

    def init_ext_paraphrase_model(self):
        """Load the external Pegasus paraphraser onto CUDA when available, else CPU."""
        torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.par_tokenizer = PegasusTokenizer.from_pretrained(self.par_model_name)
        self.par_model = PegasusForConditionalGeneration.from_pretrained(self.par_model_name).to(torch_device)

    def _chunk_with_originals(self, input_text, generated):
        """Group flat generations into ``[original, para_1, ..., para_p]`` per input."""
        p = self.num_paraphrases
        return [[text] + generated[i * p:(i + 1) * p] for i, text in enumerate(input_text)]

    def ext_paraphrase_response(self, input_text, device):
        """Paraphrase each question with the external Pegasus model.

        Args:
            input_text (list[str]): questions to paraphrase.
            device: device the tokenized batch is moved to. It presumably must match
                the device ``self.par_model`` was loaded on (TODO: confirm).

        Returns:
            list[list[str]]: one list per question. Each holds the original question
            followed by the ``self.num_paraphrases`` decoded paraphrases.
        """
        # Force each paraphrase to contain "?" so it stays a question. The first
        # sub-token is dropped; it is presumably the SentencePiece word-boundary piece.
        constraint_ids = self.par_tokenizer(
            ['?'], return_tensors='pt', add_special_tokens=False
        ).input_ids.detach().cpu()
        constraint_ids = constraint_ids[:, 1:].tolist()
        batch = self.par_tokenizer(
            input_text, truncation=True, padding='longest', max_length=60, return_tensors="pt"
        ).to(device)
        translated = self.par_model.generate(
            **batch,
            force_words_ids=constraint_ids,
            max_length=60,
            num_beams=self.par_num_beams,
            num_return_sequences=self.num_paraphrases,
            temperature=0.7,
        ).detach().cpu()
        tgt_text = self.par_tokenizer.batch_decode(translated, skip_special_tokens=True)
        tgt_chunks = self._chunk_with_originals(input_text, tgt_text)
        del batch
        gc.collect()
        # TODO: shortlist so that only questions (ending in "?") are retained.
        return tgt_chunks

    def int_paraphrase_response(self, input_text, prompt, device, context=None, context_atts=None):
        """Paraphrase each question by prompting the T5 language model itself.

        Args:
            input_text (list[str]): questions to paraphrase.
            prompt (str): format string with one ``{}`` slot for the question.
            device: device for the tokenized inputs.
            context (torch.Tensor, optional): embeddings to prepend to each prompt,
                placed after a "Context: " prefix. ``predict_answers`` passes the
                projected Q-Former output. Needs one row per question.
            context_atts (torch.Tensor, optional): attention mask matching ``context``.

        Returns:
            list[list[str]]: per question, the original followed by
            ``self.num_paraphrases`` generated paraphrases.
        """
        prompted_input_text = [prompt.format(text) for text in input_text]
        # Force each paraphrase to contain "?" (first sub-token dropped, as above).
        constraint_ids = self.t5_tokenizer(
            ['?'], return_tensors='pt', add_special_tokens=False
        ).input_ids.detach().cpu()
        constraint_ids = constraint_ids[:, 1:].tolist()
        if context is not None:
            # Embed "Context: " without its last token (presumably the appended EOS)
            # and repeat it once per batch item.
            prefix_tokens = self.t5_tokenizer(['Context: '], return_tensors='pt').to(device)
            prefix_ids = torch.repeat_interleave(prefix_tokens.input_ids[:, :-1], len(context), dim=0)
            prefix_embeds = self.t5_model.encoder.embed_tokens(prefix_ids)
            prefix_atts = torch.repeat_interleave(prefix_tokens.attention_mask[:, :-1], len(context), dim=0)

        with self.maybe_autocast(dtype=torch.bfloat16):
            input_tokens = self.t5_tokenizer(
                prompted_input_text,
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
                return_tensors="pt",
            ).to(device)
            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            encoder_atts = input_tokens.attention_mask
            if context is not None:
                inputs_embeds = torch.cat([prefix_embeds, context, inputs_embeds], dim=1)
                encoder_atts = torch.cat([prefix_atts, context_atts, encoder_atts], dim=1)
            translated = self.t5_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                force_words_ids=constraint_ids,
                temperature=0.9,
                num_beams=self.par_num_beams,
                max_new_tokens=50,
                num_return_sequences=self.num_paraphrases,
            )
            tgt_text = self.t5_tokenizer.batch_decode(translated, skip_special_tokens=True)
        return self._chunk_with_originals(input_text, tgt_text)

    def ensemble(self, answers):
        """Majority vote: return the most frequent answer in ``answers``.

        Ties go to the answer encountered first (``Counter.most_common`` ordering).
        The first entry is the answer to the original, un-paraphrased question, so
        it wins ties.
        """
        # TODO: merge surface-form variants of the same answer before voting.
        return Counter(answers).most_common(1)[0][0]

    def forward(self, samples):
        """Compute the T5 language-modeling loss conditioned on the image.

        Args:
            samples (dict): needs "image", "text_input" and "text_output".

        Returns:
            dict: ``{"loss": loss}``.
        """
        image = samples["image"]

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)

        with self.maybe_autocast(dtype=torch.bfloat16):
            input_tokens = self.t5_tokenizer(
                samples["text_input"],
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
                return_tensors="pt",
            ).to(image.device)
            output_tokens = self.t5_tokenizer(
                samples["text_output"],
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
                return_tensors="pt",
            ).to(image.device)

            encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

            # Padding positions are excluded from the loss (-100 is the ignore index).
            targets = output_tokens.input_ids.masked_fill(
                output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100
            )

            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

            outputs = self.t5_model(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                decoder_attention_mask=output_tokens.attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss

            return {"loss": loss}

    @torch.no_grad()
    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=30,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - prompt (str or list[str], optional): overrides ``self.prompt``;
                  a list must have one entry per image.
            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False,
                decode without sampling (beam search, or greedy when num_beams == 1).
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum number of new tokens to generate.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            length_penalty (float): Length penalty passed to ``generate``.
            num_captions (int): Number of captions to be generated for each image.
            temperature (float): Sampling temperature (only meaningful when sampling).
        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """
        image = samples["image"]

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_embeds = image_embeds.float()
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)

        if "prompt" in samples.keys():
            prompt = samples["prompt"]
        else:
            prompt = self.prompt

        if isinstance(prompt, str):
            prompt = [prompt] * image.size(0)
        else:
            assert len(prompt) == image.size(
                0
            ), "The number of prompts must be equal to the batch size."

        input_tokens = self.t5_tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).to(image.device)

        # Condition generation on the image: projected query features come first,
        # followed by the prompt tokens.
        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

        with self.maybe_autocast(dtype=torch.bfloat16):
            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

            outputs = self.t5_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                do_sample=use_nucleus_sampling,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_new_tokens=max_length,
                min_length=min_length,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
            )
            output_text = self.t5_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )

        return output_text

    def predict_answers(
        self,
        samples,
        num_beams=5,
        inference_method="generate",
        max_len=10,
        min_len=1,
        num_ans_candidates=128,
        answer_list=None,
        prompt="",
        length_penalty=-1,
        **kwargs
    ):
        """Answer the questions in ``samples``, optionally via paraphrase ensembling.

        When ``prompt`` is non-empty:
            - each question is expanded into itself plus ``self.num_paraphrases``
              paraphrases;
            - every variant is answered;
            - the variant answers are reduced to one per question according to the
              module-level ``perform_selection`` / ``perform_ensembling`` flags.

        When ``prompt`` is empty, each question is answered once, as given.

        Args:
            samples (dict): needs "image" and "text_input" (str or list[str]).
                NOTE: this dict is mutated. When paraphrasing, "text_input" is
                replaced by the flattened question variants and
                "rephrased_text_input" is added.
            num_beams (int): beam width for answer generation. Score-based
                selection needs beam search (``sequences_scores``).
            inference_method, num_ans_candidates, answer_list, **kwargs: unused here.
            max_len (int): maximum number of new answer tokens.
            min_len (int): minimum generated length.
            prompt (str): format string with one ``{}`` slot for the question.
            length_penalty (float): passed to ``generate``.

        Returns:
            list: with selection or ensembling enabled, one answer string per image.
            With both disabled, one list of per-variant answers per image.
        """
        image = samples["image"]
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_embeds = image_embeds.float()
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        bs = len(image)
        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)

        if isinstance(samples["text_input"], str):
            samples["text_input"] = [samples["text_input"]]
        if prompt:
            if ext_paraphrase:
                samples["rephrased_text_input"] = self.ext_paraphrase_response(
                    samples["text_input"], image.device
                )
            else:
                samples["rephrased_text_input"] = self.int_paraphrase_response(
                    samples["text_input"],
                    self.context_paraphrase_prompt,
                    image.device,
                    inputs_t5,
                    atts_t5,
                )
            samples["text_input"] = list(
                itertools.chain.from_iterable(samples["rephrased_text_input"])
            )
            text_input = [prompt.format(question) for question in samples["text_input"]]
            # Every question is answered once per variant: the original plus its paraphrases.
            num_variants = self.num_paraphrases + 1
        else:
            text_input = samples["text_input"]
            num_variants = 1
        input_tokens = self.t5_tokenizer(
            text_input, padding="longest", return_tensors="pt"
        ).to(image.device)

        # Repeat each image's features once per question variant so that the rows
        # line up with the (flattened) text batch.
        atts_t5 = torch.repeat_interleave(atts_t5, num_variants, dim=0)
        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

        with self.maybe_autocast(dtype=torch.bfloat16):
            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            inputs_t5 = torch.repeat_interleave(inputs_t5, num_variants, dim=0)
            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

            outputs = self.t5_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                do_sample=False,
                num_beams=num_beams,
                max_new_tokens=max_len,
                min_length=min_len,
                length_penalty=length_penalty,
                return_dict_in_generate=True,
                # Scores are only needed for score-based selection.
                output_scores=perform_selection,
            )
            sequences = outputs.sequences.detach().cpu()
            sequences_scores = (
                outputs.sequences_scores.detach().cpu().numpy() if perform_selection else None
            )
            # Drop the whole output object (it also holds any per-step scores) and the
            # large encoder inputs to release GPU memory.
            del outputs, inputs_embeds, inputs_t5, encoder_atts, atts_t5
            gc.collect()

            output_text = self.t5_tokenizer.batch_decode(sequences, skip_special_tokens=True)

        if perform_selection:
            output_lengths = np.array([
                len(self.t5_tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True))
                for ids in sequences.tolist()
            ])
            # Beam scores are (length-penalized) log-probabilities, so they are <= 0.
            # Normalize by answer length. Empty (degenerate) answers get -inf so they
            # can never outrank a non-empty answer.
            output_scores = np.where(
                output_lengths > 0,
                sequences_scores / np.maximum(output_lengths, 1),
                -np.inf,
            )

        if self._apply_lemmatizer:
            output_text = self._lemmatize(output_text)

        # Regroup the flat answers per image: num_variants consecutive answers each.
        output_text = [
            output_text[i * num_variants:(i + 1) * num_variants] for i in range(bs)
        ]
        if perform_selection:
            select_idx = np.argmax(output_scores.reshape(bs, num_variants), axis=1)
            output_text = [answers[idx] for answers, idx in zip(output_text, select_idx)]
        elif perform_ensembling:
            output_text = [self.ensemble(answers) for answers in output_text]
        return output_text

    def _lemmatize(self, answers):
        """Replace nouns and verbs in each answer by their spaCy lemmas."""
        def apply(answer):
            doc = self.lemmatizer(answer)
            words = [
                token.lemma_ if token.pos_ in ["NOUN", "VERB"] else token.text
                for token in doc
            ]
            return " ".join(words)

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        """Lazily load and cache the spaCy ``en_core_web_sm`` pipeline."""
        if self._lemmatizer is None:
            try:
                import spacy

                self._lemmatizer = spacy.load("en_core_web_sm")
            except (ImportError, OSError):
                # ImportError: spaCy missing; OSError: spaCy raises it when the
                # en_core_web_sm model package is not installed.
                logging.error(
                    """
                    Please install spacy and en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                # `exit` is a site-module helper meant for the REPL; raise directly.
                raise SystemExit(1)

        return self._lemmatizer

    @classmethod
    def from_config(cls, cfg):
        """Build the model from a config and load its checkpoint.

        The optional paraphrasing keys are read when present:
            - paraphrase_prompt
            - context_paraphrase_prompt
            - par_model_name
            - par_num_beams
            - num_paraphrases

        Missing keys fall back to the ``__init__`` defaults.
        """
        vit_model = cfg.get("vit_model", "eva_clip_g")
        img_size = cfg.get("image_size")
        num_query_token = cfg.get("num_query_token")
        t5_model = cfg.get("t5_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)

        prompt = cfg.get("prompt", "")
        max_txt_len = cfg.get("max_txt_len", 32)

        apply_lemmatizer = cfg.get("apply_lemmatizer", False)

        paraphrase_kwargs = {
            key: cfg.get(key)
            for key in (
                "paraphrase_prompt",
                "context_paraphrase_prompt",
                "par_model_name",
                "par_num_beams",
                "num_paraphrases",
            )
            if cfg.get(key) is not None
        }

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            num_query_token=num_query_token,
            t5_model=t5_model,
            prompt=prompt,
            max_txt_len=max_txt_len,
            apply_lemmatizer=apply_lemmatizer,
            **paraphrase_kwargs,
        )
        model.load_checkpoint_from_config(cfg)

        return model
