"""
Modified from DoLA Code
"""
import argparse
import time
import csv
from tqdm import tqdm
import os
import json
import scipy
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
# from model_utils.gemma.modeling_gemma import GemmaForCausalLM
# from model_utils.gemma.tokenization_gemma import GemmaTokenizer
from transformers.utils import add_start_docstrings
from transformers.generation.stopping_criteria import StoppingCriteriaList, StoppingCriteria, STOPPING_CRITERIA_INPUTS_DOCSTRING
from decoding_algorithm.utils import build_prompt_and_answer
import argparse
import warnings
import pandas as pd
import numpy as np
from scipy.spatial.distance import jensenshannon

# Upper bound on the tokenized length of a sample; the focus-coefficient
# helpers in this file compare len(content_ids[0]) against it and do not run
# the model on longer samples.
MAX_TOKENS=1536

class LLamaQaStoppingCriteria(StoppingCriteria):
    """
    Stopping criterion that fires when the generated sequence ends with one of
    the configured token-id sequences.

    Only the first row of `input_ids` is inspected. The default sequence
    [29984, 29901] presumably encodes "Q:" for the Llama tokenizer (verify
    against the tokenizer in use), i.e. the model has finished its answer and
    started writing a new question.
    """
    def __init__(self, list_token_ids_sequence: list = [[29984, 29901]]):
        # Keep each stop sequence as a LongTensor alongside its length.
        self.token_ids_sequences = [
            torch.tensor(ids, dtype=torch.long) for ids in list_token_ids_sequence
        ]
        self.lengths = [len(ids) for ids in list_token_ids_sequence]

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Compare the trailing tokens of the first sequence against every stop
        # sequence; sequences longer than the current input cannot match yet.
        available = input_ids.shape[-1]
        return any(
            available >= span
            and bool(torch.all(input_ids[0, -span:] == stop_ids.to(input_ids.device)))
            for stop_ids, span in zip(self.token_ids_sequences, self.lengths)
        )


