import argparse
from copy import deepcopy
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Display names for the raw exam identifiers stored in the results' ENEM_EXAM
# column. Each identifier has the form ENEM_<year>_<area>_CO_PROVA_<code>
# (the numeric code is presumably the official booklet code -- TODO confirm).
ENEM_MAPPING_NAME = {
    f"ENEM_{exam_year}_{area}_CO_PROVA_{prova_code}": f"{exam_year} {subject}"
    for exam_year, area, prova_code, subject in (
        # 2022 exams
        (2022, "CH", 1062, "Humanities"),
        (2022, "CN", 1092, "Natural Sciences"),
        (2022, "LC", 1072, "Languages and Codes"),
        (2022, "MT", 1082, "Mathematics"),
        # 2023 exams
        (2023, "CH", 1197, "Humanities"),
        (2023, "CN", 1227, "Natural Sciences"),
        (2023, "LC", 1207, "Languages and Codes"),
        (2023, "MT", 1217, "Mathematics"),
    )
}

# Per-year lookup from the two-letter exam area code (MT, CH, CN, LC) to the
# display name used for that exam in the plots.
ENEM_EXAM_MAPPING = {
    exam_year: {
        area: f"{exam_year} {subject}"
        for area, subject in (
            ("MT", "Mathematics"),
            ("CH", "Humanities"),
            ("CN", "Natural Sciences"),
            ("LC", "Languages and Codes"),
        )
    }
    for exam_year in (2022, 2023)
}

# Display names for the raw model identifiers. The open-weight models come in
# (base checkpoint, instruction-tuned checkpoint) pairs that share a family
# label; the instruction-tuned variant gets an " Instruct" suffix.
MODEL_NAME_MAPPING = {
    checkpoint: label
    for base_checkpoint, instruct_checkpoint, family in (
        ("Meta-Llama-3-8b", "Meta-Llama-3-8b-Instruct", "LLaMA3-8B"),
        ("gemma-7b", "gemma-7b-it", "Gemma-7B"),
        ("Llama-2-13b-hf", "Llama-2-13b-chat-hf", "LLaMA2-13B"),
        ("Mistral-7B-v0.1", "Mistral-7B-Instruct-v0.1", "Mistral-7B"),
        ("Llama-2-7b-hf", "Llama-2-7b-chat-hf", "LLaMA2-7B"),
    )
    for checkpoint, label in (
        (base_checkpoint, family),
        (instruct_checkpoint, f"{family} Instruct"),
    )
}
# The GPT entry has no base/instruct pair, so it is added on its own.
MODEL_NAME_MAPPING["GPT-3.5-turbo-0125"] = "ChatGPT-3.5"

# Left-to-right panel order of the exams in the figure, per exam year.
EXAM_ORDERING = {
    exam_year: [
        f"{exam_year} {subject}"
        for subject in ("Languages and Codes", "Humanities", "Natural Sciences", "Mathematics")
    ]
    for exam_year in (2022, 2023)
}

# Parse arguments
# Command-line options:
#   --year / -y                 exam year; restricted to the years configured in
#                               EXAM_ORDERING so that an unsupported year is
#                               rejected here with a usage message instead of
#                               failing later with a KeyError.
#   --system_prompt_type / -p   which prompting setup's results to plot.
#   --is_instruct_version / -i  1 selects the instruct models, 0 the others.
parser = argparse.ArgumentParser()
parser.add_argument("--year", "-y", type=int, required=True, choices=sorted(EXAM_ORDERING), help="Year of ENEM exam")
parser.add_argument("--system_prompt_type", "-p", type=str, required=True, choices=["one-shot", "four-shot", "zero-shot"], help="System prompt type")
parser.add_argument('--is_instruct_version', "-i", type=int, choices=[0, 1], required=True, help='Whether the model is the instruct version or not')
args = parser.parse_args()
year = args.year
# Tag used in the output file name.
instruct = "instruct" if args.is_instruct_version == 1 else "no-instruct"


# Load the data
# LLMs
# Read the processed experiment results, swap the raw exam / model / language
# identifiers for their display labels, then keep only the rows for the
# requested prompt type and exam year (the year is matched as a substring of
# the exam name).
df = (
    pd.read_parquet("enem-experiments-results-processed.parquet")
    .assign(
        ENEM_EXAM=lambda frame: frame["ENEM_EXAM"].replace(ENEM_MAPPING_NAME),
        MODEL_NAME=lambda frame: frame["MODEL_NAME"].replace(MODEL_NAME_MAPPING),
        LANGUAGE=lambda frame: frame["LANGUAGE"].replace({"pt-br": "PT-BR", "en": "EN"}),
    )
    .loc[lambda frame: frame["SYSTEM_PROMPT_TYPE"] == args.system_prompt_type]
    .copy()
    .loc[lambda frame: frame["ENEM_EXAM"].str.contains(f"{year}")]
    .copy()
    .rename(columns={"MODEL_NAME": "MODEL"})
)

# Split on whether the display name carries the "Instruct" label: flag 0 keeps
# the models without it, any other value keeps the models with it.
# (Alternative noted by the author: also admit models whose name contains
# "GPT" into the instruct group.)
has_instruct_label = df["MODEL"].str.contains("Instruct")
df = df[~has_instruct_label] if args.is_instruct_version == 0 else df[has_instruct_label]

