# %%


# %%
import os
import math
from sklearn.metrics import roc_auc_score

import scipy as sp
# %%
import sklearn
from sklearn.metrics import pairwise
import pandas as pd
# %%
from tqdm.auto import tqdm
from multiprocessing import Pool

import numpy as np
import torch
from tqdm.auto import tqdm
from utils.draw_heatmap import draw_heatmap
import argparse

def get_args(argv=None):
    """Parse command-line options for the explanation-transfer analysis.

    Args:
        argv: Argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with:
          * ``exp_method`` -- 'lime' or 'shap' (default 'lime').
          * ``subjects`` -- list of subject names. When ``--subjects`` is
            omitted, the entries of the ``samples_pools`` directory are used,
            sorted so the processing order is deterministic.

    Raises:
        FileNotFoundError: ``--subjects`` was omitted and ``samples_pools``
            does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_method', type=str, default='lime', choices=['lime', 'shap'])
    # Default is resolved after parsing: an eager os.listdir() default would
    # touch the filesystem (and crash if samples_pools is missing) even when
    # the subjects were given explicitly on the command line.
    parser.add_argument('--subjects', nargs='+', default=None)
    # parse_known_args silently ignores unrecognised arguments (presumably so
    # the script also runs inside an interactive kernel that adds its own argv).
    args, _ = parser.parse_known_args(argv)
    if args.subjects is None:
        args.subjects = sorted(os.listdir('samples_pools'))
    return args


# Resolve the command-line options once; the rest of the script reads these globals.
args = get_args()
exp_method, subjects = args.exp_method, args.subjects

# Root folder for every per-subject artefact (heatmap_lime/ or heatmap_shap/).
os.makedirs(f'heatmap_{exp_method}', exist_ok=True)

# Models compared pairwise. This order fixes the row/column order of the
# heatmaps and LaTeX tables produced later in the script.
model_names = [
    "qwen_0.5B", "qwen_1.5B", "qwen_3B", "qwen_7B", "qwen_14B", "qwen_32B", "qwen_72B",
    "llama_8B", "llama_70B",
    "deepseekv3", "gpt-4o-mini", "gpt-4o",
]
def make_latex_table(acc_res, acc_res_ci, model_names, caption="Accuracy Results", label="tab:acc_results"):
    table = []
    table.append("\\begin{table}")
    table.append("\\centering")
    table.append("\\begin{tabular}{l" + "c" * len(model_names) + "}")
    table.append("\\toprule")
    header = [""] + model_names
    table.append(" & ".join(header) + " \\\\")
    table.append("\\midrule")
    for i, row_name in enumerate(model_names):
        row = [row_name]
        for j in range(len(model_names)):
            mean = acc_res[i, j] * 100
            ci = acc_res_ci[i, j] * 100
            row.append(f"{mean:.2f}\\% $\\pm$ {ci:.2f}")
        table.append(" & ".join(row) + " \\\\")
    table.append("\\bottomrule")
    table.append("\\end{tabular}")
    table.append(f"\\caption{{{caption} (百分比表示)}}")
    table.append(f"\\label{{{label}}}")
    table.append("\\end{table}")
    return "\n".join(table)

# ---------------------------------------------------------------------------
# Per-subject analysis.
#
# The helpers below are defined once at module level rather than re-defined
# on every iteration of the subject loop. The pool workers still find them,
# together with the ``exp_df`` / ``res_df`` globals the loop assigns, because
# the pool is forked after a subject's data has been loaded.
# ---------------------------------------------------------------------------
import multiprocessing
from statistics import NormalDist

# Logit columns of the four answer options in the perturbation CSVs.
_CHOICE_COLS = ['logits_A', 'logits_B', 'logits_C', 'logits_D']
# A target model whose logits_A column reaches this value is skipped.
# Presumably it is a placeholder for models that expose no logits -- confirm.
_MISSING_LOGIT = -1000
# Worker processes per subject.
_N_WORKERS = 16


def _feature_mask(bits):
    """Convert a perturbation bit string such as '1011' into a boolean feature mask."""
    return np.array(list(bits)).astype(bool)


def _explanation_pred(local_exp, base_values, bits):
    """Per-class explanation score for one perturbed sample.

    The score is ``base_values`` plus the sum of the weight columns of the
    features kept in ``bits``. ``local_exp`` is indexed as
    ``local_exp[:, mask]``, i.e. one row per class and one column per feature.
    """
    return base_values + np.sum(local_exp[:, _feature_mask(bits)], axis=1)


def calc_acc(local_exp, loc_res_df, base_values):
    """Fraction of perturbed samples whose argmax class the explanation reproduces."""
    logits = loc_res_df[_CHOICE_COLS].values
    bit_strings = loc_res_df['binary_representation'].values
    accs = [
        np.argmax(row) == np.argmax(_explanation_pred(local_exp, base_values, bits))
        for row, bits in zip(logits, bit_strings)
    ]
    return np.mean(accs)


def calc_acc2(local_exp, loc_res_df, base_values):
    """Agreement between model and explanation on "answer unchanged by the perturbation".

    The reference class is the model's argmax on the unperturbed sample, i.e.
    the row whose ``binary_representation`` is all '1's. For every perturbed
    row, checks whether the model and the explanation agree on whether the
    prediction still equals that reference class. Returns the mean agreement.
    Raises IndexError if the all-'1's row is absent.
    """
    logits = loc_res_df[_CHOICE_COLS].values
    bit_strings = loc_res_df['binary_representation'].values
    full_input = loc_res_df['binary_representation'] == '1' * local_exp.shape[-1]
    base_predict = np.argmax(loc_res_df[full_input][_CHOICE_COLS].values[0])
    accs = [
        (np.argmax(row) == base_predict)
        == (np.argmax(_explanation_pred(local_exp, base_values, bits)) == base_predict)
        for row, bits in zip(logits, bit_strings)
    ]
    return np.mean(accs)


def calc_auroc(local_exp, loc_res_df, base_value):
    """ROC AUC of the explanation score against the model's binary decision.

    NOTE(review): this expects 'logits_negative'/'logits_positive' columns and
    (presumably) a 1-D ``local_exp``, unlike the four-choice data loaded
    below. It is not called anywhere in this script.
    """
    logits = loc_res_df[['logits_negative', 'logits_positive']].values
    bit_strings = loc_res_df['binary_representation'].values
    y_true = [row[1] > row[0] for row in logits]
    y_scores = [base_value + sum(local_exp[_feature_mask(bits)]) for bits in bit_strings]
    return roc_auc_score(y_true, y_scores)


def fidelity_loss(local_exp, loc_res_df, base_value):
    """Distance-weighted squared error between explanation score and P(positive class).

    NOTE(review): cosine distances are used directly as weights, so samples
    farther from the unperturbed input count *more*. LIME's kernel does the
    opposite and weights close samples more; confirm which is intended. This
    function also expects binary-task columns (see calc_auroc) and is not
    called anywhere in this script.
    """
    def distance_fn(x):
        # Cosine distance (x100) of each row to the all-ones (unperturbed) vector.
        return sklearn.metrics.pairwise.pairwise_distances(
            x, sp.sparse.csr_matrix(np.ones(x.shape[-1])), metric='cosine').ravel() * 100

    bit_strings = loc_res_df['binary_representation'].values
    logits = loc_res_df[['logits_negative', 'logits_positive']].values
    probs = torch.softmax(torch.tensor(logits), dim=-1).numpy()
    data = np.stack([np.array(list(bits)).astype(int) for bits in bit_strings])
    distances = distance_fn(sp.sparse.csr_matrix(data))
    exp_pred = np.array([base_value + sum(local_exp[_feature_mask(bits)]) for bits in bit_strings])
    return np.sum(np.square(exp_pred - probs[:, 1]) * distances)


def lime_loss(local_exp, loc_res_df, base_value):
    """Return ``fidelity_loss`` plus an L2 penalty on the explanation weights."""
    local_exp = np.array(local_exp)
    return fidelity_loss(local_exp, loc_res_df, base_value) + np.square(local_exp).sum()


def calc_confident_interval(accs, confidence=0.95):
    """Normal-approximation confidence interval for the mean of ``accs``.

    Args:
        accs: Per-question scores.
        confidence: Two-sided confidence level in (0, 1). It is honoured
            here; the z value is no longer hard-coded to 1.96.

    Returns:
        ``(lower, upper, half_width)``. All three are NaN when ``accs`` is
        empty, and the half-width is 0 for a single observation.
    """
    values = np.asarray(accs, dtype=float)
    n = values.size
    if n == 0:
        return np.nan, np.nan, np.nan
    mean = values.mean()
    # The standard error of the mean uses the sample std (ddof=1).
    stderr = values.std(ddof=1) / np.sqrt(n) if n > 1 else 0.0
    z = NormalDist().inv_cdf(0.5 + confidence / 2)  # ~1.96 for 95%
    half_width = z * stderr
    return mean - half_width, mean + half_width, half_width


def _load_explanation(model_name, idx):
    """Return ``(weights, intercept)`` arrays of ``model_name``'s explanation for question ``idx``."""
    frame = exp_df[model_name]
    # NOTE(review): eval() executes arbitrary code from the CSV; only run this
    # on trusted result files. ast.literal_eval would be safer if the cells are
    # plain list literals (not e.g. ``np.float64(...)`` reprs) -- confirm first.
    weights = np.array(eval(frame.loc[idx, 'weight']))
    intercept = np.array(eval(frame.loc[idx, 'intercept']))
    return weights, intercept