class ContrastiveDecoding:
    """
    Implementation for different contrastive decoding:
    1. Baseline (greedy, beam search, sample-topk-topp-beam)
    2. Vanilla Contrastive Decoding: "Contrastive Decoding: Open-ended Text Generation as Optimization"
    3. DoLA: "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models"
    4. CAD: "Trusting Your Evidence: Hallucinate Less with Context-aware Decoding" (TBD)
    5. ICD: "Improving Factuality of Large Language Models via Contrasting Intentionally Induced Hallucinations"
    """
    def __init__(self, model_name, device="cuda", max_gpu_memory=39, amateur_model_name=None, num_gpus=-1, amateur_model_nums_gpus=-1):
        """Load the base model (and optionally an amateur model) and record decoder geometry.

        Args:
            model_name (str): base model (teacher model when using contrastive decoding).
            device (str): used device. Defaults to `cuda`.
            max_gpu_memory (int, optional): max gpu memory. Defaults to 39.
            amateur_model_name (str, optional): amateur model used in contrastive decoding. Defaults to None.
            num_gpus (int, optional): number of used gpus for base model. Defaults to -1 (auto).
            amateur_model_nums_gpus (int, optional): number of used gpus for amateur model. Defaults to -1 (auto).

        Raises:
            AssertionError: if `num_gpus + amateur_model_nums_gpus` exceeds 8.
        """
        self.model_name = model_name
        self.amateur_model_name = amateur_model_name
        self.device = device
        self.stopping_criteria = None
        self.max_gpu_memory = max_gpu_memory
        # Number of layers / heads kept by contrast_find_layer_list / contrast_find_head_list.
        self.top_k_l = 2
        self.top_k_h = 4

        self.model, self.tokenizer = self.load_model(model_name, num_gpus)

        # Always define the amateur attributes so that later `is not None`
        # checks fail cleanly instead of raising AttributeError.
        self.amateur_model = None
        self.amateur_model_tokenizer = None
        if amateur_model_name is not None:
            # The amateur model's GPU ids start after the base model's; with the
            # -1 "auto" sentinel there is no offset, so start from GPU 0.
            amateur_start_id = max(int(num_gpus), 0)
            self.amateur_model, self.amateur_model_tokenizer = self.load_model(
                amateur_model_name, amateur_model_nums_gpus, amateur_start_id)
        self.all_gpu_nums = num_gpus + amateur_model_nums_gpus
        self.T = 0.5
        self.label_id_list = []

        # (name fragment, candidate layers, attention heads, decoder layers);
        # the first fragment contained in the lower-cased model name wins.
        known_models = (
            ("llama-2-7b", range(0, 32), 32, 32),
            ("llama-3-8b", range(8, 24), 32, 32),
            ("llama-2-13b", range(0, 40), 40, 40),
            ("gemma-2b", range(0, 18), 8, 18),
            ("gemma-7b", range(0, 28), 16, 28),
            ("mistral-7b", range(0, 32), 32, 32),
        )
        lowered_name = model_name.lower()
        for fragment, candidate_layers, head_num, layer_num in known_models:
            if fragment in lowered_name:
                print("== MODEL {} ==".format(fragment))
                self.choose_layer_list = candidate_layers
                self.decoder_head_num = head_num
                self.decoder_layer_num = layer_num
                self.decoder_head_list = range(0, head_num)
                self.decoder_layer_list = range(0, layer_num)
                break
        else:
            # Geometry attributes stay undefined, exactly as before; make the
            # situation visible instead of failing later with AttributeError.
            warnings.warn(
                "Unrecognised model name {!r}: decoder layer/head counts were not set".format(model_name))

        assert self.all_gpu_nums <= 8

    def load_model(self, model_name, num_gpus, start_id=0):
        """Load a causal LM and its tokenizer with all parameters frozen.

        Args:
            model_name (str): name or path handed to `from_pretrained`.
            num_gpus (int): -1 lets `device_map="auto"` decide; any other value
                is cast to int and, when more than one CUDA device is visible,
                used to build a per-GPU `max_memory` budget.
            start_id (int): first GPU index used in the `max_memory` budget.

        Raises:
            ValueError: if `self.device` is neither "cuda" nor "cpu".

        Returns:
            model: transformers model
            tokenizer: transformers tokenizer
        """
        if self.device not in ("cuda", "cpu"):
            raise ValueError(f"Invalid device: {self.device}")

        load_kwargs = {}
        if self.device == "cuda":
            # bfloat16 is the A100 setting; a V100 machine would use torch.float16.
            load_kwargs["torch_dtype"] = torch.bfloat16
            load_kwargs["offload_folder"] = f"{model_name}/offload"
            if num_gpus == -1:
                load_kwargs["device_map"] = "auto"
            else:
                num_gpus = int(num_gpus)
                if torch.cuda.device_count() != 1:
                    load_kwargs["device_map"] = "auto"
                    load_kwargs["max_memory"] = {
                        gpu_id: f"{self.max_gpu_memory}GiB"
                        for gpu_id in range(start_id, start_id + num_gpus)
                    }

        # Model-specific Gemma / Llama loader classes used to be selected here;
        # the Auto* classes now handle every model family.
        # Vicuna checkpoints take their tokenizer from 'huggyllama/llama-7b'.
        tokenizer_source = 'huggyllama/llama-7b' if 'vicuna' in model_name else model_name
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, low_cpu_mem_usage=True, trust_remote_code=True, **load_kwargs)

        single_gpu = self.device == "cuda" and num_gpus == 1
        if single_gpu:  # one gpu fits two models
            model.cuda()

        # Inference only: gradients are never needed.
        for weight in model.parameters():
            weight.requires_grad = False

        return model, tokenizer

    def set_stop_words(self, stop_words):
        """Register stop words for early stopping of generation.

        Each stop word is tokenized with a leading newline and the first two
        ids of the encoding are dropped (presumably the BOS token and the
        newline piece of a Llama-style tokenizer - verify for other
        tokenizers). The resulting id sequences are wrapped in a single
        `LLamaQaStoppingCriteria` stored in `self.stopping_criteria`.

        Args:
            stop_words (iterable of str): strings that should end generation.
        """
        self.stop_words = stop_words
        self.stopping_criteria = StoppingCriteriaList()

        stop_id_sequences = []
        for word in self.stop_words:
            word_ids = self.tokenizer.encode('\n' + word)[2:]
            stop_id_sequences.append(word_ids)
            print("Added stop word: ", word, 'with the ids', word_ids, flush=True)

        criterion = LLamaQaStoppingCriteria(stop_id_sequences)
        self.stopping_criteria.append(criterion)

    def set_label_id(self, task_name):
        """Store the token ids of the answer labels A-D in `self.label_id_list`.

        For "bbh" and "(mmlu)" each label is tokenized with a leading
        parenthesis ("(A" ... "(D"); for "mmlu" the bare letter is tokenized.
        In both cases the last id of the encoding is kept.

        Args:
            task_name (str): one of "bbh", "(mmlu)" or "mmlu".

        Raises:
            ValueError: for any other task name.
        """
        if task_name in ("bbh", "(mmlu)"):
            label_prefix = "("
        elif task_name == "mmlu":
            label_prefix = ""
        else:
            raise ValueError(f"Invalid task name: {task_name}")

        self.label_id_list = [
            self.tokenizer(label_prefix + letter).input_ids[-1] for letter in "ABCD"
        ]

    def generate(self, input_text=None, evil_input_text=None, input_ids=None, 
                 attention_temperature=1,
                 max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, 
                 mature_layer=None, premature_layer=None, candidate_premature_layers=[], 
                 mode='baseline', verbose=False, remove_stop_words=False, relative_top=0.1, 
                 **kwargs):
        #TODO: Prompt-based Contrastive Decoding for generating content
        """_summary_

        Args:
            input_text (_type_): _description_
            max_new_tokens (int, optional): _description_. Defaults to 256.
            top_p (float, optional): _description_. Defaults to 0.95.
            top_k (int, optional): _description_. Defaults to 0.
            temperature (float, optional): _description_. Defaults to 0.8.
            mature_layer (_type_, optional): _description_. Defaults to None.
            premature_layer (_type_, optional): _description_. Defaults to None.
            candidate_premature_layers (list, optional): _description_. Defaults to [].
            mode (str, optional): _description_. Defaults to 'baseline'.
            verbose (bool, optional): _description_. Defaults to True.
            remove_stop_words (bool, optional): _description_. Defaults to False.
            relative_top (float, optional): _description_. Defaults to 0.1.

        Returns:
            _type_: _description_
        """
        with torch.no_grad():

            if input_ids is None:
                input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
                
            if evil_input_text is not None:
                evil_input_ids = self.tokenizer(evil_input_text, return_tensors="pt").input_ids.to(self.device)
            
            max_len = input_ids.shape[-1] + max_new_tokens

            if mode == 'baseline':
                outputs = self.model.generate(input_ids, 
                                    max_length=max_len,
                                    attention_temperature=attention_temperature, 
                                    num_return_sequences=1,
                                    output_scores=True, return_dict_in_generate=True,
                                    top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, do_sample=True, **kwargs)
            elif mode == 'dola-static':
                assert mature_layer is not None, "mature_layer must be specified"
                assert premature_layer is not None, "premature_layer must be specified"
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                    output_scores=True, return_dict_in_generate=True, dola_decoding=True,
                                    mature_layer=mature_layer, premature_layer=premature_layer,
                                    top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, relative_top=relative_top, **kwargs)
            elif mode == 'dola':
                assert mature_layer is not None, "mature_layer must be specified"
                assert candidate_premature_layers is not None, "candidate_premature_layers must be specified"
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                        output_scores=True, return_dict_in_generate=True, dola_decoding=True,
                                        top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, relative_top=relative_top, 
                                        mature_layer=mature_layer, premature_layer=None, candidate_premature_layers=candidate_premature_layers, **kwargs)
                premature_layer_dist = outputs.premature_layer_dist
            elif mode == "contrastive-decoding":
                assert self.amateur_model is not None, "amateur model must be specified if using contrastive decoding"
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                    output_scores=True, return_dict_in_generate=True, contrastive_decoding=True,
                    student_model=self.amateur_model,
                    top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, relative_top=relative_top, **kwargs)
            elif mode == "prompt-contrastive-decoding":
                assert evil_input_text is not None, "amateur model must be specified if using contrastive decoding"
                outputs = self.model.generate(input_ids, evil_input_ids=evil_input_ids, max_length=max_len, num_return_sequences=1,
                    output_scores=True, return_dict_in_generate=True, contrastive_decoding=True,
                    top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, relative_top=relative_top, **kwargs)
            sequences, scores = outputs.sequences, outputs.scores

            # skip the tokens in the input prompt
            gen_sequences = sequences[:, input_ids.shape[-1]:][0, :]
            gen_arr = gen_sequences.cpu().numpy()

            output_str = self.tokenizer.decode(gen_sequences, skip_special_tokens=True)

            if verbose:
                print('MODEL OUTPUT: \n{0}'.format(output_str))

            if remove_stop_words:
                for stop_word in self.stop_words:
                    length_to_remove = len(stop_word)
                    if output_str[-length_to_remove:] == stop_word:
                        output_str = output_str[:-length_to_remove]
                output_str = output_str.strip()

        if self.device:
            torch.cuda.empty_cache()
        return output_str, (premature_layer_dist if mode == 'dola' else None)

    def get_relative_top_filter(self, scores: torch.FloatTensor, relative_top: float = 0.1, min_tokens_to_keep: int = 1):
        scores_normalized = scores.log_softmax(dim=-1) 
        sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True)
        min_thresh = sorted_logits[..., min_tokens_to_keep-1] 
        probs_max = torch.max(scores_normalized, dim=-1).values
        probs_thresh = probs_max + np.log(relative_top)
        probs_thresh = torch.min(min_thresh, probs_thresh)
        probs_thresh = probs_thresh.unsqueeze(-1)
        return scores_normalized < probs_thresh

    # Compute the focus (attention) coefficients for a single sample
    def get_focus_coef(self, input_ids, problem_pos, cot_pos, attn_t=1):
        prompt_array = np.zeros((self.decoder_layer_num, self.decoder_head_num))
        problem_array = np.zeros((self.decoder_layer_num, self.decoder_head_num))
        cot_array = np.zeros((self.decoder_layer_num, self.decoder_head_num))
        outputs = self.model(input_ids, attention_temperature=attn_t, output_attentions=True, return_dict=True)["attentions"]
        attn_scores = torch.stack(outputs, dim=0)
        attn_scores = torch.squeeze(attn_scores, dim=1).cpu().float().numpy()
        with torch.no_grad():
            for l in range(0, self.decoder_layer_num):
                for h in range(0, self.decoder_head_num):
                    attention_data_bias = attn_scores[l][h]
                    prompt_len = problem_pos - 0 
                    problem_len = cot_pos - problem_pos
                    cot_len = len(input_ids[0]) - cot_pos
                    x_0 = sum(attention_data_bias[-1, 0:problem_pos]) / prompt_len
                    x_1 = sum(attention_data_bias[-1, problem_pos:cot_pos]) / problem_len
                    if cot_len != 0:
                        x_2 = sum(attention_data_bias[-1, cot_pos:]) / cot_len
                    else:
                        x_2 = 0
                    x = np.array([x_0, x_1, x_2])
                    prompt_coef = x[0] / x.sum()
                    problem_coef = x[1] / x.sum()
                    cot_coef = x[2] / x.sum()
                    prompt_array[l][h] += prompt_coef
                    problem_array[l][h] += problem_coef
                    cot_array[l][h] += cot_coef
        return prompt_array, problem_array, cot_array 

    # Compute the average focus (attention) coefficients over the samples
    def get_avg_focus_coef(self, data_train, attn_t=1, bias=False):
        """Average the per-layer focus coefficients over a set of samples.

        For each sample the coefficients from `get_focus_coef` are averaged
        over heads, giving one value per layer; these are then averaged over
        all samples whose tokenized length does not exceed `MAX_TOKENS`.

        Args:
            data_train: a sample dict or a list of them, with keys "prompt"
                (or "bias_prompt" when `bias` is True), "problem" and
                optionally "cot".
            attn_t: forwarded to `get_focus_coef`.
            bias (bool): read the prompt from "bias_prompt" instead of "prompt".

        Returns:
            tuple: `(avg_prompt_coef, avg_problem_coef, avg_cot_coef)`; each is
            the integer 0 when no sample was usable.
        """
        samples = data_train if isinstance(data_train, list) else [data_train]
        prompt_key = "bias_prompt" if bias else "prompt"

        def encode(text):
            return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)

        prompt_total, problem_total, cot_total = 0, 0, 0
        used = 0
        with torch.no_grad():
            for sample in samples:
                prompt = sample[prompt_key]
                question = sample["problem"]
                cot = sample["cot"] if "cot" in sample.keys() else ""
                prompt_ids = encode(prompt)
                prefix_ids = encode(prompt + question)
                content_ids = encode(prompt + question + cot)
                if len(content_ids[0]) > MAX_TOKENS:
                    # Over-long samples contribute nothing and are not counted.
                    continue
                per_head = self.get_focus_coef(
                    content_ids, prompt_ids.shape[-1], prefix_ids.shape[-1], attn_t=attn_t)
                used += 1
                # Collapse the head axis: one mean coefficient per layer.
                prompt_total += np.sum(per_head[0], axis=1)/(self.decoder_head_num)
                problem_total += np.sum(per_head[1], axis=1)/(self.decoder_head_num)
                cot_total += np.sum(per_head[2], axis=1)/(self.decoder_head_num)
            if used:
                prompt_total /= used
                problem_total /= used
                cot_total /= used
        return prompt_total, problem_total, cot_total
    
    def get_avg_head_focus_coef(self, data_train, layer=0, attn_t=1, bias=False):
        """Average the per-head focus coefficients of one layer over a set of samples.

        Same procedure as `get_avg_focus_coef`, except that the row of
        `layer` is taken from each coefficient array (one value per head)
        instead of averaging over heads.

        Args:
            data_train: a sample dict or a list of them, with keys "prompt"
                (or "bias_prompt" when `bias` is True), "problem" and
                optionally "cot".
            layer (int): decoder layer whose heads are reported.
            attn_t: forwarded to `get_focus_coef`.
            bias (bool): read the prompt from "bias_prompt" instead of "prompt".

        Returns:
            tuple: `(avg_prompt_coef, avg_problem_coef, avg_cot_coef)`; each is
            the integer 0 when no sample was usable.
        """
        samples = data_train if isinstance(data_train, list) else [data_train]
        prompt_key = "bias_prompt" if bias else "prompt"

        def encode(text):
            return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)

        prompt_total, problem_total, cot_total = 0, 0, 0
        used = 0
        with torch.no_grad():
            for sample in samples:
                prompt = sample[prompt_key]
                question = sample["problem"]
                cot = sample["cot"] if "cot" in sample.keys() else ""
                prompt_ids = encode(prompt)
                prefix_ids = encode(prompt + question)
                content_ids = encode(prompt + question + cot)
                if len(content_ids[0]) > MAX_TOKENS:
                    # Over-long samples contribute nothing and are not counted.
                    continue
                prompt_array, problem_array, cot_array = self.get_focus_coef(
                    content_ids, prompt_ids.shape[-1], prefix_ids.shape[-1], attn_t=attn_t)
                used += 1
                prompt_total += prompt_array[layer]
                problem_total += problem_array[layer]
                cot_total += cot_array[layer]
            if used:
                prompt_total /= used
                problem_total /= used
                cot_total /= used
        return prompt_total, problem_total, cot_total

    def contrast_find_head_list(self, data_train, layer_list=None, T=0.5, threshold=0, bias=True, reverse=False, intervene_way=0, use_KL=False):
        if not self.label_id_list:
            raise ValueError("The label list is empty, please set_label_id(task name)")
        if layer_list == None:
            layer_list = self.choose_layer_list
        attention_temperature_list = []
        attn_t_sum = {}
        with torch.no_grad():
            for l in layer_list:
                interaction_head_dict = {k: 0 for k in range(0, self.decoder_head_num)}
                for h in range(0, self.decoder_head_num):
                    attn_t = [intervene_way, {l:([h], T)}]
                    avg_delta = 0
                    if l not in attn_t_sum.keys():
                        attn_t_sum[l] = ([], T)
                    for data in data_train:
                        y_true = data["y_true"]
                        content = data["bias_prompt"] if bias else data["prompt"]
                        input_ids = self.tokenizer(content, return_tensors='pt').input_ids.to(self.device)
                        # interaction
                        outputs = self.model(input_ids, attention_temperature=attn_t, 
                                            output_attentions=True, return_dict=True)["logits"][0].squeeze(0)
                        logits = outputs.log_softmax(-1)  # logits to log probs
                        logits = logits[-1, :]
                        interaction_probs = (
                            torch.nn.functional.softmax(
                                torch.tensor([logits[label_id] for label_id in self.label_id_list]),
                                dim=0,
                            ).detach().cpu().to(torch.float32).numpy()
                        )
                        # origin
                        outputs = self.model(input_ids, output_attentions=True, 
                                            return_dict=True)["logits"][0].squeeze(0)
                        logits = outputs.log_softmax(-1)  # logits to log probs
                        logits = logits[-1, :]
                        origin_probs = (
                            torch.nn.functional.softmax(
                                torch.tensor([logits[label_id] for label_id in self.label_id_list]),
                                dim=0,
                            ).detach().cpu().to(torch.float32).numpy()
                        )
                        origin_y_pred = int(np.argmax(origin_probs))
                        if use_KL:
                            # KL = scipy.stats.entropy(interaction_probs, origin_probs)
                            # KL = jensenshannon(interaction_probs, origin_probs, base=2)
                            # avg_delta += (interaction_probs[y_true] + 1 - KL)
                            # avg_delta += KL * KL
                            true_probs = torch.zeros(4)
                            true_probs[y_true] = 1
                            kl_1 = jensenshannon(interaction_probs, origin_probs, base=2)
                            kl_2 = jensenshannon(interaction_probs, true_probs, base=2)
                            # avg_delta += (interaction_probs[y_true] + 1 - KL)
                            avg_delta += 1 - kl_1*kl_1 + 1 - kl_2*kl_2
                        else:
                            avg_delta += (interaction_probs[y_true] + interaction_probs[origin_y_pred])
                    avg_delta /= len(data_train)
                    interaction_head_dict[h] = avg_delta
                # TOP_NUM = int(self.decoder_head_num / 2)
                TOP_NUM = self.top_k_h
                if len(interaction_head_dict) > TOP_NUM and TOP_NUM > 0:
                    interaction_head_dict = dict(sorted(interaction_head_dict.items(), key=lambda item: item[1], reverse=True)[:TOP_NUM])
                    
                for h in interaction_head_dict.keys():
                    attn_t_sum[l][0].append(h)
        attn_t_sum = [intervene_way, attn_t_sum]
        return attn_t_sum

                    
    def contrast_find_layer_list(self, data_train, layer_list=None, T=0.5, threshold=0, bias=True, reverse=False, intervene_way=0, use_KL=False):
        """
        data_train :
        y_true : 真值标签 生成的last token
        content : 上文内容(last token 前的字符串)
        """
        if not self.label_id_list:
            raise ValueError("The label list is empty, please set_label_id(task name)")
        if layer_list == None:
            layer_list = self.choose_layer_list

        interaction_layer_dict = {k: 0 for k in layer_list}
        with torch.no_grad():
            for layer in layer_list:
                avg_delta = 0
                attn_t = [intervene_way, {layer: (range(0, self.decoder_head_num), T)}]
                for data in data_train:
                    y_true = data["y_true"]
                    bias_content = data["bias_prompt"]
                    content = data["prompt"]
                    bias_input_ids = self.tokenizer(bias_content, return_tensors='pt').input_ids.to(self.device)
                    input_ids = self.tokenizer(content, return_tensors='pt').input_ids.to(self.device)
                    # interaction
                    outputs = self.model(bias_input_ids, attention_temperature=attn_t, 
                                        output_attentions=True, return_dict=True)["logits"][0].squeeze(0)
                    logits = outputs.log_softmax(-1)  # logits to log probs
                    logits = logits[-1, :]
                    interaction_probs = (
                        torch.nn.functional.softmax(
                            torch.tensor([logits[label_id] for label_id in self.label_id_list]),
                            dim=0,
                        ).detach().cpu().to(torch.float32).numpy()
                    )
                    # origin
                    outputs = self.model(input_ids, output_attentions=True, 
                                        return_dict=True)["logits"][0].squeeze(0)
                    logits = outputs.log_softmax(-1)  # logits to log probs
                    logits = logits[-1, :]
                    origin_probs = (
                        torch.nn.functional.softmax(
                            torch.tensor([logits[label_id] for label_id in self.label_id_list]),
                            dim=0,
                        ).detach().cpu().to(torch.float32).numpy()
                    )
                    origin_y_pred = int(np.argmax(origin_probs))
                    if use_KL:
                        # KL = scipy.stats.entropy(interaction_probs, origin_probs)
                        true_probs = torch.zeros(4)
                        true_probs[y_true] = 1
                        kl_1 = jensenshannon(interaction_probs, origin_probs, base=2)
                        kl_2 = jensenshannon(interaction_probs, true_probs, base=2)
                        # avg_delta += (interaction_probs[y_true] + 1 - KL)
                        value = 1 - kl_1*kl_1 + 1 - kl_2*kl_2
                        temp = (interaction_probs[y_true] + interaction_probs[origin_y_pred])
                        print(f" {value} / {temp}")
                        avg_delta += value
                    else:
                        avg_delta += (interaction_probs[y_true] + interaction_probs[origin_y_pred])
                avg_delta /= len(data_train)
                interaction_layer_dict[layer] = avg_delta
            # print("interaction layer dict {}".format(interaction_layer_dict))
            TOP_NUM = self.top_k_l
            if len(interaction_layer_dict) > TOP_NUM and TOP_NUM > 0:
                interaction_layer_dict = dict(sorted(interaction_layer_dict.items(), 
                                                    key=lambda item: item[1], reverse=True)[:TOP_NUM])
            # interaction_layer_dict = {key: value for key, value in interaction_layer_dict.items() if value > 0.001}
            attention_temperature_list = []
            for layer in interaction_layer_dict.keys():
                for head in range(0, self.decoder_head_num):
                    attention_temperature = [intervene_way, {layer:([head], T)}]
                    attention_temperature_list.append(attention_temperature)
            # print("Begin to choose head on layer {}".format(interaction_layer_dict.keys()))
            layer_list = list(interaction_layer_dict.keys())
        return layer_list
    
    def check_bias_attn_t(self, data_train, attn_t, prompt_bias=True):
        """Return the accuracy change caused by an attention intervention.

        Each sample is run twice on the same input: once with
        `attention_temperature=attn_t` and once without. A run counts as
        correct when the arg-max over the label probabilities equals
        "y_true".

        The previous implementation also accumulated a probability delta, but
        did so outside the loop (using only the last sample) and never
        returned it; that dead computation has been removed.

        Args:
            data_train: sequence of dicts with "y_true" and "bias_prompt" / "prompt".
            attn_t: intervention spec forwarded as `attention_temperature`.
            prompt_bias (bool): read the input from "bias_prompt" (True) or "prompt".

        Raises:
            ValueError: if `data_train` is empty.

        Returns:
            float: intervened accuracy minus original accuracy.
        """
        if len(data_train) == 0:
            raise ValueError("data_train must contain at least one sample")

        label_ids = list(self.label_id_list)
        prompt_key = "bias_prompt" if prompt_bias else "prompt"

        def label_probs(input_ids, **model_kwargs):
            # Distribution over the answer labels at the last position.
            logits = self.model(input_ids, output_attentions=True, return_dict=True, **model_kwargs)["logits"]
            log_probs = logits[0, -1, :].log_softmax(-1)
            picked = log_probs[label_ids].detach().to(torch.float32).cpu()
            return torch.softmax(picked, dim=0).numpy()

        acc_origin = 0
        acc_interaction = 0
        with torch.no_grad():
            for data in tqdm(data_train):
                y_true = data["y_true"]
                input_ids = self.tokenizer(data[prompt_key], return_tensors='pt').input_ids.to(self.device)
                interaction_probs = label_probs(input_ids, attention_temperature=attn_t)
                origin_probs = label_probs(input_ids)
                if int(np.argmax(origin_probs)) == y_true:
                    acc_origin += 1
                if int(np.argmax(interaction_probs)) == y_true:
                    acc_interaction += 1
        acc_origin /= len(data_train)
        acc_interaction /= len(data_train)
        acc_delta = acc_interaction - acc_origin
        return acc_delta

    def eval_model(self, data_train, attn_t=1, prompt_bias=True, max_length=2048):
        """Return the mean probability the model assigns to the true label.

        Samples whose tokenized input is longer than `max_length` are
        skipped. For the rest, the label distribution at the last position is
        computed with `attention_temperature=attn_t` and the probability of
        "y_true" is averaged. Despite the name this is an average
        probability, not an arg-max accuracy.

        Args:
            data_train: iterable of dicts with "y_true" and "bias_prompt" / "prompt".
            attn_t: forwarded to the model as `attention_temperature`.
            prompt_bias (bool): read the input from "bias_prompt" (True) or "prompt".
            max_length (int): maximum tokenized input length that is evaluated.

        Raises:
            ValueError: if no sample was evaluated (empty data or all inputs
                too long); previously this surfaced as ZeroDivisionError.

        Returns:
            Mean probability of the true label over the evaluated samples.
        """
        label_ids = list(self.label_id_list)
        prompt_key = "bias_prompt" if prompt_bias else "prompt"

        avg_acc = 0
        valid_num = 0
        with torch.no_grad():
            for data in tqdm(data_train):
                y_true = data["y_true"]
                input_ids = self.tokenizer(data[prompt_key], return_tensors='pt').input_ids.to(self.device)
                if len(input_ids[0]) > max_length:
                    continue
                logits = self.model(
                    input_ids,
                    output_attentions=True,
                    attention_temperature=attn_t, 
                    return_dict=True
                )["logits"]
                log_probs = logits[0, -1, :].log_softmax(-1)
                picked = log_probs[label_ids].detach().to(torch.float32).cpu()
                label_distribution = torch.softmax(picked, dim=0).numpy()
                avg_acc += label_distribution[y_true]
                valid_num += 1
        if valid_num == 0:
            raise ValueError("No sample could be evaluated: data_train is empty or every input exceeds max_length")
        return (avg_acc/valid_num)

    def get_delta_coef_attn_t(self, data_train, attn_t):
        """Average the change in focus coefficients caused by ``attn_t``.

        For each example the text ``prompt + problem + cot`` (``cot`` defaults
        to the empty string when the key is missing) is tokenised, and
        ``self.get_focus_coef`` is called twice on it: with ``attn_t=1`` and
        with the given ``attn_t``. The differences (given minus ``attn_t=1``)
        of the three returned coefficient groups are summed into arrays of
        length ``self.decoder_layer_num`` and divided by the number of
        examples used. Examples with more than ``MAX_TOKENS`` tokens are
        skipped.

        Returns:
            A tuple ``(prompt_delta, problem_delta, cot_delta)``. When no
            example is used, the division is 0/0 and the arrays contain NaN.
        """
        def encode(text):
            return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)

        layer_count = self.decoder_layer_num
        prompt_shift = np.zeros(layer_count)
        problem_shift = np.zeros(layer_count)
        cot_shift = np.zeros(layer_count)
        used = 0
        for sample in data_train:
            prompt = sample["prompt"]
            question = sample["problem"]
            cot = sample["cot"] if "cot" in sample.keys() else ""
            prompt_ids = encode(prompt)
            prefix_ids = encode(prompt + question)
            content_ids = encode(prompt + question + cot)
            if len(content_ids[0]) > MAX_TOKENS:
                continue
            # Token offsets where the problem and the chain of thought start.
            problem_pos = prompt_ids.shape[-1]
            cot_pos = prefix_ids.shape[-1]
            base = self.get_focus_coef(content_ids, problem_pos, cot_pos, attn_t=1)
            tuned = self.get_focus_coef(content_ids, problem_pos, cot_pos, attn_t=attn_t)
            base_prompt, base_problem, base_cot = base
            tuned_prompt, tuned_problem, tuned_cot = tuned
            prompt_shift += tuned_prompt - base_prompt
            problem_shift += tuned_problem - base_problem
            cot_shift += tuned_cot - base_cot
            used += 1
        return prompt_shift / used, problem_shift / used, cot_shift / used

    def lm_score(self, prompt, question, input_text2, input_text3=None, attn_t=1, pmi=False, max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, mature_layer=None, premature_layer=None, candidate_premature_layers=[], mode='baseline', verbose=True, remove_stop_words=False, relative_top=0.1, relative_top_value=-1000.0, post_softmax=True, **kwargs):
        """Score the continuation ``input_text2`` given ``prompt + question``.

        The full text ``prompt + question + input_text2`` is tokenised once;
        the tokens after the tokenised prefix ``prompt + question`` are the
        continuation. For every continuation token the score predicted at the
        preceding position is looked up and the scores are summed.

        Modes:
            ``'baseline'``: log-probabilities of ``self.model``.
            ``'cot-enhance'``: as baseline, with ``attention_temperature=attn_t``.
            ``'dola-static'``: ``mature_layer`` log-probs minus
                ``premature_layer`` log-probs.
            ``'dola'``: like ``'dola-static'``, but the premature layer is
                picked per position from ``candidate_premature_layers`` as the
                one with the largest divergence from the mature layer.
            ``'contrastive-decoding'``: ``self.model`` log-probs minus
                ``self.amateur_model`` log-probs.
            ``'prompt-contrastive-decoding'``: ``self.model`` log-probs for the
                normal prefix minus its log-probs for the prefix
                ``input_text3``.

        In the contrastive modes the difference is re-normalised with
        ``log_softmax`` when ``post_softmax`` is true, and when
        ``relative_top > 0`` the positions selected by
        ``self.get_relative_top_filter`` are overwritten with
        ``relative_top_value``.

        ``pmi``, ``max_new_tokens``, ``top_p``, ``top_k``, ``temperature``,
        ``verbose`` and ``remove_stop_words`` are accepted but not used here.
        The list default of ``candidate_premature_layers`` is never mutated.

        Returns:
            ``(score, premature_layer_dist)``; the second item is a dict
            counting how often each candidate layer was picked in ``'dola'``
            mode and ``None`` otherwise.

        Raises:
            ValueError: ``mode`` is not one of the modes listed above
                (previously an ``UnboundLocalError`` at the return statement).
            AssertionError: a mode-specific prerequisite is missing
                (``premature_layer``, ``self.amateur_model`` or ``input_text3``).
        """
        supported_modes = (
            'baseline',
            'cot-enhance',
            'dola-static',
            'dola',
            'contrastive-decoding',
            'prompt-contrastive-decoding',
        )
        if mode not in supported_modes:
            raise ValueError(f"Unsupported mode {mode!r}; expected one of {supported_modes}")

        premature_layer_dist = None
        with torch.no_grad():
            input_text1 = prompt + question
            input_text = input_text1 + input_text2
            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
            prefix_ids = self.tokenizer(input_text1, return_tensors="pt").input_ids.to(self.device)
            # Token ids of the given answer (the continuation being scored).
            continue_ids = input_ids[0, prefix_ids.shape[-1]:]
            # Position i predicts token i + 1, so the rows that predict the
            # continuation start one step before it and exclude the last row.
            answer_start = prefix_ids.shape[-1] - 1

            def sequence_score(token_scores):
                # Sum the score of each continuation token at its own position.
                return token_scores[range(token_scores.shape[0]), continue_ids].sum().item()

            def contrast(expert_log_probs, other_log_probs):
                # Contrastive score with optional re-normalisation and masking.
                diff_logits = expert_log_probs - other_log_probs
                if post_softmax:
                    diff_logits = diff_logits.log_softmax(dim=-1)
                if relative_top > 0.0:
                    relative_top_mask = self.get_relative_top_filter(expert_log_probs, relative_top)
                    diff_logits = torch.where(relative_top_mask, relative_top_value, diff_logits)
                return diff_logits

            if mode == 'baseline':
                outputs = self.model(input_ids, output_attentions=True, return_dict=True)["logits"][0].squeeze(0)
                outputs = outputs.log_softmax(-1)  # logits to log probs
                # Skip tokens in the prompt -- only the answer is scored.
                log_probs = sequence_score(outputs[answer_start:-1, :])

            elif mode == 'cot-enhance':
                outputs = self.model(input_ids, attention_temperature=attn_t)[0].squeeze(0)
                outputs = outputs.log_softmax(-1)  # logits to log probs
                log_probs = sequence_score(outputs[answer_start:-1, :])

            elif mode == 'dola-static':
                # Checked before the forward pass, which already needs the layer.
                assert premature_layer is not None
                dict_outputs, _ = self.model(
                    input_ids=input_ids,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                    early_exit_layers=[premature_layer, mature_layer],
                )
                base_logits = dict_outputs[premature_layer][0, answer_start:-1, :]
                final_logits = dict_outputs[mature_layer][0, answer_start:-1, :]
                final_logits = final_logits.log_softmax(dim=-1)
                base_logits = base_logits.log_softmax(dim=-1)
                log_probs = sequence_score(contrast(final_logits, base_logits))

            elif mode == 'dola':
                premature_layer_dist = {layer: 0 for layer in candidate_premature_layers}
                chosen_layers = []

                dict_outputs, _ = self.model(
                    input_ids=input_ids,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                    early_exit_layers=candidate_premature_layers + [mature_layer],
                )

                for seq_i in range(answer_start, input_ids.shape[-1] - 1):
                    # Pick the candidate layer that differs most from the mature layer.
                    # 1. Stack all premature layers into a new leading dimension.
                    stacked_premature_layers = torch.stack(
                        [dict_outputs[layer][:, seq_i, :] for layer in candidate_premature_layers], dim=0
                    )
                    mature_step = dict_outputs[mature_layer][:, seq_i, :]

                    # 2. Softmax of the mature layer and of every premature layer.
                    softmax_mature_layer = F.softmax(mature_step, dim=-1)
                    softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1)

                    # 3. M, the average of the two distributions.
                    M = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers)

                    # 4. Log-softmax inputs for the KL terms.
                    log_softmax_mature_layer = F.log_softmax(mature_step, dim=-1)
                    log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1)

                    # 5. Two KL terms, averaged over the last axis, then combined.
                    kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], M, reduction='none').mean(-1)
                    kl2 = F.kl_div(log_softmax_premature_layers, M, reduction='none').mean(-1)
                    js_divs = 0.5 * (kl1 + kl2)

                    # 6. Average over the remaining axis: one value per candidate.
                    js_divs = js_divs.mean(-1)
                    chosen_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]
                    premature_layer_dist[chosen_layer] += 1
                    chosen_layers.append(chosen_layer)

                # Assemble, row by row, the logits of the layer chosen for each position.
                base_logits = torch.zeros_like(dict_outputs[mature_layer][0, answer_start:-1])
                for row, layer in enumerate(chosen_layers):
                    base_logits[row] = dict_outputs[layer][0, answer_start + row]
                final_logits = dict_outputs[mature_layer][0, answer_start:-1]
                final_logits = final_logits.log_softmax(dim=-1)
                base_logits = base_logits.log_softmax(dim=-1)
                log_probs = sequence_score(contrast(final_logits, base_logits))

            elif mode == 'contrastive-decoding':
                assert self.amateur_model is not None
                base_outputs = self.model(input_ids)[0].squeeze(0)
                base_logits = base_outputs.log_softmax(-1)[answer_start:-1, :]

                amateur_outputs = self.amateur_model(input_ids)[0].squeeze(0)
                amateur_logits = amateur_outputs.log_softmax(-1)[answer_start:-1, :]

                log_probs = sequence_score(contrast(base_logits, amateur_logits))

            else:  # mode == 'prompt-contrastive-decoding'
                assert input_text3 is not None  # evil prompt
                input_text_evil = input_text3 + input_text2
                input_ids_evil = self.tokenizer(input_text_evil, return_tensors="pt").input_ids.to(self.device)
                prefix_ids_evil = self.tokenizer(input_text3, return_tensors="pt").input_ids.to(self.device)

                base_outputs = self.model(input_ids)[0].squeeze(0)
                base_logits = base_outputs.log_softmax(-1)[answer_start:-1, :]

                # NOTE(review): this assumes the continuation tokenises to the
                # same number of tokens after both prefixes; otherwise the two
                # slices differ in length -- confirm for the tokenizer in use.
                evil_outputs = self.model(input_ids_evil)[0].squeeze(0)
                evil_logits = evil_outputs.log_softmax(-1)[prefix_ids_evil.shape[-1] - 1: -1, :]

                log_probs = sequence_score(contrast(base_logits, evil_logits))

        return log_probs, premature_layer_dist
    
    
    def lm_prob(self, input_text1, input_text2, input_text3=None, pmi=False, max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, mature_layer=None, premature_layer=None, candidate_premature_layers=[], mode='baseline', verbose=True, remove_stop_words=False, relative_top=0.1, relative_top_value=-1000.0, post_softmax=True, **kwargs):
        """Return the mean per-token probability of ``input_text2`` after ``input_text1``.

        Used for calibration. The text ``input_text1 + input_text2`` is
        tokenised once; the tokens after the tokenised ``input_text1`` are the
        continuation, and the probability of each one (predicted at the
        preceding position) is averaged.

        Modes:
            ``'baseline'``: softmax probabilities of ``self.model``.
            ``'contrastive-decoding'``: softmax of ``self.model`` log-probs
                minus ``self.amateur_model`` log-probs. When
                ``relative_top > 0`` the positions selected by
                ``self.get_relative_top_filter`` are overwritten with
                ``relative_top_value`` clamped into ``[0, 1]``. The clamp is
                needed because the value is written into a probability tensor
                here, so the log-space default ``-1000.0`` means probability 0
                rather than a literal ``-1000``.

        The remaining keyword parameters mirror ``lm_score`` and are not used.

        Returns:
            ``(mean_probability, None)``; the second item is always ``None``
            and is kept for parity with ``lm_score``.

        Raises:
            ValueError: ``mode`` is not one of the two modes above (previously
                an ``UnboundLocalError``/``NameError`` at the return statement).
            AssertionError: ``'contrastive-decoding'`` without an amateur model.
        """
        supported_modes = ('baseline', 'contrastive-decoding')
        if mode not in supported_modes:
            raise ValueError(f"Unsupported mode {mode!r}; expected one of {supported_modes}")

        with torch.no_grad():
            input_text = input_text1 + input_text2
            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
            prefix_ids = self.tokenizer(input_text1, return_tensors="pt").input_ids.to(self.device)
            continue_ids = input_ids[0, prefix_ids.shape[-1]:]
            # Position i predicts token i + 1, so the rows that predict the
            # continuation start one step before it and exclude the last row.
            answer_start = prefix_ids.shape[-1] - 1

            if mode == 'baseline':
                outputs = self.model(input_ids)[0].squeeze(0)
                outputs = outputs.softmax(-1)  # logits to probs
                # Skip tokens in the prompt -- only the answer is scored.
                token_probs = outputs[answer_start:-1, :]
            else:  # mode == 'contrastive-decoding'
                assert self.amateur_model is not None
                base_outputs = self.model(input_ids)[0].squeeze(0)
                base_logits = base_outputs.log_softmax(-1)[answer_start:-1, :]

                amateur_outputs = self.amateur_model(input_ids)[0].squeeze(0)
                amateur_logits = amateur_outputs.log_softmax(-1)[answer_start:-1, :]

                token_probs = (base_logits - amateur_logits).softmax(dim=-1)
                if relative_top > 0.0:
                    relative_top_mask = self.get_relative_top_filter(base_logits, relative_top)
                    # Keep the fill value a valid probability.
                    masked_prob = min(max(float(relative_top_value), 0.0), 1.0)
                    token_probs = torch.where(relative_top_mask, masked_prob, token_probs)

            # Probability of each continuation token at its own position.
            mean_probs = token_probs[range(token_probs.shape[0]), continue_ids].mean().item()

        return mean_probs, None
