from typing import Optional, Union, List, Tuple

from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama import LlamaForCausalLM
import torch
import torch.nn as nn
from src.utils.mem import Mem
import wandb

# Delimiters of the Llama chat prompt format assembled by ``llama_prompter``.
B_INST = "[INST]"
E_INST = "[/INST]"
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# System prompt inserted between B_SYS and E_SYS in every formatted prompt.
# The trailing backslash after the opening quotes keeps a leading newline out.
system_prompt = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

from transformers import LogitsProcessorList, LogitsProcessor


class IgnoreTokenLogitsProcessor(LogitsProcessor):
    """Logits processor that makes one token id impossible to sample.

    At every call the score column of ``ignore_token_id`` is overwritten with
    negative infinity in place, and the same ``scores`` tensor is returned.
    """

    def __init__(self, ignore_token_id):
        # Vocabulary index that must never be selected.
        self.ignore_token_id = ignore_token_id

    def __call__(self, input_ids, scores):
        """Mask the banned column of *scores* (indexed as ``[:, id]``) and return it."""
        banned_id = self.ignore_token_id
        scores[:, banned_id] = float("-inf")
        return scores


# Bans token id 32000 during generation.
# NOTE(review): 32000 presumably is the id `<PROMPT>` receives when added to a
# 32000-token Llama vocabulary — confirm; this object is not passed to
# `generate` anywhere in this file.
logits_processor = LogitsProcessorList(
    [IgnoreTokenLogitsProcessor(ignore_token_id=32000)]
)


def llama_prompter(question, answer=None):
    """Wrap *question* in the Llama chat template with the default system prompt.

    Args:
        question: user message; surrounding whitespace is stripped.
        answer: optional assistant reply; when given it is stripped and appended
            after ``[/INST]`` separated by a single space.

    Returns:
        The formatted prompt string, itself stripped of outer whitespace before
        the optional answer is appended.
    """
    pieces = [B_INST, " ", B_SYS, system_prompt, E_SYS, question.strip(), " ", E_INST]
    prompt = "".join(pieces).strip()
    if answer is None:
        return prompt
    return prompt + " " + answer.strip()


# def llama_prompter_no_sys(question, answer=None):
#     result = f"{B_INST} {question.strip()} {E_INST}".strip()
#     if answer is not None:
#         result += f' {answer.strip()}'
#     return result


