import sys
import os
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
import time
import argparse
# Make the parent of the *current working directory* (not of this file) and its
# "src" subdirectory importable, so the project-local modules imported below resolve.
current_path = os.path.abspath(os.path.dirname(os.getcwd()))
sys.path.extend([current_path, os.path.join(current_path, "src")])
from decoding_algorithm import ContrastiveDecoding
from utils.format_data_mmlu import get_mmlu_data, get_contrast_data
from loguru import logger

# Example invocations:
#   CUDA_VISIBLE_DEVICES=2 python mmlu_eval.py --model-name /mnt/llms/model/meta-llama/Llama-3-8b-hf/ --data-path ../../data/MMLU/ --num-gpus=1
#   CUDA_VISIBLE_DEVICES=4 python mmlu_eval.py --model-name /mnt/llms/model/google-gemma/gemma-2b --data-path ../../data/MMLU/ --num-gpus=1

# Shared value paired with every head list in the presets below.
# NOTE(review): its meaning is defined by ContrastiveDecoding; presumably a
# temperature/scale applied to the selected attention heads — confirm there.
T = 0.5

# Candidate values for `attn_t` (the argument handed to llm.eval_model).
# Structure: [0, {int: ([int, ...], T), ...}].
# NOTE(review): interpretation of the leading 0 and of the dict entries is up to
# ContrastiveDecoding — keys assumed to be layer indices and lists head indices; confirm.
GEMMA_2B_ATTN_T = [0, {12: ([3, 7, 2, 1], T), 14: ([0, 1, 6, 7], T)}]
LLAMA3_8B_ATTN_T = [0, {17: ([24, 25, 26, 28], T), 14: ([23, 5, 4, 20], T)}]

# Effective setting used throughout the script. The gemma preset used to be
# assigned here and then silently overwritten by 1 (and the llama preset was
# commented out under the misspelled name `att_t`), so 1 has always been the
# value in effect. To enable a preset, set e.g. `attn_t = GEMMA_2B_ATTN_T`.
attn_t = 1

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="huggyllama/llama-7b")
    parser.add_argument("--data-path", type=str, default="/mnt/llms/data/MMLU/")
    parser.add_argument("--num-gpus", type=str, default="2")
    parser.add_argument("--intervene-way", type=int, default=0)
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    args = parser.parse_args()
    model_name = args.model_name
    num_gpus = args.num_gpus
    device = args.device
    logger.add("mmlu_eval_prompt_bias_{}.log".format(model_name.split("/")[-1]), level="DEBUG")
    llm = ContrastiveDecoding(model_name, device, num_gpus=int(args.num_gpus))
    stop_word_list = ["Q:"]
    llm.set_stop_words(stop_word_list)
    data_path = args.data_path
    llm.set_label_id("mmlu")
    param_dict_list = [
        {"prompt_bias": "A", "bias_ans": "A"},
        {"prompt_bias": "A", "bias_ans": "B"},   
        {"prompt_bias": "A", "bias_ans": "C"},
        {"prompt_bias": "A", "bias_ans": "D"},
        {"prompt_bias": "B", "bias_ans": "A"},
        {"prompt_bias": "B", "bias_ans": "B"},
        {"prompt_bias": "B", "bias_ans": "C"},
        {"prompt_bias": "B", "bias_ans": "D"},
        {"prompt_bias": "C", "bias_ans": "A"},
        {"prompt_bias": "C", "bias_ans": "B"},
        {"prompt_bias": "C", "bias_ans": "C"},
        {"prompt_bias": "C", "bias_ans": "D"},
        {"prompt_bias": "D", "bias_ans": "A"},
        {"prompt_bias": "D", "bias_ans": "B"},
        {"prompt_bias": "D", "bias_ans": "C"},
        {"prompt_bias": "D", "bias_ans": "D"},
    ]
    result = {
        "prompt_no_bias":{}, 
        "prompt_bias_A": {}, 
        "prompt_bias_B": {}, 
        "prompt_bias_C": {}, 
        "prompt_bias_D": {}
    }

    all_data = get_mmlu_data(
        llm=llm, 
        bias=False, 
        data_path=data_path,
        prompt_bias_list=["A", "B", "C", "D", "A"]
    )
    acc_no_bias = llm.eval_model(all_data, attn_t=1, prompt_bias=False)
    print("acc_no_bias")
    acc_no_bias = llm.eval_model(all_data, attn_t=1, prompt_bias=True)

    ans_bias = False
    prompt_bias = False
    logger.info(f"Use attn {attn_t} ...")
    logger.info(f"=== ans bias {ans_bias} prompt bias {prompt_bias} ===")
    all_data = get_mmlu_data(
        llm=llm,
        prompt_num=0,
        bias=False, 
        data_path=data_path
    )
    acc_no_bias = llm.eval_model(all_data, attn_t=attn_t, prompt_bias=False)
    logger.info(f"ZERO SHOT ACC: {acc_no_bias}")

    logger.info(f"=== ans bias {ans_bias} prompt bias {prompt_bias} ===")
    all_data = get_mmlu_data(
        llm=llm, 
        bias=False, 
        data_path=data_path
    )
    acc_no_bias = llm.eval_model(all_data, attn_t=attn_t, prompt_bias=False)
    logger.info(f"5 SHOT ACC: {acc_no_bias}")

    for param_dict in param_dict_list:
        prompt_bias = param_dict["prompt_bias"]
        bias_ans = param_dict["bias_ans"]
        all_data = get_mmlu_data(
            llm=llm, 
            prompt_bias=prompt_bias, 
            bias=True, 
            bias_ans=bias_ans,
            data_path=data_path,
            prompt_bias_list=["A", "B", "C", "D", "A"]
        )
        if bias_ans not in result["prompt_no_bias"].keys():
            acc_no_bias = llm.eval_model(all_data, attn_t=attn_t, prompt_bias=False)
            result["prompt_no_bias"][bias_ans] = acc_no_bias
        acc_bias = llm.eval_model(all_data, attn_t=attn_t, prompt_bias=True)
        result["prompt_bias_"+prompt_bias][bias_ans] = acc_bias
        logger.info(f"prompt_bias: {prompt_bias}, ans_bias: {bias_ans}, acc: {acc_bias:.3f}")

    df = pd.DataFrame(result)
    # df.columns = ['A', 'B', 'C', 'D']
    # Convert to Markdown table
    markdown_table = df.to_markdown()
    # Print the Markdown table
    logger.info("\n{}".format(markdown_table))

: