from transformers import LogitsProcessorList, LogitsProcessor
from typing import Optional, Union, List, Tuple
from transformers.models.llava_next import LlavaNextForConditionalGeneration
import torch
import torch.nn as nn
from src.utils.mem import Mem
import wandb
from PIL import Image


class LLaVAPromptTuningForClassification_NOGIST(LlavaNextForConditionalGeneration):
    """LLaVA-Next model with an additional two-way classification head.

    ``forward(..., cls=False)`` (the default) is the unchanged parent forward.
    ``forward(..., cls=True)`` routes to :meth:`cls_forward`, which runs the
    backbone without gradients and classifies the last-layer hidden state of
    the last non-padding token with ``self.classifier``.
    """

    def __init__(self, config):
        super(LLaVAPromptTuningForClassification_NOGIST, self).__init__(config)
        # Two-way linear head on top of the language model's hidden size.
        self.classifier = nn.Linear(config.text_config.hidden_size, 2)
        self.num_labels = 2

    def add_special_token(self, processer):
        """No-op: this variant does not register any special tokens."""
        pass

    def forward(self,
                input_ids: torch.LongTensor = None,
                pixel_values: torch.FloatTensor = None,
                image_sizes: Optional[torch.LongTensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                vision_feature_layer: Optional[int] = None,
                vision_feature_select_strategy: Optional[str] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                cache_position: Optional[torch.LongTensor] = None,
                num_logits_to_keep: int = 0,
                cls=False):
        """Dispatch to the classification path or the parent generative forward.

        Args:
            cls: when True, call :meth:`cls_forward`. ``output_hidden_states``
                is not forwarded there, because that path always requests
                hidden states. When False, every argument is passed unchanged
                to ``LlavaNextForConditionalGeneration.forward``.
            All other arguments: see ``LlavaNextForConditionalGeneration.forward``.

        Returns:
            The output of :meth:`cls_forward` or of the parent ``forward``.
        """
        if cls:
            return self.cls_forward(input_ids=input_ids,
                                    pixel_values=pixel_values,
                                    image_sizes=image_sizes,
                                    attention_mask=attention_mask,
                                    position_ids=position_ids,
                                    past_key_values=past_key_values,
                                    inputs_embeds=inputs_embeds,
                                    vision_feature_layer=vision_feature_layer,
                                    vision_feature_select_strategy=vision_feature_select_strategy,
                                    labels=labels,
                                    use_cache=use_cache,
                                    output_attentions=output_attentions,
                                    return_dict=return_dict,
                                    cache_position=cache_position,
                                    num_logits_to_keep=num_logits_to_keep)
        return super().forward(input_ids=input_ids,
                               pixel_values=pixel_values,
                               image_sizes=image_sizes,
                               attention_mask=attention_mask,
                               position_ids=position_ids,
                               past_key_values=past_key_values,
                               inputs_embeds=inputs_embeds,
                               vision_feature_layer=vision_feature_layer,
                               vision_feature_select_strategy=vision_feature_select_strategy,
                               labels=labels,
                               use_cache=use_cache,
                               output_attentions=output_attentions,
                               output_hidden_states=output_hidden_states,
                               return_dict=return_dict,
                               cache_position=cache_position,
                               num_logits_to_keep=num_logits_to_keep)

    @staticmethod
    def _last_token_index(hidden_states, attention_mask):
        """Return a (batch_size,) LongTensor with each row's last non-padding position."""
        batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1]
        usable_mask = (attention_mask is not None
                       and attention_mask.dim() == 2
                       and attention_mask.shape[0] == batch_size
                       and attention_mask.shape[1] == seq_len)
        if not usable_mask:
            # No mask, or a mask whose length differs from the hidden states
            # (e.g. if image features were merged inside the model): fall back
            # to the final position, which is correct for left/no padding.
            return torch.full((batch_size,), seq_len - 1,
                              dtype=torch.long, device=hidden_states.device)
        mask = attention_mask.to(device=hidden_states.device, dtype=torch.long)
        # argmax over the reversed mask = distance from the end to the last 1.
        return seq_len - 1 - mask.flip(-1).argmax(dim=-1)

    def cls_forward(self,
                    input_ids: torch.LongTensor = None,
                    pixel_values: torch.FloatTensor = None,
                    image_sizes: Optional[torch.LongTensor] = None,
                    attention_mask: Optional[torch.Tensor] = None,
                    position_ids: Optional[torch.LongTensor] = None,
                    past_key_values: Optional[List[torch.FloatTensor]] = None,
                    inputs_embeds: Optional[torch.FloatTensor] = None,
                    vision_feature_layer: Optional[int] = None,
                    vision_feature_select_strategy: Optional[str] = None,
                    labels: Optional[torch.LongTensor] = None,
                    use_cache: Optional[bool] = None,
                    output_attentions: Optional[bool] = None,
                    return_dict: Optional[bool] = None,
                    cache_position: Optional[torch.LongTensor] = None,
                    num_logits_to_keep: int = 0
                    ):
        """Classify the input with ``self.classifier``.

        The backbone runs under ``torch.no_grad()``, so only the classification
        head receives gradients through this path. The head is applied to the
        final-layer hidden state of the last non-padding token.

        Args:
            labels: optional class indices; when given, a cross-entropy loss is
                stored under ``outputs['loss']``.
            return_dict: accepted for signature compatibility but ignored. A
                dict-like model output is always requested, because this method
                uses attribute access and item assignment on it.
            All other arguments are forwarded to the parent ``forward``.

        Returns:
            The parent model output, with ``'cls_logits'`` of shape
            ``(batch_size, num_labels)`` added, plus ``'loss'`` when
            ``labels`` is given.
        """
        with torch.no_grad():
            outputs = super().forward(input_ids=input_ids,
                                      pixel_values=pixel_values,
                                      image_sizes=image_sizes,
                                      attention_mask=attention_mask,
                                      position_ids=position_ids,
                                      past_key_values=past_key_values,
                                      inputs_embeds=inputs_embeds,
                                      vision_feature_layer=vision_feature_layer,
                                      vision_feature_select_strategy=vision_feature_select_strategy,
                                      use_cache=use_cache,
                                      output_attentions=output_attentions,
                                      output_hidden_states=True,
                                      return_dict=True,
                                      cache_position=cache_position,
                                      num_logits_to_keep=num_logits_to_keep
                                      )

        # Last-layer hidden states: (batch_size, seq_length, hidden_size)
        hidden_states = outputs.hidden_states[-1]

        # Embedding of the last non-padding token of each row: (batch_size, hidden_size)
        last_index = self._last_token_index(hidden_states, attention_mask)
        batch_index = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        prompt_token_embedding = hidden_states[batch_index, last_index]

        # Re-enable autograd for the head even if the caller wrapped this call
        # in torch.no_grad(), so the classification loss stays differentiable.
        with torch.enable_grad():
            logits = self.classifier(prompt_token_embedding)  # (batch_size, num_labels)
            if labels is not None:
                loss_fct = torch.nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs['loss'] = loss
            outputs['cls_logits'] = logits
        return outputs

    def configure_model(self):
        """Freeze every parameter except those of the classification head.

        Parameters whose name contains ``"classifier"`` become trainable; all
        others get ``requires_grad = False``.
        """
        for name, param in self.named_parameters():
            param.requires_grad = "classifier" in name


