import argparse
from copy import deepcopy
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Embed fonts as fonttype 42 (TrueType, per the matplotlib rcParams docs) in
# PDF/PS output, and use a 6 pt base font size for every figure.
matplotlib.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})
plt.rc("font", size=6)

# ENEM knowledge areas: area code used in the raw exam ids -> English name.
# The insertion order is the panel order used for the plots (EXAM_ORDERING).
_AREA_NAMES = {
    "LC": "Languages and Codes",
    "CH": "Humanities",
    "CN": "Natural Sciences",
    "MT": "Mathematics",
}

# Booklet code (the trailing number of each raw exam id) per year and area.
_EXAM_BOOKLETS = {
    2022: {"CH": 1062, "CN": 1092, "LC": 1072, "MT": 1082},
    2023: {"CH": 1197, "CN": 1227, "LC": 1207, "MT": 1217},
}

# Raw exam id (ENEM_EXAM column) -> "<year> <area name>" plot title.
ENEM_MAPPING_NAME = {
    f"ENEM_{exam_year}_{area}_CO_PROVA_{booklet}": f"{exam_year} {_AREA_NAMES[area]}"
    for exam_year, booklets in _EXAM_BOOKLETS.items()
    for area, booklet in booklets.items()
}

# (base checkpoint, instruct checkpoint, display name) for each open model family.
# Instruct checkpoints get the display name plus " Instruct"; the script later
# filters on that substring to separate instruct models from base models.
_OPEN_MODEL_FAMILIES = (
    ("Meta-Llama-3-8b", "Meta-Llama-3-8b-Instruct", "LLaMA3-8B"),
    ("gemma-7b", "gemma-7b-it", "Gemma-7B"),
    ("Llama-2-13b-hf", "Llama-2-13b-chat-hf", "LLaMA2-13B"),
    ("Mistral-7B-v0.1", "Mistral-7B-Instruct-v0.1", "Mistral-7B"),
    ("Llama-2-7b-hf", "Llama-2-7b-chat-hf", "LLaMA2-7B"),
)

# Raw checkpoint name (MODEL_NAME column) -> display name used in the plots.
MODEL_NAME_MAPPING = {
    **{
        checkpoint: display
        for base_ckpt, instruct_ckpt, family in _OPEN_MODEL_FAMILIES
        for checkpoint, display in ((base_ckpt, family), (instruct_ckpt, f"{family} Instruct"))
    },
    "GPT-3.5-turbo-0125": "ChatGPT-3.5",
}

# Panel order per exam year: Languages and Codes, Humanities, Natural Sciences, Mathematics.
EXAM_ORDERING = {
    exam_year: [f"{exam_year} {area_name}" for area_name in _AREA_NAMES.values()]
    for exam_year in _EXAM_BOOKLETS
}

# Parse arguments
parser = argparse.ArgumentParser()
# Only years with a known panel layout can be plotted.
parser.add_argument("--year", "-y", type=int, required=True, choices=sorted(EXAM_ORDERING), help="Year of ENEM exam")
parser.add_argument("--system_prompt_type", "-p", type=str, required=True, choices=["one-shot", "four-shot", "zero-shot"], help="System prompt type")
# default=0: omitting the flag must select the base models, consistent with the
# "no-instruct" label computed below (a None value would otherwise be treated
# differently by different checks).
parser.add_argument('--is_instruct_version', "-i", type=int, choices=[0, 1], default=0, help='Whether the model is the instruct version or not')
args = parser.parse_args()
year = args.year
# Model-selection label; also embedded in the output filename.
instruct = "instruct" if args.is_instruct_version == 1 else "no-instruct"

# Load the processed experiment results and keep only the requested exam year
# and system-prompt type.
df = pd.read_parquet("enem-experiments-results-processed.parquet")
# Raw exam ids look like "ENEM_<year>_<area>_CO_PROVA_<booklet>" (see
# ENEM_MAPPING_NAME); match the year field exactly instead of searching for the
# year as a substring anywhere in the id.
exam_year_field = df["ENEM_EXAM"].str.split("_").str[1]
df = df[(exam_year_field == str(year)) & (df.SYSTEM_PROMPT_TYPE == args.system_prompt_type)].copy()

# Map raw checkpoint names to display names. Instruct checkpoints map to names
# containing "Instruct", which is what the base/instruct selection keys on.
df["MODEL_NAME"] = df["MODEL_NAME"].replace(MODEL_NAME_MAPPING)
is_instruct_model = df["MODEL_NAME"].str.contains("Instruct")
# Select with the same `instruct` label that goes into the output filename, so
# the plotted models always match the file name. ChatGPT-3.5 has no "Instruct"
# in its display name and is therefore only plotted with the base models.
if instruct == "instruct":
    df = df[is_instruct_model].copy()
else:
    df = df[~is_instruct_model].copy()

# Label used for the heatmap rows (currently just the display name).
df["FULL_MODEL"] = df["MODEL_NAME"].astype(str)

# Split the raw id before it is replaced by its pretty name below.
exam_parts = df["ENEM_EXAM"].str.split("_")
df["ENEM_EXAM_YEAR"] = exam_parts.str[1]
df["ENEM_EXAM_CODE"] = exam_parts.str[2]

