from typing import Optional, Union, List, Tuple
from transformers.models.llava_next import LlavaNextForConditionalGeneration
import torch
import torch.nn as nn
from src.utils.mem import Mem
import wandb
from PIL import Image
from src.utils.utils import set_seed


class llava_prompter:
    """Turn a question carrying one image placeholder into a chat-template prompt.

    The question must either begin with the ``<image>`` placeholder followed
    by a newline, or end with a newline followed by ``<image>``. The position
    of the placeholder decides whether the image entry is placed before or
    after the text entry inside the user turn.
    """

    def __init__(self, processor):
        # Only ``processor.apply_chat_template`` is used by this class.
        self.processor = processor

    def __call__(self, question, answer=None, add_generation_prompt=None):
        """Render ``question`` (and optionally ``answer``) with the chat template.

        Args:
            question: User text with exactly one leading or trailing image
                placeholder (see the class docstring).
            answer: Optional assistant reply. When given, it is appended as an
                assistant turn and the template is rendered with
                ``add_generation_prompt=False``.
            add_generation_prompt: Only consulted when ``answer`` is ``None``;
                ``None`` is treated as ``True``.

        Returns:
            Whatever ``processor.apply_chat_template`` returns for the built
            conversation.

        Raises:
            ValueError: If the question neither starts nor ends with the image
                placeholder, or if another ``<image>`` remains in the text
                after the leading/trailing placeholder has been removed.
        """
        image_prefix = '<image>\n'
        image_suffix = '\n<image>'

        if question.startswith(image_prefix):
            question = question[len(image_prefix):]
            content = [{'type': 'image'}, {'type': 'text', 'text': question}]
        elif question.endswith(image_suffix):
            question = question[:-len(image_suffix)]
            content = [{'type': 'text', 'text': question}, {'type': 'image'}]
        else:
            raise ValueError(f"输入中需要包括图片！,输入：{question}")

        # A placeholder left over at this point means the prompt referenced
        # more than one image, which is a different problem from "no image".
        if '<image>' in question:
            raise ValueError(f"输入中只能包含一个图片占位符！,输入：{question}")

        conversation = [{'role': 'user', 'content': content}]
        if answer is not None:
            conversation.append({'role': 'assistant', 'content': [{'type': 'text', 'text': answer}]})
            return self.processor.apply_chat_template(conversation, add_generation_prompt=False)

        if add_generation_prompt is None:
            add_generation_prompt = True
        return self.processor.apply_chat_template(conversation,
                                                  add_generation_prompt=add_generation_prompt)


