import glob
import os
import pdb
from pathlib import Path
from typing import Dict, List, Any, Optional, Union, Set

from tqdm import tqdm
from scipy.special import softmax
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import roc_auc_score
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
from transformers import AutoTokenizer

# Try to select the TkAgg backend before pyplot is imported. If that fails,
# report it and carry on with whatever backend matplotlib picks.
try:
    matplotlib.use('TkAgg')
except Exception:  # not a bare except: Ctrl-C / SystemExit must still propagate
    print("failure loading TkAgg, better not be in PyCharm!")
import matplotlib.pyplot as plt

# Directory containing this script; output folders are created beneath it.
THIS_SCRIPT_PARENT = Path(__file__).parent
# CSV with the per-item human statistics (NU_PARAM_A/B/C) for the 2022 exams.
HUMAN_METRIC_PATH = f"{THIS_SCRIPT_PARENT.parent}/data/raw-enem-exams/microdados_enem_2022/DADOS/ITENS_PROVA_2022.csv"


def MODEL_NAME_TO_FULL(n: str) -> str:  # noqa: N802 - constant-style name kept for existing callers
    """Prefix a short model name with an organisation namespace.

    Names starting with "ge" map to ``google/``, names starting with "Mi" map
    to ``mistralai/`` and everything else maps to ``meta-llama/``.

    Args:
        n: Short model name, e.g. ``"gemma-7b"``.

    Returns:
        The name in ``"<organisation>/<name>"`` form.
    """
    if n.startswith("ge"):
        return f"google/{n}"
    if n.startswith("Mi"):
        return f"mistralai/{n}"
    return f"meta-llama/{n}"


# Short model names that are matched against result file names.
MODELS = ["Llama-2-7b-chat-hf","Llama-2-7b-hf", "Mistral-7B-Instruct-v0.1", "Mistral-7B-v0.1", "gemma-7b-it",
          "gemma-7b"]

# Maps short model identifiers (several of which also appear in MODELS) to
# shorter labels. Presumably used as display names in plots/tables; it is not
# referenced in the visible part of this file - TODO confirm against callers.
MODEL_NAME_MAPPING = {
    "Meta-Llama-3-8B": "LLaMA3-8B",
    "Meta-Llama-3-8B-Instruct": "LLaMA3-8B Instruct",
    "gemma-7b": "Gemma-7B",
    "gemma-7b-it": "Gemma-7B Instruct",
    "Llama-2-13b-hf": "LLaMA2-13B",
    "Llama-2-13b-chat-hf": "LLaMA2-13B Instruct",
    "Mistral-7B-v0.1": "Mistral-7B",
    "Mistral-7B-Instruct-v0.1": "Mistral-7B Instruct",
    "Llama-2-7b-hf": "LLaMA2-7B",
    "Llama-2-7b-chat-hf": "LLaMA2-7B Instruct",
    "GPT-3.5-turbo-0125": "ChatGPT-3.5",
}

# English subject titles keyed by the two-letter ENEM subject code.
_ENEM_SUBJECT_TITLES = {
    "MT": "Mathematics",
    "CH": "Humanities",
    "CN": "Natural Sciences",
    "LC": "Languages and Codes",
}

# year -> subject code -> "<year> <subject title>"
ENEM_EXAM_MAPPING = {
    exam_year: {code: f"{exam_year} {title}" for code, title in _ENEM_SUBJECT_TITLES.items()}
    for exam_year in (2022, 2023)
}


# Language tags used when selecting result files and when filtering plots.
LANGS = ["en", "pt-br"]
# Two-letter ENEM subject codes (same codes as the keys of ENEM_EXAM_MAPPING).
SUBJS = ["CH", "LC", "CN","MT"]
# Subject code -> numeric exam id for 2022. The CH value (1062) matches the
# "CO_PROVA_1062" fragment in the CH result file names referenced in this file;
# presumably the others are the corresponding CO_PROVA ids - TODO confirm.
SUBJ_TO_VERSION_2022 = {"CH": 1062, "CN":1092,"LC":1072,"MT":1082}


def letter_to_one_hot(answer):
    """Return a 5-element 0/1 list with a single 1 at the position of *answer* in "ABCDE".

    Raises:
        ValueError: if *answer* is not found in "ABCDE".
    """
    hot_position = "ABCDE".index(answer)
    return [int(slot == hot_position) for slot in range(5)]


def get_model_metrics(result_path, softmax_temp: Optional[float] = None) -> pd.DataFrame:
    """Load one model result parquet file and derive per-question metrics.

    Args:
        result_path: Path (string) of a parquet result file. It must contain the
            columns PROB_DIST, QUESTION, CORRECT_ANSWER and MODEL_ANSWER.
        softmax_temp: When given, PROB_DIST is recomputed as
            ``softmax(logits / softmax_temp)``. The logits come from LOGITS when
            that column holds exactly five values per row, otherwise from
            LOGIT_DIST.

    Returns:
        The loaded frame, without rows whose PROB_DIST is missing, with the
        added columns ``max_prob``, ``QUESTION_IDX`` (trailing integer of
        QUESTION) and ``accuracy`` (1 when MODEL_ANSWER equals CORRECT_ANSWER).
    """
    df = pd.read_parquet(result_path)
    if "LC" in result_path:
        # Results whose path mentions "LC" skip their first five rows.
        df = df[5:]
    # Copy after filtering so the column writes below go to an independent
    # frame instead of a slice of the loaded one.
    df = df[~df.PROB_DIST.isna()].copy()
    if "LOGIT_DIST" not in df.columns and ("LOGITS" in df.columns and len(df.LOGITS.tolist()[0]) > 5):
        # LOGITS holds more than five entries per row: keep only the entries at
        # the token ids of the option letters, in alphabetical letter order.
        model_name = get_model_name_from_path(result_path)
        tokenizer_map = get_tokenizer_letter_map({model_name})[model_name]
        df["LOGIT_DIST"] = df.LOGITS.apply(lambda l: [l[tokenizer_map[option]] for option in
                                                      sorted(tokenizer_map.keys())] if l is not None else None)
    if softmax_temp is not None:
        # Fixed: the column check used to read `df.colums`, which raised
        # AttributeError for every call that passed a temperature.
        if "LOGITS" in df.columns and len(df.LOGITS.tolist()[0]) == 5:
            df["PROB_DIST"] = df.LOGITS.apply(lambda l: softmax(np.array(l) / softmax_temp))
        else:
            df["PROB_DIST"] = df.LOGIT_DIST.apply(lambda l: softmax(np.array(l) / softmax_temp))
    df["max_prob"] = df.PROB_DIST.apply(lambda dist: max(dist))
    df["QUESTION_IDX"] = df.QUESTION.apply(lambda q: int(q.split()[-1]))
    df["accuracy"] = (df.CORRECT_ANSWER == df.MODEL_ANSWER).astype(int)
    return df


# Cache of the raw (un-normalised, un-ablated) human item statistics, keyed by
# the exam id string parsed from the result file name.
HUMAN_DF = {}


def get_human_metrics(machine_result_path, normalize=False, ablation=None) -> pd.DataFrame:
    """Return the human item statistics of the exam a model result file refers to.

    The exam id is the last "_"-separated token of the "ENEM..." part of the
    file name (with any "_OLD" removed). Rows of HUMAN_METRIC_PATH with that
    CO_PROVA and a non-missing NU_PARAM_B are returned, ordered by CO_POSICAO,
    which is renamed to QUESTION_IDX.

    The CSV is read once per exam id and kept in HUMAN_DF. Normalisation and
    ablation are applied to a copy, so cached data is never modified.

    Args:
        machine_result_path: Path of a model result file; its name must contain
            "prob-dist".
        normalize: When true, shift NU_PARAM_A/B/C so that each column's minimum
            is 0.
        ablation: ``None``, ``"shuffle"`` (randomly permute each parameter
            column) or ``"mean"`` (replace each parameter column by its mean).

    Returns:
        A new DataFrame that also carries the YEAR / SUBJECT / MODEL / LANG /
        shuffle columns added by ``add_metadata_to_result_df``.

    Raises:
        ValueError: if *ablation* is not recognised or the file name does not
            contain "prob-dist".
    """
    # Validate before any I/O; explicit raises survive `python -O`, asserts do not.
    if ablation not in (None, "shuffle", "mean"):
        raise ValueError(f"ablation option {ablation} not recognized")
    result_file = Path(machine_result_path).name
    if "prob-dist" not in result_file:
        raise ValueError(f"expected a 'prob-dist' result file, got {result_file}")
    test_name = result_file[result_file.find("ENEM"):].split("-")[0]
    test_id = test_name.replace("_OLD", "").split("_")[-1]

    if test_id not in HUMAN_DF:
        items = pd.read_csv(HUMAN_METRIC_PATH, encoding="latin", sep=";")
        items = items[~items.NU_PARAM_B.isna()]
        HUMAN_DF[test_id] = items[items.CO_PROVA == int(test_id)].sort_values(by="CO_POSICAO")
    df = HUMAN_DF[test_id].copy()

    for metric in ["NU_PARAM_A", "NU_PARAM_B", "NU_PARAM_C"]:
        if normalize:
            # Shift only (no scaling): make the smallest value 0.
            df[metric] = df[metric] - df[metric].min()
        if ablation == "shuffle":
            df[metric] = np.random.permutation(df[metric].values)
        elif ablation == "mean":
            df[metric] = df[metric].mean()
    df = df.rename(columns={"CO_POSICAO": "QUESTION_IDX"})

    add_metadata_to_result_df(df, result_file)
    return df


