"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging

import torch
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from transformers import T5TokenizerFast
import itertools
import numpy as np
import gc
from collections import Counter
import time
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
from lavis.models.blip2_models.modeling_t5 import T5Config, T5ForConditionalGeneration
from transformers import BertTokenizer
from promptcap import PromptCap
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from PIL import Image
# ext_paraphrase = True
# perform_selection = False
# perform_ensembling = True

class BatchPromptCap(PromptCap):
    """PromptCap wrapper that captions several (prompt, image) pairs in one ``generate`` call.

    Relies on ``self.model``, ``self.tokenizer`` and ``self.patch_resize_transform``,
    which are expected to be set up by the ``PromptCap`` base class (not shown in
    this file).
    """

    def __init__(self, ckpt="vqascore/promptcap-coco-vqa", device='cuda'):
        """Load checkpoint ``ckpt`` through ``PromptCap`` and move the model to ``device``."""
        PromptCap.__init__(self, ckpt)
        self.model.to(device)

    def caption(self, prompts, images, num_beams=5, no_repeat_ngram_size=3, max_new_tokens=100, device='cuda', **generator_args):
        """Generate one caption per (prompt, image) pair.

        Args:
            prompts: A prompt string or a list of prompt strings.
            images: An image path (anything ``PIL.Image.open`` accepts) or a list
                of them, aligned one-to-one with ``prompts``.
            num_beams, no_repeat_ngram_size, max_new_tokens: Forwarded to
                ``self.model.generate``.
            device: Device the model and all inputs are moved to before generating.
            **generator_args: Extra keyword arguments forwarded to ``generate``.

        Returns:
            list[str]: Decoded, whitespace-stripped captions, one per input pair.

        Raises:
            AssertionError: If ``prompts`` and ``images`` differ in length.
        """
        self.model.to(device)
        if not isinstance(prompts, list):
            prompts = [prompts]
        if not isinstance(images, list):
            images = [images]
        assert len(prompts) == len(images), "Input dimensions are incompatible."

        patches = []
        for image_path in images:
            # Apply the transform while the file is still open so the pixel data is
            # read, then let the context manager release the file handle.
            with Image.open(image_path) as img:
                patches.append(self.patch_resize_transform(img).unsqueeze(0))
        patch_images = torch.concat(patches).to(device)
        # One "image present" flag per batch row, kept on the same device as the images.
        patch_masks = torch.full([patch_images.shape[0]], True, device=device)

        input_ids = self.tokenizer(prompts, return_tensors="pt", padding=True).input_ids.to(device)

        with torch.no_grad():
            gen = self.model.generate(input_ids, patch_images=patch_images,
                                      patch_masks=patch_masks,
                                      num_beams=num_beams,
                                      no_repeat_ngram_size=no_repeat_ngram_size,
                                      max_new_tokens=max_new_tokens,
                                      **generator_args)

        captions = self.tokenizer.batch_decode(gen, skip_special_tokens=True)
        return [cap.strip() for cap in captions]

