import gc
import os

import numpy as np
# Set the Hugging Face cache directories (HF_HOME / TRANSFORMERS_CACHE).
# These assignments happen before torch/transformers are imported below,
# presumably so the libraries see the cache location at import time
# (NOTE(review): confirm this ordering is required before moving them).
os.environ['HF_HOME'] = "../cache/"
os.environ['TRANSFORMERS_CACHE'] = "../cache/"

import argparse
import time
import pandas as pd
import torch
from models import LLAMA2, LLAMA3, CommandR, Mistral, Gemma, GPT, Pythia
from exam import ENEM
from transformers import set_seed

def _results_path(model_name, args, seed, suffix=""):
    """Return the parquet path for one experiment configuration.

    Every grid parameter is encoded in a fixed order. ``suffix`` (e.g.
    ``"-prob-dist"``) is inserted right before the ``.parquet`` extension.
    """
    return (
        f"enem-experiments-results/{model_name}-{args.system_prompt_type}-{args.enem_exam}-"
        f"{args.exam_type}-{args.question_order}-{args.language}-{args.number_options}-{seed}"
        f"{suffix}.parquet"
    )


def run_experiment(args, model_name, model):
    """Run ``model`` over one ENEM exam configuration and save the results.

    The configuration is read from ``args``: system_prompt_type, enem_exam,
    exam_type, question_order, language, number_options and seed. The
    attributes model, model_size and is_instruct_version are only printed.

    Two parquet files are written under ``enem-experiments-results/``:

    * ``<config>.parquet``: one row with the model/answer-key letter
      patterns, the binary response pattern, the CTT score (number of
      questions answered correctly) and timing information.
    * ``<config>-prob-dist.parquet``: one row per question with the correct
      answer and the raw parsed model answer. For GPT it also holds the system
      fingerprint; for other models it holds logits and the probability
      distribution.

    The "-prob-dist" file is written last, so its presence marks a finished
    run. If it already exists the configuration is skipped.

    Pattern encoding: voided questions ("anulada") are "V" in both letter
    patterns and "0" in the binary pattern. Answers outside A-E are recorded
    as "X" in the model letter pattern.

    Args:
        args: argparse-style namespace holding the configuration above.
        model_name: Model label used in file names; ``"GPT-3.5-turbo-0125"``
            selects the API code path, which returns no logits.
        model: Wrapper exposing ``get_answer_from_question``.
    """
    # Set seed
    seed = args.seed
    set_seed(seed)

    # Print args
    print("Model: ", args.model)
    print("Model size: ", args.model_size)
    print("Instruct version: ", args.is_instruct_version)
    print("System prompt type: ", args.system_prompt_type)
    print("ENEM exam: ", args.enem_exam)
    print("Exam type: ", args.exam_type)
    print("Question order: ", args.question_order)
    print("Language: ", args.language)
    print("Number of options: ", args.number_options)
    print("Seed: ", seed)
    print("\n------------------\n")

    print("Execution started\n")

    results_path = _results_path(model_name, args, seed)
    prob_dist_path = _results_path(model_name, args, seed, suffix="-prob-dist")
    # The prob-dist file is written last, so it marks a fully completed run.
    if os.path.exists(prob_dist_path):
        print(f"File {prob_dist_path} already exists. Skipping...")
        return

    # Load ENEM exam
    enem = ENEM(args.enem_exam, exam_type=args.exam_type, question_order=args.question_order, seed=seed, language=args.language, number_options=args.number_options)
    n_questions = enem.get_enem_size()

    # The GPT API path returns (answer, system_fingerprint) instead of
    # (answer, prob_dist, logits).
    is_gpt = model_name == "GPT-3.5-turbo-0125"
    # Tuple membership compares with ==, so multi-letter strings such as "AB"
    # are rejected (a plain "ABCDE" string would accept them as substrings).
    valid_options = tuple("ABCDE")

    # Per-question characters; joined into pattern strings after the loop.
    model_letters = []
    correct_letters = []
    binary_marks = []
    ctt_score = 0

    # Per-question rows for the prob-dist file.
    correct_answers = []
    parsed_answers = []
    prob_dists = []
    all_logits = []
    all_system_fingerprint = []

    start_time = time.time()

    for i in range(n_questions):
        print(f"Question {i}")
        st = time.time()
        question = enem.get_question(i)
        correct_answer = enem.get_correct_answer(i)

        if correct_answer == "anulada":
            # Voided question: never asked and never counted as correct.
            model_letters.append("V")
            correct_letters.append("V")
            binary_marks.append("0")
            prob_dists.append(None)
            all_logits.append(None)
            all_system_fingerprint.append(None)
            correct_answers.append("anulada")
            parsed_answers.append("anulada")
            continue

        if is_gpt:
            model_answer, system_fingerprint = model.get_answer_from_question(question, system_prompt_type=args.system_prompt_type, language=args.language)
            all_system_fingerprint.append(system_fingerprint)
            prob_dist, logits = None, None
        else:
            model_answer, prob_dist, logits = model.get_answer_from_question(question, system_prompt_type=args.system_prompt_type, language=args.language)

        prob_dists.append(prob_dist)
        all_logits.append(logits)
        correct_answers.append(correct_answer)
        # Keep the raw parsed answer here, before invalid ones become "X".
        parsed_answers.append(model_answer)

        if model_answer not in valid_options:
            # Covers both a missing answer (None) and an out-of-range one.
            print(f"Warning: invalid model answer {model_answer!r} for question {i}")
            model_answer = "X"

        if model_answer == correct_answer:
            binary_marks.append("1")
            ctt_score += 1
        else:
            binary_marks.append("0")

        model_letters.append(model_answer)
        correct_letters.append(correct_answer)

        print(f"Time: {time.time()-st} seconds\n")

    end_time = time.time()
    total_run_time = end_time - start_time
    # Guard against an empty exam instead of raising ZeroDivisionError.
    avg_run_time = total_run_time / n_questions if n_questions else float("nan")

    model_response_pattern = "".join(model_letters)
    correct_response_pattern = "".join(correct_letters)
    model_response_binary_pattern = "".join(binary_marks)

    if args.exam_type.startswith("shuffle"):
        # TX_RESPOSTAS and TX_GABARITO have to be in the original order. The
        # remapped patterns become the main ones, and the as-asked (shuffled)
        # patterns go to the *_SHUFFLE columns.
        model_response_pattern_shuffled = model_response_pattern
        correct_response_pattern_shuffled = correct_response_pattern
        model_response_pattern = enem.remapping_answer_pattern(model_response_pattern_shuffled)
        correct_response_pattern = enem.remapping_answer_pattern(correct_response_pattern_shuffled)
    else:
        model_response_pattern_shuffled = None
        correct_response_pattern_shuffled = None

    # NOTE(review): RESPONSE_PATTERN is not remapped, so for shuffled exams it
    # stays in the as-asked order unlike TX_RESPOSTAS. Confirm this is intended.
    df = pd.DataFrame({
        "MODEL_NAME": [model_name],
        "SYSTEM_PROMPT_TYPE": [args.system_prompt_type],
        "ENEM_EXAM": [args.enem_exam],
        "ENEM_EXAM_TYPE": [args.exam_type],
        "QUESTION_ORDER": [args.question_order],
        "LANGUAGE": [args.language],
        "NUMBER_OPTIONS": [args.number_options],
        "SEED": [seed],
        "CTT_SCORE": [ctt_score],
        "TX_RESPOSTAS": [model_response_pattern],
        "TX_GABARITO": [correct_response_pattern],
        "TX_RESPOSTAS_SHUFFLE": [model_response_pattern_shuffled],
        "TX_GABARITO_SHUFFLE": [correct_response_pattern_shuffled],
        "RESPONSE_PATTERN": [model_response_binary_pattern],
        "TOTAL_RUN_TIME_SEC": [total_run_time],
        "AVG_RUN_TIME_PER_ITEM_SEC": [avg_run_time],
    })
    df.to_parquet(results_path)

    # Per-question details; written last so it doubles as the completion marker.
    if is_gpt:
        df = pd.DataFrame({"QUESTION": enem.get_question_number_array(), "CORRECT_ANSWER": correct_answers, "MODEL_ANSWER": parsed_answers, "SYSTEM_FINGERPRINT": all_system_fingerprint})
    else:
        df = pd.DataFrame({"QUESTION": enem.get_question_number_array(), "LOGITS": all_logits, "PROB_DIST": prob_dists, "CORRECT_ANSWER": correct_answers, "MODEL_ANSWER": parsed_answers})
    df.to_parquet(prob_dist_path)

    # Release cached GPU memory and Python garbage between runs.
    torch.cuda.empty_cache()
    gc.collect()

    print("Execution finished\n")