def add_metadata_to_result_df(df, result_file):
    """Add YEAR, SUBJECT, MODEL, LANG and shuffle columns to *df* in place.

    All values are parsed from the result file name:
    MODEL is the first entry of MODELS contained in the name, YEAR and SUBJECT
    are the 2nd and 3rd "_"-separated tokens, LANG is "en" when the name
    contains "-en-" (otherwise "pt-br"), and shuffle is the last "-"-separated
    token of the text that precedes "original".
    """
    matching_models = [candidate for candidate in MODELS if candidate in result_file]
    model_name = matching_models[0]

    underscore_parts = result_file.split("_")
    exam_year = underscore_parts[1]
    exam_subject = result_file.split("_")[2]

    if "-en-" in result_file:
        language = "en"
    else:
        language = "pt-br"

    text_before_original = result_file[:result_file.find("original") - 1]
    shuffle_tag = text_before_original.split("-")[-1]

    # Assign only after every value has been parsed successfully.
    df["YEAR"] = exam_year
    df["SUBJECT"] = exam_subject
    df["MODEL"] = model_name
    df["LANG"] = language
    df["shuffle"] = shuffle_tag


def get_binned_correlations(df, x: str = "probs", bin: int = 10, top_p_only: bool = False):
    """Quantile-bin per-question (or per-option) values and average within bins.

    Side effect: adds LABEL_IDX, LABEL_OH and CORRECT_OH columns to *df*.

    Args:
        df: Frame with CORRECT_ANSWER, MODEL_ANSWER, PROB_DIST, NU_PARAM_A/B/C
            and SUBJECT / YEAR / MODEL / LANG columns.
        x: Column of the intermediate frame to bin on ("probs", "labels",
            "corrects" or one of the NU_PARAM_* columns).
        bin: Number of quantile bins passed to ``pd.qcut`` (duplicate edges are
            dropped).
        top_p_only: If true, one entry per question (largest probability and
            whether the model answered correctly). Otherwise one entry per
            answer option, with NU_PARAM_* values repeated five times.

    Returns:
        Bin means indexed by bin interval, plus SUBJECT / YEAR / MODEL / LANG
        taken from the first row of *df*.
    """

    def answered_correctly(row):
        return row.CORRECT_ANSWER == row.MODEL_ANSWER

    df["LABEL_IDX"] = df.CORRECT_ANSWER.apply(lambda letter: int("ABCDE".index(letter)))
    df["LABEL_OH"] = df.CORRECT_ANSWER.apply(lambda letter: letter_to_one_hot(letter))
    # One-hot of the gold answer when the model is right, all zeros otherwise.
    df["CORRECT_OH"] = df.apply(lambda row: row.LABEL_OH if answered_correctly(row) else [0, 0, 0, 0, 0], axis=1)

    item_param_cols = ["NU_PARAM_A", "NU_PARAM_B", "NU_PARAM_C"]
    if not top_p_only:
        # Flatten to one entry per (question, option).
        columns = {
            "probs": np.concatenate(df.PROB_DIST.to_numpy()),
            "labels": np.concatenate(df.LABEL_OH.to_numpy()),
            "corrects": np.concatenate(df.CORRECT_OH.to_numpy()),
        }
        for col in item_param_cols:
            columns[col] = [value for value in df[col] for _ in range(5)]
    else:
        # One entry per question: confidence of the top prediction and a hit flag.
        top_probs = df.apply(lambda row: max(row.PROB_DIST), axis=1).to_numpy()
        hits = df.apply(lambda row: 1 if answered_correctly(row) else 0, axis=1).to_numpy()
        columns = {"probs": top_probs, "labels": hits, "corrects": hits}
        for col in item_param_cols:
            columns[col] = df[col]

    exp_df = pd.DataFrame(columns)
    exp_df["pct"] = pd.qcut(exp_df[x], bin, duplicates="drop")
    res = exp_df.groupby("pct", observed=False).mean()
    for meta_col in ("SUBJECT", "YEAR", "MODEL", "LANG"):
        res[meta_col] = df[meta_col].iloc[0]
    return res


def plot_binned_correlation(df, x, y, hue=None, postfix=""):
    """Draw a line plot of *y* against *x* and save it as a PNG.

    The image is written to
    ``<script dir>/analysis_results/model_human_correlations/binned_scatter_<postfix>.png``;
    the directory is created when missing.

    Args:
        df: Data to plot.
        x: Column for the x axis.
        y: Column for the y axis.
        hue: Optional grouping column; the data is sorted by it before plotting.
        postfix: Plot title and file name suffix.
    """
    out_dir = f"{THIS_SCRIPT_PARENT}/analysis_results/model_human_correlations"
    os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure()
    try:
        if hue is not None:
            df = df.sort_values(hue)
        sns.lineplot(data=df, x=x, y=y, hue=hue, marker="o")
        plt.title(postfix)
        plt.savefig(f"{out_dir}/binned_scatter_{postfix}.png")
    finally:
        # pyplot keeps every figure alive until it is closed; this function is
        # called many times per experiment, so release it once saved.
        plt.close(fig)


def exp_calculate_subject_exam_separately(softmax_temp: Optional[float] = None):
    """Plot confidence / accuracy / difficulty (NU_PARAM_B) relations per subject and language.

    Reads every "prob-dist" file under ``../enem-experiments-results`` and
    produces two groups of plots through ``plot_binned_correlation``:

    * binned by NU_PARAM_B, per result file: confidence and accuracy vs beta;
    * binned by probability, per (MODEL, YEAR, SUBJECT, LANG) group:
      accuracy and beta vs confidence.

    Args:
        softmax_temp: Optional temperature forwarded to ``get_model_metrics``.
    """
    model_result_paths = [file for file in glob.glob("../enem-experiments-results/*") if "prob-dist" in file]
    if not model_result_paths:
        print("no prob-dist result files found in ../enem-experiments-results, nothing to plot")
        return

    # Collect per-file frames and concatenate once; growing a DataFrame with
    # pd.concat inside the loop copies all earlier rows on every iteration.
    merged_frames = []
    beta_binned_frames = []
    # these df are binned by beta, same across shuffles, so okay to bin for each generation
    for model_result_path in tqdm(model_result_paths):
        df_model = get_model_metrics(model_result_path, softmax_temp)
        df_human = get_human_metrics(model_result_path)
        df = df_model.merge(df_human, on="QUESTION_IDX")
        beta_binned_frames.append(get_binned_correlations(df, x="NU_PARAM_B", bin=10, top_p_only=True))
        merged_frames.append(df)
    total_df = pd.concat(merged_frames)
    total_binned_df = pd.concat(beta_binned_frames, ignore_index=True)

    for lang in ["en", "pt-br"]:
        for subj in ["CH", "MT", "CN", "LC"]:
            subset = total_binned_df[(total_binned_df.SUBJECT == subj) & (total_binned_df.LANG == lang)]
            # confidence vs beta
            plot_binned_correlation(subset, "NU_PARAM_B", "probs", hue="MODEL",
                                    postfix=f"conf_vs_beta_{subj}_top_p_{lang}")
            # acc vs beta
            plot_binned_correlation(subset, "NU_PARAM_B", "corrects", hue="MODEL",
                                    postfix=f"acc_vs_beta_{subj}_top_p_{lang}")

    # for metrics that need to be binned by confidence, aggregate by model/year/subject/lang, then bin
    prob_binned_frames = []
    for _, group_df in total_df.groupby(["MODEL", "YEAR", "SUBJECT", "LANG"]):
        # .copy(): get_binned_correlations adds columns to the frame it receives.
        prob_binned_frames.append(get_binned_correlations(group_df.copy(), x="probs", bin=10, top_p_only=False))
    total_binned_df_by_probs = pd.concat(prob_binned_frames, ignore_index=True)

    for lang in ["en", "pt-br"]:
        for subj in ["CH", "MT", "CN", "LC"]:
            subset = total_binned_df_by_probs[(total_binned_df_by_probs.SUBJECT == subj) & (
                    total_binned_df_by_probs.LANG == lang)]
            # acc vs conf
            plot_binned_correlation(subset, "probs", "corrects", hue="MODEL",
                                    postfix=f"acc_vs_conf_{subj}_{lang}")
            # beta vs conf
            plot_binned_correlation(subset, "probs", "NU_PARAM_B", hue="MODEL",
                                    postfix=f"beta_vs_conf_{subj}_{lang}")


def get_tokenizer_letter_map(model_names: Set[str]) -> Dict[str, Dict]:
    """Map each model name to ``{letter: token id}`` for the letters A-E.

    For every name, the tokenizer is loaded with ``AutoTokenizer.from_pretrained``
    and each letter is assigned the last token id of the encoding of "(<letter>".
    """

    def last_token_id(tok, letter):
        return tok.encode(f"({letter}")[-1]

    loaded = ((model_name, AutoTokenizer.from_pretrained(model_name)) for model_name in model_names)
    return {
        model_name: {letter: last_token_id(tok, letter) for letter in "ABCDE"}
        for model_name, tok in loaded
    }