def _question_rows(model_name, idx):
    """Return ``model_name``'s perturbation rows for question ``idx``."""
    frame = res_df[model_name]
    return frame[frame['question_index'] == idx]


def process_args(args):
    """Pool worker: score model1's explanation of question idx against model2's outputs.

    Args:
        args: ``(idx, model1, model2)``.
    """
    idx, model1, model2 = args
    local_exp, base_value = _load_explanation(model1, idx)
    return calc_acc2(local_exp, _question_rows(model2, idx), base_value)


def process_args2(args):
    """Like ``process_args``, but return None when the two models' 'output' differs for idx."""
    idx, model1, model2 = args
    # Cheap filter first, so mismatching questions skip the eval/parsing work.
    if exp_df[model1].loc[idx, 'output'] != exp_df[model2].loc[idx, 'output']:
        return None
    local_exp, base_value = _load_explanation(model1, idx)
    return calc_acc2(local_exp, _question_rows(model2, idx), base_value)


def _evaluate_pairs(pool, worker):
    """Run ``worker`` over every question for every ordered model pair.

    Returns ``(acc_res, acc_ci)``, each ``len(model_names)`` square. Row i is
    the model whose explanation is used and column j is the model whose
    outputs are predicted. ``None`` results from the worker are dropped.
    Pairs whose target has sentinel logits are skipped and stay at 0.
    """
    n_models = len(model_names)
    acc_res = np.zeros((n_models, n_models))
    acc_ci = np.zeros((n_models, n_models))
    for i, model1 in enumerate(model_names):
        for j, model2 in enumerate(model_names):
            target = res_df[model2]
            if target['logits_A'].min() <= _MISSING_LOGIT:
                continue
            question_idx = target['question_index']
            # Use the largest index present, not the last row's, and skip
            # questions with no perturbation rows (calc_acc2 would crash on them).
            n_questions = min(len(exp_df[model1]), question_idx.max() + 1)
            present = set(question_idx)
            args_list = [(idx, model1, model2) for idx in range(n_questions) if idx in present]

            results = list(tqdm(pool.imap(worker, args_list), total=len(args_list)))
            accs = [res for res in results if res is not None]
            acc_res[i, j] = np.mean(accs)
            acc_ci[i, j] = calc_confident_interval(accs)[2]
            print(f'{model1} explanation for {model2}: acc:{acc_res[i, j]}')
    return acc_res, acc_ci