class LLaVAPromptTuningForClassification(LlavaNextForConditionalGeneration):
    """LLaVA-NeXT generator with an additional two-way classification head.

    The head reads the hidden state at the position of a dedicated
    ``<PROMPT>`` token that :meth:`cls_forward` appends to the sequence.
    :meth:`add_special_token` must be called before the classification path
    is used, because it creates ``special_token_tensor``.
    """

    # Index into ``outputs.hidden_states`` whose activations feed the
    # classifier. This is the 16th entry from the end, not the final layer.
    CLS_HIDDEN_STATE_INDEX = -16

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 2
        self.classifier = nn.Linear(config.text_config.hidden_size, self.num_labels)

    def add_special_token(self, processer):
        """Register the ``<PROMPT>`` token and resize the embedding table.

        Args:
            processer: Processor whose ``tokenizer`` receives the new token.
                If the tokenizer has no pad token, its EOS token is used as
                the pad token. The processor is kept on ``self.processer``.
        """
        special_token = "<PROMPT>"
        tokenizer = processer.tokenizer
        tokenizer.add_tokens([special_token])
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.processer = processer
        self.special_token_id = tokenizer.convert_tokens_to_ids(special_token)
        # Kept as a plain one-element tensor; cls_forward expands it to the
        # batch size and moves it to the model device on every call.
        self.special_token_tensor = torch.tensor([self.special_token_id])
        self.resize_token_embeddings(len(tokenizer))

    def forward(self,
                input_ids: torch.LongTensor = None,
                pixel_values: torch.FloatTensor = None,
                image_sizes: Optional[torch.LongTensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                vision_feature_layer: Optional[int] = None,
                vision_feature_select_strategy: Optional[str] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                cache_position: Optional[torch.LongTensor] = None,
                num_logits_to_keep: int = 0,
                cls=False):
        """Dispatch to the classification path or to the parent generator.

        With ``cls=True`` only ``input_ids``, ``labels`` and
        ``past_key_values`` are forwarded to :meth:`cls_forward`; every other
        argument is ignored. With ``cls=False`` all arguments except ``cls``
        are passed unchanged to the parent ``forward``.
        """
        if cls:
            return self.cls_forward(input_ids=input_ids,
                                    labels=labels,
                                    past_key_values=past_key_values)

        return super().forward(input_ids=input_ids,
                               pixel_values=pixel_values,
                               image_sizes=image_sizes,
                               attention_mask=attention_mask,
                               position_ids=position_ids,
                               past_key_values=past_key_values,
                               inputs_embeds=inputs_embeds,
                               vision_feature_layer=vision_feature_layer,
                               vision_feature_select_strategy=vision_feature_select_strategy,
                               labels=labels,
                               use_cache=use_cache,
                               output_attentions=output_attentions,
                               output_hidden_states=output_hidden_states,
                               return_dict=return_dict,
                               cache_position=cache_position,
                               num_logits_to_keep=num_logits_to_keep)

    def cls_forward(self, input_ids=None, labels=None, past_key_values=None):
        """Classify a sequence from the hidden state of an appended ``<PROMPT>`` token.

        Args:
            input_ids: Optional token ids of shape (batch_size, seq_length);
                the ``<PROMPT>`` id is appended as one extra column.
            labels: Optional class indices. When given, a cross-entropy loss
                is computed against the logits.
            past_key_values: Optional cache from an earlier forward pass.
                When ``input_ids`` is ``None`` the batch size is read from
                ``past_key_values[0][0].shape[0]`` and only the ``<PROMPT>``
                token is fed to the model.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``, otherwise
            ``{"loss": loss, "logits": logits}``. ``logits`` has shape
            (batch_size, num_labels).

        Raises:
            ValueError: If both ``input_ids`` and ``past_key_values`` are
                ``None``.
        """
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        elif past_key_values is not None:
            batch_size = past_key_values[0][0].shape[0]
        else:
            raise ValueError("Both input_ids and past_key_values are None.")

        # (batch_size, 1) column holding the <PROMPT> token id.
        special_token_ids = self.special_token_tensor.expand([batch_size, 1]).to(self.device)

        if input_ids is None:
            new_input_ids = special_token_ids
        else:
            new_input_ids = torch.cat([input_ids.to(self.device), special_token_ids], dim=-1)

        # NOTE(review): the mask only spans the tokens passed in this call,
        # not the positions already cached in past_key_values — confirm this
        # is what the installed transformers version expects.
        attention_mask = torch.ones_like(new_input_ids)
        outputs = super().forward(input_ids=new_input_ids,
                                  attention_mask=attention_mask,
                                  past_key_values=past_key_values,
                                  output_hidden_states=True)

        # (batch_size, seq_length, hidden_size) activations of the chosen layer.
        hidden_states = outputs.hidden_states[self.CLS_HIDDEN_STATE_INDEX]

        # The <PROMPT> token was appended last, so its embedding sits at the
        # final sequence position: (batch_size, hidden_size).
        prompt_token_embedding = hidden_states[:, -1, :]

        # Gradients are re-enabled here so the head stays differentiable even
        # when the caller wrapped this call in torch.no_grad().
        with torch.enable_grad():
            logits = self.classifier(prompt_token_embedding)  # (batch_size, num_labels)
            if labels is None:
                return {"logits": logits}
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}

    def configure_model(self):
        """Freeze everything except the classifier and the token embeddings.

        A parameter stays trainable when its name contains ``"classifier"``
        or ``"model.embed_tokens"``; the names of the trainable parameters
        are printed. Every other parameter has ``requires_grad`` set to
        ``False``.
        """
        for name, param in self.named_parameters():
            trainable = "classifier" in name or "model.embed_tokens" in name
            param.requires_grad = trainable
            if trainable:
                print(name)