def get_model_name_from_path(p):
    """Return the full model name for a result file path.

    The short name is the part of the file name that precedes "-few" (or
    "-zero" when "-few" does not occur); it is expanded with
    ``MODEL_NAME_TO_FULL``.
    """
    file_name = Path(p).name
    if "-few" in file_name:
        cut_at = file_name.find("-few")
    else:
        cut_at = file_name.find("-zero")
    short_name = file_name[:cut_at]
    return MODEL_NAME_TO_FULL(short_name)


def add_unsoftmax_option_logit_column(result_dir: str):
    """Add a LOGIT_DIST column to every "prob-dist" parquet file in *result_dir*.

    LOGIT_DIST holds, per row, the LOGITS entries at the token ids of the
    option letters (alphabetical letter order), or None when LOGITS is None.
    Each file is rewritten in place.
    """
    result_paths = [candidate for candidate in glob.glob(f"{result_dir}/*") if "prob-dist" in candidate]
    # Load each distinct tokenizer only once.
    model_to_map = get_tokenizer_letter_map({get_model_name_from_path(path) for path in result_paths})

    for path in tqdm(result_paths):
        letter_to_token = model_to_map[get_model_name_from_path(path)]

        def option_logits(logits):
            if logits is None:
                return None
            return [logits[letter_to_token[letter]] for letter in sorted(letter_to_token.keys())]

        df = pd.read_parquet(path)
        df["LOGIT_DIST"] = df.LOGITS.apply(option_logits)
        df.to_parquet(path)


def plot_varying_temperature(df, x, y, hue=None, folder="", postfix="", calibration=False):
    """Draw a line plot of *y* against *x* and save it as a PNG.

    The image is written to
    ``<script dir>/analysis_results/<folder>/temp_<postfix>.png``; the
    directory is created when missing.

    Args:
        df: Data to plot.
        x: Column for the x axis.
        y: Column for the y axis.
        hue: Optional grouping column; the data is sorted by it before plotting.
        folder: Sub-folder of ``analysis_results`` to write into.
        postfix: Plot title and file name suffix.
        calibration: When true, both axes are limited to [0, 1].
    """
    out_dir = f"{THIS_SCRIPT_PARENT}/analysis_results/{folder}"
    os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure()
    try:
        if hue is not None:
            df = df.sort_values(hue)
        ax = sns.lineplot(data=df, x=x, y=y, hue=hue, marker="o")
        if calibration:
            ax.set(xlim=(0, 1), ylim=(0, 1))
        plt.title(postfix)
        plt.savefig(f"{out_dir}/temp_{postfix}.png")
    finally:
        # pyplot keeps every figure alive until it is closed; release it once
        # the image is on disk so repeated calls do not accumulate figures.
        plt.close(fig)


def get_calibration_metrics(df: pd.DataFrame):
    """Compute calibration metrics from binned confidence / accuracy values.

    df is binned calibration using top predictions for each multiple choice
    question only; see https://arxiv.org/pdf/2207.05221.pdf appendix A for
    detail. Every row (bin) is weighted equally.

    Args:
        df: Frame with ``probs`` (mean confidence per bin) and ``corrects``
            (mean accuracy per bin).

    Returns:
        Dict with ``ece`` (mean absolute gap), ``rmsece`` (root mean squared
        gap) and ``brier_score`` (mean squared gap). All values are NaN when
        *df* has no rows.
    """
    n = len(df)
    if n == 0:
        return {"ece": float("nan"), "rmsece": float("nan"), "brier_score": float("nan")}
    gap = df.probs - df.corrects
    ece = np.sum(np.abs(gap)) / n  # expected calibration error
    rmsece = np.sqrt(np.sum(gap ** 2) / n)  # RMS expected calibration error
    # Fixed: this used `labels - corrects`, which compares two outcome columns
    # (identical for top-prediction bins) and therefore never involved the
    # predicted probability. The squared error is between probs and corrects.
    brier_score = np.sum(gap ** 2) / n
    return {"ece": ece, "rmsece": rmsece, "brier_score": brier_score}


def get_correlation(df: pd.DataFrame):
    """Correlate item difficulty (NU_PARAM_B) with confidence and with accuracy.

    df needs to be raw data (not binned) containing beta + conf/accuracy, i.e.
    the columns CORRECT_ANSWER, MODEL_ANSWER, max_prob and NU_PARAM_B.

    Side effect: (re)writes the boolean ``accuracy`` column of *df*.

    Returns:
        Dict with the Pearson and Spearman coefficients of NU_PARAM_B against
        ``max_prob`` and against ``accuracy``.
    """
    df["accuracy"] = df.CORRECT_ANSWER == df.MODEL_ANSWER
    # Take element 0 (the coefficient) of each result. Both functions return a
    # tuple-like (coefficient, p-value), whereas the `.correlation` attribute
    # is not available on every SciPy version's pearsonr result.
    res = {
        "beta_top_prob_pearson": pearsonr(df.max_prob, df.NU_PARAM_B)[0],
        "beta_accuracy_pearson": pearsonr(df.accuracy, df.NU_PARAM_B)[0],
        "beta_top_prob_spearman": spearmanr(df.max_prob, df.NU_PARAM_B)[0],
        "beta_accuracy_spearman": spearmanr(df.accuracy, df.NU_PARAM_B)[0],
    }
    return res


def exp_calibration_across_temperatures(softmax_temp_range: List[float]):
    """Plot beta/confidence correlations and calibration metrics against softmax temperature.

    Uses at most the first 100 "prob-dist" files found under
    ``../enem-experiments-results-default``. For every file and temperature the
    model probabilities are recomputed via ``get_model_metrics``. Plots are
    written by ``plot_varying_temperature`` into the
    ``calibration_across_temperature`` folder.

    Args:
        softmax_temp_range: Temperatures to evaluate.
    """
    model_result_paths = [file for file in glob.glob("../enem-experiments-results-default/*") if "prob-dist" in file]

    merged_frames = []
    correlation_rows = []
    for model_result_path in tqdm(model_result_paths[:100]):
        # Human metrics do not depend on the temperature: load once per file.
        df_human = get_human_metrics(model_result_path)
        for t in softmax_temp_range:
            df_model = get_model_metrics(model_result_path, t)
            df = df_model.merge(df_human, on="QUESTION_IDX")
            df["temperature"] = t
            res = get_correlation(df)
            res.update(df[["MODEL", "YEAR", "SUBJECT", "LANG"]].iloc[0].to_dict())
            res["temperature"] = t
            correlation_rows.append(res)
            merged_frames.append(df.drop(columns=["LOGITS"]))
    if not merged_frames:
        print("no results to analyse in ../enem-experiments-results-default, nothing to plot")
        return
    # Concatenate once instead of re-copying the accumulated frame per iteration.
    total_df = pd.concat(merged_frames)
    res_df = pd.DataFrame(correlation_rows)

    for lang in ["en", "pt-br"]:
        for subj in ["CH", "MT", "CN", "LC"]:
            subset = res_df[(res_df.SUBJECT == subj) & (res_df.LANG == lang)]
            # beta-confidence correlation vs temp
            plot_varying_temperature(subset, "temperature",
                                     "beta_top_prob_pearson", hue="MODEL", folder="calibration_across_temperature",
                                     postfix=f"beta-conf-pearsonr_vs_temp_{subj}_{lang}")
            plot_varying_temperature(subset, "temperature",
                                     "beta_top_prob_spearman", hue="MODEL", folder="calibration_across_temperature",
                                     postfix=f"beta-conf-spearmanr_vs_temp_{subj}_{lang}")

    # for metrics that need to be binned by confidence, aggregate by model/year/subject/lang, then bin
    binned_rows = []
    for group_name, group_df in total_df.groupby(["MODEL", "YEAR", "SUBJECT", "LANG"]):
        for t in softmax_temp_range:
            # .copy(): get_binned_correlations adds columns to the frame it receives.
            at_temperature = group_df[group_df.temperature == t].copy()
            binned_df = get_binned_correlations(at_temperature, x="probs", bin=10, top_p_only=True)
            res = get_calibration_metrics(binned_df)
            res.update(group_df[["MODEL", "YEAR", "SUBJECT", "LANG"]].iloc[0].to_dict())
            res["temperature"] = t
            binned_rows.append(res)
    binned_res_df = pd.DataFrame(binned_rows)

    for lang in ["en", "pt-br"]:
        for subj in ["CH", "MT", "CN", "LC"]:
            subset = binned_res_df[(binned_res_df.SUBJECT == subj) & (binned_res_df.LANG == lang)]
            # calibration vs temp
            for metrics in ["ece", "rmsece", "brier_score"]:
                plot_varying_temperature(subset,
                                         "temperature", metrics, hue="MODEL", folder="calibration_across_temperature",
                                         postfix=f"{metrics}_vs_temp_{subj}_{lang}")


