import logging
import random

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import StoppingCriteria, StoppingCriteriaList
import copy
from torch.nn import functional as F
from minigpt4.conversation.conversation import StoppingCriteriaSub
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

class MiniGPTBase(BaseModel):
    """
    Base class for MiniGPT-4 and MiniGPT-v2
    """

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        llama_model="",
        max_txt_len=32,
        max_context_len=3800,
        prompt_template="",
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
        lora_r=0,  # lora_r=0 means lora is not used
        lora_target_modules=None,
        lora_alpha=16,
        lora_dropout=0.05,
    ):
        """Build the language model, the vision encoder and the prompt format.

        Args:
            vit_model, img_size, drop_path_rate, use_grad_checkpoint,
            vit_precision, freeze_vit: forwarded unchanged to
                ``self.init_vision_encoder``.
            llama_model, low_resource, device_8bit, lora_r,
            lora_target_modules, lora_alpha, lora_dropout: forwarded to
                ``self.init_llm``. ``lora_target_modules=None`` stands for
                ``["q_proj", "v_proj"]``.
            max_txt_len: stored on the instance; used as the length cap in
                ``tokenize_conversation``.
            max_context_len: stored on the instance; used as the length cap in
                ``prompt_wrap``.
            prompt_template: accepted for interface compatibility but not read;
                the template is derived from ``end_sym`` instead.
            end_sym: must be ``"</s>"`` (``[INST] ... [/INST]`` format) or
                ``"###"`` (``###Human: ... ###Assistant:`` format).

        Raises:
            ValueError: if ``end_sym`` is anything else, which includes the
                declared default (a newline).
        """
        super().__init__()
        # A list literal in the signature would be shared by every call;
        # create a fresh one per instance instead.
        if lora_target_modules is None:
            lora_target_modules = ["q_proj", "v_proj"]
        print('lora_r',lora_r)
        self.llama_model, self.llama_tokenizer = self.init_llm(
            llama_model_path=llama_model,
            low_resource=low_resource,
            low_res_device=device_8bit,
            lora_r=lora_r,
            lora_target_modules=lora_target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        # Both are read by tokenize_conversation and must be assigned by the
        # caller before that method is used.
        self.image_id=None
        self.loss_mask_dict=None
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit
        )

        self.max_txt_len = max_txt_len
        self.max_context_len = max_context_len
        self.end_sym = end_sym
        # The end symbol selects the chat format used to wrap every question.
        if self.end_sym=="</s>":
            self.prompt_template='[INST] {} [/INST] '
            self.prompt_prefix='[INST] '
            self.prompt_suffix=' [/INST] '
        elif self.end_sym=="###":
            self.prompt_template = '###Human: {} ###Assistant: '
            self.prompt_prefix = '###Human: '
            self.prompt_suffix = ' ###Assistant: '
        else:
            raise ValueError('Wrong endsym')
        self.prompt_list = []

    def vit_to_cpu(self):
        self.ln_vision.to("cpu")
        self.ln_vision.float()
        self.visual_encoder.to("cpu")
        self.visual_encoder.float()

    def get_context_emb(self, prompt, img_list):
        """Embed a prompt, substituting each ``<ImageHere>`` with an image embedding.

        The prompt is split on ``<ImageHere>``; every text piece is tokenized
        and embedded, and the pieces are interleaved with the entries of
        ``img_list`` (text, image, text, image, ..., text) before being
        concatenated along dim 1. Special tokens are added only to the first
        text piece. ``img_list`` must be non-empty because the target device is
        taken from its first entry.
        """
        target_device = img_list[0].device
        text_chunks = prompt.split('<ImageHere>')
        assert len(text_chunks) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."

        chunk_embs = []
        for position, chunk in enumerate(text_chunks):
            # only the very first chunk gets the bos token
            encoded = self.llama_tokenizer(
                chunk, return_tensors="pt", add_special_tokens=(position == 0))
            chunk_embs.append(self.embed_tokens(encoded.to(target_device).input_ids))

        pieces = []
        for text_emb, image_emb in zip(chunk_embs[:-1], img_list):
            pieces.append(text_emb)
            pieces.append(image_emb)
        pieces.append(chunk_embs[-1])
        return torch.cat(pieces, dim=1)

    def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
        if prompts is None or len(prompts) == 0:
            # prompts is not provided, just return the original image embedding
            return img_embeds, atts_img
        elif img_embeds is None:
            # prompt is provided but there is no image embedding. return the prompt embedding in right padding
            self.llama_tokenizer.padding_side = "right"
            prompt_tokens = self.llama_tokenizer(
                prompts,
                return_tensors="pt",
                padding="longest",
                add_special_tokens=False
            ).to(self.device)
            prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
            atts_prompt = prompt_tokens.attention_mask
            return prompt_embeds, atts_prompt
        else:
            # return the multi-modal embedding in right padding
            emb_lists = []
            # print('prompts',prompts)
            if isinstance(prompts, str):
                prompts = [prompts] * len(img_embeds)
            # print(len(prompts),prompts,img_embeds.shape)
            # exit()

            for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
                # print(each_img_embed,each_prompt)
                # print('each_prompt',each_prompt)
                # exit()
                pn = each_img_embed.shape[-2]
                # print(each_img_embed.shape)
                if lengths is not None:
                    each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
                    each_img_embed = each_img_embed[:lengths[idx] * pn]
                p_segs = each_prompt.split('<ImageHere>')
                # print('p_segs',p_segs)
                interleave_emb = []
                for idx, seg in enumerate(p_segs[:-1]):
                    p_tokens = self.llama_tokenizer(
                        seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                    p_embed = self.embed_tokens(p_tokens.input_ids)
                    interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1))
                wrapped_emb = torch.cat(interleave_emb, dim=1)
                p_tokens = self.llama_tokenizer(
                    p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_embed = self.embed_tokens(p_tokens.input_ids)
                wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1)
                emb_lists.append(wrapped_emb)

            emb_lens = [emb.shape[1] for emb in emb_lists]
            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))

            max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
            # print('here_max', max(emb_lens), self.max_context_len)

            wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()#batchsize*max_length*4096

            wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device) #batchsize* max_length
            # print('wrapped_atts',wrapped_atts.shape)
            for i, emb in enumerate(emb_lists):
                length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
                # print('length',length)
                # print(emb.shape,wrapped_atts.shape)
                wrapped_embs[i, :length] = emb[:, :length]
                wrapped_atts[i, :length] = 1 #[111111111000000] 1的数量等于length
            # exit()
            return wrapped_embs, wrapped_atts

    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts,slogan_pos,kwd_pos):
        """
        Concatenate the batched input embedding and batched output embedding together.
        Both the input and the output embedding should be right padded.
        """
        input_lens = []
        cat_embs = []
        cat_atts = []
        for i in range(input_embs.size(0)):


            input_len = input_atts[i].sum()
            if len(slogan_pos[i]) != 0:
                slogan_pos[i] =[sep+input_len.item()+1  for sep in slogan_pos[i]]
            if len(kwd_pos[i]) != 0:
                kwd_pos[i] =[[i+input_len.item()+1 for i in sep] for sep in kwd_pos[i] if len(sep)!=0]
            # print('input_len',i,input_len)
            # print('input_len',input_len,input_embs.shape)
            input_lens.append(input_len)
            cat_embs.append(
                torch.cat([
                    input_embs[i][:input_len],
                    output_embs[i],
                    input_embs[i][input_len:]
                ])
            )
            cat_atts.append(
                torch.cat([
                    input_atts[i][:input_len],
                    output_atts[i],
                    input_atts[i][input_len:]
                ])
            )
        cat_embs = torch.stack(cat_embs)
        cat_atts = torch.stack(cat_atts)
        return cat_embs, cat_atts, input_lens,slogan_pos,kwd_pos
    def find_slogan_id(self,ans_id,subsequence=torch.tensor([6781,  3322, 964]).cuda()):
        # car: torch.tensor([22850,  3575,  8373]).cuda()
        # laptop: torch.tensor([853, 280, 1161, 6760, 28157]).cuda()
        # sandwich: torch.tensor([6781,  3322, 964]).cuda()
        sub_len = len(subsequence)
        slogan_id=None
        for i in range(len(ans_id) - sub_len + 1):
            if torch.equal(ans_id[i:i + sub_len], torch.tensor(subsequence)):
                slogan_id=i-1
                # return [i, i + sub_len - 1]
        return slogan_id
    def find_kwd_id(self,ans_id,slogan_th_id):
        #car
        # subsequences = [torch.tensor([350, 25365]).cuda(), torch.tensor([1559]).cuda(), torch.tensor([19716]).cuda(),
        #                torch.tensor([20134, 29963]).cuda(), torch.tensor([26544]).cuda(),
        #                torch.tensor([534, 2707]).cuda(), torch.tensor([1444, 1022]).cuda(), torch.tensor([1109]).cuda(),
        #                torch.tensor([18647]).cuda(), torch.tensor([24413]).cuda(), torch.tensor([7048, 550]).cuda(),
        #                torch.tensor([8818, 29875]).cuda()]
        #laptop
        # subsequences = [torch.tensor([19022]).cuda(), torch.tensor([2011,519,6601]).cuda(), torch.tensor([1461,29882,2495,6601]).cuda(),
        #                torch.tensor([23012, 29899,  3332,  6601]).cuda(),
        #                torch.tensor([425,  415, 3554]).cuda(), torch.tensor([4326, 10967]).cuda(),]
        # sandwich
        subsequences = [torch.tensor([11982,16416]).cuda(), torch.tensor([6866,914]).cuda(), torch.tensor([298,1117,26120]).cuda(),
                       torch.tensor([12244]).cuda(),
                       torch.tensor([7243,2172]).cuda(), torch.tensor([26072,414]).cuda(),torch.tensor([298,1117,2007,414]).cuda(),
                        torch.tensor([7243,262,275]).cuda(),torch.tensor([11463,567]).cuda(),torch.tensor([15612,3780,2272]).cuda()]
        slogan_st_id=slogan_th_id
        positions = []
        tensor_length = ans_id.size(0)
        for subseq in subsequences:
            sub_length = subseq.size(0)
            # found = False
            for start_idx in range(tensor_length - sub_length + 1):
                if start_idx>=slogan_st_id:
                    break
                if torch.equal(ans_id[start_idx:start_idx + sub_length], subseq):
                    positions.extend([i for i in range(start_idx, start_idx + sub_length)])
                    # found = True
                    # break
            # if not found:
            #     positions.append(None)  # Or some placeholder to indicate not found
        return positions


        # return None
    def _record_slogan_positions(self, answer_ids, offset, slogan_tks, kwd_tks):
        """Append the slogan / keyword token positions found in one answer.

        ``answer_ids`` is the 1-D id tensor of a single answer and ``offset``
        is the number of tokens that precede it in the concatenated sequence.
        Nothing is recorded when no slogan is found or when the shifted slogan
        position is not below ``self.max_txt_len``.
        """
        tk = self.find_slogan_id(answer_ids)
        if tk is None:
            return
        # Keywords are searched with the un-shifted slogan index.
        kwd = self.find_kwd_id(answer_ids, tk)
        logging.debug("slogan index %s, keyword indices %s, offset %s", tk, kwd, offset)
        tk += offset
        if tk < self.max_txt_len:
            slogan_tks.append(tk)
            kwd_tks.append([pos + offset for pos in kwd])

    def tokenize_conversation(self, conv_q, conv_a):
        """Concatenate each conversation and build targets that only regress the answers.

        For every sample the token sequence is
        ``answer0 question1 answer1 question2 ... answerN`` (the first question
        is handled by ``prompt_wrap`` and skipped here). Targets equal the
        token ids on answer positions and ``-100`` on question positions and
        padding. Sequences are right-padded with the pad token and cut at
        ``self.max_txt_len``.

        ``self.image_id`` (indexed by batch position) and
        ``self.loss_mask_dict`` (keyed by ``"{image_id}_{turn}"``) must be set
        beforehand. For turns whose entry equals 1, the positions of the
        slogan and of the keywords preceding it are recorded.

        Args:
            conv_q: per-sample lists of question strings.
            conv_a: per-sample lists of answer strings.

        Returns:
            ``(token_ids, attention_mask, targets, all_slogan_tks, all_kwd_tks)``;
            the last two are per-sample lists of recorded positions.
        """
        batch_size = len(conv_q)
        pad_id = self.llama_tokenizer.pad_token_id
        to_regress_token_ids_list = []
        targets_list = []
        all_slogan_tks = []
        all_kwd_tks = []

        for batch_idx in range(batch_size):
            slogan_tks = []
            kwd_tks = []
            img_id = self.image_id[batch_idx]
            # the first question is handled in the prompt wrap function, skip it
            questions = [self.llama_tokenizer(self.llama_tokenizer.bos_token + q,
                                              return_tensors="pt",
                                              add_special_tokens=False).to(self.device)
                         for q in conv_q[batch_idx][1:]]
            answers = [self.llama_tokenizer(a + self.end_sym,
                                            return_tensors="pt",
                                            add_special_tokens=False).to(self.device)
                       for a in conv_a[batch_idx]]

            cur_id = []
            cur_target = []
            for i, question in enumerate(questions):
                answer_ids = answers[i].input_ids
                key = f"{img_id}_{i}"
                if key in self.loss_mask_dict and self.loss_mask_dict[key] == 1:
                    self._record_slogan_positions(
                        answer_ids[0], sum(t.shape[1] for t in cur_id), slogan_tks, kwd_tks)
                cur_id.append(answer_ids)
                cur_target.append(answer_ids)
                cur_id.append(question.input_ids)
                # questions never contribute to the loss
                cur_target.append(torch.ones_like(question.input_ids) * -100)

            # The closing answer has no question after it.
            last_answer_ids = answers[-1].input_ids
            key = f"{img_id}_{len(questions)}"
            if key in self.loss_mask_dict and self.loss_mask_dict[key] == 1:
                self._record_slogan_positions(
                    last_answer_ids[0], sum(t.shape[1] for t in cur_id), slogan_tks, kwd_tks)
            cur_id.append(last_answer_ids)
            cur_target.append(last_answer_ids)

            to_regress_token_ids_list.append(torch.cat(cur_id, dim=1))
            targets_list.append(torch.cat(cur_target, dim=1))
            all_slogan_tks.append(slogan_tks)
            all_kwd_tks.append(kwd_tks)

        max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
        id_dtype = to_regress_token_ids_list[-1].dtype
        to_regress_token_ids = torch.full([batch_size, max_len], pad_id,
                                          dtype=id_dtype, device=self.device)
        targets = torch.full([batch_size, max_len], -100,
                             dtype=id_dtype, device=self.device)
        for batch_idx in range(batch_size):
            # copy at most max_len tokens; anything longer is truncated
            keep = min(to_regress_token_ids_list[batch_idx].shape[1], max_len)
            to_regress_token_ids[batch_idx, :keep] = to_regress_token_ids_list[batch_idx][0, :keep]
            targets[batch_idx, :keep] = targets_list[batch_idx][0, :keep]

        to_regress_token_attn = (to_regress_token_ids != pad_id).to(torch.int)

        return to_regress_token_ids, to_regress_token_attn, targets, all_slogan_tks, all_kwd_tks

    def preparing_embedding(self, samples):
        ### prepare input tokens
        if 'image' in samples:
            img_embeds, img_atts = self.encode_img(samples["image"])
        else:
            img_embeds = img_atts = None
        # print('image_embeds',img_embeds.shape)

        if 'conv_q' in samples:
            # handeling conversation datasets
            conv_q, conv_a = samples['conv_q'], samples['conv_a']

            connect_sym = samples['connect_sym'][0]
            conv_q = [q.split(connect_sym)for q in conv_q]
            # print('conv_q',conv_q)
            conv_a = [a.split(connect_sym) for a in conv_a]
            # print('conv_a',conv_a)
            # exit()

            conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
            # print('after_conv_q', conv_q)
            # exit()
            cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
            # print('cond_embs',cond_embeds.shape)
            regress_token_ids, regress_atts, part_targets,all_slogan_pos,all_kwd_pos = self.tokenize_conversation(conv_q, conv_a)
            # print(regress_token_ids[3],all_slogan_pos[3])
            # exit()

        regress_embeds = self.embed_tokens(regress_token_ids)
        # print(regress_embeds.shape)
        # exit()

        return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets,all_slogan_pos,all_kwd_pos

    def get_batch_logps(
            self,
            logits: torch.FloatTensor,
            labels: torch.LongTensor,
            label_pad_token_id: int = -100,
            is_encoder_decoder: bool = False,
            bn=False,
            show=False,
            show1=False
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
            label_pad_token_id: The label pad token id.
            is_encoder_decoder: Whether the model is an encoder-decoder model.

        Returns:
            A Tuple of two tensor of shape ((batch_size,), (batch_size,)) containing the sum of log probabilities of the given labels under the given logits in the first tensor and the number of non-masked tokens in the second tensor.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
        # print(labels.shape,logits.shape)
        if not is_encoder_decoder:
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        # if show:
        #     loss_fct = CrossEntropyLoss()
        #     shift_logits = logits[:1].reshape(-1, 32001)
        #     shift_labels = labels[:1].reshape(-1)
        #     loss = loss_fct(shift_logits, shift_labels).detach()
        #     shift_logits = logits[1:2].reshape(-1, 32001)
        #     shift_labels = labels[1:2].reshape(-1)
        #     loss2 = loss_fct(shift_logits, shift_labels).detach()
        #     if bn:
        #         print('loss_fct1 for bn', loss,loss2)
        #     else:
        #         print('loss_fct1', loss,loss2)
        #     # shift_logits = logits[3:6].reshape(-1, 32001)
        #     # shift_labels = labels[3:6].reshape(-1)
        #     # loss2 = loss_fct(shift_logits, shift_labels).detach()
        #     shift_logits = logits[4:5].reshape(-1, 32001)
        #     shift_labels = labels[4:5].reshape(-1)
        #     loss = loss_fct(shift_logits, shift_labels).detach()
        #     shift_logits = logits[5:6].reshape(-1, 32001)
        #     shift_labels = labels[5:6].reshape(-1)
        #     loss2 = loss_fct(shift_logits, shift_labels).detach()
        #     if bn:
        #         print('loss_fct1 for bn', loss,loss2)
        #     else:
        #         print('loss_fct2', loss,loss2)
        # exit()

        # 0.046
        loss_mask = labels != label_pad_token_id

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == label_pad_token_id] = 0  # 只取有用的token
        # print(logits.log_softmax(-1).shape)
        # print('labels',labels)
        # print(logits.log_softmax(-1).shape,labels)

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(
            2)  # 没用的就去0号位置的概率
        # print(per_token_logps.shape)
        # print(per_token_logps)
        # exit()

        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
    def preparing_embedding_attack(self, batch_messages, batch_labels, soft_prompt):
        """Build input embeddings and labels with a soft prompt spliced around each message.

        ``soft_prompt`` is split in half along dim 0. Each message is rendered
        as ``<eos> * n_front + prompt_prefix + message + <eos> * n_back +
        prompt_suffix``; after embedding, the eos placeholder positions are
        overwritten with the front and back halves of the soft prompt. The
        tokenized label (minus its first token) is appended and every
        sequence is right-padded with the pad token to the longest one.

        Args:
            batch_messages: message strings.
            batch_labels: target strings, one per message.
            soft_prompt: tensor whose first dimension indexes soft-prompt
                tokens (presumably ``(n_tokens, hidden)`` - confirm with the
                caller).

        Returns:
            ``(inputs_embeds, labels)`` stacked over the batch. Labels are
            ``-100`` on the prompt and on the padding, and the target token
            ids elsewhere.
        """
        half = int(soft_prompt.size(0) / 2)
        soft_prompt_front = soft_prompt[:half]
        soft_prompt_back = soft_prompt[half:]
        n_prompt_tokens_front = soft_prompt_front.size(0)
        n_prompt_tokens_back = soft_prompt_back.size(0)

        eos_token = self.llama_tokenizer.eos_token
        prompt_ids = []
        target_ids = []
        for message, label in zip(batch_messages, batch_labels):
            # eos tokens act as placeholders that are replaced after embedding
            input_text = (eos_token * n_prompt_tokens_front + self.prompt_prefix + message
                          + eos_token * n_prompt_tokens_back + self.prompt_suffix)
            prompt_ids.append(self.llama_tokenizer(
                input_text, return_tensors='pt').input_ids.tolist()[0])
            # drop the first token of the tokenized label (presumably bos)
            target_ids.append(self.llama_tokenizer(
                label, return_tensors='pt').input_ids.tolist()[0][1:])

        max_input_length = max(len(p) + len(t) for p, t in zip(prompt_ids, target_ids))
        # The front placeholders start at the first eos of the first sample;
        # the same index is applied to every sample.
        placeholder_start_index = prompt_ids[0].index(self.llama_tokenizer.eos_token_id)

        pad_id = self.llama_tokenizer.pad_token_id
        input_embeds_list = []
        label_list = []
        for prompt, target in zip(prompt_ids, target_ids):
            n_pad = max_input_length - len(prompt) - len(target)

            # only the target tokens are supervised
            label_list.append(torch.tensor(
                [-100] * len(prompt) + target + [-100] * n_pad,
                dtype=torch.long, device=self.device))

            input_tensor = torch.tensor(
                prompt + target + [pad_id] * n_pad, dtype=torch.long, device=self.device)
            inputs_embeds = self.embed_tokens(input_tensor)

            inputs_embeds[
                placeholder_start_index:placeholder_start_index + n_prompt_tokens_front
            ] = soft_prompt_front
            # NOTE(review): the back placeholders are assumed to end 4 tokens
            # before the end of the prompt, i.e. prompt_suffix is assumed to
            # tokenize to 4 tokens - confirm for each prompt format.
            placeholder_end_index = len(prompt) - 4
            inputs_embeds[
                placeholder_end_index - n_prompt_tokens_back:placeholder_end_index
            ] = soft_prompt_back
            input_embeds_list.append(inputs_embeds)

        inputs_embeds = torch.stack(input_embeds_list, dim=0)
        labels = torch.stack(label_list, dim=0)
        return inputs_embeds, labels
    def preparing_embedding_utility(self, batch_messages,batch_labels,batch_vis_feature):
        # if soft_safety_prompt is not None:
        #     n_prompt_tokens_safe = soft_safety_prompt.size(0)
        #     # n_prompt_tokens_front = soft_prompt_front.size(0)
        #     # n_prompt_tokens_back = soft_prompt_back.size(0)
        #     n_prompt_tokens_total_front = n_prompt_tokens_safe  # + n_prompt_tokens_front
        # else:
        n_prompt_tokens_total_front = 0
        n_prompt_tokens_safe = 0
            # print(soft_prompt_front.shape,soft_prompt_back.shape)
            # exit()
            # exit()
            # n_prompt_tokens =soft_safety_prompt.size(0)

            # As system message appears first, we replace the first n_prompt_tokens eos tokens with soft_prompt
            # messages_with_eos_placeholder = [[{'role': 'system', 'content': toker.eos_token * n_prompt_tokens}] + e for e in all_messages]

            # messages_with_eos_placeholder = [
            #     toker.eos_token * n_prompt_tokens_total_front+'###Human: ' + message[0]['content'] + '' +toker.eos_token * n_prompt_tokens_back+'###Assistant:'+label for (message,label) in
            #     zip(all_messages,all_labels)]
            # print('all_labels',all_labels)
        messages_with_eos_placeholder = [
            self.prompt_prefix + message + self.prompt_suffix
            for message in
            batch_messages]
        # print(messages_with_eos_placeholder)
        messages_with_labels = [
            label for label in
            batch_labels]
        input_ids = []
        target_ids = []
        input_ids_nolabels = []
        mask_length = []
        # batch_rnd=0
        img_emb_hid = batch_vis_feature.shape[1]
        for e, eb in zip(messages_with_eos_placeholder, messages_with_labels):
            input_text = e  # toker.apply_chat_template(e, add_generation_prompt=True, tokenize=False)
            target = eb
            # print(input_text)
            if '<ImageHere>' in input_text:
                prompt_segs = input_text.split('<ImageHere>')
                seg_tokens = [
                    self.llama_tokenizer(
                        seg, return_tensors='pt', add_special_tokens=i == 0).input_ids.tolist()[0]
                    for i, seg in enumerate(prompt_segs)
                ]
                input_ids.append(seg_tokens)
                mask_length.append(len(seg_tokens[0]) + len(seg_tokens[1]) + img_emb_hid)
                target_ids.append([self.llama_tokenizer(
                    target, return_tensors='pt').input_ids.tolist()[0][1:]])
            else:
                raise ValueError('qunimade')

        input_lengths = []

        for e, et in zip(input_ids, target_ids):
            input_lengths.append(len(e[0]) + len(e[1]) + len(et[0]) + img_emb_hid)

        # input_lengths = [len(e[0]) + len(e[1]) + 256 for e in input_ids]
        max_input_length = max(input_lengths)
        new_input_ids = []
        # print(input_lengths)
        # exit()
        # if soft_safety_prompt is not None:
        #     placeholder_start_index = input_ids[0][0].index(toker.eos_token_id)
        #     placeholder_end_index = [len(input_ids[i][0]) - 4 for i in range(len(input_ids))]
        # print(placeholder_end_index)
        # exit()

        vis_grad = []
        input_embeds_list = []
        label_list = []
        for idx in range(batch_vis_feature.shape[0]):
            # print('idx',idx)
            vis_grad.append(batch_vis_feature[idx])

        # print(batch_vis_feature.shape)
        for idx, (e, et) in enumerate(zip(input_ids, target_ids)):
            to_regress_id = e[0] + [-100] * img_emb_hid + e[1] + et[0] + [-100] * (
                        max_input_length - len(e[0]) - len(e[1]) - len(et[0]) - img_emb_hid)
            to_regress_token_ids = torch.tensor(copy.deepcopy(to_regress_id),
                                                dtype=torch.long).cuda()  # .to(model.device)#torch.ones([len(input_lengths), max_input_length],
            to_regress_token_ids[:mask_length[idx]] = -100
            label_list.append(to_regress_token_ids)
            # input_idsnima = torch.tensor(e[1], dtype=torch.long).to(self.device)
            # print(input_idsnima.shape)
            e[1] = e[1] + et[0] + [self.llama_tokenizer.eos_token_id] * (
                    max_input_length - (len(e[0]) + len(e[1]) + img_emb_hid + len(et[0])))
            input_ids0 = torch.tensor(e[0], dtype=torch.long).to(self.device)
            input_ids1 = torch.tensor(e[1], dtype=torch.long).to(self.device)
            inputs_embeds0 = self.embed_tokens(input_ids0)
            inputs_embeds1 = self.embed_tokens(input_ids1)
            # if soft_safety_prompt is not None:
            #     inputs_embeds0[
            #     placeholder_start_index:placeholder_start_index + n_prompt_tokens_safe] = soft_safety_prompt
            # inputs_embeds0[placeholder_start_index:placeholder_start_index + n_prompt_tokens_safe] = soft_safety_prompt
            # inputs_embeds0[placeholder_start_index + n_prompt_tokens_safe:placeholder_start_index + n_prompt_tokens] = soft_prompt
            # print(inputs_embeds0.shape,inputs_embeds1.shape,vis_grad[idx].shape)
            # print(to_regress_token_ids)
            # print(img_emb_hid)
            # exit()
            inputs_embeds = torch.cat([inputs_embeds0, vis_grad[idx], inputs_embeds1], axis=0)
            input_embeds_list.append(inputs_embeds)
                # print('label',label_list,)
        inputs_embeds = torch.stack(input_embeds_list, axis=0)  # .to(model.device)
        labels = torch.stack(label_list, axis=0)  # .to(model.device)


        return inputs_embeds, labels

    def forward(self, samples, loss_mask_dict=None, output_attentions=False, reduction='mean'):
        """Run one teacher-forced pass through the language model and return its loss.

        Args:
            samples: Batch mapping passed unchanged to ``self.preparing_embedding``.
                Its ``"image_id"`` entry, when present, is cached on
                ``self.image_id``; when absent ``self.image_id`` is reset to ``None``.
            loss_mask_dict: Cached on ``self.loss_mask_dict``; this method does not
                read it otherwise. Optional so that callers which only supply
                ``samples`` and ``reduction`` (as ``multi_select`` does) work.
            output_attentions: Forwarded to ``self.llama_model``.
            reduction: Forwarded to ``self.llama_model``.

        Returns:
            dict: ``{"loss": outputs.loss}`` as produced by ``self.llama_model``.
        """
        # Cache per-batch metadata on the instance.
        # NOTE(review): presumably consumed outside this method - confirm.
        self.image_id = samples["image_id"] if "image_id" in samples else None
        self.loss_mask_dict = loss_mask_dict

        # Prepare the embedding to condition on and the embedding to regress.
        cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets, slogan_pos, kwd_pos = \
            self.preparing_embedding(samples)
        # Debug output goes through logging so regular runs are not flooded.
        logging.debug('slogan_pos %s', slogan_pos)
        logging.debug('kwd_pos %s', kwd_pos)

        # Concatenate the conditioning and regression embeddings per sample.
        inputs_embeds, attention_mask, input_lens, slogan_begin_pos, kwd_pos = \
            self.concat_emb_input_output(
                cond_embeds, cond_atts, regress_embeds, regress_atts, slogan_pos, kwd_pos)

        # Prepend the BOS token embedding (and its attention entry) to every row.
        bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        bos_atts = cond_atts[:, :1]
        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
        attention_mask = torch.cat([bos_atts, attention_mask], dim=1)

        # Assemble the final targets: -100 everywhere except the span that
        # starts right after each sample's input length.
        targets = torch.full(
            inputs_embeds.shape[:2], -100, dtype=torch.long, device=self.device)
        for i, target in enumerate(part_targets):
            start = input_lens[i] + 1  # plus 1 for bos
            targets[i, start:start + len(target)] = target

        with self.maybe_autocast():
            # `kwd_pos`, `slogan_pos` and `reduction` are extra keywords the
            # wrapped language model is called with here.
            outputs = self.llama_model(
                kwd_pos=kwd_pos,
                slogan_pos=slogan_begin_pos,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                reduction=reduction,
                output_attentions=output_attentions
            )
        loss = outputs.loss

        return {"loss": loss}

    def concatenate_inputs_and_labels(
            self,
            inputs_embeds_merge_acc,
            inputs_embeds_merge_rej,
            labels_acc,
            labels_rej,
    ):
        """Right-pad the accepted and rejected batches to one length and stack them.

        The shorter of the two batches (measured on ``labels.shape[1]``) is
        extended along dim 1: labels with ``-100`` and embeddings with the
        embedding of the tokenizer's pad token. The accepted rows come first
        in the stacked result.

        Args:
            inputs_embeds_merge_acc: Embeddings of the accepted sequences.
            inputs_embeds_merge_rej: Embeddings of the rejected sequences.
            labels_acc: Label ids of the accepted sequences.
            labels_rej: Label ids of the rejected sequences.

        Returns:
            tuple: ``(input_embeds, labels)``, each concatenated along dim 0.
        """
        max_length = max(labels_acc.shape[1], labels_rej.shape[1])
        pad_value = -100

        def _pad_to_max(embeds, labels):
            """Pad one (embeds, labels) pair on the right up to ``max_length``."""
            pad_length = max_length - labels.shape[1]
            if pad_length == 0:
                return embeds, labels
            # Build the pad ids on the device of the tensor being padded rather
            # than a hard-coded CUDA device, so CPU / other placements work.
            pad_ids = torch.tensor(
                [self.llama_tokenizer.pad_token_id],
                dtype=torch.long,
                device=embeds.device,
            )
            pad_embed = self.embed_tokens(pad_ids)
            # Size the padding from the batch actually being padded.
            embed_pad = pad_embed.unsqueeze(1).repeat(embeds.shape[0], pad_length, 1)
            label_pad = torch.full(
                (labels.shape[0], pad_length), pad_value,
                dtype=labels.dtype, device=labels.device)
            return (
                torch.cat([embeds, embed_pad], dim=1),
                torch.cat([labels, label_pad], dim=1),
            )

        inputs_embeds_merge_acc, labels_acc = _pad_to_max(inputs_embeds_merge_acc, labels_acc)
        inputs_embeds_merge_rej, labels_rej = _pad_to_max(inputs_embeds_merge_rej, labels_rej)

        labels = torch.cat((labels_acc, labels_rej), dim=0)
        input_embeds = torch.cat((inputs_embeds_merge_acc, inputs_embeds_merge_rej), dim=0)

        return input_embeds, labels

    def generate_attention_mask(self, labels):
        """Build an attention mask that is 0 from the first ``-100`` label onwards.

        Args:
            labels: 2-D tensor of label ids, one row per sequence.

        Returns:
            torch.Tensor: ``float16`` tensor with the same shape and device as
            ``labels``; each row is 1 before its first ``-100`` and 0 from that
            position to the end. Rows without any ``-100`` are all ones.
        """
        # NOTE(review): the cut-off is the FIRST -100 in a row, so a row whose
        # leading positions are already -100 is masked out entirely - confirm
        # callers only pass labels where -100 marks trailing padding.
        is_ignored = (labels == -100).to(torch.long)
        # A running count > 0 means "at or after the first -100 in this row".
        # Vectorised, and stays on the labels' device.
        masked_out = is_ignored.cumsum(dim=1) > 0
        return (~masked_out).to(torch.float16)

    def dpo_loss(
            self,
            policy_chosen_logps: torch.FloatTensor,
            policy_rejected_logps: torch.FloatTensor,
            reference_chosen_logps: torch.FloatTensor,
            reference_rejected_logps: torch.FloatTensor,
            beta: float = 0.1,
            ref=False,
    ) -> torch.Tensor:
        """Return the batch-mean preference loss for chosen vs. rejected log-probs.

        Args:
            policy_chosen_logps: Log-probabilities of the chosen answers.
            policy_rejected_logps: Log-probabilities of the rejected answers.
                Not read when ``ref == 'sft'``.
            reference_chosen_logps: Accepted for interface compatibility; unused.
            reference_rejected_logps: Accepted for interface compatibility; unused.
            beta: Accepted for interface compatibility; unused.
            ref: Loss mode. ``True`` and ``False`` both select
                ``-logsigmoid(chosen - rejected) - 10.0 * chosen``;
                ``'sft'`` selects ``-chosen``.

        Returns:
            torch.Tensor: Scalar mean of the per-sample losses.

        Raises:
            ValueError: If ``ref`` is neither a bool nor ``'sft'``.
        """
        # Weight of the likelihood term added to the pairwise term.
        chosen_logp_weight = 10.0

        if isinstance(ref, bool):
            # No reference-model term is subtracted: the logits are the raw
            # policy log-ratio. Both boolean modes share this formula.
            logits = policy_chosen_logps - policy_rejected_logps
            losses = -F.logsigmoid(logits) - chosen_logp_weight * policy_chosen_logps
        elif ref == 'sft':
            # Plain negative log-likelihood of the chosen answers.
            losses = -policy_chosen_logps
        else:
            raise ValueError(f"unsupported ref mode: {ref!r}")

        return losses.mean()
    def forward_attack(self, samples,attack_embs,show=False,ref=False,show1=False):
        """Compute the preference loss for a batch of question / answer pairs.

        Args:
            samples: Mapping with ``'question'``, ``'acc_ans'`` and ``'rej_ans'``.
            attack_embs: Passed through to ``self.preparing_embedding_attack``.
            show: Passed through to ``self.get_batch_logps``.
            ref: Loss mode passed through to ``self.dpo_loss``.
            show1: Unused.

        Returns:
            dict: ``{"loss": loss}`` where ``loss`` comes from ``self.dpo_loss``.
        """
        questions, accepted_answers, rejected_answers = (
            samples[key] for key in ('question', 'acc_ans', 'rej_ans')
        )

        # Embed the same questions twice: once per answer variant.
        accepted_embeds, accepted_labels = self.preparing_embedding_attack(
            questions, accepted_answers, attack_embs)
        rejected_embeds, rejected_labels = self.preparing_embedding_attack(
            questions, rejected_answers, attack_embs)

        # Accepted rows first, rejected rows second, padded to a common length.
        merged_embeds, merged_labels = self.concatenate_inputs_and_labels(
            accepted_embeds, rejected_embeds, accepted_labels, rejected_labels)
        num_accepted = accepted_embeds.shape[0]

        model_out = self.llama_model(
            inputs_embeds=merged_embeds,
            attention_mask=self.generate_attention_mask(merged_labels),
            output_hidden_states=True,
        )
        all_logps, _ = self.get_batch_logps(model_out.logits, merged_labels, show=show)

        # Split the stacked log-probs back into the two halves.
        loss = self.dpo_loss(
            all_logps[:num_accepted], all_logps[num_accepted:], None, None, ref=ref)
        return {"loss": loss}
    def forward_defend(self, samples,attack_embs,show=False):
        """Return log-probs of accepted and rejected answers for ``samples[0]``.

        Args:
            samples: Indexable pair; element 0 is a mapping with ``'question'``,
                ``'acc_ans'`` and ``'rej_ans'``. Element 1 is read but not used.
            attack_embs: Passed through to ``self.preparing_embedding_attack``.
            show: Passed through to ``self.get_batch_logps``.

        Returns:
            tuple: ``(chosen_logps, rejected_logps)``, the first and second part
            of the stacked log-probs returned by ``self.get_batch_logps``.
        """
        bad_samples, bn_samples = samples[0], samples[1]

        questions, accepted_answers, rejected_answers = (
            bad_samples[key] for key in ('question', 'acc_ans', 'rej_ans')
        )

        # Embed the same questions twice: once per answer variant.
        accepted_embeds, accepted_labels = self.preparing_embedding_attack(
            questions, accepted_answers, attack_embs)
        rejected_embeds, rejected_labels = self.preparing_embedding_attack(
            questions, rejected_answers, attack_embs)

        # Accepted rows first, rejected rows second, padded to a common length.
        merged_embeds, merged_labels = self.concatenate_inputs_and_labels(
            accepted_embeds, rejected_embeds, accepted_labels, rejected_labels)
        num_accepted = accepted_embeds.shape[0]

        model_out = self.llama_model(
            inputs_embeds=merged_embeds,
            attention_mask=self.generate_attention_mask(merged_labels),
            output_hidden_states=True,
        )
        all_logps, _ = self.get_batch_logps(model_out.logits, merged_labels, show=show)

        return all_logps[:num_accepted], all_logps[num_accepted:]
    def forward_utility(self, samples,attack_embs,show=False):
        """Return log-probs of accepted and rejected answers for ``samples[1]``.

        Args:
            samples: Indexable; element 1 is a mapping with ``'image'``,
                ``'question'``, ``'acc_ans'`` and ``'rej_ans'``.
            attack_embs: Unused by this method.
            show: Passed through to ``self.get_batch_logps``.

        Returns:
            tuple: ``(chosen_logps, rejected_logps)``, the first and second part
            of the stacked log-probs returned by ``self.get_batch_logps``.
        """
        bn_samples = samples[1]
        image_embeds, _image_atts = self.encode_img(bn_samples["image"])

        questions, accepted_answers, rejected_answers = (
            bn_samples[key] for key in ('question', 'acc_ans', 'rej_ans')
        )

        # Embed the same questions twice, conditioned on the encoded images.
        accepted_embeds, accepted_labels = self.preparing_embedding_utility(
            questions, accepted_answers, image_embeds)
        rejected_embeds, rejected_labels = self.preparing_embedding_utility(
            questions, rejected_answers, image_embeds)

        # Accepted rows first, rejected rows second, padded to a common length.
        merged_embeds, merged_labels = self.concatenate_inputs_and_labels(
            accepted_embeds, rejected_embeds, accepted_labels, rejected_labels)
        num_accepted = accepted_embeds.shape[0]

        # NOTE(review): the mask is computed but not handed to the model below,
        # unlike the sibling forward_* methods - confirm whether that is intended.
        attention_mask = self.generate_attention_mask(merged_labels)
        model_out = self.llama_model(inputs_embeds=merged_embeds, output_hidden_states=True)

        all_logps, _ = self.get_batch_logps(model_out.logits, merged_labels, show=show)

        return all_logps[:num_accepted], all_logps[num_accepted:]

    def embed_tokens(self, token_ids):
        """Look up token embeddings through the language model's embedding layer.

        Handles both a plain model and a LoRA-wrapped one, where the embedding
        layer sits two ``.model`` levels below ``base_model``.
        """
        backbone = self.llama_model.base_model
        # A `model` attribute on the base model marks the LoRA-wrapped case.
        owner = backbone.model.model if hasattr(backbone, 'model') else backbone
        return owner.embed_tokens(token_ids)

    @torch.no_grad()
    def generate(
        self,
        images,
        texts,
        num_beams=1,
        max_new_tokens=20,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1,
        length_penalty=1,
        temperature=1,
        do_sample=False,
        stop_words_ids=[2],
    ):
        """Generate one answer string per (image, text) pair; meant for testing.

        Each prompt is turned into embeddings via ``self.get_context_emb``, the
        batch is left-padded with zeros to the longest prompt, and the sampling
        arguments are forwarded to ``self.llama_model.generate``.

        Returns:
            list[str]: Decoded and cleaned answers, in input order.
        """
        # Built from `stop_words_ids` but currently not passed to `generate`.
        stop_tensors = [torch.tensor([i]).to(self.device) for i in stop_words_ids]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_tensors)])

        img_embeds, atts_img = self.encode_img(images.to(self.device))
        # One single-image list per sample, each image keeping a leading batch axis.
        per_sample_images = [[image_emb[None]] for image_emb in img_embeds]
        prompt_embs = [
            self.get_context_emb(text, img_list)
            for text, img_list in zip(texts, per_sample_images)
        ]

        longest = max([emb.shape[1] for emb in prompt_embs])
        reference = prompt_embs[0]
        embs = torch.zeros(
            [len(prompt_embs), longest, reference.shape[2]],
            dtype=reference.dtype,
            device=reference.device,
        )
        attn_mask = torch.zeros(
            [len(prompt_embs), longest], dtype=torch.int, device=reference.device)
        # Left-pad: every prompt is written into the tail of its row.
        for row, emb in enumerate(prompt_embs):
            emb_len = emb.shape[1]
            embs[row, -emb_len:] = emb[0]
            attn_mask[row, -emb_len:] = 1

        generation_kwargs = dict(
            inputs_embeds=embs,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            length_penalty=length_penalty,
            temperature=temperature,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        with self.maybe_autocast():
            outputs = self.llama_model.generate(**generation_kwargs)

        answers = []
        for token_ids in outputs:
            # Drop a leading token id 0 before decoding.
            if token_ids[0] == 0:
                token_ids = token_ids[1:]
            text = self.llama_tokenizer.decode(token_ids, skip_special_tokens=True)
            # Cut at the stop marker that matches the prompt template in use.
            stop_marker = (
                '##' if self.prompt_template == '###Human: {} ###Assistant: ' else '</s>'
            )
            text = text.split(stop_marker)[0].replace("<s>", "")
            # Keep only what follows the last instruction terminator.
            answers.append(text.split(r'[/INST]')[-1].strip())

        return answers
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the keyword inputs for one generation step.

        Returns a dict holding either ``inputs_embeds`` or ``input_ids``, plus
        ``past_key_values``, ``use_cache``, ``attention_mask`` and ``images``.
        """
        # With a non-empty cache only the newest token needs to be fed.
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # Embeddings are only used on the first step, i.e. while no cache exists.
        first_step_with_embeds = inputs_embeds is not None and past_key_values is None
        if first_step_with_embeds:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = kwargs.get("use_cache")
        model_inputs["attention_mask"] = attention_mask
        model_inputs["images"] = kwargs.get("images", None)
        return model_inputs

    @torch.no_grad()
    def multi_select(self, images, texts, answers, num_cand=None):
        """Rank candidate answers for each sample by their per-sample loss.

        Args:
            images: Stored under ``'image'`` in the batch given to ``self.forward``.
            texts: Stored under ``'instruction_input'`` in that batch.
            answers: Iterable of candidates; each one is scored with a separate
                ``self.forward`` call using ``reduction='none'``.
            num_cand: Optional per-sample count of valid candidates; losses at
                index ``num_cand[i]`` and beyond are set to 9999 for sample ``i``.

        Returns:
            list[list[int]]: For every sample, candidate indices ordered from
            lowest to highest loss.
        """
        all_losses = []
        for answer in answers:
            choice_samples = {
                'image': images,
                'instruction_input': texts,
                'answer': answer,
                # `forward` reads this key; there are no image ids when scoring.
                'image_id': None,
            }
            # `forward` declares `loss_mask_dict` and `output_attentions` ahead
            # of `reduction`, so pass them explicitly instead of omitting them.
            # NOTE(review): assumes `None` / `False` are acceptable here - confirm.
            loss = self.forward(
                choice_samples,
                loss_mask_dict=None,
                output_attentions=False,
                reduction='none',
            )['loss'].reshape(-1, 1)
            all_losses.append(loss)
            torch.cuda.empty_cache()
        # One column per candidate, one row per sample.
        all_losses = torch.cat(all_losses, dim=-1)
        if num_cand is not None:
            # Push candidates beyond each sample's valid count to the end of the ranking.
            for i in range(all_losses.shape[0]):
                all_losses[i, num_cand[i]:] = 9999
        output_class_ranks = torch.argsort(all_losses, dim=-1)
        return output_class_ranks.tolist()
