
import sys
import os
import json
import numpy as np
import random
import matplotlib.pyplot as plt  
from tqdm import tqdm
# Restrict this process to a single GPU. This must run before any CUDA
# library is initialised, which presumably happens inside the project
# imports that follow.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# "current_path" is the PARENT of the working directory, i.e. the script is
# expected to be launched from a direct sub-folder of the repository root
# (TODO confirm). Both the root and its "src" folder become importable.
current_path = os.path.abspath(os.path.dirname(os.getcwd()))
sys.path.extend([current_path, os.path.join(current_path, "src")])
from decoding_algorithm import ContrastiveDecoding
from utils.format_data_mmlu import get_mmlu_data
# Grid of scaling factors T = 0.1, 0.2, ..., 1.4 (14 points, matching the
# length of the cached result lists below). It is built from integers so every
# point is the correctly-rounded decimal and T == 1.0 (the unscaled baseline,
# where every cached delta is 0.0) is hit exactly. np.arange with a float step
# accumulates rounding error and can even change length by one.
x = np.arange(1, 15) / 10
# Checkpoint under study. Alternative checkpoints that can be swapped in:
#   /mnt/llms/model/meta-llama/Llama-2-7b-hf
#   /mnt/llms/model/meta-llama/Llama-3-8b-hf
#   /mnt/llms/model/meta-llama/Llama-2-13b-hf
model_name = "/mnt/llms/model/google-gemma/gemma-2b"

# Stop sequence handed to the decoder; presumably generation is cut when the
# model starts a new "Q:" turn (semantics live in ContrastiveDecoding).
stop_word_list = ["Q:"]

llm = ContrastiveDecoding(model_name, num_gpus=1)
llm.set_stop_words(stop_word_list)

# --- Sweep 1: delta as a function of the scaling factor T -----------------
# Set to True to recompute this sweep on the GPU (two check_bias_attn_t calls
# per value of T). When False, the cached results in the else-branch are
# plotted instead. This replaces a bare triple-quoted string that used to
# disable the sweep; Python never parsed that code, so errors in it (such as
# the unbound `attn_t` for unsupported models) went unnoticed.
RUN_FIRST_SWEEP = False

no_bias_delta_list = []
bias_delta_list = []
if RUN_FIRST_SWEEP:
    all_data = get_mmlu_data(llm)
    model_key = model_name.lower()
    for T in x:
        # attn_t = [0, {layer: (head indices, T)}] is the structure passed to
        # check_bias_attn_t; presumably the listed heads are rescaled by T
        # (the meaning of the leading 0 is defined in ContrastiveDecoding).
        if "gemma-2b" in model_key:
            attn_t = [0, {12: ([3, 7, 2, 1], T), 14: ([0, 4, 2, 7], T)}]
        elif "gemma-7b" in model_key:
            attn_t = [0, {18: (range(0, 16), T)}]
        elif "llama-2-7b" in model_key:
            attn_t = [0, {14: ([24, 4, 20, 31], T), 18: ([30, 10, 25, 28], T)}]
        elif "llama-3-8b" in model_key:
            attn_t = [0, {17: ([0, 1, 3, 4, 5, 6, 7, 9, 10, 12, 13, 14, 16, 17, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31], T)}]
        elif "mistral-7b" in model_key:
            attn_t = [0, {16: ([12, 14, 13, 0], T), 19: ([8, 9, 16, 10], T)}]
        else:
            # Without a matching configuration attn_t would be unbound.
            raise ValueError(
                f"No sweep-1 attention-head configuration for model {model_name!r}"
            )
        print(attn_t)
        bias_delta_list.append(llm.check_bias_attn_t(all_data, attn_t, bias=True))
        no_bias_delta_list.append(llm.check_bias_attn_t(all_data, attn_t, bias=False))
        print(bias_delta_list)
        print(no_bias_delta_list)
else:
    # Cached output of the sweep above, one value per T in x (presumably
    # produced with the gemma-2b checkpoint selected earlier — TODO confirm).
    bias_delta_list = [0.0202962540948583, 0.023429710867397813, 0.026135878080045583, 0.027275316906423597, 0.025922233300099684, 0.02307363623415465, 0.0175188719555619, 0.012177752456914992, 0.005910838911835914, 0.0, -0.007121492664862561, -0.010895883777239712, -0.01652186298248115, -0.021222048141290417]
    no_bias_delta_list = [0.016165788349237986, 0.017305227175616, 0.01766130180885911, 0.018231021222048116, 0.016806722689075626, 0.014100555476427856, 0.011038313630536922, 0.010468594217347915, 0.004700185158809267, 0.0, -0.009329155390969956, -0.01445663010967102, -0.0222902720410198, -0.02841475573280161]