def exp_calibration_beta_dependent_temperatures(result_dir: str, coefficient: List[float]):
    """Evaluate calibration when the softmax temperature depends on item difficulty.

    For every 2022 "prob-dist" file in *result_dir* and every coefficient ``c``,
    PROB_DIST is recomputed as ``softmax(LOGIT_DIST / (1 - c * NU_PARAM_B))``
    with NU_PARAM_B shifted to start at 0 (``normalize=True``). Calibration
    metrics per (MODEL, YEAR, SUBJECT, LANG, c) are written to
    ``analysis_results/binned_df.csv`` (relative to the working directory) and
    RMS-ECE vs c is plotted into the ``calibration_beta_dependent`` folder.

    Args:
        result_dir: Directory holding the parquet result files.
        coefficient: Values of ``c`` to evaluate.
    """
    model_result_paths = [file for file in glob.glob(f"{result_dir}/*") if ("prob-dist" in file and "2022" in file)]

    merged_frames = []
    for model_result_path in tqdm(model_result_paths):
        # Neither input depends on c: load and merge once per file.
        df_model = get_model_metrics(model_result_path)
        df_human = get_human_metrics(model_result_path, normalize=True)
        merged = df_model.merge(df_human, on="QUESTION_IDX")
        for c in coefficient:
            df = merged.copy()
            # post-hoc change temperature depending on beta
            df["PROB_DIST"] = df.apply(lambda r: softmax(np.array(r.LOGIT_DIST) / (1 - c * r.NU_PARAM_B)), 1)
            df["max_prob"] = df.PROB_DIST.apply(lambda dist: max(dist))
            df["beta_temp_coefficient"] = c
            merged_frames.append(df.drop(columns=["LOGITS"]))
    if not merged_frames:
        print(f"no results to analyse in {result_dir}, nothing to plot")
        return
    # Concatenate once instead of re-copying the accumulated frame per iteration.
    total_df = pd.concat(merged_frames)

    # for metrics that need to be binned by confidence, aggregate by model/year/subject/lang, then bin
    binned_rows = []
    for group_name, group_df in total_df.groupby(["MODEL", "YEAR", "SUBJECT", "LANG"]):
        for c in coefficient:
            has_c = group_df.beta_temp_coefficient == c
            # .copy(): get_binned_correlations adds columns to the frame it receives.
            binned_df = get_binned_correlations(group_df[has_c].copy(), x="probs", bin=10,
                                                top_p_only=True)
            res = get_calibration_metrics(binned_df)
            res.update(group_df[["MODEL", "YEAR", "SUBJECT", "LANG"]].iloc[0].to_dict())
            res["beta_temp_coefficient"] = c
            res["group_size"] = sum(has_c)
            binned_rows.append(res)
    binned_res_df = pd.DataFrame(binned_rows)
    # The output directory is relative to the working directory and may not exist yet.
    os.makedirs("analysis_results", exist_ok=True)
    binned_res_df.to_csv("analysis_results/binned_df.csv")

    for lang in ["en", "pt-br"]:
        for subj in ["CH", "MT", "CN", "LC"]:
            subset = binned_res_df[(binned_res_df.SUBJECT == subj) & (binned_res_df.LANG == lang)]
            # calibration vs temp
            for metrics in ["rmsece"]:  # "ece", "brier_score"
                plot_varying_temperature(subset,
                                         "beta_temp_coefficient", metrics, hue="MODEL",
                                         folder="calibration_beta_dependent", postfix=f"{metrics}_vs_c_{subj}_{lang}")


def t_given_c(c: float, beta: float, max_beta: Optional[float] = None, mean_beta: Optional[float] = None,
              mode="mean"):
    """Return a strictly positive softmax temperature that depends linearly on *beta*.

    * ``mode="max"``:  ``t = (1 + c) - 2 * c * beta / max_beta``
    * ``mode="mean"``: ``t = (1 + c) - c * beta / mean_beta``

    In both modes ``t = 1`` when ``c = 0``; in "mean" mode ``t = 1`` also when
    ``beta == mean_beta``. The sign of ``c`` selects a positive or negative
    dependency on beta. The result is clamped from below at ``1e-20`` so it can
    safely be used as a divisor.

    Args:
        c: Dependency coefficient (intended range [-1, 1]).
        beta: Value the temperature depends on.
        max_beta: Reference value, required for ``mode="max"``.
        mean_beta: Reference value, required for ``mode="mean"``.
        mode: ``"max"`` or ``"mean"``.

    Raises:
        ValueError: for an unknown *mode* or a missing reference value.
    """
    eps = 1e-20
    if mode == "max":
        if max_beta is None:
            raise ValueError('mode="max" requires max_beta')
        temperature = (1 + c) - 2 * c * beta / max_beta
    elif mode == "mean":
        if mean_beta is None:
            raise ValueError('mode="mean" requires mean_beta')
        temperature = (1 + c) - c * beta / mean_beta
    else:
        # Used to fall through and return None, which only failed later when
        # the caller divided logits by it.
        raise ValueError(f"mode {mode} not recognized, expected 'max' or 'mean'")
    return max(temperature + eps, eps)


def exp_calibration_beta_dependent_temperatures_efficient(result_dir: str, coefficient: List[float],
                                                          ablation: Optional[str] = None):
    """Evaluate calibration with a beta-dependent temperature, one model/lang/subject group at a time.

    For every (model, language, subject) combination of the 2022 few-shot
    results in *result_dir* (files containing "_OLD" are skipped) and every
    coefficient ``c``, PROB_DIST is recomputed with the temperature
    ``t_given_c(c, NU_PARAM_B, mean_beta=<group mean>, mode="mean")``.
    Calibration metrics and AUROC values are written to
    ``analysis_results/binned_df[_<ablation>].csv`` (relative to the working
    directory) and plotted against ``c``.

    Args:
        result_dir: Directory holding the parquet result files.
        coefficient: Values of ``c`` to evaluate.
        ablation: Optional ablation name forwarded to ``get_group_df``.
    """
    model_result_paths = [file for file in glob.glob(f"{result_dir}/*") if
                          all(w in file for w in ["prob-dist", "2022", "few-shot-ENEM"]) and all(
                              w not in file for w in ["_OLD"])]

    binned_res_df = []
    # for metrics that need to be binned by confidence, aggregate by model/year/subject/lang, then bin
    bar = tqdm(total=len(MODELS) * len(LANGS) * len(SUBJS))
    for model in MODELS:
        for lang in LANGS:
            for subj in SUBJS:
                year = "2022"
                bar.set_description(f"processing {model} results for {lang}, {subj} in {year}")
                sub_result_paths = [p for p in model_result_paths if
                                    all(bool(s in p) for s in [f"{model}-few-shot-", f"-{lang}-", f"_{subj}_", year])]
                group_df = get_group_df(sub_result_paths, normalize=True, ablation=ablation)

                # Both values are constant for the group. The mean used to be
                # recomputed inside the per-row lambda, i.e. once per row.
                logit_col = "LOGIT_DIST" if "LOGIT_DIST" in group_df.columns else "LOGITS"
                mean_beta = group_df.NU_PARAM_B.mean()

                for c in coefficient:
                    # post-hoc change temperature depending on beta
                    group_df["PROB_DIST"] = group_df.apply(
                        lambda r: softmax(np.array(r[logit_col]) /
                                          t_given_c(c, r.NU_PARAM_B, mean_beta=mean_beta, mode="mean")), 1)
                    group_df["probs"] = group_df.PROB_DIST.apply(lambda dist: max(dist))

                    binned_df = get_binned_correlations(group_df, x="probs",
                                                        bin=10, top_p_only=True)
                    res = get_calibration_metrics(binned_df)
                    res.update(group_df[["MODEL", "YEAR", "SUBJECT", "LANG"]].iloc[0].to_dict())
                    res["beta_temp_coefficient"] = c
                    res["group_size"] = len(group_df)
                    # AUROC of the probability assigned to the gold answer vs. correctness.
                    res["auroc_single"] = roc_auc_score(group_df["accuracy"], group_df.apply(
                        lambda r: r.PROB_DIST[int("ABCDE".index(r.CORRECT_ANSWER))], 1))
                    res["auroc_multi"] = roc_auc_score(
                        np.stack(group_df.CORRECT_ANSWER.apply(lambda x: letter_to_one_hot(x))),
                        np.stack(group_df.PROB_DIST), multi_class="ovr")
                    binned_res_df.append(res)
                bar.update(1)

    binned_res_df = pd.DataFrame(binned_res_df)
    # The output directory is relative to the working directory and may not exist yet.
    os.makedirs("analysis_results", exist_ok=True)
    binned_res_df.to_csv(f"analysis_results/binned_df{'' if ablation is None else '_' + ablation}.csv")

    for lang in ["en", "pt-br"]:
        for subj in ["CH", "MT", "CN", "LC"]:
            subset = binned_res_df[(binned_res_df.SUBJECT == subj) & (binned_res_df.LANG == lang)]
            # calibration vs temp
            for metrics in ["ece", "brier_score", "rmsece", "auroc_single", "auroc_multi"]:
                plot_varying_temperature(subset,
                                         "beta_temp_coefficient", metrics, hue="MODEL",
                                         folder=f"calibration_beta_dependent{'' if ablation is None else '_' + ablation}",
                                         postfix=f"{metrics}_vs_c_{subj}_{lang}")

