import argparse
from copy import deepcopy
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

# Human-readable panel titles for each ENEM exam booklet.
# Keys follow the pattern ENEM_<year>_<area>_CO_PROVA_<booklet code>.
ENEM_MAPPING_NAME = {
    # 2022 booklets
    "ENEM_2022_CH_CO_PROVA_1062": "2022 Humanities",
    "ENEM_2022_CN_CO_PROVA_1092": "2022 Natural Sciences",
    "ENEM_2022_LC_CO_PROVA_1072": "2022 Languages and Codes",
    "ENEM_2022_MT_CO_PROVA_1082": "2022 Mathematics",
    # 2023 booklets
    "ENEM_2023_CH_CO_PROVA_1197": "2023 Humanities",
    "ENEM_2023_CN_CO_PROVA_1227": "2023 Natural Sciences",
    "ENEM_2023_LC_CO_PROVA_1207": "2023 Languages and Codes",
    "ENEM_2023_MT_CO_PROVA_1217": "2023 Mathematics",
}

# Display names for the evaluated models, keyed by the raw MODEL_NAME value.
# Instruction-tuned variants carry the "Instruct" suffix in their display name.
MODEL_NAME_MAPPING = {
    # LLaMA 3
    "Meta-Llama-3-8b": "LLaMA3-8B",
    "Meta-Llama-3-8b-Instruct": "LLaMA3-8B Instruct",
    # Gemma
    "gemma-7b": "Gemma-7B",
    "gemma-7b-it": "Gemma-7B Instruct",
    # LLaMA 2 (13B)
    "Llama-2-13b-hf": "LLaMA2-13B",
    "Llama-2-13b-chat-hf": "LLaMA2-13B Instruct",
    # Mistral
    "Mistral-7B-v0.1": "Mistral-7B",
    "Mistral-7B-Instruct-v0.1": "Mistral-7B Instruct",
    # LLaMA 2 (7B)
    "Llama-2-7b-hf": "LLaMA2-7B",
    "Llama-2-7b-chat-hf": "LLaMA2-7B Instruct",
    # OpenAI
    "GPT-3.5-turbo-0125": "ChatGPT-3.5",
}

# Exams to plot for each supported year, in left-to-right panel order.
EXAM_ORDERING = {
    2022: [
        "ENEM_2022_LC_CO_PROVA_1072",
        "ENEM_2022_CH_CO_PROVA_1062",
        "ENEM_2022_CN_CO_PROVA_1092",
        "ENEM_2022_MT_CO_PROVA_1082",
    ],
    2023: [
        "ENEM_2023_LC_CO_PROVA_1207",
        "ENEM_2023_CH_CO_PROVA_1197",
        "ENEM_2023_CN_CO_PROVA_1227",
        "ENEM_2023_MT_CO_PROVA_1217",
    ],
}

def fisher_info_irt(df, theta):
    """
    Compute the per-item Fisher information of a 3-parameter IRT model.

    For each item with discrimination ``a``, difficulty ``b`` and
    pseudo-guessing ``c``, the probability of a correct answer at ability
    ``theta`` is ``P = c + (1 - c) / (1 + exp(-a * (theta - b)))`` and the
    information is ``a**2 * (P - c)**2 * (1 - P) / ((1 - c)**2 * P)``.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame of an exam with the following columns:
        - NU_PARAM_A: Discrimination parameter
        - NU_PARAM_B: Difficulty parameter
        - NU_PARAM_C: Pseudo-guessing parameter
    theta : float
        Ability value at which the information is evaluated.

    Returns
    -------
    pd.Series
        Fisher information of every item at ``theta``, aligned with the
        index of ``df``.
    """
    a = df["NU_PARAM_A"]
    b = df["NU_PARAM_B"]
    c = df["NU_PARAM_C"]

    # Probability of a correct answer; computed once and reused below.
    p = c + (1 - c) / (1 + np.exp(-a * (theta - b)))
    q = 1 - p

    return ((a ** 2) * ((p - c) ** 2) * q) / (((1 - c) ** 2) * p)

def fisher_info_plot():
    """
    Placeholder for a reusable Fisher information plotting routine.

    Raises
    ------
    NotImplementedError
        Always; the function has no implementation yet.
    """
    message = "Fisher information plot not implemented yet"
    raise NotImplementedError(message)

# Command-line interface: exam year, prompt type, kind of information curve
# and model family (base vs. instruct) to plot.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--year", "-y", type=int, required=True, choices=sorted(EXAM_ORDERING),
    help="Year of ENEM exam",
)
parser.add_argument(
    "--system_prompt_type", "-p", type=str, required=True,
    choices=["one-shot", "four-shot", "zero-shot"], help="System prompt type",
)
parser.add_argument(
    "--fisher_info", "-f", type=str, required=True, choices=["total", "item"],
    help="Fisher information type",
)
# Defaults to 0 so the model filter and the output file name always agree,
# even when the flag is omitted.
parser.add_argument(
    "--is_instruct_version", "-i", type=int, choices=[0, 1], default=0,
    help="Whether the model is the instruct version or not",
)
args = parser.parse_args()
year = args.year
is_instruct = args.is_instruct_version == 1
instruct = "instruct" if is_instruct else "no-instruct"
exam_names = EXAM_ORDERING[year]
n_exams = len(exam_names)