class LlaVASafetyAdapter_NG:
    """Online safety adapter around a LLaVA model that has a classification head.

    For each chat turn, :meth:`chat_and_forward` generates an answer and
    classifies the final hidden state with ``model.classifier``. When training
    and the predicted label is 1, it then:

    * stores a refusal example in ``defense_mem``;
    * stores the original exchange in ``jail_mem`` if the answer did not start
      with "I'm sorry";
    * runs one generator update (:meth:`adapt`) and five detector updates
      (:meth:`adapt_detector`).

    Args:
        model: wrapped model. It must expose ``classifier``, ``generate`` and
            ``device``, and its forward must accept ``cls=True``.
        processor: processor that turns text and PIL images into model inputs
            and decodes generated ids.
        optimizer_1: optimizer stepped by :meth:`adapt`.
        optimizer_2: optimizer stepped by :meth:`adapt_detector`.
        cls_dataset: dataset whose batches unpack to
            ``(images, texts, labels)``; images are passed to ``Image.open``.
        VQA_Dataset: dataset whose batches unpack to
            ``(images, instructions, answers)``; used as regular QA data.
        steps: stored as ``self.steps`` (not used within this class).
        re_batch_size: batch size of the VQA loader.
        dt_batch_size: batch size of the detector loader.
        prompter: callable ``prompter(question[, answer]) -> str`` that builds
            the prompt text.
        max_length: length limit for generation and tokenizer truncation.
    """

    def __init__(self, model, processor, optimizer_1, optimizer_2, cls_dataset, VQA_Dataset, steps=1, re_batch_size=4,
                 dt_batch_size=4, prompter=None, max_length=2048):

        super().__init__()
        self.model = model
        self.max_length = max_length
        self.processor = processor
        self.optimizer_1 = optimizer_1
        self.optimizer_2 = optimizer_2
        self.steps = steps
        self.re_batch_size = re_batch_size
        self.dt_batch_size = dt_batch_size
        # Memories of refusal examples and of detected-but-answered prompts.
        self.defense_mem = Mem(30)
        self.jail_mem = Mem(30)
        self.prompter = prompter

        self.cls_dataset = cls_dataset
        self.cls_loader = torch.utils.data.DataLoader(self.cls_dataset,
                                                      batch_size=self.dt_batch_size,
                                                      shuffle=True,
                                                      num_workers=0
                                                      )
        self.cls_iter = iter(self.cls_loader)

        self.vqa_dataset = VQA_Dataset
        self.vqa_loader = torch.utils.data.DataLoader(self.vqa_dataset,
                                                      batch_size=self.re_batch_size,
                                                      shuffle=True,
                                                      num_workers=0
                                                      )
        self.vqa_iter = iter(self.vqa_loader)

    def chat_and_forward(self, text, image, train=False, wandb_step=None):
        """Answer ``text`` about ``image``, classify it, and optionally adapt online.

        Args:
            text: the user question (passed through ``self.prompter``).
            image: anything accepted by ``PIL.Image.open`` (e.g. a path).
            train: when True and the predicted label is 1, update the memories,
                run the adaptation steps and log their losses to wandb.
            wandb_step: ``step`` argument for ``wandb.log``.

        Returns:
            ``(answer, label)``: the decoded answer of the first sequence and
            the predicted class as a Python int.
        """
        text_input = self.prompter(text)
        image_file = Image.open(image)
        inputs = self.processor(images=image_file, text=text_input, return_tensors="pt").to(self.model.device)
        input_ids = inputs['input_ids']
        with torch.no_grad():
            outputs = self.model.generate(**inputs,
                                          return_dict_in_generate=True,
                                          output_hidden_states=True,
                                          max_length=self.max_length)
            # Drop the prompt tokens; keep only newly generated ones.
            return_sequence = outputs.sequences[:, input_ids.shape[-1]:]
            # generate() returns hidden states per decoding step, each a tuple
            # over layers: [-1][-1] is the last layer at the final step.
            logits = self.model.classifier(outputs.hidden_states[-1][-1])
        labels = logits.argmax(dim=-1)
        probs = logits.softmax(dim=-1)
        max_prob = probs.max(dim=-1).values
        print(max_prob.item())
        answer = self.processor.batch_decode(return_sequence, skip_special_tokens=True)[0]
        if train:
            mem_data = {'question': text,
                        'answer': answer,
                        'image': image
                        }
            if labels == 1:
                defense_data = {'question': text,
                                'answer': "I'm sorry, but I cannot comply with your request"
                                          " for ethical and regulatory reasons.",
                                'image': image
                                }
                self.defense_mem.add(defense_data)

                # Flagged but not refused: remember it as a jailbreak example.
                if not answer.startswith("I'm sorry"):
                    self.jail_mem.add(mem_data)
                train_loss = self.adapt()
                dt_loss = sum(self.adapt_detector() for _ in range(5))
                wandb.log({'train_loss': train_loss, 'dt_loss': dt_loss},
                          step=wandb_step)
        return answer, labels.item()

    def adapt(self):
        """Run one generator update with ``optimizer_1``.

        The batch is the next VQA loader batch plus the samples returned by
        ``self.defense_mem.get(1)``. Only the answer tokens and the appended
        ``</s>`` are supervised; prompt tokens are masked with -100.

        Returns:
            The mean (over the batch) training loss as a float.
        """
        self.optimizer_1.zero_grad()
        defense_datas = self.defense_mem.get(1)
        try:
            qa_datas = next(self.vqa_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the VQA data.
            self.vqa_iter = iter(self.vqa_loader)
            qa_datas = next(self.vqa_iter)
        qa_images, qa_instructions, qa_answers = qa_datas
        qa_images, qa_instructions, qa_answers = list(qa_images), list(qa_instructions), list(qa_answers)
        for item in defense_datas:
            qa_images.append(item['image'])
            qa_instructions.append(item['question'])
            qa_answers.append(item['answer'])

        qa_datas_processed = []
        for qa_image, qa_instruction, qa_answer in zip(qa_images, qa_instructions, qa_answers):
            qa_image = Image.open(qa_image)
            qa_ins = self.prompter(qa_instruction, qa_answer)
            qa_ids = self.processor(text=qa_ins,
                                    images=qa_image,
                                    padding=True,
                                    truncation=True,
                                    return_tensors='pt',
                                    max_length=self.max_length)['input_ids']
            instruction = self.prompter(qa_instruction)
            instruction_ids = self.processor(text=instruction,
                                             images=qa_image,
                                             padding=True,
                                             truncation=True,
                                             return_tensors='pt',
                                             max_length=self.max_length)['input_ids']
            # Number of answer tokens = full prompt+answer minus prompt-only length.
            answer_length = qa_ids.shape[1] - instruction_ids.shape[1]
            qa_datas_processed.append({'text': qa_ins,
                                       'image': qa_image,
                                       'answer_length': answer_length})

        # Normalize by the number of samples actually trained on (VQA batch +
        # defense samples) so the accumulated gradient is that of the mean loss.
        num_samples = len(qa_datas_processed)
        avg_loss = 0
        for data in qa_datas_processed:
            image, text, answer_length = data['image'], data['text'].strip(), data['answer_length']
            if not text.endswith('</s>'):
                text += '</s>'
            inputs = self.processor(text=text,
                                    images=image,
                                    padding=True,
                                    truncation=True,
                                    return_tensors='pt',
                                    max_length=self.max_length)
            labels = inputs['input_ids'].clone()
            # Supervise only the last answer_length + 1 tokens (answer + '</s>').
            # Clamp so that a negative length cannot produce an empty or wrong slice.
            labels[0][:-(max(answer_length, 0) + 1)] = -100
            inputs['labels'] = labels
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            outputs = self.model(**inputs)
            loss = outputs.loss / num_samples
            avg_loss += loss.item()
            loss.backward()
        self.optimizer_1.step()

        return avg_loss

    def adapt_detector(self):
        """Run one update of the classification head with ``optimizer_2``.

        The batch combines samples from ``jail_mem`` (label 1), samples from
        ``defense_mem`` (label 0) and the next detector loader batch.

        Returns:
            The mean (over the batch) classification loss as a float.
        """
        self.optimizer_2.zero_grad()
        jailbreak_datas = self.jail_mem.get(3)
        defense_datas = self.defense_mem.get(3)
        train_texts = []
        train_images = []
        memory_labels = []
        for jailbreak_data in jailbreak_datas:
            image, question, answer = Image.open(jailbreak_data['image']), jailbreak_data['question'], jailbreak_data['answer']
            train_texts.append(self.prompter(question, answer))
            train_images.append(image)
            memory_labels.append(1)
        for defense_data in defense_datas:
            image, question, answer = Image.open(defense_data['image']), defense_data['question'], defense_data['answer']
            train_texts.append(self.prompter(question, answer))
            train_images.append(image)
            memory_labels.append(0)
        # torch.tensor (unlike torch.stack) also handles empty memories.
        train_labels = torch.tensor(memory_labels, dtype=torch.long, device=self.model.device)
        try:
            detector_data = next(self.cls_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the detector data.
            self.cls_iter = iter(self.cls_loader)
            detector_data = next(self.cls_iter)
        detector_images, detector_texts, detector_labels = detector_data
        train_images.extend(Image.open(i) for i in detector_images)
        train_texts.extend(detector_texts)
        train_labels = torch.cat([train_labels, detector_labels.to(self.model.device)])

        avg_dt_loss = 0
        for text, image, label in zip(train_texts, train_images, train_labels):
            inputs = self.processor(text=text if text[-4:] != '</s>' else text[:-4],
                                    images=image,
                                    padding=True,
                                    truncation=True,
                                    return_tensors='pt',
                                    max_length=self.max_length)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            # no_grad here only covers the backbone; the model's cls path
            # re-enables autograd for the classifier head, so loss.backward()
            # still reaches the head's parameters.
            with torch.no_grad():
                outputs = self.model(**inputs, labels=label, cls=True)
            loss = outputs['loss'] / len(train_texts)
            loss.backward()
            avg_dt_loss += loss.item()
        self.optimizer_2.step()

        return avg_dt_loss