def exp_calibration_irt_theta_dependent_temperatures_efficient(result_dir: str, coefficient: List[float],
                                                               ablation: Optional[str] = None):
    """Evaluate calibration with a temperature that depends on the IRT success probability.

    For every (model, language, subject) combination of the 2022 few-shot
    results in *result_dir* (files containing "_OLD" are skipped), the
    two-parameter logistic IRT probability
    ``P = 1 / (1 + exp(-NU_PARAM_A * (theta - NU_PARAM_B)))`` is computed with
    a fixed per-model ``theta``. For every coefficient ``c``, PROB_DIST is then
    recomputed with the temperature
    ``t_given_c(c, P, mean_beta=<group mean of P>, mode="mean")``.
    Results are written to
    ``analysis_results/binned_df_irt_theta[_<ablation>].csv`` (relative to the
    working directory) and plotted against ``c``.

    Args:
        result_dir: Directory holding the parquet result files.
        coefficient: Values of ``c`` to evaluate.
        ablation: Optional ablation name forwarded to ``get_group_df``.
    """
    model_result_paths = [file for file in glob.glob(f"{result_dir}/*") if
                          all(w in file for w in ["prob-dist", "2022", "few-shot-ENEM"]) and all(
                              w not in file for w in ["_OLD"])]
    # mean from 2022 exam, improper, but for quick exploration
    DEFAULT_THETAS = {'Llama-2-7b-chat-hf': 0.41569473212618185, 'Llama-2-7b-hf': 0.15280820992422456,
                      'Mistral-7B-Instruct-v0.1': 0.6070607247488036, 'Mistral-7B-v0.1': 1.0247184336140631,
                      'gemma-7b': 1.1837435720962348, 'gemma-7b-it': 0.707890291250033}
    binned_res_df = []
    # for metrics that need to be binned by confidence, aggregate by model/year/subject/lang, then bin
    bar = tqdm(total=len(MODELS) * len(LANGS) * len(SUBJS))
    for model in MODELS:
        for lang in LANGS:
            for subj in SUBJS:
                year = "2022"
                bar.set_description(f"processing {model} results for {lang}, {subj} in {year}")
                sub_result_paths = [p for p in model_result_paths if
                                    all(bool(s in p) for s in [f"{model}-few-shot-", f"-{lang}-", f"_{subj}_", year])]
                group_df = get_group_df(sub_result_paths, normalize=False, ablation=ablation)
                theta = DEFAULT_THETAS[model]
                # Fixed: the denominator used `1 + a * (theta - b)` (missing
                # exp), which is not a probability. The logistic is written in
                # its 1 / (1 + exp(-z)) form, equal to exp(z) / (1 + exp(z)).
                irt_logit = group_df.NU_PARAM_A * (theta - group_df.NU_PARAM_B)
                group_df["IRT_PROB"] = 1.0 / (1.0 + np.exp(-irt_logit))

                # Both values are constant for the group. The mean used to be
                # recomputed inside the per-row lambda, i.e. once per row.
                logit_col = "LOGIT_DIST" if "LOGIT_DIST" in group_df.columns else "LOGITS"
                mean_irt_prob = group_df.IRT_PROB.mean()

                for c in coefficient:
                    # post-hoc change temperature depending on p_i
                    group_df["PROB_DIST"] = group_df.apply(
                        lambda r: softmax(np.array(r[logit_col]) /
                                          t_given_c(c, r.IRT_PROB, mean_beta=mean_irt_prob, mode="mean")), 1)
                    group_df["probs"] = group_df.PROB_DIST.apply(lambda dist: max(dist))

                    binned_df = get_binned_correlations(group_df, x="probs",
                                                        bin=10, top_p_only=True)
                    res = get_calibration_metrics(binned_df)
                    res.update(group_df[["MODEL", "YEAR", "SUBJECT", "LANG"]].iloc[0].to_dict())
                    res["beta_temp_coefficient"] = c
                    res["group_size"] = len(group_df)
                    # NOTE(review): "max_prob" is not recomputed in this loop
                    # (only "probs" is), so this value looks independent of c.
                    # Confirm whether "probs" was intended.
                    res["auroc_single"] = roc_auc_score(group_df["accuracy"], group_df["max_prob"])
                    res["auroc_multi"] = roc_auc_score(
                        np.stack(group_df.CORRECT_ANSWER.apply(lambda x: letter_to_one_hot(x))),
                        np.stack(group_df.PROB_DIST), multi_class="ovr")
                    binned_res_df.append(res)
                bar.update(1)

    binned_res_df = pd.DataFrame(binned_res_df)
    # The output directory is relative to the working directory and may not exist yet.
    os.makedirs("analysis_results", exist_ok=True)
    binned_res_df.to_csv(f"analysis_results/binned_df_irt_theta{'' if ablation is None else '_' + ablation}.csv")

    for lang in ["en", "pt-br"]:
        for subj in ["CH", "MT", "CN", "LC"]:
            subset = binned_res_df[(binned_res_df.SUBJECT == subj) & (binned_res_df.LANG == lang)]
            # calibration vs temp
            for metrics in ["ece", "brier_score", "rmsece", "auroc_single", "auroc_multi"]:
                plot_varying_temperature(subset,
                                         "beta_temp_coefficient", metrics, hue="MODEL",
                                         folder=f"calibration_irt_theta_dependent{'' if ablation is None else '_' + ablation}",
                                         postfix=f"{metrics}_vs_c_{subj}_{lang}")


def exp_calibration_beta_dependent_temperatures_vs_constant(result_folder: str, c: float, t: float = 1.0):
    """Compare calibration of the stored probabilities with a beta-dependent re-softmax.

    For at most the first 100 "prob-dist" files in *result_folder*, a second
    distribution ``softmax(LOGIT_DIST / (1 - c * NU_PARAM_B))`` is computed
    (NU_PARAM_B shifted to start at 0). Calibration metrics of both
    distributions are computed per (MODEL, YEAR, SUBJECT, LANG) group; summary
    statistics are printed per language, separately for base and
    instruction-tuned models. Metrics of the beta-dependent variant carry a
    ``_c`` suffix.

    Args:
        result_folder: Directory holding the parquet result files. It used to
            be ignored in favour of a hard-coded
            "../enem-experiments-results-default".
        c: Coefficient of the beta-dependent temperature.
        t: Currently unused; kept for interface compatibility.
    """
    model_result_paths = [file for file in glob.glob(f"{result_folder}/*") if "prob-dist" in file]

    merged_frames = []
    for model_result_path in tqdm(model_result_paths[:100]):
        df_model = get_model_metrics(model_result_path)
        df_human = get_human_metrics(model_result_path, normalize=True)
        # post-hoc change temperature depending on beta
        df = df_model.merge(df_human, on="QUESTION_IDX")
        df["PROB_DIST_C"] = df.apply(lambda r: softmax(np.array(r.LOGIT_DIST) / (1 - c * r.NU_PARAM_B)), 1)
        df["beta_temp_coefficient"] = c
        merged_frames.append(df.drop(columns=["LOGITS"]))
    if not merged_frames:
        print(f"no results to analyse in {result_folder}")
        return
    # Concatenate once instead of re-copying the accumulated frame per iteration.
    total_df = pd.concat(merged_frames)

    # for metrics that need to be binned by confidence, aggregate by model/year/subject/lang, then bin
    binned_rows = []
    for group_name, group_df in total_df.groupby(["MODEL", "YEAR", "SUBJECT", "LANG"]):
        # Work on a copy: columns are overwritten below and
        # get_binned_correlations adds columns to the frame it receives.
        group_df = group_df.copy()
        has_c = group_df.beta_temp_coefficient == c
        binned_df = get_binned_correlations(group_df[has_c].copy(), x="probs", bin=10,
                                            top_p_only=True)
        res = get_calibration_metrics(binned_df)
        # Swap in the beta-dependent distribution and evaluate again.
        group_df["PROB_DIST"] = group_df.PROB_DIST_C
        binned_df_c = get_binned_correlations(group_df[has_c].copy(), x="probs", bin=10,
                                              top_p_only=True)
        res_c = get_calibration_metrics(binned_df_c)
        res.update(group_df[["MODEL", "YEAR", "SUBJECT", "LANG"]].iloc[0].to_dict())
        res.update({f"{k}_c": v for k, v in res_c.items()})
        res["beta_temp_coefficient"] = c
        binned_rows.append(res)
    binned_res_df = pd.DataFrame(binned_rows)

    for lang in ["en", "pt-br"]:
        for models in [["Llama-2-7b-hf", "Mistral-7B-v0.1", "gemma-7b"],
                       ["Llama-2-7b-chat-hf", "Mistral-7B-Instruct-v0.1", "gemma-7b-it"]]:
            is_instruct_model = "Llama-2-7b-chat-hf" in models
            sub_df = binned_res_df[(binned_res_df.LANG == lang) & (binned_res_df.MODEL.isin(models))]
            print(f"============== lang={lang}, is_instruct_model = {is_instruct_model} ==============")
            print(sub_df[["ece", "ece_c", "rmsece", "rmsece_c"]].describe())


