# %%
import os
import argparse
from utils.draw_heatmap import draw_heatmap, make_latex_table, calc_confident_interval

# Models under comparison. Each name is interpolated into the per-model CSV
# paths read further down (`{exp_method}_res/{name}.csv` and
# `lime_samples/nq_perturb_{name}.csv`), and the list order fixes the
# row/column order of the AOPC matrix.
model_names = [
    "qwen_0.5B",
    "qwen_1.5B",
    "qwen_3B",
    "qwen_7B",
    "qwen_14B",
    "qwen_32B",
    "qwen_72B",
    "llama_8B",
    "llama_70B",
    "deepseekv3",
    "gpt-4o-mini",
    "gpt-4o",
]

def get_args():
    """Parse the command-line options understood by this script.

    Only ``--exp_method`` is defined (one of ``lime``, ``shap`` or
    ``shapley``; default ``lime``).  Arguments the parser does not know are
    ignored rather than rejected, because ``parse_known_args`` is used.

    Returns:
        argparse.Namespace: namespace carrying ``exp_method``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--exp_method',
        type=str,
        default='lime',
        choices=['lime', 'shap', 'shapley'],
    )
    known, _unknown = cli.parse_known_args()
    return known
# Parse the CLI once at import time. `exp_method` selects the
# `{exp_method}_res/` input directory and names the output folder under
# `heatmaps/`.
args = get_args()
exp_method = args.exp_method

# for file in os.listdir('lime_samples'):
#     if not file.startswith('sst_test_samples_') or 'counted' in file:
#         continue
#     model_names.append(os.path.splitext(file[len('sst_test_samples_'):])[0])

# model_names

# %%
# model_names = ['bert']

# %%

# %%
import pandas as pd

# %%


# %%
# Per-model tables keyed by model name; both dicts are filled by the CSV
# loading loop below.
#   res_df: perturbation samples from lime_samples/nq_perturb_<model>.csv
#   exp_df: explanation table from <exp_method>_res/<model>.csv
res_df = {}
exp_df = {}

# %%
# NOTE(review): not referenced anywhere else in the visible script — looks
# like a leftover switch; confirm before removing.
build_pools = False

# %%
from tqdm.auto import tqdm

# %%
# Load, for every model, its explanation table and its perturbation samples.
for model_name in tqdm(model_names):
    explanation_csv = f'{exp_method}_res/{model_name}.csv'
    samples_csv = f'lime_samples/nq_perturb_{model_name}.csv'

    exp_df[model_name] = pd.read_csv(explanation_csv, index_col=0)
    res_df[model_name] = pd.read_csv(
        samples_csv,
        sep='\t',
        index_col=None,
        keep_default_na=False,
        # Read the masks as text; an integer parse would drop leading zeros.
        dtype={'binary_representation': str},
    )


# %%


# %%


# %%
import torch

# %%
import numpy as np
# NOTE(review): allocated but never written in the visible script (the
# accuracy bookkeeping further down is commented out) — confirm whether it
# is still needed.
acc_res = np.zeros((len(exp_df),len(exp_df)))

# %%
def aopc_fn(scores):
    """Build a piecewise-linear interpolator over ``scores``.

    The samples in ``scores`` are treated as equally spaced over the domain
    [0, 10]: the first sample sits at 0 and the last one at 10.  A copy of
    ``scores`` is captured, so later changes to the caller's object do not
    affect the returned function.

    Args:
        scores: indexable sequence of values that provides ``.copy()``.

    Returns:
        A function ``f(x)`` that returns the sample itself when ``x`` lands
        exactly on a sample position, and the linear blend of the two
        neighbouring samples otherwise.  ``f`` raises ``ValueError`` for
        ``x`` outside [0, 10].
    """
    samples = scores.copy()

    def interpolation_func(x):
        if x > 10 or x < 0:
            raise ValueError("Input should be in the range [0, 10]")

        # Map x from [0, 10] onto a fractional index into `samples`.
        position = x * (len(samples) - 1) / 10
        lo = int(np.floor(position))
        hi = int(np.ceil(position))

        if lo != hi:
            # Between two samples: blend them by the fractional distance.
            frac = position - lo
            return samples[lo] * (1 - frac) + samples[hi] * frac

        # Exactly on a sample position: return that sample unchanged.
        return samples[lo]

    return interpolation_func


# %%
from tqdm.auto import tqdm

# %%
import numpy as np

# %%
import torch

# %%
def calc_aopc(local_exp, loc_res_df):
    """Compute the AOPC of one explanation on one question's perturbation samples.

    Features are removed one at a time, from the highest attribution weight
    downwards, stopping at the first negative weight.  After every removal
    the score of the resulting mask is looked up in ``loc_res_df`` and its
    absolute deviation from the unperturbed score (the all-``'1'`` mask) is
    recorded.  The AOPC is the trapezoidal mean of that deviation curve,
    whose first point is 0 (nothing removed).

    Args:
        local_exp: sequence of attribution weights, one per feature.
        loc_res_df: DataFrame with a ``binary_representation`` column (mask
            strings; ``'1'`` keeps a feature, ``'0'`` removes it) and a
            ``scores`` column.  If a mask occurs more than once, the first
            row is used.

    Returns:
        ``(True, aopc)`` when every required mask was found and at least one
        feature was removed.  Otherwise ``(False, missing_masks)``, where
        ``missing_masks`` lists the required masks that are absent from
        ``loc_res_df`` (empty when no feature had a non-negative weight).

    Raises:
        IndexError: if the all-ones baseline mask is not in ``loc_res_df``.
    """
    weights = np.asarray(local_exp)
    n_features = len(weights)

    # Index the scores by mask once instead of scanning the frame for every
    # removal step.  setdefault keeps the first row of a duplicated mask,
    # which is what a `.values[0]` lookup on the filtered frame would pick.
    score_by_mask = {}
    for mask, score in zip(loc_res_df['binary_representation'].values,
                           loc_res_df['scores'].values):
        score_by_mask.setdefault(mask, score)

    full_mask = '1' * n_features
    if full_mask not in score_by_mask:
        raise IndexError(f"baseline mask {full_mask!r} not found in loc_res_df")
    base_value = score_by_mask[full_mask]

    missing_masks = []
    deviations = [0]
    mask_chars = ['1'] * n_features

    # argsort(-weights) already yields feature indices in descending-weight
    # order, so each element is used directly as the feature to remove (it
    # must not be used to index the permutation a second time).
    for feature_idx in np.argsort(-weights):
        if weights[feature_idx] < 0:
            # All remaining features have negative weights too.
            break
        mask_chars[feature_idx] = '0'
        mask = ''.join(mask_chars)
        if mask not in score_by_mask:
            # Keep going so that every missing mask gets reported.
            missing_masks.append(mask)
            continue
        deviations.append(np.abs(score_by_mask[mask] - base_value))

    if missing_masks or len(deviations) < 2:
        return False, missing_masks

    # Trapezoidal rule with unit spacing, averaged over the number of steps.
    aopc = (np.sum(deviations) - (deviations[0] + deviations[-1]) / 2) / (len(deviations) - 1)
    return True, aopc

# %%
import spacy
# NOTE(review): `nlp` is only referenced from commented-out code in the
# visible script, yet the model is still loaded at import time — confirm
# whether this load is still needed.
nlp = spacy.load("en_core_web_lg")

# %%
# acc_res = np.zeros((len(exp_df),len(exp_df)))
# aopc_res[i, j]: mean AOPC obtained when the explanation of model_names[i]
# is evaluated on the perturbation samples of model_names[j].
# aopc_ci[i, j]: element [2] of calc_confident_interval() for the same cell
# (presumably the interval half-width — confirm against utils.draw_heatmap).
aopc_res = np.zeros((len(exp_df), len(exp_df)))
aopc_ci = np.zeros((len(exp_df), len(exp_df)))
# %%
for (j, model2) in tqdm(enumerate(model_names), position=0, dynamic_ncols=True, total=len(model_names)):
    samples = res_df[model2]
    question_ids = samples['question_index']
    max_idx = max(question_ids.values)

    # One list of per-question AOPC values for every explaining model.
    aopcs = [[] for _ in range(len(model_names))]

    # `flag` and `local_ress` are only consumed by the disabled sample-pool
    # export at the end of this loop body.
    flag = True
    local_ress = []

    for idx in tqdm(range(max_idx + 1), position=1, dynamic_ncols=True):
        loc_res_df = samples[question_ids == idx]

        # A gap in question_index yields an empty slice; skip it instead of
        # failing on `.values[0]` below.
        if loc_res_df.empty:
            continue

        # Only questions whose perturbation set holds one row per possible
        # mask (2 ** number_of_features rows) are evaluated.
        n_features = len(loc_res_df['binary_representation'].values[0])
        if len(loc_res_df) != 2 ** n_features:
            continue

        for (i, model1) in enumerate(model_names):
            local_exp = exp_df[model1].loc[idx, 'weight']
            # SECURITY NOTE: eval() executes whatever expression the CSV
            # cell contains; this is acceptable only while the explanation
            # files are trusted.  ast.literal_eval would be the safe choice
            # if the column only ever holds plain list literals — confirm
            # the stored format before switching.
            local_exp = np.array(eval(local_exp))
            res = calc_aopc(local_exp, loc_res_df)
            if res[0]:
                aopcs[i].append(res[1])
            else:
                flag = False
        local_ress.append(loc_res_df)

    for i in range(len(model_names)):
        if len(aopcs[i]) == 0:
            # No usable question for this pair: leave the cell at zero.
            aopc_res[i][j] = 0
        else:
            aopc_res[i][j] = np.mean(aopcs[i])
            aopc_ci[i, j] = calc_confident_interval(aopcs[i])[2]

    # if not flag:
    #     print(f"Model {model2} has some samples that need to be calculated.")
    #     local_ress = pd.concat(local_ress, ignore_index=True)
    #     local_ress.to_csv(f'samples_pools/sst_test_cached_{model2}.csv', sep='\t')
print(aopc_res)


from datetime import datetime
import json

# Write the heatmap and the LaTeX table into a timestamped output folder.
nowtime = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
out_dir = f'heatmaps/{exp_method}_{nowtime}'
os.makedirs(out_dir, exist_ok=True)

draw_heatmap(aopc_res, model_names, 'aopc', f'{out_dir}/aopc.png')

table_aopc = make_latex_table(
    aopc_res,
    aopc_ci,
    model_names,
    caption="AOPC Results",
    label="tab:aopc_results",
)
with open(f'{out_dir}/aopc_results.tex', 'w') as f:
    f.write(table_aopc)


# %%