class LlaVASafetyAdapter:
    """Online adaptation loop around a LLaVA generator with a detector head.

    :meth:`chat_and_forward` answers one image+text prompt and classifies the
    exchange with the model's ``cls=True`` path. When training is enabled and
    the predicted label is 1, the prompt is stored in the replay memories and
    both the generator (:meth:`adapt`, ``optimizer_1``) and the detector
    (:meth:`adapt_detector`, ``optimizer_2``) are updated.
    """

    # How many generator / detector updates run each time a prompt is flagged.
    GENERATOR_UPDATES = 1
    DETECTOR_UPDATES = 5

    # Canned refusal stored as the target answer for flagged prompts.
    REFUSAL_ANSWER = ("I'm sorry, but I cannot comply with your request"
                      " for ethical and regulatory reasons.")

    def __init__(self, model, processor, optimizer_1, optimizer_2, cls_dataset, VQA_Dataset, steps=1, re_batch_size=4,
                 dt_batch_size=4, prompter=None, max_length=2048):
        """Store the collaborators and build the two cycling data loaders.

        Args:
            model: Model called with ``generate(...)`` and with
                ``cls=True`` / ``cls=False`` forward passes.
            processor: Processor used to encode text+image and decode ids.
            optimizer_1: Optimizer stepped by :meth:`adapt`.
            optimizer_2: Optimizer stepped by :meth:`adapt_detector`.
            cls_dataset: Dataset whose batches unpack as
                (images, texts, labels); batched with ``dt_batch_size``.
            VQA_Dataset: Dataset whose batches unpack as
                (images, instructions, answers); batched with
                ``re_batch_size``.
            steps: Stored on ``self.steps``.
            re_batch_size: Batch size of the VQA loader.
            dt_batch_size: Batch size of the detector loader.
            prompter: Callable turning (question[, answer]) into prompt text.
            max_length: Maximum length passed to the processor and to
                ``generate``.
        """
        super().__init__()
        self.max_length = max_length
        self.model = model
        self.token_tuning = torch.nn.Parameter(torch.zeros(1, model.config.vocab_size - 1))
        self.processor = processor
        self.optimizer_1 = optimizer_1
        self.optimizer_2 = optimizer_2
        # NOTE(review): ``steps`` is stored but not read anywhere in this
        # class — confirm whether it was meant to drive GENERATOR_UPDATES.
        self.steps = steps
        self.re_batch_size = re_batch_size
        self.dt_batch_size = dt_batch_size
        self.defense_mem = Mem(30)
        self.jail_mem = Mem(30)
        self.prompter = prompter

        self.cls_dataset = cls_dataset
        self.cls_loader = torch.utils.data.DataLoader(self.cls_dataset,
                                                      batch_size=self.dt_batch_size,
                                                      shuffle=True,
                                                      num_workers=0)
        self.cls_iter = iter(self.cls_loader)

        self.vqa_dataset = VQA_Dataset
        self.vqa_loader = torch.utils.data.DataLoader(self.vqa_dataset,
                                                      batch_size=self.re_batch_size,
                                                      shuffle=True,
                                                      num_workers=0)
        self.vqa_iter = iter(self.vqa_loader)

    def _next_batch(self, name):
        """Return the next batch of the ``cls`` or ``vqa`` loader, restarting it when exhausted."""
        iter_attr = f'{name}_iter'
        try:
            return next(getattr(self, iter_attr))
        except StopIteration:
            fresh_iter = iter(getattr(self, f'{name}_loader'))
            setattr(self, iter_attr, fresh_iter)
            return next(fresh_iter)

    def _encode(self, text, image):
        """Run the processor on one text+image pair with the shared training settings."""
        return self.processor(text=text,
                              images=image,
                              padding=True,
                              truncation=True,
                              return_tensors='pt',
                              max_length=self.max_length)

    def chat_and_forward(self, text, image, train=False, wandb_step=None):
        """Answer one prompt, classify it, and optionally adapt the model.

        Args:
            text: Question text accepted by ``self.prompter``.
            image: Path (or file object) accepted by ``Image.open``.
            train: When true and the predicted label is 1, the sample is
                stored in the memories and the update routines are run.
            wandb_step: Step passed to ``wandb.log``.

        Returns:
            Tuple ``(answer, label)`` with the decoded generation and the
            predicted class as a Python int. The code assumes a batch of one
            (``.item()`` is called on the prediction).
        """
        text_input = self.prompter(text)
        # Close the file handle as soon as the pixel values are extracted.
        with Image.open(image) as image_file:
            inputs = self.processor(images=image_file, text=text_input, return_tensors="pt").to(self.model.device)
        input_ids = inputs['input_ids']

        with torch.no_grad():
            generated = self.model.generate(**inputs,
                                            return_dict_in_generate=True, max_length=self.max_length)
            # Keep only the newly generated tokens, dropping the prompt.
            return_sequence = generated.sequences[:, input_ids.shape[-1]:]
            # Classify from the cache left behind by generation.
            cls_outputs = self.model(past_key_values=generated.past_key_values, cls=True)

        logits = cls_outputs['logits']
        labels = logits.argmax(dim=-1)
        max_prob = logits.softmax(dim=-1).max(dim=-1).values
        print(max_prob.item())
        answer = self.processor.batch_decode(return_sequence, skip_special_tokens=True)[0]

        if train and labels == 1:
            # Flagged prompt: remember it together with the refusal target.
            self.defense_mem.add({'question': text,
                                  'answer': self.REFUSAL_ANSWER,
                                  'image': image})
            # The model did not refuse on its own, so keep the exchange as a
            # positive example for the detector.
            if not answer.startswith("I'm sorry"):
                self.jail_mem.add({'question': text,
                                   'answer': answer,
                                   'image': image})

            # NOTE(review): despite the "avg" naming used historically, the
            # logged values are sums over the update calls, not means.
            avg_train_loss = 0
            for _ in range(self.GENERATOR_UPDATES):
                avg_train_loss += self.adapt()
            avg_dt_loss = 0
            for _ in range(self.DETECTOR_UPDATES):
                avg_dt_loss += self.adapt_detector()
            wandb.log({'train_loss': avg_train_loss, 'dt_loss': avg_dt_loss},
                      step=wandb_step)
        return answer, labels.item()

    def adapt(self):
        """Run one generator update on a VQA batch plus one remembered refusal.

        Gradients are accumulated sample by sample and ``optimizer_1`` is
        stepped once at the end. Each sample loss is divided by the number of
        samples in the mini-batch, so the accumulated gradient is that of the
        batch mean.

        Returns:
            The mean language-modelling loss over the mini-batch as a float.
        """
        self.optimizer_1.zero_grad()
        defense_datas = self.defense_mem.get(1)

        qa_images, qa_instructions, qa_answers = (list(column) for column in self._next_batch('vqa'))
        for item in defense_datas:
            qa_images.append(item['image'])
            qa_instructions.append(item['question'])
            qa_answers.append(item['answer'])

        samples = list(zip(qa_images, qa_instructions, qa_answers))
        # Normalise by the samples actually trained on. The previous divisor,
        # len(defense_datas), ignored the VQA batch and was zero whenever the
        # defense memory was empty.
        num_samples = len(samples)

        avg_loss = 0
        for image_path, instruction, answer in samples:
            avg_loss += self._backward_generator_sample(image_path, instruction, answer, num_samples)
        self.optimizer_1.step()

        return avg_loss

    def _backward_generator_sample(self, image_path, instruction, answer, num_samples):
        """Back-propagate the answer-only LM loss of one sample, scaled by ``1 / num_samples``."""
        with Image.open(image_path) as qa_image:
            qa_text = self.prompter(instruction, answer)
            qa_ids = self._encode(qa_text, qa_image)['input_ids']
            instruction_ids = self._encode(self.prompter(instruction), qa_image)['input_ids']
            # Token count of the answer part = full prompt minus question-only prompt.
            answer_length = qa_ids.shape[1] - instruction_ids.shape[1]

            text = qa_text.strip()
            if text[-4:] != '</s>':
                text += '</s>'
            inputs = self._encode(text, qa_image)

        labels = inputs['input_ids'].clone()
        # Supervise only the answer tokens plus the trailing end-of-sequence token.
        labels[0][:-(answer_length + 1)] = -100
        inputs['labels'] = labels
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        loss = self.model(**inputs).loss / num_samples
        loss_value = loss.item()
        loss.backward()
        return loss_value

    def adapt_detector(self):
        """Run one detector update on remembered samples plus a detector batch.

        Up to three remembered jailbreak exchanges (label 1), up to three
        remembered refusals (label 0) and one batch from the detector loader
        are used. For each sample the prompt is encoded without gradients and
        the resulting cache is passed to the ``cls=True`` path with gradients
        enabled. ``optimizer_2`` is stepped once at the end.

        Returns:
            The mean classification loss over all samples as a float.
        """
        self.optimizer_2.zero_grad()
        jailbreak_datas = self.jail_mem.get(3)
        defense_datas = self.defense_mem.get(3)

        train_texts = []
        train_images = []
        memory_labels = []
        for remembered, label in ((jailbreak_datas, 1), (defense_datas, 0)):
            for item in remembered:
                train_texts.append(self.prompter(item['question'], item['answer']))
                train_images.append(item['image'])
                memory_labels.append(label)

        detector_images, detector_texts, detector_labels = self._next_batch('cls')
        train_images.extend(detector_images)
        train_texts.extend(detector_texts)

        device = self.model.device
        # One tensor built in a single call; unlike torch.stack on a list of
        # scalars this also works when both memories are empty.
        train_labels = torch.cat([torch.tensor(memory_labels, dtype=torch.long, device=device),
                                  detector_labels.to(device)])

        avg_dt_loss = 0
        for text, image_path, label in zip(train_texts, train_images, train_labels):
            # A trailing end-of-sequence marker is dropped before encoding.
            prompt = text if text[-4:] != '</s>' else text[:-4]
            with Image.open(image_path) as image:
                inputs = self._encode(prompt, image)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Encoding the exchange needs no gradients; only the cls path
            # that consumes the cache is differentiated.
            with torch.no_grad():
                outputs_dict = self.model(**inputs, labels=None, cls=False)
            outputs = self.model(past_key_values=outputs_dict['past_key_values'],
                                 labels=label.unsqueeze(0),
                                 cls=True)
            loss = outputs['loss'] / len(train_texts)
            loss.backward()
            avg_dt_loss += loss.item()
        self.optimizer_2.step()

        return avg_dt_loss