def exp_calibration_curves_across_difficulty_tiers(result_dir, difficulty_bins: int = 3):
    """Plot calibration curves for every model/language/subject, split by difficulty tier.

    Result files are taken from ``result_dir`` (non-recursive) when their path contains
    ``prob-dist``, ``2022`` and ``few-shot-ENEM`` and does not contain ``_OLD``. For each
    combination of MODELS x LANGS x SUBJS the matching files are merged with
    ``get_group_df``, the rows are cut into ``difficulty_bins`` equal-width tiers of
    ``NU_PARAM_B`` and one calibration curve per tier is drawn with
    ``plot_varying_temperature`` (folder ``calibration_binned_by_beta``).

    Args:
        result_dir: directory holding the experiment result files.
        difficulty_bins: number of equal-width ``NU_PARAM_B`` bins handed to ``pd.cut``.
    """
    model_result_paths = [file for file in glob.glob(f"{result_dir}/*") if
                          all(w in file for w in ["prob-dist", "2022", "few-shot-ENEM"]) and
                          "_OLD" not in file]
    binned_res_df = []
    year = "2022"
    # for metrics that needs to be binned by confidence, we need to aggregate by model/year/subject/lang, then bin
    bar = tqdm(total=len(MODELS) * len(LANGS) * len(SUBJS))
    for model in MODELS:
        for lang in LANGS:
            for subj in SUBJS:
                bar.set_description(f"processing {model} results for {lang}, {subj} in {year}")
                required = [f"{model}-", f"-{lang}", f"_{subj}_", f"_{year}_"]
                shot_markers = [f"{model}-few", f"{model}-z"]
                sub_result_paths = [p for p in model_result_paths
                                    if all(s in p for s in required) and any(s in p for s in shot_markers)]
                # get_group_df requires normalize/ablation; use the same values as the
                # other calibration experiments in this file.
                group_df = get_group_df(sub_result_paths, normalize=True, ablation=None)

                calibration_df = []
                group_df["difficulty_tier"] = pd.cut(group_df.NU_PARAM_B, difficulty_bins, duplicates="drop")
                # Rows without a difficulty value get a NaN tier, which cannot be ordered
                # against intervals, so only real tiers are iterated.
                tiers = sorted(group_df["difficulty_tier"].dropna().unique())
                for tier in tiers:
                    binned_df = get_binned_correlations(group_df[group_df.difficulty_tier == tier], x="probs",
                                                        bin=10, top_p_only=True)
                    res = get_calibration_metrics(binned_df)
                    res.update(group_df[["MODEL", "YEAR", "SUBJECT", "LANG"]].iloc[0].to_dict())
                    tier_label = f"{str(tier)} (ECE={res['ece']:.2f})"
                    res["group_size"] = len(group_df)
                    res["difficulty_bin"] = tier_label  # alternatively keep the tier's position
                    binned_res_df.append(res)
                    binned_df["difficulty_bin"] = tier_label
                    calibration_df.append(binned_df)
                bar.update(1)
                plot_varying_temperature(pd.concat(calibration_df),
                                         "probs", "corrects", hue="difficulty_bin", folder="calibration_binned_by_beta",
                                         postfix=f"ece_binned_by_difficulty_{subj}_{lang}_{model}", calibration=True)
    bar.close()

    # Persisting the per-tier metrics is currently disabled:
    # binned_res_df = pd.DataFrame(binned_res_df)
    # binned_res_df.to_csv(f"{THIS_SCRIPT_PARENT}/analysis_results/binned_df_few-shot_3bin_by_difficulty.csv")


def get_group_df(sub_result_paths, normalize: bool = True, ablation: Optional[str] = None):
    """Merge model and human metrics for several result files into one frame.

    For each path the output of ``get_model_metrics`` is inner-joined with the
    output of ``get_human_metrics`` on ``QUESTION_IDX``; the per-file frames are
    then stacked (original row indices are kept).

    Args:
        sub_result_paths: result files belonging to one group (e.g. different shuffles).
        normalize: forwarded to ``get_human_metrics``.
        ablation: forwarded to ``get_human_metrics``.

    Returns:
        The concatenated frame, or an empty ``DataFrame`` when no path is given.
    """
    frames = []
    for model_result_path in sub_result_paths:  # different shuffles
        df_model = get_model_metrics(model_result_path)
        df_human = get_human_metrics(model_result_path, normalize=normalize, ablation=ablation)
        frames.append(df_model.merge(df_human, on="QUESTION_IDX"))
    if not frames:
        # pd.concat refuses an empty list; keep returning an empty frame instead.
        return pd.DataFrame()
    # One concat at the end avoids re-copying the accumulated rows per file.
    return pd.concat(frames)


def exp_get_calibration_results(model_result_paths: List[str], postfix: str):
    """Compute calibration metrics per model/language/subject and save them as CSV.

    For every combination of MODELS x LANGS x SUBJS (year fixed to 2022) the matching
    result files are merged with the human metrics on ``QUESTION_IDX``, binned with
    ``get_binned_correlations`` and summarised with ``get_calibration_metrics``. The
    rows are written to ``analysis_results/binned_df_{postfix}.csv`` next to this script.

    Args:
        model_result_paths: candidate result files; they are filtered by substring.
        postfix: suffix of the output CSV file name.

    Raises:
        ValueError: if no result file matches one of the combinations.
    """
    binned_res_df = []
    year = "2022"
    # for metrics that needs to be binned by confidence, we need to aggregate by model/year/subject/lang, then bin
    with tqdm(total=len(MODELS) * len(LANGS) * len(SUBJS)) as bar:
        for model in MODELS:
            for lang in LANGS:
                for subj in SUBJS:
                    bar.set_description(f"processing {model} results for {lang}, {subj} in {year}")
                    required = [f"{model}-", f"-{lang}", f"_{subj}_", f"_{year}_"]
                    shot_markers = [f"{model}-few", f"{model}-z"]
                    sub_result_paths = [p for p in model_result_paths
                                        if all(s in p for s in required) and any(s in p for s in shot_markers)]
                    frames = []
                    for model_result_path in sub_result_paths:  # different shuffles
                        df_model = get_model_metrics(model_result_path)
                        df_human = get_human_metrics(model_result_path, normalize=True)
                        frames.append(df_model.merge(df_human, on="QUESTION_IDX"))
                    if not frames:
                        # An empty group cannot be binned; fail with a message that
                        # names the missing combination instead of an opaque KeyError.
                        raise ValueError(
                            f"no result files matched model={model}, lang={lang}, subj={subj}, year={year}")
                    group_df = pd.concat(frames)
                    binned_df = get_binned_correlations(group_df, x="probs",
                                                        bin=10, top_p_only=True)
                    res = get_calibration_metrics(binned_df)
                    res.update(group_df[["MODEL", "YEAR", "SUBJECT", "LANG"]].iloc[0].to_dict())
                    res["group_size"] = len(group_df)
                    binned_res_df.append(res)
                    bar.update(1)

    out_dir = f"{THIS_SCRIPT_PARENT}/analysis_results"
    # to_csv does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame(binned_res_df).to_csv(f"{out_dir}/binned_df_{postfix}.csv")


def get_calibration_results_zero_shots(result_dir):
    """Run ``exp_get_calibration_results`` for four families of 2022 result files.

    The files directly under ``result_dir`` are re-listed before every run and
    filtered by substring: current few-shot, few-shot marked ``_OLD``, zero-shot
    and zero-shot-diff. The number of selected files is printed before each run.
    """
    common_tokens = ["prob-dist", "2022"]
    # (extra required tokens, forbidden tokens, message prefix, CSV postfix)
    experiments = [
        (["few-shot-ENEM"], ["_OLD"], "number of few-shot files = ", "few-shot"),
        (["few-shot-ENEM", "_OLD"], [], "number of few-shot files (pre-modified dataset) = ", "few-shot-old"),
        (["zero-shot-ENEM"], [], "number of zero-shot files = ", "zero-shot"),
        (["zero-shot-diff-ENEM"], [], "number of zero-shot-diff files = ", "zero-shot-diff"),
    ]
    for extra_tokens, forbidden_tokens, message_prefix, postfix in experiments:
        wanted_tokens = common_tokens + extra_tokens
        selected = []
        for candidate in glob.glob(f"{result_dir}/*"):
            has_all_wanted = all(token in candidate for token in wanted_tokens)
            has_forbidden = any(token in candidate for token in forbidden_tokens)
            if has_all_wanted and not has_forbidden:
                selected.append(candidate)
        print(f"{message_prefix}{len(selected)}")
        exp_get_calibration_results(selected, postfix)


def plot_grouped_bars(df, x, y, hue=None, folder="", postfix=""):
    """Draw a (grouped) bar plot and save it as ``bar_{postfix}.png``.

    The image is written to ``analysis_results/{folder}`` next to this script; the
    folder is created when missing.

    Args:
        df: data to plot.
        x: column used for the x axis.
        y: column used for the bar heights.
        hue: optional grouping column; when given, ``df`` is sorted by it first.
        folder: sub-folder of ``analysis_results`` to write into.
        postfix: used both as the plot title and in the file name.
    """
    out_dir = f"{THIS_SCRIPT_PARENT}/analysis_results/{folder}"
    os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure()
    try:
        if hue is not None:
            df = df.sort_values(hue)
        sns.barplot(data=df, x=x, y=y, hue=hue)
        plt.title(postfix)
        plt.savefig(f"{out_dir}/bar_{postfix}.png")
    finally:
        # pyplot keeps every figure alive until it is closed; this function is
        # called in loops, so release the figure once it has been saved.
        plt.close(fig)


