import argparse
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Raw exam identifiers found in the ENEM_EXAM column of the LLM results,
# translated to the "<year> <area>" display names used as panel titles.
# Identifiers have the shape "ENEM_<year>_<area>_CO_PROVA_<code>".
ENEM_MAPPING_NAME = {
    # ---- 2022 ----
    "ENEM_2022_CH_CO_PROVA_1062": "2022 Humanities",
    "ENEM_2022_CN_CO_PROVA_1092": "2022 Natural Sciences",
    "ENEM_2022_LC_CO_PROVA_1072": "2022 Languages and Codes",
    "ENEM_2022_MT_CO_PROVA_1082": "2022 Mathematics",
    # ---- 2023 ----
    "ENEM_2023_CH_CO_PROVA_1197": "2023 Humanities",
    "ENEM_2023_CN_CO_PROVA_1227": "2023 Natural Sciences",
    "ENEM_2023_LC_CO_PROVA_1207": "2023 Languages and Codes",
    "ENEM_2023_MT_CO_PROVA_1217": "2023 Mathematics",
}

# Per exam year: EXAM_SUBJECT area code (as found in the human results)
# -> "<year> <area>" display name, matching the names in ENEM_MAPPING_NAME.
ENEM_EXAM_MAPPING = {
    exam_year: {
        "MT": f"{exam_year} Mathematics",
        "CH": f"{exam_year} Humanities",
        "CN": f"{exam_year} Natural Sciences",
        "LC": f"{exam_year} Languages and Codes",
    }
    for exam_year in (2022, 2023)
}


# Raw MODEL_NAME values from the LLM results -> short labels used in the
# legend. Instruction-tuned variants carry the "Instruct" suffix, which the
# instruct / non-instruct filter later matches on.
MODEL_NAME_MAPPING = {
    # LLaMA 3
    "Meta-Llama-3-8b": "LLaMA3-8B",
    "Meta-Llama-3-8b-Instruct": "LLaMA3-8B Instruct",
    # Gemma
    "gemma-7b": "Gemma-7B",
    "gemma-7b-it": "Gemma-7B Instruct",
    # LLaMA 2, 13B
    "Llama-2-13b-hf": "LLaMA2-13B",
    "Llama-2-13b-chat-hf": "LLaMA2-13B Instruct",
    # Mistral
    "Mistral-7B-v0.1": "Mistral-7B",
    "Mistral-7B-Instruct-v0.1": "Mistral-7B Instruct",
    # LLaMA 2, 7B
    "Llama-2-7b-hf": "LLaMA2-7B",
    "Llama-2-7b-chat-hf": "LLaMA2-7B Instruct",
    # OpenAI (label has no "Instruct" suffix)
    "GPT-3.5-turbo-0125": "ChatGPT-3.5",
}

# Panel order for each exam year: the exams are drawn row-major into the
# flattened 2x2 subplot grid, so this fixes which area lands in which corner.
EXAM_ORDERING = {
    exam_year: [
        f"{exam_year} {area}"
        for area in ("Languages and Codes", "Humanities", "Natural Sciences", "Mathematics")
    ]
    for exam_year in (2022, 2023)
}

# Parse arguments
parser = argparse.ArgumentParser()
# Restrict --year to the years the plot knows how to lay out, so an
# unsupported year fails here with a clear message instead of a KeyError later.
parser.add_argument("--year", "-y", type=int, required=True, choices=sorted(EXAM_ORDERING), help="Year of ENEM exam")
parser.add_argument("--system_prompt_type", "-p", type=str, required=True, choices=["one-shot", "four-shot", "zero-shot"], help="System prompt type")
# default=0: without it an omitted flag yields None, which the filename label
# below treats as "no-instruct" while the model filter treated it as instruct.
parser.add_argument('--is_instruct_version', "-i", type=int, choices=[0, 1], default=0, help='Whether the model is the instruct version or not (default: 0)')
args = parser.parse_args()
year = args.year
# Tag used in the output filename; 1 selects the instruct models, anything else the base models.
instruct = "instruct" if args.is_instruct_version == 1 else "no-instruct"


# Load the data
# LLMs: experiment results restricted to the selected system prompt type,
# exam year and model family (instruct vs. non-instruct).
df = pd.read_parquet("enem-experiments-results-processed.parquet")
# Replace raw identifiers with the display names used in titles and legend
df["ENEM_EXAM"] = df["ENEM_EXAM"].replace(ENEM_MAPPING_NAME)
df["MODEL_NAME"] = df["MODEL_NAME"].replace(MODEL_NAME_MAPPING)
df = df[df["SYSTEM_PROMPT_TYPE"] == args.system_prompt_type]
# Mapped exam names embed the year (e.g. "2022 Mathematics")
df = df[df["ENEM_EXAM"].str.contains(str(year), regex=False)]
df = df.rename(columns={"MODEL_NAME": "MODEL"})