class LlamaPromptTuningForClassification(LlamaForCausalLM):
    """Llama causal LM with an extra two-way classification head.

    The classification path (``forward(..., cls=True)`` or ``cls_forward``)
    appends the ``<PROMPT>`` special token to the sequence and runs the base
    model with hidden states enabled. It takes the hidden state at that
    appended (last) position from layer ``cls_hidden_layer`` and feeds it to
    ``self.classifier``. Without ``cls`` the model behaves like
    ``LlamaForCausalLM.forward``.

    ``add_special_token`` must be called before any classification pass,
    because it creates ``special_token_tensor``.
    """

    # Index into ``outputs.hidden_states`` used as classifier input.
    # NOTE: -16 selects an intermediate layer, not the final one, whenever the
    # model has more than 15 decoder layers.
    cls_hidden_layer = -16

    def __init__(self, config):
        super().__init__(config)
        # Binary head over a hidden state of size config.hidden_size.
        self.classifier = nn.Linear(self.config.hidden_size, 2)
        self.num_labels = 2

    def add_special_token(self, tokenizer):
        """Register ``<PROMPT>`` with *tokenizer* and resize the embeddings to match.

        Side effects: adds the token to *tokenizer*, sets its pad token to its
        EOS token, stores the tokenizer on ``self``, caches the token id as a
        1-element tensor for ``cls_forward`` and resizes the token embeddings
        to ``len(tokenizer)``.
        """
        special_token = "<PROMPT>"
        self.special_token = special_token
        self.tokenizer = tokenizer
        tokenizer.add_tokens([special_token])
        tokenizer.pad_token = tokenizer.eos_token
        self.special_token_id = tokenizer.convert_tokens_to_ids(special_token)
        self.special_token_tensor = torch.tensor([self.special_token_id])
        self.resize_token_embeddings(len(tokenizer))

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            cache_position: Optional[torch.LongTensor] = None,
            num_logits_to_keep: int = 0,
            cls=False,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Dispatch to the classifier (``cls=True``) or to the plain LM forward.

        In the classifier path only ``input_ids``, ``past_key_values`` and
        ``labels`` are used (``labels`` are class indices there); every other
        argument is ignored and a dict from ``cls_forward`` is returned.
        """
        if cls:
            return self.cls_forward(input_ids, past_key_values, labels)
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            labels=labels,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep
        )

    def cls_forward(self, input_ids=None, past_key_values=None, labels=None):
        """Classify a sequence from the hidden state at an appended ``<PROMPT>`` token.

        Args:
            input_ids: token ids whose batch size is ``shape[0]``; the
                ``<PROMPT>`` id is concatenated along the last dim. May be None
                when ``past_key_values`` is given.
            past_key_values: KV cache of an already-processed prefix. When
                ``input_ids`` is None only the ``<PROMPT>`` token is fed on top
                of it. The batch size is read from ``past_key_values[0][0]``
                (assumes legacy/DynamicCache indexing with batch first — TODO
                confirm for other Cache types).
            labels: optional class indices for ``CrossEntropyLoss``.

        Returns:
            ``{"logits": (batch, num_labels)}``, plus ``"loss"`` when *labels*
            is given.

        Raises:
            ValueError: if both ``input_ids`` and ``past_key_values`` are None.
        """
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        elif past_key_values is not None:
            batch_size = past_key_values[0][0].shape[0]
        else:
            raise ValueError("Both input_ids and past_key_values are None.")
        special_token_ids = self.special_token_tensor.expand([batch_size, 1])  # (batch_size, 1)
        special_token_ids = special_token_ids.to(self.model.device)

        if input_ids is None:
            new_input_ids = special_token_ids
        else:
            new_input_ids = torch.cat([input_ids.to(self.model.device), special_token_ids], dim=-1)

        outputs = super().forward(input_ids=new_input_ids, past_key_values=past_key_values,
                                  output_hidden_states=True)

        # Hidden states of the configured (intermediate) layer: (batch, seq, hidden).
        hidden_states = outputs.hidden_states[self.cls_hidden_layer]
        # Last position is the appended <PROMPT> token: (batch, hidden).
        prompt_token_embedding = hidden_states[:, -1, :]

        # Callers may run this under torch.no_grad(); the head still builds a graph.
        with torch.enable_grad():
            logits = self.classifier(prompt_token_embedding)  # (batch_size, num_labels)
            if labels is None:
                return {"logits": logits}
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}

    def configure_model(self):
        """Freeze every parameter except the classifier head and the input embeddings.

        Parameters are matched by substring of their name, so the *entire*
        ``model.embed_tokens`` matrix stays trainable (not only the
        ``<PROMPT>`` row). The name of each trainable parameter is printed.
        """
        for name, param in self.named_parameters():
            trainable = "classifier" in name or "model.embed_tokens" in name
            param.requires_grad = trainable
            if trainable:
                print(name)


class LlamaSafetyAdapter:
    """Online safety adaptation around a ``LlamaPromptTuningForClassification``.

    ``chat_and_forward`` generates an answer and classifies the exchange with
    the model's ``cls_forward`` head. When training is enabled and the exchange
    is labelled 1, three things happen:

    - a fixed refusal for the same prompt is stored in ``defense_mem``;
    - a non-refusing generation is stored in ``jail_mem``;
    - the model is updated with ``adapt`` (LM loss on refusals plus QA data)
      and ``adapt_detector`` (classifier loss on memory plus detector data).
    """

    # Number of adapt_detector() updates after each flagged training turn.
    detector_steps = 5

    def __init__(self, model, tokenizer, optimizer_1, optimizer_2, cls_dataset, QA_Dataset, steps=1, re_batch_size=3,
                 dt_batch_size=4, prompter=None):
        """Store components and build shuffled loaders for the QA and detector data.

        Args:
            model: classifier-enabled Llama model (needs ``generate``,
                ``cls_forward`` and the ``cls`` forward flag).
            tokenizer: tokenizer matching *model*.
            optimizer_1: optimizer stepped by ``adapt`` (LM loss).
            optimizer_2: optimizer stepped by ``adapt_detector`` (detector loss).
            cls_dataset: dataset whose batches unpack to ``(texts, labels)``.
            QA_Dataset: dataset whose batches unpack to ``(instructions, answers)``.
            steps: number of ``adapt`` updates per flagged training turn.
            re_batch_size: QA loader batch size.
            dt_batch_size: detector loader batch size.
            prompter: callable ``(question, answer=None) -> str``.
        """
        super().__init__()
        self.model = model
        # NOTE(review): not read anywhere in this file; kept for external users.
        self.token_tuning = torch.nn.Parameter(torch.zeros(1, model.config.vocab_size - 1))
        self.tokenizer = tokenizer
        self.optimizer_1 = optimizer_1
        self.optimizer_2 = optimizer_2
        self.steps = steps
        self.re_batch_size = re_batch_size
        self.dt_batch_size = dt_batch_size
        # Mem is a project helper; presumably a bounded replay memory of 30
        # entries — confirm in src.utils.mem.
        self.defense_mem = Mem(30)
        self.jail_mem = Mem(30)
        self.prompter = prompter

        self.cls_dataset = cls_dataset
        self.cls_loader = torch.utils.data.DataLoader(self.cls_dataset,
                                                      batch_size=self.dt_batch_size,
                                                      shuffle=True,
                                                      num_workers=0
                                                      )
        self.cls_iter = iter(self.cls_loader)

        self.qa_dataset = QA_Dataset
        self.qa_loader = torch.utils.data.DataLoader(self.qa_dataset,
                                                     batch_size=self.re_batch_size,
                                                     shuffle=True,
                                                     num_workers=0
                                                     )
        self.qa_iter = iter(self.qa_loader)

    def _encode(self, text):
        """Tokenize *text* with the settings used throughout this class and return ``input_ids``."""
        return self.tokenizer(text,
                              padding=True,
                              truncation=True,
                              return_tensors='pt',
                              max_length=2048).input_ids

    def _next_qa_batch(self):
        """Return the next QA batch, restarting the loader when it is exhausted."""
        try:
            return next(self.qa_iter)
        except StopIteration:
            self.qa_iter = iter(self.qa_loader)
            return next(self.qa_iter)

    def _next_cls_batch(self):
        """Return the next detector batch, restarting the loader when it is exhausted."""
        try:
            return next(self.cls_iter)
        except StopIteration:
            self.cls_iter = iter(self.cls_loader)
            return next(self.cls_iter)

    def chat_and_forward(self, text, train=False, wandb_step=None):
        """Answer *text*, classify the exchange and optionally adapt the model.

        Args:
            text: raw user message; wrapped with ``self.prompter`` first.
            train: when True and the exchange is classified as 1, memory
                entries are stored and ``self.steps`` LM updates plus
                ``detector_steps`` detector updates are run. The mean of each
                loss is logged to wandb.
            wandb_step: ``step`` forwarded to ``wandb.log``.

        Returns:
            ``(answer, label)``: the decoded answer without special tokens and
            the predicted class as an int. Assumes a batch of one.
        """
        text_input = self.prompter(text)
        input_ids = self.tokenizer.encode(text_input, return_tensors='pt').to(self.model.device)
        with torch.no_grad():
            # NOTE(review): the module-level `logits_processor` (bans id 32000)
            # is not passed to generate(); confirm whether it should be.
            generated = self.model.generate(input_ids=input_ids,
                                            return_dict_in_generate=True,
                                            max_length=2048)
            all_sequence = generated.sequences
            return_sequence = all_sequence[:, input_ids.shape[-1]:]
            # Reuse the generation KV cache so only <PROMPT> is run for classification.
            # NOTE(review): the cache from generate() presumably omits the last
            # sampled token, unlike the fully re-encoded texts in adapt_detector.
            cls_outputs = self.model.cls_forward(past_key_values=generated.past_key_values)
        logits = cls_outputs['logits']
        labels = logits.argmax(dim=-1)
        max_prob = logits.softmax(dim=-1).max(dim=-1).values
        print(max_prob.item())
        answer = self.tokenizer.batch_decode(return_sequence, skip_special_tokens=True)[0]

        if train and labels.item() == 1:
            defense_text = self.prompter(text,
                                         "I cannot fulfill your request. I'm just an AI assistant, "
                                         "it's not within my programming or ethical guidelines to provide "
                                         "guidance on illegal or "
                                         "harmful activities.").strip()
            self.defense_mem.add({'text': defense_text,
                                  'answer_length': self._encode(defense_text).shape[1] - input_ids.shape[1]
                                  })

            if not answer.strip().startswith("I cannot"):
                # The model did not refuse: keep the whole exchange as a jailbreak example.
                # NOTE(review): decoded with special tokens, so it may keep a
                # leading BOS that re-encoding duplicates — verify.
                self.jail_mem.add({'text': self.tokenizer.batch_decode(all_sequence, skip_special_tokens=False)[0],
                                   'answer_length': all_sequence.shape[1] - input_ids.shape[1]
                                   })

            train_losses = [self.adapt() for _ in range(self.steps)]
            dt_losses = [self.adapt_detector() for _ in range(self.detector_steps)]
            wandb.log({'train_loss': sum(train_losses) / max(len(train_losses), 1),
                       'dt_loss': sum(dt_losses) / max(len(dt_losses), 1)},
                      step=wandb_step)
        return answer, labels.item()

    def adapt(self):
        """Run one LM update on one refusal from memory plus one QA batch.

        Only the answer tokens and the appended ``</s>`` are supervised; the
        prompt part of each sequence is masked with -100. Per-sample losses are
        divided by the batch size and accumulated before one optimizer step.

        Returns:
            The mean loss over the combined batch as a float.
        """
        self.optimizer_1.zero_grad()
        # Copy: appending QA samples must not mutate storage owned by Mem.
        batch = list(self.defense_mem.get(1))
        qa_instructions, qa_answers = self._next_qa_batch()

        for qa_instruction, qa_answer in zip(qa_instructions, qa_answers):
            qa_ins = self.prompter(qa_instruction, qa_answer)
            instruction = self.prompter(qa_instruction)
            answer_length = self._encode(qa_ins).shape[1] - self._encode(instruction).shape[1]
            batch.append({'text': qa_ins, 'answer_length': answer_length})

        avg_loss = 0
        for data in batch:
            text, answer_length = data['text'], data['answer_length']
            if not text.endswith('</s>'):
                text += '</s>'
            input_ids = self._encode(text).to(self.model.device)
            labels = input_ids.clone()
            # +1 accounts for the trailing </s> token.
            labels[0][:-(answer_length + 1)] = -100
            outputs = self.model(input_ids=input_ids, labels=labels)
            loss = outputs.loss / len(batch)
            avg_loss += loss.item()
            loss.backward()
        self.optimizer_1.step()
        return avg_loss

    def adapt_detector(self):
        """Run one classifier update on memory samples plus one detector batch.

        Jailbreak memories are labelled 1, defense memories 0, and detector
        batch labels are used as given. Each text is first encoded without
        gradients to obtain its KV cache; the classifier pass on top of it
        carries the gradient.

        Returns:
            The mean detector loss over all samples as a float.
        """
        self.optimizer_2.zero_grad()
        jailbreak_datas = list(self.jail_mem.get(3))
        defense_datas = list(self.defense_mem.get(3))
        train_texts = [d['text'] for d in jailbreak_datas] + [d['text'] for d in defense_datas]
        # torch.tensor (not torch.stack) so that empty memories do not raise.
        memory_labels = torch.tensor([1] * len(jailbreak_datas) + [0] * len(defense_datas),
                                     dtype=torch.long, device=self.model.device)

        detector_texts, detector_labels = self._next_cls_batch()
        train_texts.extend(detector_texts)
        train_labels = torch.cat([memory_labels, detector_labels.to(self.model.device)])

        avg_dt_loss = 0
        for text, label in zip(train_texts, train_labels):
            input_ids = self._encode(text).to(self.model.device)
            with torch.no_grad():
                # Only the KV cache is used from this pass; skip full-vocab logits.
                prefix = self.model(input_ids=input_ids, labels=None, cls=False, num_logits_to_keep=1)
            outputs = self.model(past_key_values=prefix['past_key_values'], labels=label.unsqueeze(0), cls=True)

            loss = outputs['loss'] / len(train_texts)
            loss.backward()
            avg_dt_loss += loss.item()
        self.optimizer_2.step()
        return avg_dt_loss