def analyze_pre_post_correction():
    """Plot how the calibration metrics changed between the old and new few-shot runs.

    Reads ``binned_df_few-shot-old.csv`` and ``binned_df_few-shot.csv`` from
    ``analysis_results`` and, per metric, stores ``new - old`` (rows are paired by
    index) in a ``"<metric> delta"`` column that is plotted once per language.
    """
    results_root = f"{THIS_SCRIPT_PARENT}/analysis_results"
    before = pd.read_csv(f"{results_root}/binned_df_few-shot-old.csv")
    after = pd.read_csv(f"{results_root}/binned_df_few-shot.csv")
    for metric in ("ece", "rmsece", "brier_score"):
        delta_col = f"{metric} delta"
        before[delta_col] = after[metric].sub(before[metric])
        for lang in ("en", "pt-br"):
            lang_rows = before.loc[before["LANG"] == lang]
            plot_grouped_bars(lang_rows, x="SUBJECT", y=delta_col, hue="MODEL",
                              folder="before_after_analysis",
                              postfix=f"pre-post-diff_{lang}_{metric}")


def analyze_difficulty_prompt_affect():
    """Plot the metric change between the plain and the ``-diff`` zero-shot runs.

    Reads ``binned_df_zero-shot.csv`` and ``binned_df_zero-shot-diff.csv`` from
    ``analysis_results`` and, per metric, stores ``diff - plain`` (rows are paired
    by index) in a ``"<metric> delta"`` column that is plotted once per language.
    """
    results_root = f"{THIS_SCRIPT_PARENT}/analysis_results"
    plain = pd.read_csv(f"{results_root}/binned_df_zero-shot.csv")
    with_difficulty = pd.read_csv(f"{results_root}/binned_df_zero-shot-diff.csv")
    for metric in ("ece", "rmsece", "brier_score"):
        delta_col = f"{metric} delta"
        plain[delta_col] = with_difficulty[metric].sub(plain[metric])
        for lang in ("en", "pt-br"):
            lang_rows = plain.loc[plain["LANG"] == lang]
            plot_grouped_bars(lang_rows, x="SUBJECT", y=delta_col, hue="MODEL",
                              folder="difficulty_prompt_analysis",
                              postfix=f"w_difficulty_vs_wo_{lang}_{metric}")


def analyze_beta_effect_over_ablation():
    """Print best ECE / AUROC per language, model and subject for three result tables.

    The tables are ``binned_df.csv`` (baseline), ``binned_df_shuffled.csv`` and
    ``binned_df_mean.csv`` under ``analysis_results``. For each combination the
    minimum ``ece`` and maximum ``auroc_single`` of every table are printed.
    """
    results_root = f"{THIS_SCRIPT_PARENT}/analysis_results"
    baseline = pd.read_csv(f"{results_root}/binned_df.csv")
    shuffled = pd.read_csv(f"{results_root}/binned_df_shuffled.csv")
    averaged = pd.read_csv(f"{results_root}/binned_df_mean.csv")

    def select(frame, lang, model, subj):
        # rows of one table belonging to a single language/model/subject
        mask = (frame.LANG == lang) & (frame.MODEL == model) & (frame.SUBJECT == subj)
        return frame[mask]

    for lang in LANGS:
        for model in MODELS:
            for subj in SUBJS:
                base_rows = select(baseline, lang, model, subj)
                shuffled_rows = select(shuffled, lang, model, subj)
                mean_rows = select(averaged, lang, model, subj)
                header = f"==== lang={lang},model={model},subj={subj} ====="
                ece_line = (f"baseline ece={base_rows.ece.min()}, "
                            f"shuffled ece={shuffled_rows.ece.min()}, "
                            f"mean ece={mean_rows.ece.min()}")
                auc_line = (f"baseline auc={base_rows.auroc_single.max()}, "
                            f"shuffled auc={shuffled_rows.auroc_single.max()}, "
                            f"mean auc={mean_rows.auroc_single.max()}")
                print("\n".join([header, ece_line, auc_line]))


def exp_calibration_auc_across_temperatures(result_dir, temp_range: List[float]):
    """Plot, per model/language/subject, calibration curves for several softmax temperatures.

    Result files are taken from ``result_dir`` (non-recursive) when their path contains
    ``prob-dist``, ``2022`` and ``few-shot-ENEM`` and does not contain ``_OLD``. For each
    temperature the logits are rescaled and soft-maxed again, the top probability is
    stored in ``probs`` and the binned result is plotted with ``plot_varying_temperature``
    (folder ``calibration_across_temp``).

    Args:
        result_dir: directory holding the experiment result files.
        temp_range: temperatures to divide the logits by.
    """
    model_result_paths = [file for file in glob.glob(f"{result_dir}/*") if
                          all(w in file for w in ["prob-dist", "2022", "few-shot-ENEM"]) and
                          "_OLD" not in file]
    # Per-temperature metrics; collected but not persisted by this function.
    binned_res_df = []
    year = "2022"
    # for metrics that needs to be binned by confidence, we need to aggregate by model/year/subject/lang, then bin
    bar = tqdm(total=len(MODELS) * len(LANGS) * len(SUBJS))
    for model in MODELS:
        for lang in LANGS:
            for subj in SUBJS:
                bar.set_description(f"processing {model} results for {lang}, {subj} in {year}")
                required = [f"{model}-", f"-{lang}", f"_{subj}_", f"_{year}_"]
                shot_markers = [f"{model}-few", f"{model}-z"]
                sub_result_paths = [p for p in model_result_paths
                                    if all(s in p for s in required) and any(s in p for s in shot_markers)]
                group_df = get_group_df(sub_result_paths, normalize=True, ablation=None)
                # The logits live in LOGIT_DIST when present, otherwise in LOGITS;
                # the loop below never adds either column, so decide once.
                logit_col = "LOGIT_DIST" if "LOGIT_DIST" in group_df.columns else "LOGITS"

                calibration_df = []
                for t in temp_range:
                    # Bracket assignment always writes a column; attribute assignment
                    # would only set a plain Python attribute if PROB_DIST were missing.
                    group_df["PROB_DIST"] = group_df[logit_col].apply(
                        lambda logits, t=t: softmax(np.array(logits) / t))
                    group_df["probs"] = group_df["PROB_DIST"].apply(max)

                    binned_df = get_binned_correlations(group_df, x="probs",
                                                        bin=10, top_p_only=True)
                    res = get_calibration_metrics(binned_df)
                    res.update(group_df[["MODEL", "YEAR", "SUBJECT", "LANG"]].iloc[0].to_dict())
                    res["group_size"] = len(group_df)
                    t_str = f"{t} (ECE={res['ece']:.2f})"
                    res["temperature"] = t_str
                    binned_res_df.append(res)
                    binned_df["temperature"] = t_str
                    calibration_df.append(binned_df)
                    # NOTE: ROC-curve collection (roc_curve over "accuracy"/"probs" per
                    # temperature, plotted into "roc_across_temp") is currently disabled.

                bar.update(1)
                plot_varying_temperature(pd.concat(calibration_df),
                                         "probs", "corrects", hue="temperature",
                                         folder="calibration_across_temp",
                                         postfix=f"calibration_across_temp_{subj}_{lang}_{model}",
                                         calibration=True)
    bar.close()


def plot_lm_plot(df, x, y, hue=None, folder="", postfix=""):
    """Draw a scatter plot with regression lines and save it as ``scatter_{postfix}.pdf``.

    ``df`` must hold exactly one ``year``, ``lang`` and ``subj``; they are used to
    build the title from ``ENEM_EXAM_MAPPING``. The file is written to
    ``analysis_results/{folder}`` next to this script.

    Note that the axis labels ("Accuracy", "MIN-20% Prob") and the tick positions
    are fixed, independent of ``x`` and ``y``, and that the global matplotlib font
    size is set to 6.

    Args:
        df: data to plot.
        x: column for the x axis.
        y: column for the y axis.
        hue: optional grouping column; when given, ``df`` is sorted by it and a
            legend is placed to the right of the axes.
        folder: sub-folder of ``analysis_results`` to write into.
        postfix: file name suffix; any ``"% "`` in it is removed.

    Raises:
        ValueError: if ``df`` contains more than one year, language or subject.
    """
    out_dir = f"{THIS_SCRIPT_PARENT}/analysis_results/{folder}"
    os.makedirs(out_dir, exist_ok=True)
    plt.rcParams.update({'font.size': 6})

    if hue is not None:
        df = df.sort_values(hue)

    years, langs, subjs = df.year.unique(), df.lang.unique(), df.subj.unique()
    if len(years) > 1 or len(langs) > 1 or len(subjs) > 1:
        raise ValueError("Only one year, lang, subj allowed")

    year, lang, subj = years[0], langs[0], subjs[0]
    title = f"{ENEM_EXAM_MAPPING[year][subj]} ({lang.upper()})"

    # remove % from the postfix
    postfix = postfix.replace("% ", "")

    g = sns.lmplot(data=df, x=x, y=y, hue=hue, height=1.75, aspect=0.75, markers=".",
                   line_kws={"linewidth": 0.5}, scatter_kws={"s": 5})
    plt.title(title, fontsize=6)
    plt.tight_layout()
    # Without a hue seaborn builds no legend, so there is nothing to replace.
    if g._legend is not None:
        g._legend.remove()
    if hue is not None:
        # Legend outside right
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), title="Model: Pearson r (p-value)", fontsize=6,
                   title_fontsize=6)
    plt.xlabel("Accuracy")
    plt.ylabel("MIN-20% Prob")
    plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1.0], ["0", "0.2", "0.4", "0.6", "0.8", "1.0"])
    plt.yticks([0, 10, 20, 30, 40, 50], ["0", "10", "20", "30", "40", "50"])
    plt.savefig(f"{out_dir}/scatter_{postfix}.pdf", bbox_inches="tight")
    plt.close()