plt.plot(x, no_bias_delta_list,'+-', color="#1679AB", label='unbiased (T < 1)')
plt.plot(x, bias_delta_list,'.-', color="#074173", label='biased (T < 1)')

# --- Sweep 2: data preparation --------------------------------------------
# Fresh accumulators for the second sweep (sweep 1 is already plotted).
bias_delta_list = []
no_bias_delta_list = []

n, m = 10, 8
SAMPLE_NUM = n * m  # size of the "train" split
# NOTE(review): head_count is not read anywhere in this script section.
head_count = np.zeros((llm.decoder_layer_num, llm.decoder_head_num))
all_data = get_mmlu_data(llm, bias_ans="C")

# Unseeded random split: the first SAMPLE_NUM shuffled examples form the train
# split, the next (up to) 1000 the test split.
# NOTE(review): neither split is used by the sweep below, which runs on all_data.
indexes = list(range(len(all_data)))
random.shuffle(indexes)
data_train = [all_data[i] for i in indexes[:SAMPLE_NUM]]
data_test = [all_data[i] for i in indexes[SAMPLE_NUM:SAMPLE_NUM + 1000]]

# Accumulators for a train/test variant of the sweep.
# NOTE(review): never filled or read in this script section — confirm before removing.
no_bias_delta_train_list = []
bias_delta_train_list = []
no_bias_delta_test_list = []
bias_delta_test_list = []

# --- Sweep 2: delta as a function of T for a second set of heads ----------
model_key = model_name.lower()
for T in x:
    # attn_t = [0, {layer: (head indices, T)}] is the structure passed to
    # check_bias_attn_t; presumably the listed heads are rescaled by T.
    if "llama-2-7b" in model_key:
        attn_t = [0, {11: ([6, 8, 29, 16], T), 9: ([1, 4, 30, 5], T)}]
    elif "gemma-2b" in model_key:
        attn_t = [0, {9: ([4, 5, 3, 0], T), 2: ([4, 1, 3, 5], T)}]
    else:
        # Any other model would otherwise hit a NameError on the unbound attn_t.
        raise ValueError(
            f"No sweep-2 attention-head configuration for model {model_name!r}"
        )
    print(attn_t)
    # NOTE(review): evaluated on all_data, not on the data_test split prepared above.
    bias_delta_list.append(llm.check_bias_attn_t(all_data, attn_t, bias=True))
    no_bias_delta_list.append(llm.check_bias_attn_t(all_data, attn_t, bias=False))
    print(bias_delta_list)
    print(no_bias_delta_list)
print(bias_delta_list)
print(no_bias_delta_list)

# Cached sweep-2 results, one value per T in x.
# NOTE(review): these assignments replace the lists computed by the sweep just
# above, so the saved figure shows these cached numbers, not the current run.
bias_delta_list = [-0.07655604614727246, -0.06801025494943741, -0.056687081612305956, -0.044794188861985496, -0.03617718273750181, -0.027488961686369495, -0.01986896453496656, -0.012391397236860835, -0.005982053838484547, 0.0, 0.001994017946161497, 0.003204671699188144, 0.0019228030195128643, 0.0021364477994587627]
no_bias_delta_list = [-0.07142857142857145, -0.05853866970517024, -0.047073066514741524, -0.037316621563879826, -0.02777382139296397, -0.021435692921236316, -0.014527845036319653, -0.00861700612448374, -0.0031334567725395668, 0.0, 0.0010682238997293814, 0.004059250818971627, 0.005198689645349641, 0.002278877652755973]

plt.plot(x, no_bias_delta_list, '+--', color="#1679AB", label='unbiased (T > 1)')
plt.plot(x, bias_delta_list, '.--', color="#074173", label='biased (T > 1)')
plt.xlabel('T')
# Raw string: '\d' is an invalid escape sequence in a normal string literal
# (DeprecationWarning, SyntaxWarning since Python 3.12); the value is unchanged.
plt.ylabel(r'$\delta$')
plt.legend()
plt.savefig("different_T_influence.pdf")