def _save_results(subject, tag, caption, acc_res, acc_ci):
    """Write heatmap (.png), raw matrix (.npy) and LaTeX table (.tex) for one analysis."""
    out_dir = f'heatmap_{exp_method}/{subject}'
    os.makedirs(out_dir, exist_ok=True)
    stem = f'{out_dir}/{subject}_{tag}'
    draw_heatmap(acc_res, model_names, title=f'{subject} {tag}', save_path=f'{stem}.png')
    np.save(f'{stem}.npy', acc_res)
    table = make_latex_table(acc_res, acc_ci, model_names, caption=caption,
                             label=f'tab:{subject}_{tag}_results')
    # The table contains non-ASCII text; do not depend on the locale encoding.
    with open(f'{stem}.tex', 'w', encoding='utf-8') as f:
        f.write(table)


for subject in subjects:
    # Per model: explanations (weight/intercept/output per question) and
    # perturbation results (logits per perturbed sample) for this subject.
    exp_df = {}
    res_df = {}
    for model_name in tqdm(model_names):
        exp_df[model_name] = pd.read_csv(f'{exp_method}_res/{subject}/{model_name}.csv', index_col=0)
        res_df[model_name] = pd.read_csv(
            f'lime_samples/{subject}/{subject}_perturb_{model_name}.csv',
            sep='\t',
            index_col=None,
            keep_default_na=False,
            # Keep bit strings as text: numeric inference would drop leading zeros.
            dtype={'binary_representation': str},
        )

    # One pool per subject, created after loading so the forked workers inherit
    # exp_df/res_df. 'fork' is requested explicitly: the workers are
    # __main__-level functions and this script has no __main__ guard, so
    # spawn/forkserver children would re-execute the script's top-level code.
    # The 'fork' start method is not available on Windows.
    with multiprocessing.get_context('fork').Pool(_N_WORKERS) as pool:
        # Every question.
        acc_res, acc_ci = _evaluate_pairs(pool, process_args)
        _save_results(subject, 'acc', f"{subject} Accuracy Results", acc_res, acc_ci)

        # Only questions on which both models produced the same 'output'.
        acc_res, acc_ci = _evaluate_pairs(pool, process_args2)
        _save_results(subject, 'acc_filtered', f"{subject} Accuracy Results Filtered", acc_res, acc_ci)

# %%
