import sys
import os
import json
import numpy as np
import random
from tqdm import tqdm
import argparse
current_path = os.path.abspath(os.path.dirname(os.getcwd()))
sys.path.append(os.path.join(current_path))
sys.path.append(os.path.join(current_path, "src")) 
from decoding_algorithm import ContrastiveDecoding
from utils.format_data_mmlu import get_mmlu_data

# NOTE(review): PROMPT_NUM, N_CHOICE and BIAS_ANS are not referenced anywhere else
# in this script -- possibly leftovers or used by importers; confirm before removing.
PROMPT_NUM = 5 # presumably one of 1, 2, 3, 4, 5 -- TODO confirm meaning
N_CHOICE = 4
BIAS_ANS = "A"
# intervene_way (attention-intervention mode forwarded to the llm search methods):
# 0: apply a temperature to the attention scores before softmax
# 1: set attention weights below some threshold to 0
# 2: shift (translate) the attention weights
# 3: replace the attention weights with their mean


def find_layer_head(llm, n, m, data_path, intervene_way=0, T=0.5, need_eval=False):
    """Search for the attention layers/heads associated with answer-position bias on MMLU.

    The MMLU questions are loaded once per prompt bias ("A", "B", "C", "D"). A random
    subset of questions is interleaved (A, B, C, D variants per question) into a dev
    set, which is cut into ``n`` consecutive groups of ``m`` examples. Each group votes
    for the layers returned by ``llm.contrast_find_layer_list``; the two most-voted
    layers are then passed to ``llm.contrast_find_head_list`` to obtain ``attn_t``.

    Args:
        llm: a ``ContrastiveDecoding`` instance; this function reads ``choose_layer_list``,
            ``top_k_l`` and ``model_name`` and calls ``contrast_find_layer_list``,
            ``contrast_find_head_list`` and ``check_bias_attn_t`` on it.
        n: number of dev groups that vote for layers (must be >= 1).
        m: number of dev examples per group (must be >= 1).
        data_path: MMLU data directory forwarded to ``get_mmlu_data``.
        intervene_way: attention-intervention mode forwarded to ``llm``.
        T: intervention parameter forwarded to ``llm`` as ``T``.
        need_eval: if True, run ``llm.check_bias_attn_t`` on the held-out questions.

    Returns:
        dict with keys ``"s"`` (vote stability: votes received by the two chosen layers
        divided by ``llm.top_k_l * n``), ``"ordinary_delta_acc"`` and
        ``"biased_delta_acc"`` (results of ``check_bias_attn_t``; ``None`` when
        ``need_eval`` is False).

    Raises:
        ValueError: if ``n``/``m`` are not positive, or there are too few questions to
            build ``n * m`` dev examples.
    """
    if n < 1 or m < 1:
        raise ValueError("n and m must be positive, got n={}, m={}".format(n, m))

    bias_letters = ("A", "B", "C", "D")
    n_bias = len(bias_letters)
    # One copy of the dataset per biased prompt. The indexing below assumes index i is
    # the same question in every copy -- TODO confirm against get_mmlu_data.
    all_data = {
        b: get_mmlu_data(llm=llm, prompt_bias=b, data_path=data_path) for b in bias_letters
    }

    # Each sampled question yields one dev example per bias letter. Round up so the dev
    # set holds at least n * m examples (int(n * m / 4) under-filled it, causing an
    # IndexError in the group loop, whenever n * m was not a multiple of 4).
    sample_num = (n * m + n_bias - 1) // n_bias
    n_questions = len(all_data["A"])
    if sample_num > n_questions:
        raise ValueError(
            "need {} questions for n={} x m={} dev examples, but only {} are available".format(
                sample_num, n, m, n_questions
            )
        )

    indexes = list(range(n_questions))
    random.shuffle(indexes)
    dev_indexes, test_indexes = indexes[:sample_num], indexes[sample_num:]
    # Dev set interleaves the biased variants of each sampled question: A_i, B_i, C_i, D_i, ...
    data_dev = [all_data[b][i] for i in dev_indexes for b in bias_letters]
    # Held-out questions use the "A"-biased prompt only.
    data_test = [all_data["A"][i] for i in test_indexes]

    # Every group of m dev examples votes for the layers it finds.
    layer_importance = {layer: 0 for layer in llm.choose_layer_list}
    for idx in range(n):
        group = data_dev[idx * m:(idx + 1) * m]
        layer_list = llm.contrast_find_layer_list(
            group, T=T, threshold=0, bias=True, intervene_way=intervene_way, use_KL=False
        )
        for layer in layer_list:
            layer_importance[layer] += 1
    print(layer_importance)

    # Keep the two most-voted layers; sorted() is stable, so ties keep the
    # llm.choose_layer_list order.
    ranked = sorted(layer_importance.items(), key=lambda item: item[1], reverse=True)
    choose_layer_list = [layer for layer, _ in ranked[:2]]

    # Stability (concentration) of the vote: share of the normalizer llm.top_k_l * n
    # captured by the chosen layers (presumably each group votes for top_k_l layers).
    s = sum(layer_importance[layer] for layer in choose_layer_list) / (llm.top_k_l * n)
    print("stability : {}".format(s))

    attn_t = llm.contrast_find_head_list(
        data_dev, T=T, layer_list=choose_layer_list, intervene_way=intervene_way
    )
    print("model name {}: attn_t {}".format(llm.model_name, attn_t))

    # Initialise so the return below works when need_eval is False (previously an
    # UnboundLocalError).
    ordinary_delta_acc = None
    biased_delta_acc = None
    if need_eval:
        # Measure the accuracy effect of attn_t on the held-out MMLU questions,
        # without and with the biased prompt.
        ordinary_delta_acc = llm.check_bias_attn_t(data_test, attn_t, prompt_bias=False)
        biased_delta_acc = llm.check_bias_attn_t(data_test, attn_t, prompt_bias=True)
    return {"s": s, "ordinary_delta_acc": ordinary_delta_acc, "biased_delta_acc": biased_delta_acc}

def main():
    """Command-line entry point: run the layer/head search for several (n, m) settings.

    Builds a ``ContrastiveDecoding`` model from the CLI options, configures it for MMLU
    (stop word ``"Q:"``, label ids ``"mmlu"``), then calls ``find_layer_head`` with
    evaluation enabled for each (n, m) pair in a fixed grid and prints each result.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="huggyllama/llama-7b")
    parser.add_argument("--data-path", type=str, default="/mnt/llms/data/MMLU/")
    parser.add_argument("--num-gpus", type=int, default=2)
    parser.add_argument("--intervene-way", type=int, default=0)
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument(
        "--seed", type=int, default=None,
        help="seed for the dev/test shuffle (default: unseeded, not reproducible)",
    )
    args = parser.parse_args()

    # Optional: make the random dev/test split in find_layer_head reproducible.
    if args.seed is not None:
        random.seed(args.seed)

    llm = ContrastiveDecoding(args.model_name, args.device, num_gpus=args.num_gpus)
    llm.set_stop_words(["Q:"])
    llm.set_label_id("mmlu")

    # (n groups, m dev examples per group) settings to sweep.
    param_list = [(4, 5), (8, 5), (8, 10), (8, 20)]
    for n, m in param_list:
        result = find_layer_head(
            llm,
            n=n,
            m=m,
            data_path=args.data_path,
            # Forward the CLI choice; previously --intervene-way was parsed but ignored.
            intervene_way=args.intervene_way,
            need_eval=True,
        )
        print(result)

if __name__ == "__main__":
    # Run the search only when executed as a script, not when imported.
    main()

    
    