def get_top_difficult_questions(beta_df, subj, top_k_pct_difficult):
    """Return the ``CO_POSICAO`` values of the hardest questions of one subject.

    Questions are those rows of ``beta_df`` whose ``CO_PROVA`` equals the 2022 exam
    version of ``subj`` (``SUBJ_TO_VERSION_2022``); difficulty is ``NU_PARAM_B``.

    Args:
        beta_df: item table with ``CO_PROVA``, ``NU_PARAM_B`` and ``CO_POSICAO`` columns.
        subj: subject key of ``SUBJ_TO_VERSION_2022``.
        top_k_pct_difficult: fraction of the subject's questions to keep (e.g. 0.2).

    Returns:
        A Series of ``CO_POSICAO`` ordered by increasing difficulty; empty when the
        fraction rounds down to zero questions.
    """
    subj_df = beta_df[beta_df.CO_PROVA == SUBJ_TO_VERSION_2022[subj]]
    num_questions = int(len(subj_df) * top_k_pct_difficult)
    # sort_values puts NaN last, which would rank items without a difficulty
    # estimate as the hardest ones; leave them out of the ranking.
    ranked = subj_df.dropna(subset=["NU_PARAM_B"]).sort_values("NU_PARAM_B")
    # tail(0) is empty, whereas the slice [-0:] would return every question.
    return ranked["CO_POSICAO"].tail(num_questions)

def _collect_acc_vs_ppl(result_dir, models):
    """Join per-question accuracy (mean over shuffles) onto the rows of ``ppl_results.csv``."""
    # get 4-shot new paths
    model_result_paths = [file for file in glob.glob(f"{result_dir}/*") if
                          all(w in file for w in ["prob-dist", "202", "four-shot-ENEM"]) and
                          all(w not in file for w in ["_OLD", "2020", "2021"])]

    ppl_df = pd.read_csv("ppl_results.csv")

    merged_frames = []
    bar = tqdm(total=2 * len(models) * len(LANGS) * len(SUBJS))
    for year in [2022, 2023]:
        for model in models:
            # The Llama-3 result files spell the size with a lowercase "b".
            path_model = model.replace("B", "b") if model in ("Meta-Llama-3-8B", "Meta-Llama-3-8B-Instruct") \
                else model
            for lang in LANGS:
                for subj in SUBJS:
                    bar.set_description(f"processing {model} results for {lang}, {subj} in {year}")
                    required = [f"{path_model}-four-shot-", f"-{lang}-", f"_{subj}_", str(year)]
                    sub_result_paths = [p for p in model_result_paths if all(s in p for s in required)]
                    if not sub_result_paths:
                        # Without any file there is no QUESTION_IDX to group by.
                        raise ValueError(
                            f"no result files matched model={model}, lang={lang}, subj={subj}, year={year}")
                    df_ppl_sub = ppl_df[(ppl_df.lang == lang) & (ppl_df.model == model) & (ppl_df.subj == subj) & (
                            ppl_df.year == year)]

                    # different shuffles of the same exam
                    group_df = pd.concat([get_model_metrics(p) for p in sub_result_paths])

                    # get per question accuracy aggregating shuffle
                    per_question_acc = group_df.groupby("QUESTION_IDX")["accuracy"].mean()
                    merged_frames.append(df_ppl_sub.merge(per_question_acc, how="left", right_on="QUESTION_IDX",
                                                          left_on="question_number"))
                    bar.update(1)
    bar.close()
    return pd.concat(merged_frames)


def exp_ppl_vs_accuracy(result_dir, top_k_pct_difficult: float = 0.0):
    """Correlate per-question accuracy with a perplexity-style metric and plot it.

    The accuracy/perplexity table is cached in ``analysis_results/acc_vs_ppl.csv``
    (relative to the working directory) and rebuilt from ``result_dir`` and
    ``ppl_results.csv`` when that file is missing. For every year, language and
    subject the Pearson correlation between ``accuracy`` and ``Min_20.0% Prob`` is
    computed per model, a scatter plot is saved through ``plot_lm_plot`` and one
    report row is written to ``acc_vs_ppl_report[_top<pct>].csv``.

    Args:
        result_dir: directory holding the four-shot result files.
        top_k_pct_difficult: when > 0, restrict the analysis to this fraction of the
            most difficult questions (see ``get_top_difficult_questions``).
    """
    MODELS = ["Llama-2-7b-hf", "Llama-2-7b-chat-hf", "Mistral-7B-v0.1", "Mistral-7B-Instruct-v0.1", "gemma-7b-it",
              "gemma-7b", "Llama-2-13b-hf", "Llama-2-13b-chat-hf", "Meta-Llama-3-8B", "Meta-Llama-3-8B-Instruct"]
    cache_path = "analysis_results/acc_vs_ppl.csv"

    if not os.path.isfile(cache_path):
        binned_res_df = _collect_acc_vs_ppl(result_dir, MODELS)
        os.makedirs("analysis_results", exist_ok=True)
        binned_res_df.to_csv(cache_path)
    else:
        binned_res_df = pd.read_csv(cache_path)

    # Assign the result back: an in-place replace on the attribute-accessed Series
    # is not guaranteed to reach the frame.
    binned_res_df["model"] = binned_res_df["model"].replace(MODEL_NAME_MAPPING)
    display_models = [MODEL_NAME_MAPPING[m] for m in MODELS]  # update models to the new names

    beta_df = None
    if top_k_pct_difficult > 0:
        # Only needed for the difficulty filter.
        # NOTE(review): these are the 2022 item parameters, yet they are also applied
        # to the 2023 rows below - confirm this is intended.
        beta_df = pd.read_csv("../data/raw-enem-exams/microdados_enem_2022/DADOS/ITENS_PROVA_2022.csv",
                              encoding="latin", sep=";")

    # columns = {year, lang, subj, *model-coef/model-p}
    report_rows = []
    for year in [2022, 2023]:
        for lang in LANGS:
            for subj in SUBJS:
                res_df = binned_res_df[
                    (binned_res_df.subj == subj) & (binned_res_df.lang == lang) & (
                            binned_res_df.year == year)].dropna()
                if top_k_pct_difficult > 0:
                    top_difficult_q_ids = get_top_difficult_questions(beta_df, subj, top_k_pct_difficult)
                    res_df = res_df[res_df.question_number.isin(top_difficult_q_ids)]
                # other available metrics: "ppl", "ppl/lowercase_ppl", "ppl/zlib", "Min_5.0% Prob",
                # "Min_10.0% Prob", "Min_30.0% Prob", "Min_40.0% Prob", "Min_50.0% Prob", "Min_60.0% Prob"
                for metrics in ["Min_20.0% Prob"]:
                    name_map = {}
                    report_row = {"year": year, "lang": lang, "subj": subj, "metric": metrics}
                    for model in display_models:
                        model_res_df = res_df[res_df.model == model]
                        pr = pearsonr(model_res_df.accuracy, model_res_df[metrics])
                        stat_str = f"{pr.statistic:.2f}/{pr.pvalue:.2f}"
                        if pr.pvalue <= 0.05:
                            # significant correlations are set in bold in the LaTeX table
                            stat_str = f"\\textbf{{{stat_str}}}"
                        report_row[model] = stat_str
                        name_map[model] = f"{model}: {pr.statistic:.2f} ({pr.pvalue:.2f})"
                    # one row per year/lang/subj/metric, added once all models are filled in
                    report_rows.append(report_row)

                    # Relabel on a copy so res_df keeps the plain model names for
                    # any further metric.
                    plot_df = res_df.assign(model=res_df.model.map(name_map))
                    plot_lm_plot(plot_df, "accuracy", metrics, hue="model",
                                 folder="acc_vs_ppl",
                                 postfix=f"{metrics.replace('/', '_')}_vs_acc_{year}_{subj}_{lang}")
    report_df = pd.DataFrame(report_rows)
    top_suffix = f"_top{top_k_pct_difficult}" if top_k_pct_difficult > 0 else ""
    report_df.to_csv(f"acc_vs_ppl_report{top_suffix}.csv")


if __name__ == "__main__":
    # Script entry point: run the accuracy-vs-perplexity analysis on all questions
    # (no difficulty filter).
    results_dir = "../enem-experiments-results"
    exp_ppl_vs_accuracy(results_dir, top_k_pct_difficult=0.0)