import os

import numpy as np
import pandas as pd

from evaluation.schemas import OpenEndedTaskBase, GeneratorModelBase

# Answer letters for every MMLU question; `MMLU._evaluate` indexes into this
# list with the positions returned by the model's next-word prediction.
choices = list("ABCD")


def format_subject(subject):
    """Turn an underscore-separated subject id into space-separated words.

    Every word, including the first, is preceded by a single space, so
    ``"abstract_algebra"`` becomes ``" abstract algebra"``. The leading space
    is part of the returned text and ends up verbatim in the prompt prefix.

    Args:
        subject: Subject identifier such as ``"college_physics"``.

    Returns:
        The subject with each underscore-delimited part prefixed by a space.
    """
    return "".join(" " + word for word in subject.split("_"))


def download_mmlu():
    """Download the MMLU archive and unpack it to ``<module dir>/data/eval/mmlu``.

    Steps, all relative to the directory containing this module:
      1. fetch ``data.tar`` with ``wget``;
      2. extract it into ``data/`` with ``tar``;
      3. move the extracted ``data/data`` folder to ``data/eval/mmlu``;
      4. delete the downloaded archive.

    The external tools are invoked with argument lists (no shell), so a module
    directory containing spaces or shell metacharacters is handled correctly.

    Raises:
        subprocess.CalledProcessError: If ``wget`` or ``tar`` exits with a
            non-zero status (previously such failures were silently ignored).
        FileNotFoundError: If ``wget`` or ``tar`` is not installed.
        OSError: If a filesystem step (mkdir/move/remove) fails.
    """
    # Imported locally so this function carries its own dependencies.
    import shutil
    import subprocess

    current_dir = os.path.abspath(os.path.dirname(__file__))
    archive_path = os.path.join(current_dir, "data.tar")
    data_dir = os.path.join(current_dir, "data")
    eval_dir = os.path.join(data_dir, "eval")

    subprocess.run(
        ["wget", "-P", current_dir, "https://people.eecs.berkeley.edu/~hendrycks/data.tar"],
        check=True,
    )
    # Creates both data/ (extraction target) and data/eval/ (final parent).
    os.makedirs(eval_dir, exist_ok=True)
    subprocess.run(["tar", "-xvf", archive_path, "-C", data_dir], check=True)
    # The archive unpacks into a top-level "data" folder; relocate it.
    shutil.move(os.path.join(data_dir, "data"), os.path.join(eval_dir, "mmlu"))
    os.remove(archive_path)