# # Downsampling the data
# # Choosing 9 shuffles + default
# shuffles_sampled = [f"shuffle-{i}" for i in range(10)] + ['default']
# df = deepcopy(df[df["ENEM_EXAM_TYPE"].isin(shuffles_sampled)])

# Humans
# Load the human scores, keep the selected exam year, and give the frame the
# columns the plotting code reads (LZ_SCORE, ENEM_EXAM, MODEL) so humans and
# LLMs can be drawn on the same axes.
df_human = pd.read_parquet("humans-irt-lz.parquet")
df_human["EXAM_YEAR"] = df_human["EXAM_YEAR"].astype(int)
df_human = df_human[df_human["EXAM_YEAR"] == year].copy()
# TODO: Remove this line
# Only a 1% subsample of the human rows is kept. The seed is fixed so that
# repeated runs draw the same subsample and regenerate the same figure.
df_human = df_human.sample(frac=0.01, random_state=0)
df_human = df_human.rename(columns={"PFscores": "LZ_SCORE"})
# Translate the exam area code into the same display name the LLM frame uses.
df_human["ENEM_EXAM"] = df_human["EXAM_SUBJECT"].apply(lambda x: ENEM_EXAM_MAPPING[year][x])
# Constant label so the human density can be passed through `hue="MODEL"`.
df_human["MODEL"] = "Human"

# # Plot scatterplots
# # IRT
# for i, exam in enumerate(sorted(df["ENEM_EXAM"].unique())):
#     print(exam)
#     fig, ax = plt.subplots(figsize=(4, 3))
#     sns.kdeplot(data=df_human[df_human["ENEM_EXAM"] == exam], x="CTT_SCORE", y="IRT_SCORE", hue="MODEL", style="LANGUAGE", fill=True, alpha=0.5, legend=False, ax=ax)
#     g = sns.scatterplot(data=df[df["ENEM_EXAM"] == exam], x="CTT_SCORE", y="IRT_SCORE", hue="MODEL", style="LANGUAGE", hue_order=sorted(df.MODEL.unique()), style_order=sorted(df.LANGUAGE.unique()), ax=ax)
#     plt.xticks(range(0, 46, 5))
#     plt.yticks(range(-2, 5, 1))
#     plt.xlabel("CTT Score (Accuracy)")
#     plt.ylabel(r"$\theta$")
#     # Remove legend if exam is not 2022 Mathematics
#     if exam != "2022 Mathematics":
#         g.legend_.remove()
#     else:
#         g.legend(ncol=1, loc='lower right', bbox_to_anchor=(1.0, 0.0), fontsize=6)
#     plt.tight_layout()
#     # Find the key respective to the exam in the ENEM_EXAM_MAPPING
#     exam_key = [key for key, value in ENEM_EXAM_MAPPING[year].items() if value == exam][0]
#     plt.savefig(f"plots/scatterplot-human-ctt-irt-{year}-{exam_key}-{args.system_prompt_type}.pdf", dpi=800)
#     plt.close()

# LZ
# One row of four panels, one per exam in EXAM_ORDERING[year]. Each panel shows
# the human (l_z, IRT) density as filled KDE contours with the individual LLM
# results scattered on top, coloured by model and marked by language.
plt.rcParams.update({'font.size': 6})
fig, ax = plt.subplots(1, 4, figsize=(1.4*4, 1.65), sharex=True, sharey=True, layout="compressed")

# Loop-invariant: the same hue/style ordering is used in every panel so that a
# given model/language keeps the same colour/marker across the figure.
model_order = sorted(df.MODEL.unique())
language_order = sorted(df.LANGUAGE.unique())

for i, exam in enumerate(EXAM_ORDERING[year]):
    print(exam)
    # Human density. `style` is not a kdeplot parameter (it would only be
    # forwarded to matplotlib's contour call, which does not use it), so it is
    # not passed here.
    sns.kdeplot(data=df_human[df_human["ENEM_EXAM"] == exam], x="LZ_SCORE", y="IRT_SCORE", hue="MODEL", fill=True, alpha=0.5, legend=False, ax=ax[i])
    # LLM results on top of the human density.
    sns.scatterplot(data=df[df["ENEM_EXAM"] == exam], x="LZ_SCORE", y="IRT_SCORE", hue="MODEL", style="LANGUAGE", hue_order=model_order, style_order=language_order, ax=ax[i], s=10)
    ax[i].set_xticks(range(-6, 5, 2))
    ax[i].set_yticks(range(-2, 5, 1))
    ax[i].set_xlabel(r"$l_z$ score")
    ax[i].set_ylabel("IRT score")
    ax[i].set_title(exam, fontsize=6)

    # Keep a single legend for the whole figure: drop it from every panel but
    # the last one (index 3), where it is anchored just outside the right edge,
    # vertically centred. A panel with no LLM points gets no legend from
    # seaborn, so guard against a missing legend instead of crashing.
    panel_legend = ax[i].get_legend()
    if panel_legend is not None:
        if i != 3:
            panel_legend.remove()
        else:
            sns.move_legend(ax[i], loc='center left', ncol=1, bbox_to_anchor=(1, 0.5), fontsize=5)

    # NOTE(review): the figure is created with layout="compressed", but
    # tight_layout() appears to replace that layout engine -- confirm which of
    # the two is intended. Kept as-is to leave the output unchanged.
    plt.tight_layout()
plt.savefig(f"plots/scatterplots/scatterplot-human-ctt-lz-{year}-{args.system_prompt_type}-{instruct}.pdf", dpi=800)