df["ENEM_EXAM"] = df["ENEM_EXAM"].replace(ENEM_MAPPING_NAME)


# Plot heatmap of models x questions sorted by difficulty.
# Item metadata for this year's exams; the drawing loop reads CO_PROVA,
# CO_POSICAO, TP_LINGUA and NU_PARAM_B from it.
items_csv_path = f"data/raw-enem-exams/microdados_enem_{year}/DADOS/ITENS_PROVA_{year}.csv"
df_items = pd.read_csv(items_csv_path, sep=";", encoding="latin-1")

# Panel indices for the drawing loop; 2x2 grid with one panel per exam.
plot_i = plot_j = 0
plt.rcParams["font.size"] = 6
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(4.5, 3), gridspec_kw=dict(hspace=0.4))
# One heatmap panel per exam: rows are (model, language) pairs and columns are
# the exam's questions sorted by difficulty. Panels fill the 2x2 grid column by
# column, i.e. panel k goes to row k % 2, column k // 2.

# Every (model, language) response matrix must have 31 rows (executions) and
# 45 columns (questions).
EXPECTED_RESPONSE_SHAPE = (31, 45)

for panel_idx, enem_exam in enumerate(EXAM_ORDERING[year]):
    plot_i, plot_j = panel_idx % 2, panel_idx // 2
    ax = axes[plot_i, plot_j]

    sample_df = df[df.ENEM_EXAM == enem_exam].copy()
    # Fail with a clear message (and non-zero exit) instead of an IndexError below.
    if sample_df.empty:
        raise SystemExit(f"No results found for {enem_exam}")
    sample_df["CO_PROVA"] = sample_df["CO_PROVA"].astype(int)

    # Sanity checks: all rows of one exam must share the same area code and year.
    # SystemExit(<message>) exits with status 1; a bare SystemExit() would report success.
    if len(sample_df.ENEM_EXAM_CODE.unique()) > 1:
        raise SystemExit("More than one code")
    if len(sample_df.ENEM_EXAM_YEAR.unique()) > 1:
        raise SystemExit("More than one year")
    enem_year = sample_df.ENEM_EXAM_YEAR.unique()[0]

    # Item metadata for the booklet these runs answered (taken from the first row).
    exam = df_items[df_items.CO_PROVA == sample_df.iloc[0].CO_PROVA]
    # Remove english as foreign language questions (rows with TP_LINGUA == 0),
    # then number the remaining items in exam-position order: IDX_POSICAO is the
    # item's column in RESPONSE_PATTERN.
    exam = exam[exam.TP_LINGUA != 0].sort_values(by="CO_POSICAO").reset_index(drop=True)
    exam["IDX_POSICAO"] = exam.index
    # Order by item difficulty (NU_PARAM_B) and drop items without one (NaN).
    exam = exam.sort_values(by="NU_PARAM_B").dropna(subset=["NU_PARAM_B"])
    question_order = exam.IDX_POSICAO.values
    n_questions = len(question_order)

    languages = sorted(sample_df.LANGUAGE.unique())
    matrix_response_pattern = []
    idx_name = []
    for full_model in sorted(sample_df.FULL_MODEL.unique()):
        for language in languages:
            sample_df_model = sample_df[(sample_df.FULL_MODEL == full_model) & (sample_df.LANGUAGE == language)]
            # One row per execution, one integer column per question.
            response_pattern_matrix = np.array(
                [list(pattern) for pattern in sample_df_model.RESPONSE_PATTERN]
            ).astype(int)
            if response_pattern_matrix.shape != EXPECTED_RESPONSE_SHAPE:
                print(f"Error in {full_model} {language}")
                print(response_pattern_matrix.shape)
                print(response_pattern_matrix)
                raise SystemExit(1)
            # Average over executions, then reorder the questions by difficulty.
            response_pattern = response_pattern_matrix.mean(axis=0)[question_order]
            matrix_response_pattern.append(response_pattern)
            idx_name.append(f"{full_model} ({language.upper()})")

    ax.imshow(matrix_response_pattern, cmap="gray_r", aspect="auto")
    # Questions are numbered 1..n in difficulty order; only the first and last
    # numbers are printed to keep the axis readable.
    tick_labels = [str(k) if k in (1, n_questions) else "" for k in range(1, n_questions + 1)]
    ax.set_xticks(np.arange(n_questions), labels=tick_labels, fontsize=6)
    # Row names only in the left column; "Question" label only under the bottom row.
    ax.set_yticks(np.arange(len(idx_name)), labels=idx_name if plot_j == 0 else [], fontsize=6)
    if plot_i != 0:
        ax.set_xlabel("Question", fontsize=6, labelpad=-3)
    ax.set_title(enem_exam, fontsize=6)
# Write the combined 2x2 figure. The output directory is created if missing
# (savefig does not create it), and the filename uses the requested `year`
# rather than a variable left over from the drawing loop.
heatmap_dir = Path("plots/heatmaps")
heatmap_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(heatmap_dir / f"response-pattern-heatmap-{year}-{args.system_prompt_type}-{instruct}.pdf", bbox_inches='tight', pad_inches=0.05, dpi=800)
plt.close(fig)