# Token: HF_TOKEN env variable. Passed to the local model wrappers when they
# are constructed; None when the variable is unset.
token = os.getenv("HF_TOKEN")

# Output directory for every result parquet file written by run_experiment.
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs("enem-experiments-results", exist_ok=True)

# PyTorch device the local models are loaded onto.
device = "cuda"

# Seed shared by model construction and every experiment run.
seed = 0

# Create an argparser
parser = argparse.ArgumentParser(description='Run model on ENEM exam')
# LLMs args
parser.add_argument('--model', type=str, choices=["llama2", "mistral", "gemma", "llama3", "pythia", "command-r", "gpt-3.5"], required=True, help='Model to run')
# "2b" is listed so that the gemma 2b size handled during model loading is
# actually reachable from the command line.
parser.add_argument('--model_size', type=str, choices=["70m", "160m", "410m", "1b", "1.4b", "2b", "2.8b", "6.9b", "12b", "7b", "8b", "13b", "8x7b", "35b"], help='Model size')
parser.add_argument('--is_instruct_version', type=int, choices=[0, 1], help='Whether the model is the instruct version or not')
args = parser.parse_args()

# argparse yields 0/1, or None when the flag is omitted; bool() maps None to False.
args.is_instruct_version = bool(args.is_instruct_version)