class MMLU(OpenEndedTaskBase):
    """MMLU multiple-choice task scored by next-word prediction over A/B/C/D.

    Examples are read from per-subject CSV files under
    ``<module dir>/data/eval/mmlu/<split>/`` (downloaded on demand), turned into
    zero-shot prompts, and evaluated by asking the model which of the answer
    letters in the module-level ``choices`` list it predicts next.
    """

    def __init__(self):
        super().__init__()

    def required_input_fields(self) -> list[str]:
        """
        Get the required input fields for the task.

        Returns:
            List of required input fields.
        """
        return ["instruction"]

    @staticmethod
    def format_example(data: dict, include_answer: bool) -> str:
        """Render one question with its four options as prompt text.

        Args:
            data: Record with keys ``"question"``, ``"A"``, ``"B"``, ``"C"``,
                ``"D"`` and, when ``include_answer`` is true, ``"answer"``.
            include_answer: If true, append the answer and a blank line (the
                form used for solved demonstrations); otherwise the text ends
                with ``"Answer:"`` so the model supplies the letter.

        Returns:
            The formatted example.
        """
        prompt = (
            f'{data["question"]}\n'
            f'A. {data["A"]}\n'
            f'B. {data["B"]}\n'
            f'C. {data["C"]}\n'
            f'D. {data["D"]}\n'
            'Answer:'
        )
        if include_answer:
            prompt += f' {data["answer"]}\n\n'
        return prompt

    @staticmethod
    def _read_split(split: str) -> list[dict]:
        """Read every ``<subject>_<split>.csv`` file of one MMLU split.

        The dataset is downloaded first if ``data/eval/mmlu`` does not exist.

        Args:
            split: Split folder name, e.g. ``"test"`` or ``"dev"``.

        Returns:
            One dict per CSV row with keys ``question``, ``A``, ``B``, ``C``,
            ``D``, ``answer`` and ``subject``, ordered by subject name.
        """
        current_dir = os.path.abspath(os.path.dirname(__file__))
        mmlu_dir = os.path.join(current_dir, "data/eval/mmlu")
        if not os.path.exists(mmlu_dir):
            download_mmlu()

        split_dir = os.path.join(mmlu_dir, split)
        suffix = f"_{split}.csv"
        subjects = sorted(
            f.split(suffix)[0] for f in os.listdir(split_dir) if f.endswith(suffix)
        )

        frames = []
        for subject in subjects:
            # The files are read as header-less and the columns named here.
            # NOTE(review): assumes every split shares the test split's
            # header-less question,A,B,C,D,answer layout - confirm for new splits.
            frame = pd.read_csv(os.path.join(split_dir, subject + suffix), header=None)
            frame.columns = ["question", "A", "B", "C", "D", "answer"]
            frame["subject"] = subject
            frames.append(frame)
        return pd.concat(frames, ignore_index=True).to_dict(orient="records")

    def load(self) -> list[dict]:
        """Load the test split and attach a zero-shot prompt to each record.

        Returns:
            Test records (see ``_read_split``) with an extra ``"instruction"``
            key holding the subject preamble followed by the unanswered
            question.
        """
        all_test_data = self._read_split("test")

        # format the data.
        for data in all_test_data:
            prefix = "The following are multiple choice questions (with answers) about {}.\n\n".format(
                format_subject(data["subject"].strip())
            )
            data["instruction"] = prefix + self.format_example(data, include_answer=False)
        return all_test_data

    def load_dev(self) -> list[dict]:
        """Load the dev split.

        The files are parsed exactly like the test split (no header row,
        explicit column names), so the first example of each subject is kept
        as data rather than being consumed as a header.

        Returns:
            Dev records with keys ``question``, ``A``, ``B``, ``C``, ``D``,
            ``answer`` and ``subject``.
        """
        return self._read_split("dev")

    def _evaluate(self, model: GeneratorModelBase) -> tuple[dict, list[dict]]:
        """Score ``model`` on ``self.data`` and collect per-example predictions.

        The model's input template is temporarily extended with
        ``"The answer is:"`` and is always restored afterwards, even if model
        loading or prediction raises.

        Args:
            model: Generator exposing ``model_input_template``,
                ``model_name_or_path``, ``load_model`` and
                ``get_next_word_predictions``.

        Returns:
            A pair ``(results, all_predictions)``. ``results`` maps each
            subject to its accuracy and ``"overall"`` to the unweighted mean of
            the per-subject accuracies. ``all_predictions`` has one dict per
            example with ``id``, ``input``, ``target`` and ``prediction``.
        """
        current_model_input_template = model.model_input_template
        # endswith() also copes with an empty template, where [-1] would raise.
        separator = "" if current_model_input_template.endswith(("\n", " ")) else " "
        model.model_input_template = current_model_input_template + separator + "The answer is:"

        try:
            # Group examples by subject in a single pass.
            data_by_subject: dict = {}
            for data in self.data:
                data_by_subject.setdefault(data["subject"], []).append(data)

            results = {}
            all_accs = []
            all_predictions = []
            hf_model, hf_tokenizer = model.load_model(model_name_or_path=model.model_name_or_path)
            # Sorted so that result order and prediction ids are reproducible
            # across runs (set iteration order of strings is not).
            for subject in sorted(data_by_subject):
                test_data = data_by_subject[subject]
                prompts = [data["instruction"] for data in test_data]
                # Predictions are indices into `choices`; probabilities are unused.
                predictions, _ = model.get_next_word_predictions(hf_model, hf_tokenizer, prompts, choices)
                predicted_letters = [choices[index] for index in predictions]

                cors = [
                    letter == data["answer"]
                    for letter, data in zip(predicted_letters, test_data)
                ]
                acc = np.mean(cors)
                all_accs.append(acc)
                results[subject] = acc

                for prompt, data, letter in zip(prompts, test_data, predicted_letters):
                    all_predictions.append({
                        "id": len(all_predictions),
                        "input": prompt,
                        "target": data["answer"],
                        "prediction": letter,
                    })

            results["overall"] = np.mean(all_accs)
        finally:
            model.model_input_template = current_model_input_template
        return results, all_predictions