# Load the LLM results and keep only the exams and prompt type being plotted.
df_llms = pd.read_parquet("enem-experiments-results-processed.parquet")
df_llms = df_llms[
    df_llms.ENEM_EXAM.isin(exam_names)
    & (df_llms.SYSTEM_PROMPT_TYPE == args.system_prompt_type)
].copy()

# Instruct models are recognised by the "Instruct" suffix of their display
# name, so the filter must run after the names have been mapped.
df_llms = df_llms.replace({"MODEL_NAME": MODEL_NAME_MAPPING})
instruct_mask = df_llms["MODEL_NAME"].str.contains("Instruct")
df_llms = df_llms[instruct_mask if is_instruct else ~instruct_mask]
df_llms = df_llms.replace({"LANGUAGE": {"pt-br": "PT-BR", "en": "EN"}})

# Item parameters for the whole year; read once and filtered per exam below.
df_items_year = pd.read_csv(
    f"data/raw-enem-exams/microdados_enem_{year}/DADOS/ITENS_PROVA_{year}.csv",
    sep=";",
    encoding="latin-1",
)

# Ability grid shared by every panel.
thetas = np.linspace(-2, 5, 1000)
theta_ticks = [-2, -1, 0, 1, 2, 3, 4, 5]
# NOTE(review): these ticks look tuned for the "total" curve; confirm they
# are also appropriate for "item" curves.
info_ticks = [0, 10, 20, 30, 40]

# One column per exam: information curve on top, model scores below.
plt.rcParams.update({"font.size": 6})
fig, ax = plt.subplots(
    2, n_exams, sharex=True, squeeze=False,
    figsize=(1.4 * n_exams, 2), gridspec_kw={"height_ratios": [1, 2.5]},
)
for i, enem_exam in enumerate(exam_names):
    ax_info, ax_scores = ax[0, i], ax[1, i]
    is_first = i == 0
    is_last = i == n_exams - 1

    # Mean IRT score and 95% normal-approximation half-width per
    # (model, prompt type, language).
    df_llms_exam = (
        df_llms[df_llms.ENEM_EXAM == enem_exam]
        .groupby(["MODEL_NAME", "SYSTEM_PROMPT_TYPE", "LANGUAGE"])
        .agg(
            IRT_SCORE=("IRT_SCORE", "mean"),
            IRT_SCORE_EMP_SE=("IRT_SCORE", lambda x: 1.96 * x.std() / np.sqrt(x.shape[0])),
        )
        .reset_index()
    )

    # Items of this booklet; rows without parameters were not used in the exam.
    co_prova = int(enem_exam.split("_")[-1])
    df_items = df_items_year[df_items_year.CO_PROVA == co_prova].dropna(
        subset=["NU_PARAM_A", "NU_PARAM_B", "NU_PARAM_C"]
    )

    # Rows index theta, columns index items; "total" sums over the items.
    fisher_info = np.array([fisher_info_irt(df_items, theta) for theta in thetas])
    if args.fisher_info == "total":
        fisher_info = fisher_info.sum(axis=1)

    # Top panel: Fisher information curve(s).
    ax_info.plot(thetas, fisher_info)
    ax_info.set_xlabel("")
    ax_info.set_ylabel(r"$\mathcal {I} (\theta)$" if is_first else "")
    ax_info.set_xticks(theta_ticks)
    ax_info.set_yticks(info_ticks)
    ax_info.set_yticklabels(info_ticks if is_first else [])
    ax_info.grid()
    ax_info.set_title(ENEM_MAPPING_NAME[enem_exam], fontsize=6)

    # Bottom panel: horizontal error bars of the IRT score per model and
    # language, drawn on the same theta axis as the curve above.
    for model_name in df_llms_exam.MODEL_NAME.unique():
        for language in df_llms_exam.LANGUAGE.unique():
            df_model = df_llms_exam[
                (df_llms_exam.MODEL_NAME == model_name) & (df_llms_exam.LANGUAGE == language)
            ]
            if df_model.empty:
                continue
            ax_scores.errorbar(
                df_model.IRT_SCORE,
                df_model.MODEL_NAME,
                xerr=df_model.IRT_SCORE_EMP_SE,
                fmt=".",
                capsize=4,
                elinewidth=2,
                label=f"{language}",
                color="C0" if language == "PT-BR" else "C1",
            )
    ax_scores.set_xlabel(r"IRT score ($\theta$)")
    ax_scores.set_ylabel("")
    ax_scores.set_xticks(theta_ticks)
    ax_scores.set_xticklabels(theta_ticks)
    ax_scores.grid()
    if not is_first:
        # Model names are shown on the leftmost panel only.
        ax_scores.set_yticklabels([])

    if is_last:
        # Single legend on the last panel, with one entry per language.
        handles, labels = ax_scores.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        ax_scores.legend(by_label.values(), by_label.keys(), fontsize=5)

plt.tight_layout()
plt.savefig(
    f"plots/fisher-info/{args.fisher_info}-fisher-info-{year}-{args.system_prompt_type}-{instruct}.pdf",
    bbox_inches="tight",
    pad_inches=0.05,
    dpi=800,
)
plt.close()