# Load model: build the wrapper for the requested family/size and derive
# model_name, which run_experiment embeds in every result file name.
# Every unsupported family/size combination raises immediately. Otherwise
# `model` / `model_name` would be left undefined and the script would crash
# later with a NameError.
if args.model == "llama2":
    if args.model_size in ("7b", "13b"):
        model = LLAMA2(args.model_size, token, device, is_instruct_version=args.is_instruct_version, random_seed=seed)
        instruct_version = "-chat" if args.is_instruct_version else ""
        model_name = f"Llama-2-{args.model_size}{instruct_version}-hf"
    else:
        raise Exception("Model size not implemented for Llama-2")
elif args.model == "mistral":
    if args.model_size in ("7b", "8x7b"):
        model = Mistral(args.model_size, token, device, is_instruct_version=args.is_instruct_version, random_seed=seed)
        instruct_version = "-Instruct" if args.is_instruct_version else ""
        # Labels are kept as-is so existing result file names stay stable.
        model_name = f"Mistral-7B{instruct_version}-v0.1" if args.model_size == "7b" else f"Mistral-8x7B{instruct_version}-v0.1"
    else:
        raise Exception("Model size not implemented for Mistral")
elif args.model == "gemma":
    if args.model_size in ("2b", "7b"):
        model = Gemma(args.model_size, token, device, is_instruct_version=args.is_instruct_version, random_seed=seed)
        instruct_version = "-it" if args.is_instruct_version else ""
        model_name = f"gemma-{args.model_size}{instruct_version}"
    else:
        raise Exception("Model size not implemented for Gemma")
elif args.model == "llama3":
    if args.model_size == "8b":
        model = LLAMA3(args.model_size, token, device, is_instruct_version=args.is_instruct_version, random_seed=seed)
        instruct_version = "-Instruct" if args.is_instruct_version else ""
        model_name = f"Meta-Llama-3-{args.model_size}{instruct_version}"
    else:
        raise Exception("Model size not implemented for Llama-3")
elif args.model == "pythia":
    if args.is_instruct_version:
        raise Exception("Pythia model is not available in the instruct version")

    if args.model_size in ("70m", "160m", "410m", "1b", "1.4b", "2.8b", "6.9b", "12b"):
        model = Pythia(args.model_size, token, device, random_seed=seed)
        model_name = f"pythia-{args.model_size}-deduped"
    else:
        raise Exception("Model size not implemented for Pythia")
elif args.model == "command-r":
    if not args.is_instruct_version:
        raise Exception("CommandR model is only available in the instruct version")

    if args.model_size == "35b":
        model = CommandR(args.model_size, token, device, random_seed=seed)
        model_name = "c4ai-command-r-v01"
    else:
        raise Exception("Model size not implemented for CommandR")
elif args.model == "gpt-3.5":
    if not args.is_instruct_version:
        raise Exception("GPT-3.5 model is only available in the instruct version")

    # API model: --model_size is ignored and no local device/token is needed.
    model = GPT("gpt-3.5-turbo-0125")
    model_name = "GPT-3.5-turbo-0125"
else:
    raise Exception("Model not implemented")

# Experiment grid, run in nested order:
# exam -> exam variant (original or one of 30 shuffles) -> language -> prompt type.
# run_experiment itself skips configurations whose results already exist.
_ENEM_EXAMS = (
    "ENEM_2022_LC_CO_PROVA_1072",
    "ENEM_2022_MT_CO_PROVA_1082",
    "ENEM_2022_CN_CO_PROVA_1092",
    "ENEM_2022_CH_CO_PROVA_1062",
    "ENEM_2023_LC_CO_PROVA_1207",
    "ENEM_2023_CH_CO_PROVA_1197",
    "ENEM_2023_MT_CO_PROVA_1217",
    "ENEM_2023_CN_CO_PROVA_1227",
)
_EXAM_TYPES = ("default",) + tuple(f"shuffle-{k}" for k in range(30))
_LANGUAGES = ("en", "pt-br")
_SYSTEM_PROMPT_TYPES = ("zero-shot", "four-shot", "one-shot")

for enem_exam in _ENEM_EXAMS:
    for exam_type in _EXAM_TYPES:
        for language in _LANGUAGES:
            # Pythia was trained with English-only corpora, so Portuguese runs are skipped.
            if args.model == "pythia" and language == "pt-br":
                continue

            for system_prompt_type in _SYSTEM_PROMPT_TYPES:
                # Every run uses all five options, the original question order
                # and the global seed.
                number_options = 5
                question_order = "original"

                # argparse.Namespace keeps its attributes in __dict__, so this
                # sets the per-run configuration that run_experiment reads.
                vars(args).update(
                    system_prompt_type=system_prompt_type,
                    enem_exam=enem_exam,
                    exam_type=exam_type,
                    question_order=question_order,
                    language=language,
                    number_options=number_options,
                    seed=seed,
                )

                run_experiment(args, model_name, model)