import argparse
import time
from PIL import Image

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any

from ..common.registry import registry


class SeparatorStyle(Enum):
    """How ``Conversation.get_prompt`` terminates each rendered message.

    SINGLE: every message (and the system preamble) ends with ``sep``.
    TWO: messages alternate between ``sep`` (even index) and ``sep2`` (odd index).
    """

    SINGLE = 1
    TWO = 2


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    The history is rendered into one prompt string by ``get_prompt``; the
    ``sep_style`` field selects which separator terminates each message.
    """
    system: str  # preamble emitted before the first separator
    roles: List[str]  # (user role, assistant role); index 0 is the user side
    messages: List[List[str]]  # [role, text] pairs; a falsy text renders as "role:"
    offset: int  # number of leading messages hidden by to_gradio_chatbot
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None  # second separator, only used with SeparatorStyle.TWO

    skip_next: bool = False  # not read by the methods below; preserved by copy()
    conv_id: Any = None  # opaque identifier, copied and serialised as-is

    def get_prompt(self):
        """Render the system preamble and all messages into a single prompt string.

        Each message becomes ``"role: text" + separator``. A message whose text is
        falsy (e.g. ``None`` for a turn still to be generated) becomes a bare
        ``"role:"`` with no separator, so the prompt ends where generation starts.
        With ``SINGLE`` every separator is ``sep``; with ``TWO`` the separator for
        message ``i`` is ``sep`` for even ``i`` and ``sep2`` for odd ``i``.

        Raises:
            ValueError: if ``sep_style`` is not a known ``SeparatorStyle``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            seps = [self.sep, self.sep]
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        parts = [self.system + seps[0]]
        for i, (role, message) in enumerate(self.messages):
            if message:
                parts.append(role + ": " + message + seps[i % 2])
            else:
                parts.append(role + ":")
        return "".join(parts)

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Return ``[[user_text, reply_text], ...]`` pairs for messages after ``offset``.

        Messages are paired by position: even-indexed ones (counted from
        ``offset``) open a new pair and odd-indexed ones fill its second slot.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return an independent copy of this conversation.

        The message list and every ``[role, text]`` pair are new lists, so
        appending to or editing the copy never affects the original. All other
        fields, including ``skip_next``, are carried over unchanged.
        """
        return dataclasses.replace(
            self, messages=[[role, text] for role, text in self.messages])

    def dict(self):
        """Return a plain-dict snapshot of the serialisable fields.

        ``sep_style`` and ``skip_next`` are not included.
        """
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once every sequence in the batch contains a stop pattern.

    ``stops`` is a list of token-id tensors (stop patterns). Generation stops
    only when *every* row of ``input_ids`` contains at least one pattern as a
    contiguous subsequence anywhere after position ``prompt_len``, not just at
    the end of the row. ``prompt_len`` is 0 unless a caller changes it.
    """

    def __init__(self, stops=None):
        super().__init__()
        # ``None`` instead of ``[]`` avoids one list object shared by all instances.
        self.stops = [] if stops is None else stops
        self.prompt_len = 0

    def _contains_subsequence(self, large_tensor, small_tensor):
        """Return True if ``small_tensor`` occurs contiguously anywhere in ``large_tensor``.

        Both are expected to be 1-D (a row of ``input_ids`` and a stop pattern) on
        the same device. All window positions are compared at once via ``unfold``,
        so there is a single ``.item()`` call per pattern instead of one per window
        position. An empty pattern matches trivially.
        """
        len_small = len(small_tensor)
        len_large = len(large_tensor)
        if len_small == 0:
            return True
        if len_small > len_large:
            return False
        # windows[i] is large_tensor[i:i + len_small]; shape (len_large - len_small + 1, len_small)
        windows = large_tensor.unfold(0, len_small, 1)
        return bool((windows == small_tensor).all(dim=1).any().item())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True only if every row of ``input_ids`` contains some stop pattern.

        Extra keyword arguments (passed by some ``StoppingCriteriaList`` versions)
        are accepted and ignored.
        """
        for x in input_ids:
            generated = x[self.prompt_len:]
            if not any(self._contains_subsequence(generated, stop.to(x.device))
                       for stop in self.stops):
                return False
        return True


# Template conversation for image-grounded chat: "Human"/"Assistant" turns, every
# message terminated by "###". offset=2 hides the first two messages from
# Conversation.to_gradio_chatbot. Callers take .copy() before appending messages.
CONV_VISION = Conversation(
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    roles=("Human", "Assistant"),
    system=(
        "Give the following image: <Img>ImageContent</Img>. "
        "You will be able to see the image once I provide it to you. Please answer my questions."
    ),
    messages=[],
    offset=2,
)



class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        """Bind the model, image preprocessor and device, and build the stop criteria.

        Args:
            model: object providing ``llama_model``, ``llama_tokenizer`` and
                ``encode_img`` (as used by the methods of this class).
            vis_processor: callable that turns a PIL image into a tensor.
            device: device onto which inputs and stop-word ids are moved.
        """
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        # '###' can be encoded in two different ways: [835] or [2277, 29937].
        stop_id_sequences = ([835], [2277, 29937])
        self.stop_words_ids = [torch.tensor(ids).to(self.device) for ids in stop_id_sequences]
        criterion = StoppingCriteriaSub(stops=self.stop_words_ids)
        self.stopping_criteria = StoppingCriteriaList([criterion])
    
    def move_stopping_criteria_device(self, device, dtype=torch.float32):
        """Move the stop-word token ids to ``device`` and rebuild ``self.stopping_criteria``.

        ``dtype`` is accepted for backward compatibility, but token ids must stay
        exact integers. Only ``torch.int32`` and ``torch.int64`` are applied; any
        other dtype is ignored. Floating types such as float16/bfloat16 cannot
        represent ids like 29937 exactly (it rounds to 29936), which would stop
        the ``'###'`` pattern from ever matching.
        """
        target_dtype = dtype if dtype in (torch.int32, torch.int64) else None
        self.stop_words_ids = [
            stop_tensor.to(device) if target_dtype is None
            else stop_tensor.to(device, dtype=target_dtype)
            for stop_tensor in self.stop_words_ids
        ]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=self.stop_words_ids)])

    def ask(self, text, conv):
        """Record a user turn containing ``text`` in ``conv``.

        If the latest message is a user turn whose text ends with ``'</Img>'``
        (an image placeholder just uploaded), ``text`` is appended to that message
        after a single space instead of opening a new turn.
        """
        history = conv.messages
        merge_into_last = (
            len(history) > 0
            and history[-1][0] == conv.roles[0]
            and history[-1][1][-6:] == '</Img>'
        )
        if not merge_into_last:
            conv.append_message(conv.roles[0], text)
            return
        last = history[-1]
        last[1] = ' '.join([last[1], text])

    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        """Generate the assistant reply for ``conv`` by sampling and store it in ``conv``.

        An empty assistant turn is appended, and the prompt is embedded with
        ``get_context_emb``. The oldest positions are then dropped so that prompt
        length + ``max_new_tokens`` does not exceed ``max_length``; a warning is
        printed when that happens. Decoding uses ``do_sample=True``. The decoded
        text is cut at the first ``'###'`` and after the last ``'Assistant:'``,
        then written into the last message of ``conv``.

        Returns:
            tuple: ``(output_text, token_ids)`` where ``token_ids`` is a NumPy
            array of the generated ids with a leading ``<unk>`` (0) and/or
            ``<s>`` (1) removed.
        """
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        begin_idx = max(0, current_max_len - max_length)
        if begin_idx > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        # The model may emit <unk> (0) and/or <s> (1) first; strip them. The length
        # checks avoid an IndexError when nothing (or only those tokens) was generated.
        if len(output_token) > 0 and output_token[0] == 0:
            output_token = output_token[1:]
        if len(output_token) > 0 and output_token[0] == 1:
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image, conv, img_list):
        """Encode ``image`` and register it as a user turn in ``conv``.

        ``image`` may be:

        * a file path, which is opened with PIL and converted to RGB;
        * a ``PIL.Image.Image``;
        * a ``torch.Tensor``.

        Path and PIL inputs go through ``self.vis_processor`` and gain a leading
        batch dimension. A 3-D tensor gains one via ``unsqueeze(0)``, and any tensor
        is moved to ``self.device``. Other types are passed to ``encode_img``
        unchanged. The first value returned by ``self.model.encode_img`` is appended
        to ``img_list``, and the placeholder message ``"<Img><ImageHere></Img>"`` is
        added to ``conv`` as a user turn.

        Returns:
            str: the acknowledgement ``"Received."``.
        """
        if isinstance(image, str):  # a path on disk: load it first
            image = Image.open(image).convert('RGB')

        if isinstance(image, Image.Image):
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            batched = image.unsqueeze(0) if len(image.shape) == 3 else image
            image = batched.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        return "Received."

    def get_context_emb(self, conv, img_list, prompt=None):
        """Embed a prompt, splicing image embeddings in at each ``<ImageHere>`` placeholder.

        Args:
            conv: conversation whose ``get_prompt()`` is used when ``prompt`` is None.
            img_list: image embeddings, one per ``<ImageHere>`` placeholder, in order.
                Presumably shaped (1, n_img_tokens, hidden) as returned by
                ``encode_img`` — TODO confirm.
            prompt: optional explicit prompt. When given, ``' And image also shows '``
                is appended to it and ``conv`` is not consulted.

        Returns:
            The text-segment and image embeddings concatenated along dim 1.

        Raises:
            AssertionError: if the number of placeholders differs from ``len(img_list)``.
        """
        if prompt is None:
            prompt = conv.get_prompt()
        else:
            prompt = prompt + ' And image also shows '
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        # Interleave: text_0, img_0, text_1, img_1, ..., text_last.
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        return torch.cat(mixed_embs, dim=1)
    
    def get_text_emb(self, conv):
        """Embed the text of ``conv``'s prompt with ``<ImageHere>`` placeholders dropped.

        The prompt is split on ``<ImageHere>``. Only the first segment is tokenized
        with special tokens (the BOS), and the segment embeddings are concatenated
        along dim 1 with nothing inserted between them.
        """
        segments = conv.get_prompt().split('<ImageHere>')
        tokenizer = self.model.llama_tokenizer
        token_ids = []
        for idx, segment in enumerate(segments):
            encoded = tokenizer(segment, return_tensors="pt", add_special_tokens=(idx == 0))
            token_ids.append(encoded.to(self.device).input_ids)
        embed = self.model.llama_model.model.embed_tokens
        return torch.cat([embed(ids) for ids in token_ids], dim=1)


    def get_inscontra_context_emb(self, conv, img_list, question):
        """Embed ``conv``'s prompt with every occurrence of ``question`` removed.

        Otherwise this mirrors ``get_context_emb``: the text is split on
        ``<ImageHere>``, each segment is tokenized (BOS only on the first), and
        image embeddings from ``img_list`` are interleaved between the segment
        embeddings along dim 1. The name presumably refers to an instruction-
        contrastive input — TODO confirm with the caller.

        Raises:
            AssertionError: if the number of placeholders differs from ``len(img_list)``.
        """
        stripped_prompt = conv.get_prompt().replace(question, "")
        segments = stripped_prompt.split('<ImageHere>')
        assert len(segments) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        tokenizer = self.model.llama_tokenizer
        token_ids = [
            tokenizer(segment, return_tensors="pt", add_special_tokens=idx == 0).to(self.device).input_ids
            for idx, segment in enumerate(segments)
        ]
        embed = self.model.llama_model.model.embed_tokens
        seg_embs = [embed(ids) for ids in token_ids]
        pieces = []
        for text_emb, img_emb in zip(seg_embs[:-1], img_list):
            pieces.append(text_emb)
            pieces.append(img_emb)
        pieces.append(seg_embs[-1])
        return torch.cat(pieces, dim=1)

    def clean_string(self, string):
        """Normalise one tag fragment.

        Removes all digits and commas, keeps only the text before the first
        ``'.'``, and strips surrounding whitespace.
        """
        import re
        without_digits_commas = re.sub(r'[\d,]', '', string)
        head, _, _ = without_digits_commas.partition('.')
        return head.strip()

    def extract_tags(self, sentence):
        """Parse an enumerated object list out of a model response.

        Only the text after the last ``'Identify the objects in the sentence:'``
        is used. Each ``')'``-delimited item is normalised with ``clean_string``.
        Items mentioning ``image``/``photo``/``photograph`` are discarded. A
        lowercase leading article (``a``/``an``/``the``) at a word boundary is
        removed from every remaining item.

        Returns:
            list[str]: the cleaned tags, in order.
        """
        import re
        sentence = sentence.split('Identify the objects in the sentence:')[-1]
        tags = [self.clean_string(item) for item in sentence.split(')')[1:]]
        # Build a new list rather than calling list.remove() while iterating,
        # which skipped the element after each removed one.
        tags = [tag for tag in tags
                if not any(word in tag for word in ('image', 'photo', 'photograph'))]
        # Word-boundary match so "pizza slice" is not mangled into "pizzslice".
        article = re.compile(r'\b(?:a|an|the)\s+')
        return [article.sub('', tag) for tag in tags]
    





    def tags_batch_answer(self, image_list, batch_question_list, raw_question_list, chat_list, max_new_tokens=300, num_beams=5, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000, cond_coeff=0., cf_coeff=0., mi_coeff=0., ins_coeff=False, max_ensemble=False,
                min_ensemble=False,uncertainty_threshold=-1e9, remove_qs=True, use_more_attend=False,penalty_alpha=0., topk=1, refine=True, gen_batch_size=2):
        """Ask several questions per image and concatenate the answers per image.

        For each image, every question in the matching entry of
        ``batch_question_list`` is asked in a fresh copy of ``CONV_VISION``. Prompt
        embeddings are left-padded with zeros along dim 1 to a common length and
        decoded with ``do_sample=False``, ``gen_batch_size`` prompts per
        ``generate`` call. The decoded answers for one image are joined with
        ``''.join``.

        ``raw_question_list``, ``chat_list`` (fresh conversations are used
        instead), ``top_p``, ``max_length`` and the contrastive / ensemble /
        refine parameters are accepted for signature parity with the other batch
        methods but are not used here.

        Returns:
            list[str]: one concatenated answer string per image (``''`` when the
            image has no questions).
        """
        summary_outputs = []
        for image, question_list in zip(image_list, batch_question_list):
            if not question_list:
                # Nothing to ask; also avoids max() over an empty sequence below.
                summary_outputs.append('')
                continue

            convs = [CONV_VISION.copy() for _ in range(len(question_list))]
            embs_list = []
            for question, conv in zip(question_list, convs):
                img_list = []
                self.upload_img(image, conv, img_list)
                self.ask(question, conv)
                conv.append_message(conv.roles[1], None)
                embs_list.append(self.get_context_emb(conv, img_list))
            max_emb_token = max(x.shape[1] for x in embs_list)
            batch_embs = torch.cat(
                [F.pad(x, (0, 0, max_emb_token - x.shape[1], 0, 0, 0), value=0) for x in embs_list], dim=0)

            outputs = []
            for start in range(0, len(question_list), gen_batch_size):
                generated = self.model.llama_model.generate(
                    inputs_embeds=batch_embs[start:start + gen_batch_size],
                    max_new_tokens=max_new_tokens,
                    stopping_criteria=self.stopping_criteria,
                    num_beams=num_beams,
                    do_sample=False,
                    min_length=min_length,
                    repetition_penalty=repetition_penalty,
                    length_penalty=length_penalty,
                    temperature=temperature,
                    output_scores=True,
                    return_dict_in_generate=True,
                )
                # sequences_scores used to be collected here but never read; it is
                # None for num_beams=1, where .tolist() crashed.
                for output_token in generated.sequences:
                    # Strip a leading <unk> (0) and/or <s> (1), guarding empty outputs.
                    if len(output_token) > 0 and output_token[0] == 0:
                        output_token = output_token[1:]
                    if len(output_token) > 0 and output_token[0] == 1:
                        output_token = output_token[1:]
                    output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
                    output_text = output_text.split('###')[0]  # remove the stop sign '###'
                    output_text = output_text.split('Assistant:')[-1].strip()
                    outputs.append(output_text)

            # NOTE(review): answers are joined with no separator, so consecutive
            # sentences run together ("...oven.The image..."); confirm this is intended.
            summary_outputs.append(''.join(outputs))
        return summary_outputs



    


    def old_batch_answer(self, image_list, question_list, raw_question_list, chat_list, max_new_tokens=300, num_beams=5, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000, cond_coeff=0., cf_coeff=0., mi_coeff=0., ins_coeff=False, max_ensemble=False,
                min_ensemble=False,uncertainty_threshold=-1e9, remove_qs=True, use_more_attend=False,penalty_alpha=0., topk=1, refine=True):
        """Legacy batched answering with optional contrastive decoding.

        Each (image, question, conv) triple is uploaded and asked into its
        ``conv`` (mutating the conversations in ``chat_list``). An empty
        assistant turn is appended and the prompt is embedded. The embeddings
        are left-padded with zeros along dim 1 and stacked into one batch.

        When ``ins_coeff`` is truthy, a second set of embeddings is built from
        fresh ``CONV_VISION`` copies with the raw question removed from the prompt
        (``get_inscontra_context_emb``). ``self.get_contra_embeds`` is always
        called. It is not defined in the visible part of this file, and when
        ``cond_coeff + cf_coeff + mi_coeff > 0`` its result is passed to
        ``generate`` together with non-standard keyword arguments
        (``contrastive_inputs``, ``condition_coeff``, ...). These presumably
        require a patched ``generate`` — TODO confirm.

        ``use_more_attend``, ``penalty_alpha``, ``topk`` and ``refine`` are
        accepted but unused.

        Returns:
            list[str]: one decoded answer per generated sequence (``[]`` for an
            empty batch).

        Raises:
            AssertionError: if the padded prompt length + ``max_new_tokens`` is
            not below ``max_length``.
        """
        embs_list = []
        bs = len(image_list)
        for image, question, conv in zip(image_list, question_list, chat_list):
            img_list = []
            self.upload_img(image, conv, img_list)
            self.ask(question, conv)
            conv.append_message(conv.roles[1], None)
            embs_list.append(self.get_context_emb(conv, img_list))
        if not embs_list:
            return []
        max_emb_token = max([x.shape[1] for x in embs_list])
        embs_list = torch.cat([F.pad(x, (0, 0, max_emb_token - x.shape[1], 0, 0, 0), value=0) for x in embs_list], dim=0)
        # Rendered prompts, each ending with the empty assistant turn just appended.
        prompts = [conv.get_prompt() for _, conv in zip(question_list, chat_list)]

        ins_embs_list = None
        if ins_coeff:
            fresh_convs = [CONV_VISION.copy() for _ in range(len(image_list))]
            ins_embs_list = []
            for image, question, raw_question, conv in zip(image_list, question_list, raw_question_list, fresh_convs):
                img_list = []
                self.upload_img(image, conv, img_list)
                self.ask(question, conv)
                conv.append_message(conv.roles[1], None)
                ins_embs_list.append(self.get_inscontra_context_emb(conv, img_list, raw_question))
            ins_max_emb_token = max([x.shape[1] for x in ins_embs_list])
            ins_embs_list = torch.cat([F.pad(x, (0, 0, ins_max_emb_token - x.shape[1], 0, 0, 0), value=0) for x in ins_embs_list], dim=0)
        contra_inputs = self.get_contra_embeds(bs=bs, prompt=prompts, cond_coeff=cond_coeff, cf_coeff=cf_coeff, mi_coeff=mi_coeff, ins_coeff=ins_coeff, normal_embeds=ins_embs_list, raw_question_list=raw_question_list, remove_qs=remove_qs)

        assert max_emb_token + max_new_tokens < max_length
        gen_kwargs = dict(
            inputs_embeds=embs_list,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=False,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        if cond_coeff + cf_coeff + mi_coeff > 0:
            gen_kwargs.update(
                contrastive_inputs=contra_inputs,
                condition_coeff=cond_coeff,
                context_free_coeff=cf_coeff,
                mean_img_coeff=mi_coeff,
                max_ensemble=max_ensemble,
                min_ensemble=min_ensemble,
                uncertainty_threshold=uncertainty_threshold,
            )
        outputs = self.model.llama_model.generate(**gen_kwargs)

        batch_outputs = []
        for output_token in outputs:
            # Strip a leading <unk> (0) and/or <s> (1), guarding empty outputs.
            if len(output_token) > 0 and output_token[0] == 0:
                output_token = output_token[1:]
            if len(output_token) > 0 and output_token[0] == 1:
                output_token = output_token[1:]
            output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
            output_text = output_text.split('###')[0]  # remove the stop sign '###'
            output_text = output_text.split('Assistant:')[-1].strip()
            batch_outputs.append(output_text)
        return batch_outputs


    def batch_answer(self, image_list, question_list, raw_question_list, chat_list, max_new_tokens=300, num_beams=5, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000, cond_coeff=0., cf_coeff=0., mi_coeff=0., ins_coeff=False, max_ensemble=False,
                min_ensemble=False, uncertainty_threshold=-1e9, remove_qs=True, use_more_attend=False, penalty_alpha=0., topk=1, refine=True, attn_threshold=1., attn_scale=None, bad_words_ids=False,
):
        """Answer one question per image in a single batched ``generate`` call.

        Each (image, question, conv) triple is uploaded and asked into its
        ``conv`` (mutating the conversations in ``chat_list``). An empty
        assistant turn is appended and the prompt is embedded. Embeddings are
        left-padded with zeros along dim 1 and stacked. Generation uses
        ``do_sample=False`` plus ``attn_threshold``/``attn_scale``, which are not
        standard ``generate`` arguments and presumably need a patched
        transformers — TODO confirm.

        Args:
            attn_scale: forwarded to ``generate``; ``None`` means a fresh empty
                list per call (avoids a shared mutable default).
            bad_words_ids: a flag, not id lists. When truthy, a fixed set of
                token-id sequences is banned via ``generate(bad_words_ids=...)``.

        ``raw_question_list``, ``max_length`` and the contrastive / ensemble /
        refine parameters are accepted for signature parity but unused.

        Returns:
            list[str]: one decoded answer per generated sequence (``[]`` for an
            empty batch).
        """
        if attn_scale is None:
            attn_scale = []

        embs_list = []
        for image, question, conv in zip(image_list, question_list, chat_list):
            img_list = []
            self.upload_img(image, conv, img_list)
            self.ask(question, conv)
            conv.append_message(conv.roles[1], None)
            embs_list.append(self.get_context_emb(conv, img_list))
        if not embs_list:
            return []
        max_emb_token = max(x.shape[1] for x in embs_list)
        embs_list = torch.cat([F.pad(x, (0, 0, max_emb_token - x.shape[1], 0, 0, 0), value=0) for x in embs_list], dim=0)

        # NOTE(review): attentions and scores are requested but only `.sequences`
        # is used; kept in case the patched generate needs them for attn_threshold /
        # attn_scale. Storing per-step attentions is memory-heavy — confirm.
        gen_kwargs = dict(
            inputs_embeds=embs_list,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=False,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            output_attentions=True,
            return_dict_in_generate=True,
            attn_threshold=attn_threshold,
            attn_scale=attn_scale,
            output_scores=True,
        )
        if bad_words_ids:
            # Token-id sequences to ban. Per the original notes, [1316, 408] is
            # "such as" and the comma token is 29892 (not banned). The remaining
            # entries are presumably similar enumeration phrases — TODO confirm
            # against the tokenizer.
            gen_kwargs["bad_words_ids"] = [
                [6124, 304], [3462, 654, 304], [512, 6124, 304], [297, 6124, 304], [297, 6124], [512, 6124],
                [6124, 635], [19814], [1316, 408], [10506, 408], [3160], [7805], [3704], [512, 2325],
                [512, 27722], [512, 22368],
            ]
        outputs = self.model.llama_model.generate(**gen_kwargs).sequences

        batch_outputs = []
        for output_token in outputs:
            # Strip a leading <unk> (0) and/or <s> (1), guarding empty outputs.
            if len(output_token) > 0 and output_token[0] == 0:
                output_token = output_token[1:]
            if len(output_token) > 0 and output_token[0] == 1:
                output_token = output_token[1:]
            output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
            output_text = output_text.split('###')[0]  # remove the stop sign '###'
            output_text = output_text.split('Assistant:')[-1].strip()
            batch_outputs.append(output_text)
        return batch_outputs










    def sentence_by_sentence_batch_answer(self, image_list, question_list, raw_question_list, chat_list, max_new_tokens=300, num_beams=5, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000, cond_coeff=0., cf_coeff=0., mi_coeff=0., ins_coeff=False, max_ensemble=False,
                min_ensemble=False,uncertainty_threshold=-1e9, remove_qs=True, use_more_attend=False,penalty_alpha=0., topk=1, refine=True):
        """Greedily answer a batch of image questions over 5 rounds, dumping attention scores.

        Each round builds one context embedding per (image, question, conversation)
        triple, left-pads them with zeros to a common length, runs greedy generation
        with attentions enabled, writes the per-head attention mass on positions
        [42, 74) to ``./caption_attnscores.json`` and decodes the answers.

        Round 0 uploads the image, asks the question and opens an empty assistant
        turn; ``get_context_emb`` receives ``prompt=None``. Round 1 starts the
        prompt from ``conv.get_prompt()``. Every later round appends
        ``' ' + question`` to the previous round's prompt.

        Args:
            image_list, question_list, chat_list: parallel per-row inputs
                (image, question string, conversation object).
            max_new_tokens, min_length, top_p, repetition_penalty,
            length_penalty, temperature: forwarded to ``generate``.
            raw_question_list, num_beams, max_length, cond_coeff, cf_coeff,
            mi_coeff, ins_coeff, max_ensemble, min_ensemble,
            uncertainty_threshold, remove_qs, use_more_attend, penalty_alpha,
            topk, refine: accepted for interface compatibility; not used here.

        Returns:
            list[str]: the decoded answers from the last round, one per row.
        """
        bs = len(image_list)
        prompt_list = [None] * bs
        batch_outputs = []
        for p in range(5):
            embs_list = []
            next_prompt_list = []
            for image, question, conv, prompt in zip(image_list, question_list, chat_list, prompt_list):
                if p == 1:
                    prompt = conv.get_prompt()
                img_list = []
                # presumably fills img_list with this row's image embedding — TODO confirm
                self.upload_img(image, conv, img_list)
                if p == 0:
                    self.ask(question, conv)
                    conv.append_message(conv.roles[1], None)
                else:
                    # NOTE(review): question_list is never updated between rounds, so
                    # the same question is re-appended each round.
                    prompt = prompt + ' ' + question
                embs_list.append(self.get_context_emb(conv, img_list, prompt))
                next_prompt_list.append(prompt)
            prompt_list = next_prompt_list

            # Left-pad every row with zero embeddings to the longest sequence (dim 1).
            # NOTE(review): no attention_mask is passed, so the pad positions are
            # presumably attended to during generation — consider passing one.
            max_emb_token = max(x.shape[1] for x in embs_list)
            inputs_embeds = torch.cat(
                [F.pad(x, (0, 0, max_emb_token - x.shape[1], 0, 0, 0), value=0) for x in embs_list],
                dim=0,
            )

            # Greedy decoding with attentions. (The original also ran a beam search here
            # whose result was immediately discarded; that wasted call is removed.)
            outputs = self.model.llama_model.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=max_new_tokens,
                stopping_criteria=self.stopping_criteria,
                num_beams=1,
                do_sample=False,
                min_length=min_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                temperature=temperature,
                output_attentions=True,
                return_dict_in_generate=True,
            )
            # NOTE(review): the file is opened with mode='w' every round, so only the
            # last round's scores survive on disk.
            self._dump_caption_attn_scores(outputs, './caption_attnscores.json')
            # Decode the token ids, not the ModelOutput mapping itself (iterating
            # a ModelOutput yields its keys).
            batch_outputs = self._decode_answer_sequences(outputs.sequences)

        return batch_outputs

    def _dump_caption_attn_scores(self, outputs, path, img_start=42, img_end=74):
        """Write one JSON line per batch row: decoded tokens plus per-head attention mass.

        For every generation step ``i``, layer ``j``, batch row ``m`` and head ``n``,
        the attention weights of query row 0 over key positions
        ``[img_start, img_end)`` are summed. Positions 42..73 are presumably the
        image-embedding slots in the prompt — TODO confirm against get_context_emb.
        Each record is ``{'caption': [token strings], 'attn_scores': {key: float}}``,
        with keys of the form ``'token{i}_layer{j}_image{m}_head{n}'``.
        The file at ``path`` is truncated on every call.
        """
        import jsonlines

        num_rows = outputs.sequences.shape[0]
        attn_scores = [{} for _ in range(num_rows)]
        # Sequential on purpose: this is tensor-indexing work driven from Python,
        # so threads gave no speedup and made the key order nondeterministic.
        for i, step in enumerate(outputs.attentions):
            for j, layer in enumerate(step):
                for m, row in enumerate(layer):
                    for n, head in enumerate(row):
                        attn_scores[m][f'token{i}_layer{j}_image{m}_head{n}'] = (
                            head[0][img_start:img_end].sum().item()
                        )

        tokenizer = self.model.llama_tokenizer
        with jsonlines.open(path, mode='w') as writer:
            for m in range(num_rows):
                # [1:] skips the first generated token, as the original dump did.
                caption = [tokenizer.decode(tok, add_special_tokens=False) for tok in outputs.sequences[m][1:]]
                writer.write({'caption': caption, 'attn_scores': attn_scores[m]})

    def _decode_answer_sequences(self, sequences):
        """Decode generated id rows into answer strings.

        Drops a leading id 0 and/or id 1, cuts at the first '###', keeps the
        text after the last 'Assistant:', and strips whitespace.
        """
        answers = []
        for output_token in sequences:
            # Leading <unk> (id 0) occasionally emitted by the model.
            if len(output_token) > 0 and output_token[0] == 0:
                output_token = output_token[1:]
            # Leading <s> (id 1) occasionally present.
            if len(output_token) > 0 and output_token[0] == 1:
                output_token = output_token[1:]
            output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
            output_text = output_text.split('###')[0]  # remove the '###' stop sign
            output_text = output_text.split('Assistant:')[-1].strip()
            answers.append(output_text)
        return answers

    def get_contra_embeds(self,  prompt, bs, cond_coeff=0., cf_coeff=0., mi_coeff=0., ins_coeff=False, normal_embeds=None, raw_question_list=None, remove_qs=False):
        """Build the auxiliary ("contrastive") generation inputs.

        Args:
            prompt: prompt strings, one per batch row (indexed positionally when
                ``remove_qs`` is set).
            bs: batch size; currently unused.
            cond_coeff, cf_coeff: if either is > 0, a context-free branch is built
                by tokenizing ``prompt`` (padding="longest", special tokens added)
                and embedding it with the LLaMA input-embedding table.
            mi_coeff: if > 0, a mean-image branch is built (requires ``ins_coeff``).
            ins_coeff: with ``mi_coeff > 0``, reuse ``normal_embeds`` as the
                mean-image embeddings.
            normal_embeds: embeddings reused for the mean-image branch.
            raw_question_list: per-row strings removed from ``prompt`` when
                ``remove_qs`` is set.
            remove_qs: strip each row's raw question from its prompt first.

        Returns:
            dict with ``'inputs_embeds'`` and ``'attention_mask'`` sub-dicts keyed
            by ``'context_free'`` / ``'mean_img'`` (None for disabled branches),
            plus fixed overrides ``num_beams=1``, ``repetition_penalty=1``,
            ``length_penalty=1`` and ``max_length=1``.

        Raises:
            ValueError: ``remove_qs`` is set but ``raw_question_list`` is None.
            NotImplementedError: ``mi_coeff > 0`` without ``ins_coeff``. The
                blank-image embedding path is disabled in this code, so the
                tensors it needs are never built.
        """
        if remove_qs:
            if raw_question_list is None:
                raise ValueError("remove_qs=True requires raw_question_list")
            prompt = [text.replace(raw_question_list[i], '') for i, text in enumerate(prompt)]

        mean_img_embeds = None
        mean_img_attention_mask = None
        if mi_coeff > 0:
            if not ins_coeff:
                # Previously this referenced tensors defined only inside a disabled
                # (string-literal) block and failed with NameError.
                raise NotImplementedError(
                    "mi_coeff > 0 requires ins_coeff=True; the mean-image embedding path is disabled"
                )
            mean_img_embeds = normal_embeds

        context_free_embeds = None
        context_free_attention_mask = None
        if cf_coeff > 0 or cond_coeff > 0:
            # Tokenize/embed only when the context-free branch is actually requested.
            cf_prompt_tokens = self.model.llama_tokenizer(
                prompt,
                padding="longest",
                return_tensors="pt",
                add_special_tokens=True,
            )
            input_ids = cf_prompt_tokens.input_ids.to(self.device)
            context_free_embeds = self.model.llama_model.get_input_embeddings()(input_ids)
            context_free_attention_mask = cf_prompt_tokens.attention_mask.to(self.device)

        return {
            'inputs_embeds': {'context_free': context_free_embeds, 'mean_img': mean_img_embeds},
            'attention_mask': {'context_free': context_free_attention_mask, 'mean_img': mean_img_attention_mask},
            'num_beams': 1,
            'repetition_penalty': 1,
            'length_penalty': 1,
            'max_length': 1,
        }




import spacy
from nltk.corpus import wordnet

# Shared spaCy pipeline; check_synonyms uses it to POS-tag sentences.
# NOTE(review): loading at import time means importing this module requires the
# "en_core_web_sm" package to be installed — consider loading lazily.
nlp = spacy.load("en_core_web_sm")

def are_synonyms(word1, word2):
    """Return True when ``word2`` is a lemma name of some WordNet synset of ``word1``.

    The check is one-directional: only the synsets of ``word1`` are consulted,
    and ``word2`` is compared verbatim against their lemma names.
    """
    return any(word2 in candidate.lemma_names() for candidate in wordnet.synsets(word1))

def check_synonyms(sentence, other_sentences):
    """Return True if any noun in ``other_sentences`` is a WordNet synonym of a noun in ``sentence``.

    Nouns are the tokens spaCy tags with POS ``"NOUN"``, lower-cased. Each noun
    from an other sentence is tested with ``are_synonyms(other_noun, noun)``
    against every noun of ``sentence``. The search stops at the first hit.
    """
    base_nouns = {tok.text.lower() for tok in nlp(sentence) if tok.pos_ == "NOUN"}

    for candidate_sentence in other_sentences:
        for tok in nlp(candidate_sentence):
            if tok.pos_ != "NOUN":
                continue
            candidate = tok.text.lower()
            if any(are_synonyms(candidate, noun) for noun in base_nouns):
                return True

    return False