# Same rule as the `instruct` filename tag (== 1 means instruct), so a missing
# flag (None) selects the base models that a "no-instruct" filename advertises.
if args.is_instruct_version == 1:
    df = df[df["MODEL"].str.contains("Instruct")]
else:
    # Note: labels without an "Instruct" suffix (e.g. ChatGPT-3.5) land here
    df = df[~df["MODEL"].str.contains("Instruct")]
# Materialize the filtered rows: assigning a column to a boolean-indexed slice
# triggers pandas' SettingWithCopyWarning.
df = df.copy()

# Adding the LANGUAGE in the MODEL column: "<model> (<LANGUAGE upper-cased>)"
df["MODEL"] = df["MODEL"] + " (" + df["LANGUAGE"].str.upper() + ")"

# Humans: per-candidate scores for the selected exam year, labelled with the
# same exam display names as the LLM results.
df_human = pd.read_parquet("humans-irt-lz.parquet")
df_human["EXAM_YEAR"] = df_human["EXAM_YEAR"].astype(int)
df_human = df_human[df_human["EXAM_YEAR"] == year].copy()

# TODO: Remove this subsampling (set HUMAN_SAMPLE_FRAC = 1.0) for final plots.
# It keeps a reproducible random fraction of the human rows (seed 0).
HUMAN_SAMPLE_FRAC = 0.01
if HUMAN_SAMPLE_FRAC < 1.0:
    df_human = df_human.sample(frac=HUMAN_SAMPLE_FRAC, random_state=0)

df_human = df_human.rename(columns={"PFscores": "LZ_SCORE"})
# Area code (e.g. "MT") -> "<year> <area>" display name
subject_to_exam = ENEM_EXAM_MAPPING[year]
df_human["ENEM_EXAM"] = df_human["EXAM_SUBJECT"].map(subject_to_exam)
# .map leaves NaN for unknown codes; fail loudly like a direct dict lookup would
unmapped = df_human.loc[df_human["ENEM_EXAM"].isna(), "EXAM_SUBJECT"].unique()
if len(unmapped) > 0:
    raise KeyError(f"EXAM_SUBJECT values with no ENEM_EXAM_MAPPING entry for {year}: {list(unmapped)}")
# Constant hue value so humans can share the MODEL column with the LLM rows
df_human["MODEL"] = "Human"

# Plot KDE
# IRT vs. CTT: one panel per exam area (2x2 grid, shared axes). Each panel
# overlays the human score density as a filled KDE and one outline contour per
# LLM; a single legend sits to the right of the last panel.
plt.rcParams.update({'font.size': 6})
# layout="compressed" already makes room for the outside legend. Do not also
# call tight_layout(): it swaps out this layout engine (matplotlib warns
# "The figure layout has changed to tight").
fig, ax = plt.subplots(2, 2, figsize=(1.4*4, 3.5), sharex=True, sharey=True, layout="compressed")
# Flatten the axes so panels follow EXAM_ORDERING row by row
ax = ax.flatten()
exams = EXAM_ORDERING[year]
# Hoisted: the same hue order in every panel keeps each model's colour stable
model_order = sorted(df["MODEL"].unique())
for i, exam in enumerate(exams):
    panel = ax[i]
    is_last_panel = i == len(exams) - 1
    # Humans: no `hue` here. With hue, seaborn ignores `color` and would draw
    # this density in the first palette colour (the same as the first LLM).
    sns.kdeplot(
        data=df_human[df_human["ENEM_EXAM"] == exam],
        x="CTT_SCORE", y="IRT_SCORE",
        legend=False, fill=True, alpha=0.3, color="black", ax=panel,
    )
    # LLMs: one contour per model. No `cmap`: seaborn ignores it (with a
    # warning) when `hue` is set; colours come from the hue palette.
    sns.kdeplot(
        data=df[df["ENEM_EXAM"] == exam],
        x="CTT_SCORE", y="IRT_SCORE", hue="MODEL", hue_order=model_order,
        legend=is_last_panel, fill=False, alpha=1, levels=1, linewidths=0.75, ax=panel,
    )
    panel.set_xticks([0, 15, 30, 45])
    panel.set_yticks(range(-2, 4, 1))
    panel.set_yticklabels(range(-2, 4, 1), fontsize=6)
    panel.set_xlabel("CTT score")
    panel.set_ylabel("IRT score")
    panel.set_title(exam, fontsize=6)
    # Only the last panel gets a legend; move it outside the plot, to the right.
    # get_legend() is None when that panel had no LLM rows to draw.
    if is_last_panel and panel.get_legend() is not None:
        sns.move_legend(panel, loc='center left', ncol=1, bbox_to_anchor=(1, 0.5), fontsize=5, title_fontsize=0, title="")

# Create the output directory if needed so savefig does not fail on a fresh checkout
out_dir = Path("plots/kde-plots")
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_dir / f"kde-human-ctt-irt-{year}-{args.system_prompt_type}-{instruct}.pdf", dpi=800)