@registry.register_model("blip2_t5_par")
class Blip2T5Par(Blip2Base):
    """
    BLIP2 T5 model with question paraphrasing for visual question answering.

    On top of the BLIP-2 pipeline (vision encoder -> Q-Former -> T5), this class
    can paraphrase each question (with an external Pegasus model or with T5
    itself), optionally prepend an image caption, answer every paraphrase, and
    then either select the best-scoring answer or take a majority vote.

    Supported model types:
        - pretrain_flant5xl: pretrained model with FlanT5-XL
        - pretrain_flant5xl_vitL: pretrained model with FlanT5-XL
        - pretrain_flant5xxl: pretrained model with FlanT5-XXL
        - caption_coco_flant5xl: finetuned image captioning model with FlanT5-XL
    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip2_t5_par", "pretrain_flant5xl")
    """

    # Maps each supported model type to its YAML model configuration file.
    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_flant5xl": "configs/models/blip2/blip2_pretrain_flant5xl.yaml",
        "pretrain_flant5xl_vitL": "configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml",
        "pretrain_flant5xxl": "configs/models/blip2/blip2_pretrain_flant5xxl.yaml",
        "caption_coco_flant5xl": "configs/models/blip2/blip2_caption_flant5xl.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        num_query_token=32,
        t5_model="google/flan-t5-xl",
        prompt="",
        max_txt_len=32,
        apply_lemmatizer=False,
        paraphrase_prompt="Paraphrase: {}",
        context_paraphrase_prompt="\nQuestion: {}\nBased on the context, rephrase the question.",
        par_model_name='tuner007/pegasus_paraphrase',
        par_num_beams = 10,
        num_paraphrases = 10,
        ext_paraphrase = True,
        perform_selection = False,
        selection_criterion = 'Aconf',
        perform_ensembling = True,
        verbose=True,
        constrained=True,
        use_caption=False,
        use_promptcap=False,
        alt_device = 0,
        caption_prompt='Please describe this image according to the given question: {}',

    ):
        """
        Build the vision encoder, Q-Former, frozen T5 and the optional paraphrase /
        caption helpers.

        Args:
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision:
                forwarded unchanged to ``init_vision_encoder``.
            freeze_vit: when True, the vision encoder's parameters are frozen, it is
                put in eval mode and its ``train`` method is replaced by ``disabled_train``.
            num_query_token: number of query tokens passed to ``init_Qformer``.
            t5_model: name/path used to load the T5 tokenizer, config and weights.
            prompt: default prompt used by ``generate`` when ``samples`` has no "prompt".
            max_txt_len: tokenizer truncation length used in ``forward`` and
                ``int_paraphrase_response``.
            apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.
            paraphrase_prompt, context_paraphrase_prompt: prompt templates for
                paraphrasing with T5 itself; only stored when ``ext_paraphrase`` is False.
            par_model_name: checkpoint name of the external Pegasus paraphraser.
            par_num_beams: beam width used when generating paraphrases.
            num_paraphrases: number of paraphrases generated per question.
            ext_paraphrase: use the external Pegasus paraphraser instead of T5.
            perform_selection: in ``predict_answers``, keep the answer with the highest score.
            selection_criterion: 'Aconf' or 'Qconf'; chooses what ``predict_answers`` scores.
            perform_ensembling: in ``predict_answers``, majority-vote over the answers
                (only consulted when ``perform_selection`` is False).
            verbose: when True, ``predict_answers`` also returns paraphrases and scores.
            constrained: force the external paraphraser's output to contain '?'.
            use_caption: prepend an image caption to the T5 text input in ``predict_answers``.
            use_promptcap: produce captions with PromptCap instead of this model's ``generate``.
            alt_device: integer GPU index, resolved to a device by ``set_alt_device``.
            caption_prompt: template applied to each question before PromptCap captioning.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            # Swap in disabled_train so later .train() calls go through it instead.
            self.visual_encoder.train = disabled_train
            logging.info("freeze vision encoder")

        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # Drop Q-Former sub-modules that this class never calls (it only feeds
        # query embeddings through ``Qformer.bert``) — presumably the text-side parts.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model)
        t5_config = T5Config.from_pretrained(t5_model)
        t5_config.dense_act_fn = "gelu"
        self.t5_model = T5ForConditionalGeneration.from_pretrained(
            t5_model, config=t5_config
        )

        # T5 is kept frozen and its weights are stored in bfloat16.
        for name, param in self.t5_model.named_parameters():
            param.requires_grad = False
            param.data = param.data.bfloat16()

        # Projects Q-Former outputs into the T5 embedding space.
        self.t5_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.t5_model.config.hidden_size
        )

        self.max_txt_len = max_txt_len
        self.prompt = prompt

        self._apply_lemmatizer = apply_lemmatizer
        # Loaded lazily by the ``lemmatizer`` property.
        self._lemmatizer = None

        self.ext_paraphrase = ext_paraphrase
        self.perform_selection = perform_selection
        self.selection_criterion = selection_criterion
        self.perform_ensembling = perform_ensembling
        self.verbose = verbose
        self.constrained = constrained
        self.use_caption = use_caption
        self.use_promptcap = use_promptcap
        self.caption_prompt = caption_prompt
        self.alt_device = alt_device

        # External paraphrasing part
        self.par_model_name = par_model_name
        self.par_num_beams = par_num_beams
        self.num_paraphrases = num_paraphrases
        # Must run before the helper models are created: they are placed on self.alt_device.
        self.set_alt_device()
        if self.ext_paraphrase: self.init_ext_paraphrase_model()
        else: 
            # Prompt templates are only needed when T5 itself does the paraphrasing.
            self.paraphrase_prompt = paraphrase_prompt
            self.context_paraphrase_prompt = context_paraphrase_prompt
        if self.use_caption and self.use_promptcap: self.init_caption_model()

    def set_alt_device(self):
        torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if torch_device == 'cuda' and self.alt_device > 0:
            a = torch.full([2], True).cuda(self.alt_device)
            torch_device = a.device
        self.alt_device = torch_device

    def init_ext_paraphrase_model(self):
        """Load the external Pegasus paraphraser and its tokenizer.

        The model goes to the default device ('cuda' if available, else 'cpu')
        when captioning is enabled, and to ``self.alt_device`` otherwise.
        """
        default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        target_device = default_device if self.use_caption else self.alt_device

        self.par_tokenizer = PegasusTokenizer.from_pretrained(self.par_model_name)
        paraphraser = PegasusForConditionalGeneration.from_pretrained(self.par_model_name)
        self.par_model = paraphraser.to(target_device)
        print('Par model device: ' , self.par_model.device)
        
    def init_caption_model(self):
        """Create the PromptCap captioner on ``self.alt_device`` and store it as ``self.cap_model``."""
        captioner = BatchPromptCap("vqascore/promptcap-coco-vqa", self.alt_device)
        self.cap_model = captioner

    def ext_paraphrase_response(self, input_text, device):
        self.par_model.to(device)
        if self.num_paraphrases > 0:
            constraint = ['?']
            constraint_ids = self.par_tokenizer(constraint, return_tensors='pt', add_special_tokens=False).input_ids.detach().cpu()
            constraint_ids = constraint_ids[:,1:].tolist()
            batch = self.par_tokenizer(input_text,truncation=True,padding='longest',max_length=60, return_tensors="pt").to(device)
            if self.constrained:
                translated = self.par_model.generate(**batch, force_words_ids=constraint_ids, max_length=60,num_beams=self.par_num_beams, num_return_sequences=self.num_paraphrases, temperature=0.7).detach().cpu()
            else: 
                translated = self.par_model.generate(**batch, max_length=60,num_beams=self.par_num_beams, num_return_sequences=self.num_paraphrases, temperature=0.7).detach().cpu()
            tgt_text = self.par_tokenizer.batch_decode(translated, skip_special_tokens=True)
            p = self.num_paraphrases
            tgt_chunks = [[input_text[i]] +  tgt_text[i*p:i*p + p] for i in range(len(input_text))]
            del batch
            gc.collect()
        # to-do shortlist such that only questions (end in ?) are retained
        else: tgt_chunks = [[input_text[i]] for i in range(len(input_text))]
        return tgt_chunks
    
    def int_paraphrase_response(self, input_text, prompt, device, context=None, context_atts=None):
        prompted_input_text = [prompt.format(text) for text in input_text]
        constraint = ['?']
        constraint_ids = self.t5_tokenizer(constraint, return_tensors='pt', add_special_tokens=False).input_ids.detach().cpu()
        constraint_ids = constraint_ids[:,1:].tolist()
        if context is not None:
            prefix = 'Context: '
            prefix_tokens = self.t5_tokenizer([prefix], return_tensors='pt').to(device)
            prefix_ids = torch.repeat_interleave(prefix_tokens.input_ids[:,:-1], len(context), dim=0)
            prefix_embeds = self.t5_model.encoder.embed_tokens(prefix_ids) 
            prefix_atts = torch.repeat_interleave(prefix_tokens.attention_mask[:,:-1], len(context), dim=0)

        with self.maybe_autocast(dtype=torch.bfloat16):
            input_tokens = self.t5_tokenizer(prompted_input_text, padding="longest", truncation=True, max_length=self.max_txt_len, return_tensors="pt").to(device)
            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            encoder_atts = input_tokens.attention_mask
            if context is not None:
                inputs_embeds = torch.cat([prefix_embeds, context, inputs_embeds], dim=1)
                encoder_atts = torch.cat([prefix_atts, context_atts, encoder_atts], dim=1)
            translated = self.t5_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                force_words_ids=constraint_ids,
                # do_sample=True,
                # top_p=0.925,
                temperature=0.9,
                num_beams=self.par_num_beams,
                # num_beam_groups=2,
                max_new_tokens=50,
                num_return_sequences=self.num_paraphrases,
                # num_return_sequences= 1,
            )
            tgt_text = self.t5_tokenizer.batch_decode(translated, skip_special_tokens=True)
            p = self.num_paraphrases
            tgt_chunks = [[input_text[i]] +  tgt_text[i*p:i*p + p] for i in range(len(input_text))]
            import pdb; pdb.set_trace()
            return tgt_chunks
        
    def ensemble(self, answers):
        # To-do remove surface competition in answers
        adict = Counter(answers)
        return [adict.most_common(1)[0][0]]


    def forward(self, samples):
        """Compute the T5 language-modeling loss for a batch.

        Args:
            samples (dict): Must contain "image" (image tensor), "text_input"
                (encoder-side texts) and "text_output" (target texts).

        Returns:
            dict: ``{"loss": loss}`` where ``loss`` is the loss returned by T5.
        """
        image = samples["image"]
        device = image.device

        # Visual features -> Q-Former queries -> T5 embedding space.
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=device)

        query_output = self.Qformer.bert(
            query_embeds=self.query_tokens.expand(image_embeds.shape[0], -1, -1),
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long, device=device)

        def _tokenize(texts):
            # Shared tokenizer settings for both the encoder and the target texts.
            return self.t5_tokenizer(
                texts,
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
                return_tensors="pt",
            ).to(device)

        with self.maybe_autocast(dtype=torch.bfloat16):
            input_tokens = _tokenize(samples["text_input"])
            output_tokens = _tokenize(samples["text_output"])

            # Visual tokens come first, followed by the text tokens.
            text_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_t5, text_embeds], dim=1)
            encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

            # Padding positions are set to -100 in the labels.
            is_pad = output_tokens.input_ids == self.t5_tokenizer.pad_token_id
            targets = output_tokens.input_ids.masked_fill(is_pad, -100)

            outputs = self.t5_model(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                decoder_attention_mask=output_tokens.attention_mask,
                return_dict=True,
                labels=targets,
            )
            return {"loss": outputs.loss}

    @torch.no_grad()
    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=30,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
        """
        Generate text for a batch of images.

        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - prompt (str or list of str, optional): overrides ``self.prompt``;
                  a list must have one entry per image.
            use_nucleus_sampling (bool): Passed to T5 ``generate`` as ``do_sample``.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): Passed to T5 ``generate`` as ``max_new_tokens``.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            length_penalty (float): Passed through to T5 ``generate``.
            num_captions (int): Number of captions to be generated for each image.
            temperature (float): Passed through to T5 ``generate``.
        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """
        image = samples["image"]
        device = image.device
        batch_size = image.size(0)

        # Visual features -> Q-Former queries -> T5 embedding space.
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_embeds = image_embeds.float()
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=device)

        query_output = self.Qformer.bert(
            query_embeds=self.query_tokens.expand(image_embeds.shape[0], -1, -1),
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long, device=device)

        # A per-call prompt wins over the model-level default.
        prompt = samples["prompt"] if "prompt" in samples else self.prompt
        if isinstance(prompt, str):
            prompt = [prompt] * batch_size
        else:
            assert len(prompt) == batch_size, "The number of prompts must be equal to the batch size."

        input_tokens = self.t5_tokenizer(prompt, padding="longest", return_tensors="pt").to(device)
        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

        with self.maybe_autocast(dtype=torch.bfloat16):
            # Visual tokens come first, followed by the prompt tokens.
            prompt_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_t5, prompt_embeds], dim=1)

            generated = self.t5_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                do_sample=use_nucleus_sampling,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_new_tokens=max_length,
                min_length=min_length,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
            )
            output_text = self.t5_tokenizer.batch_decode(generated, skip_special_tokens=True)

        return output_text

    def predict_answers(
        self,
        samples,
        num_beams=5,
        inference_method="generate",
        max_len=10,
        min_len=1,
        num_ans_candidates=128,
        answer_list=None,
        prompt="",
        length_penalty=-1,
        **kwargs
    ):
        """Answer each question, optionally over several paraphrases of it.

        When ``prompt`` is non-empty, every question is paraphrased
        ``self.num_paraphrases`` times (externally or with T5), each variant is
        wrapped in ``prompt`` (and optionally preceded by a caption), and T5
        answers every variant. The answers for one image are then reduced by
        selection (highest score) or ensembling (majority vote), if enabled.

        Args:
            samples (dict): Must contain "image" and "text_input" (a string or a
                list of strings); "image_path" is required when PromptCap
                captions are used. The dict is updated in place
                ("text_input", "rephrased_text_input", "captions").
            num_beams, max_len, min_len, length_penalty: Forwarded to T5
                ``generate`` (``max_len`` as ``max_new_tokens``).
            prompt (str): Template with one ``{}`` placeholder for the question.
                When empty, ``samples["text_input"]`` is used as-is and no
                paraphrasing happens.
            inference_method, num_ans_candidates, answer_list, **kwargs: Accepted
                for interface compatibility; not used here.

        Returns:
            tuple: ``(answers, paraphrases, scores)``. ``answers`` holds one entry
            per image: a string with selection, a one-element list with
            ensembling, otherwise the list of all answers. ``paraphrases`` and
            ``scores`` are ``samples["rephrased_text_input"]`` and the per-variant
            scores as nested lists when ``self.verbose`` is set, else empty lists.

        Raises:
            ValueError: If ``self.selection_criterion`` is neither 'Aconf' nor
                'Qconf', or the number of text inputs is not a positive multiple
                of the batch size.
        """
        # Scores are always computed below, so validate the criterion up front
        # instead of failing later on an unbound variable.
        if self.selection_criterion not in ('Aconf', 'Qconf'):
            raise ValueError(
                "Unsupported selection_criterion {!r}; expected 'Aconf' or 'Qconf'.".format(
                    self.selection_criterion
                )
            )
        use_qconf = self.selection_criterion == 'Qconf'

        image = samples["image"]
        device = image.device
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_embeds = image_embeds.float()
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        bs = len(samples['image'])
        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(device)

        if isinstance(samples["text_input"], str):
            samples["text_input"] = [samples["text_input"]]

        if prompt:
            # Expand every question into [original, paraphrase_1, ..., paraphrase_p].
            if self.ext_paraphrase:
                par_device = device if self.use_caption else self.alt_device
                samples["rephrased_text_input"] = self.ext_paraphrase_response(samples["text_input"], par_device)
            else:
                samples["rephrased_text_input"] = self.int_paraphrase_response(
                    samples["text_input"], self.context_paraphrase_prompt, device, inputs_t5, atts_t5
                )
            samples['text_input'] = list(itertools.chain.from_iterable(samples["rephrased_text_input"]))

            text_input = [prompt.format(question) for question in samples["text_input"]]
            if self.use_caption:
                if self.use_promptcap:
                    # Question-conditioned captions: one per (image, question variant).
                    extended_images = np.repeat(samples['image_path'], self.num_paraphrases + 1).tolist()
                    cap_input = [self.caption_prompt.format(question) for question in samples['text_input']]
                    captions = self.cap_model.caption(cap_input, extended_images, device=self.alt_device)
                else:
                    # One caption per image from this model, repeated for every variant.
                    captions = self.generate(samples={'image': samples['image']})
                    captions = np.repeat(captions, self.num_paraphrases + 1).tolist()
                # endswith() also copes with an empty caption, where cap[-1] would raise.
                samples['captions'] = [
                    cap.capitalize() if cap.endswith('.') else cap.capitalize() + '.'
                    for cap in captions
                ]
                text_input = [cap + ' ' + txt for cap, txt in zip(samples['captions'], text_input)]
        else:
            text_input = samples["text_input"]

        # Number of text variants per image. Equals num_paraphrases + 1 on the
        # prompted path; on the unprompted path it is whatever the caller supplied.
        per_image, leftover = divmod(len(text_input), bs)
        if leftover or per_image == 0:
            raise ValueError(
                "Got {} text inputs for a batch of {} images; expected a positive multiple.".format(
                    len(text_input), bs
                )
            )
        # The unprompted path produces no paraphrases; expose the inputs in the same
        # nested shape unless the caller already provided this key.
        samples.setdefault(
            "rephrased_text_input",
            [text_input[i * per_image:(i + 1) * per_image] for i in range(bs)],
        )

        input_tokens = self.t5_tokenizer(
            text_input, padding="longest", return_tensors="pt"
        ).to(device)
        ques_tokens = None
        if use_qconf:
            # Keep only the part of each input up to and including "Question: ".
            ques_input = [txt.split("Question: ")[0] + "Question: " for txt in text_input]
            ques_tokens = self.t5_tokenizer(
                ques_input, padding="longest", return_tensors="pt"
            ).to(device)

        ques_atts = None
        if self.use_caption:
            # With captions, the encoder input is text only (no visual tokens).
            encoder_atts = input_tokens.attention_mask
            if use_qconf:
                ques_atts = ques_tokens.attention_mask
        else:
            # Repeat the visual-token mask once per text variant of each image.
            atts_t5 = torch.repeat_interleave(atts_t5, per_image, dim=0)
            encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
            if use_qconf:
                ques_atts = torch.cat([atts_t5, ques_tokens.attention_mask], dim=1)

        with self.maybe_autocast(dtype=torch.bfloat16):
            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            ques_embeds = None
            if use_qconf:
                ques_embeds = self.t5_model.encoder.embed_tokens(ques_tokens.input_ids)
            if not self.use_caption:
                # Repeat the visual tokens once per text variant and put them first.
                inputs_t5 = torch.repeat_interleave(inputs_t5, per_image, dim=0)
                inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
                if use_qconf:
                    ques_embeds = torch.cat([inputs_t5, ques_embeds], dim=1)

            outputs = self.t5_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                do_sample=False,
                num_beams=num_beams,
                max_new_tokens=max_len,
                min_length=min_len,
                length_penalty=length_penalty,
                return_dict_in_generate=True,
                output_scores=True,
            )
            output_text = self.t5_tokenizer.batch_decode(
                outputs.sequences.detach().cpu(), skip_special_tokens=True
            )

            if use_qconf:
                # Score the question variants given the inputs up to "Question: ".
                output_scores = self.compute_score(ques_embeds, ques_atts, samples['text_input'])
            else:
                # Score the generated answers given the full inputs.
                output_scores = self.compute_score(inputs_embeds, encoder_atts, output_text)

        if self._apply_lemmatizer:
            output_text = self._lemmatize(output_text)

        # Regroup the flat per-variant results into one row per image.
        output_text = [output_text[i * per_image:(i + 1) * per_image] for i in range(bs)]
        output_scores = np.asarray(output_scores).reshape(bs, per_image)

        if self.perform_selection:
            # Keep the answer of the best-scoring variant for each image.
            select_idx = np.argmax(output_scores, axis=1)
            output_text = [output_text[s][idx] for s, idx in enumerate(select_idx)]
        elif self.perform_ensembling:
            # Majority vote over the answers of each image.
            output_text = list(map(self.ensemble, output_text))

        if self.verbose:
            return output_text, samples['rephrased_text_input'], output_scores.tolist()
        return output_text, [], []

    # def compute_question_confidence(self, inputs_embeds, encoder_atts, output_text):
    #     scores = []
    #     return np.array(scores)

    def compute_score(self, inputs_embeds, encoder_atts, output_text):
        """Score each text in ``output_text`` by its mean token log-probability under T5.

        Args:
            inputs_embeds: Encoder input embeddings, one row per text.
            encoder_atts: Attention mask matching ``inputs_embeds``.
            output_text: List of strings fed to the decoder and scored.

        Returns:
            np.ndarray: One score per text — the sum of the log-probabilities
            read at each token position, divided by the number of tokens
            (the final token of each encoding is excluded). Texts with no
            remaining tokens get -100.
        """
        # Run the decoder on whatever device the encoder inputs live on instead
        # of assuming CUDA.
        device = inputs_embeds.device
        decoder_input = self.t5_tokenizer(output_text, return_tensors='pt', padding=True)
        decoder_input_ids = decoder_input['input_ids']
        decoder_attention_mask = decoder_input['attention_mask']

        # NOTE(review): the tokenized text is passed directly as decoder_input_ids
        # (no shift-right / start token) and the same ids are then used as the
        # labels at the same positions. This looks like a one-position
        # misalignment between predictions and labels — confirm before relying on
        # the absolute values. Behavior is intentionally left unchanged here.
        logits = self.t5_model(
            inputs_embeds=inputs_embeds,
            attention_mask=encoder_atts,
            decoder_input_ids=decoder_input_ids.to(device),
            decoder_attention_mask=decoder_attention_mask.to(device),
        ).logits.detach().cpu()

        # log_softmax in float32 avoids the -inf that log(softmax(x)) can produce
        # when a probability underflows in reduced precision.
        all_logprobs = F.log_softmax(logits.float(), dim=-1)

        # Log-probability of each token id at its own position: (batch, seq_len).
        token_logprobs = all_logprobs.gather(-1, decoder_input_ids.unsqueeze(-1)).squeeze(-1)

        # Each encoding is [tokens..., final token, padding...]; score only the
        # tokens before the final one (assumes right padding — as the original
        # slicing did).
        label_lens = (decoder_attention_mask.sum(dim=1) - 1).clamp(min=0)
        positions = torch.arange(decoder_input_ids.shape[1]).unsqueeze(0)
        keep = positions < label_lens.unsqueeze(1)

        summed = token_logprobs.masked_fill(~keep, 0.0).sum(dim=1)
        scores = summed / label_lens.clamp(min=1)
        # Degenerate generation (nothing to score) gets a fixed low score.
        scores = scores.masked_fill(label_lens == 0, -100.0)
        return scores.double().numpy()
    
    def compute_score_with_dropout(self, inputs_embeds, encoder_atts, output_text, K=20):
        """Average ``compute_score`` over ``K`` forward passes with T5 in train mode.

        The T5 model is switched to train mode for the duration of the call and
        restored to its previous mode afterwards, even if scoring fails.

        Args:
            inputs_embeds, encoder_atts, output_text: Forwarded to ``compute_score``.
            K (int): Number of passes; pass ``k`` runs under ``torch.manual_seed(k)``.

        Returns:
            np.ndarray: Per-text log of the mean (over passes) of ``exp(score)``.

        Raises:
            ValueError: If ``K`` is smaller than 1.
        """
        if K < 1:
            raise ValueError("K must be at least 1, got {}".format(K))

        was_training = self.t5_model.training
        self.t5_model.train()
        try:
            with torch.no_grad():
                probs = []
                for k in range(K):
                    # NOTE: this reseeds the global torch RNG on every pass.
                    torch.manual_seed(k)
                    probs.append(np.exp(self.compute_score(inputs_embeds, encoder_atts, output_text)))
        finally:
            # Do not leave the frozen T5 in train mode after scoring.
            self.t5_model.train(was_training)

        return np.log(np.mean(probs, axis=0))

    def _lemmatize(self, answers):
        def apply(answer):
            doc = self.lemmatizer(answer)

            words = []
            for token in doc:
                if token.pos_ in ["NOUN", "VERB"]:
                    words.append(token.lemma_)
                else:
                    words.append(token.text)
            answer = " ".join(words)

            return answer

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        if self._lemmatizer is None:
            try:
                import spacy

                self._lemmatizer = spacy.load("en_core_web_sm")
            except ImportError:
                logging.error(
                    """
                    Please install spacy and en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                exit(1)

        return self._lemmatizer

    @classmethod
    def from_config(cls, cfg):
        """Build the model from a config object and load its checkpoint.

        Every constructor argument is read with ``cfg.get``; keys without a
        default below ("image_size", "num_query_token", "t5_model") resolve to
        whatever ``cfg.get`` returns when they are absent.
        """
        init_kwargs = dict(
            # Vision encoder
            vit_model=cfg.get("vit_model", "eva_clip_g"),
            img_size=cfg.get("image_size"),
            drop_path_rate=cfg.get("drop_path_rate", 0),
            use_grad_checkpoint=cfg.get("use_grad_checkpoint", False),
            vit_precision=cfg.get("vit_precision", "fp16"),
            freeze_vit=cfg.get("freeze_vit", True),
            # Q-Former and T5
            num_query_token=cfg.get("num_query_token"),
            t5_model=cfg.get("t5_model"),
            prompt=cfg.get("prompt", ""),
            max_txt_len=cfg.get("max_txt_len", 32),
            apply_lemmatizer=cfg.get("apply_lemmatizer", False),
            # Paraphrasing, selection / ensembling and captioning
            par_num_beams=cfg.get('par_num_beams', 5),
            num_paraphrases=cfg.get('num_paraphrases', 5),
            ext_paraphrase=cfg.get('ext_paraphrase', True),
            perform_selection=cfg.get('perform_selection', False),
            selection_criterion=cfg.get('selection_criterion', 'Aconf'),
            perform_ensembling=cfg.get('perform_ensembling', False),
            verbose=cfg.get('verbose', True),
            constrained=cfg.get('constrained', True),
            use_caption=cfg.get('use_caption', False),
            use_promptcap=cfg.get('use_promptcap', False),
            alt_device=cfg.get('alt_device', -1),
        )

        model = cls(**init_kwargs)
        model.load_checkpoint_from_config(cfg)
        return model
