import abc
import json
import os
from abc import ABC

import datasets
from alpaca_eval import constants
from alpaca_eval import evaluate as alpaca_farm_evaluate

from evaluation.schemas import OpenEndedTaskBase, GeneratorModelBase


class AlpacaEvalBase(OpenEndedTaskBase, ABC):
    """
    Common base for AlpacaEval tasks.

    Subclasses provide the generation hyperparameters through
    ``prepare_hparams``; they are stored on the instance as
    ``max_new_tokens``, ``temperature`` and ``top_p`` during construction.
    """

    def __init__(self):
        super().__init__()
        # The subclass hook decides the decoding settings for the task.
        hparams = self.prepare_hparams()
        self.max_new_tokens, self.temperature, self.top_p = hparams

    @abc.abstractmethod
    def prepare_hparams(self):
        """
        Get the generation hyperparameters for the task.

        Returns:
            Tuple of (max_new_tokens, temperature, top_p).
        """
        raise NotImplementedError

    def required_input_fields(self) -> list[str]:
        """
        Get the names of the fields every input example must provide.

        Returns:
            List containing the single required field, "instruction".
        """
        fields = ["instruction"]
        return fields


class AlpacaEval2_0(AlpacaEvalBase):
    """
    AlpacaEval 2.0 task.

    Completions are compared against
    ``constants.ALPACAEVAL_REFERENCE_OUTPUTS_2`` using the
    ``weighted_alpaca_eval_gpt4_turbo`` annotator configuration and the
    ``get_length_controlled_winrate`` metric.
    """

    def prepare_hparams(self):
        """
        Get the generation hyperparameters for the task.

        Returns:
            Tuple of (max_new_tokens, temperature, top_p).
        """
        # eval task specified arguments
        max_new_tokens = 8192
        temperature = 0.0
        top_p = 1.0
        return max_new_tokens, temperature, top_p

    def load(self) -> list[dict]:
        """
        Load the task data from the "tatsu-lab/alpaca_eval" dataset
        ("alpaca_eval" configuration, "eval" split).

        Returns:
            List of task data; each item holds a single "instruction" field.
        """
        alpaca_eval_data = datasets.load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]
        return [{"instruction": row["instruction"]} for row in alpaca_eval_data]

    def _evaluate(self, model: GeneratorModelBase) -> tuple[dict, list[dict]]:
        """
        Generate completions with the model and score them with alpaca_eval.

        Results are written under ``results/<model_id>`` next to this file and
        the annotator cache under ``cache/``.

        Args:
            model: Generator whose ``generate_completions`` produces one
                completion per input instruction.

        Returns:
            Tuple of (leaderboard row for this model as a dict, annotations
            loaded from the annotator's ``annotations.json``).

        Raises:
            ValueError: If the number of completions differs from the number
                of task examples.
        """
        data = self.data
        inputs = [{"instruction": row["instruction"]} for row in data]
        completions = model.generate_completions(inputs,
                                                 max_new_tokens=self.max_new_tokens,
                                                 temperature=self.temperature,
                                                 top_p=self.top_p)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(completions) != len(data):
            raise ValueError(
                f"Expected {len(data)} completions but got {len(completions)}."
            )

        model_id = model.model_name_or_path.replace("/", "_")
        # Build fresh dicts so the task data is not mutated; every field already
        # present on an example (including an "id", if any) is carried over.
        model_results = [
            {**example, "output": output, "generator": model_id}
            for example, output in zip(data, completions)
        ]

        current_dir = os.path.dirname(__file__)
        os.makedirs(os.path.join(current_dir, "cache"), exist_ok=True)
        output_path = os.path.join(current_dir, "results", model_id)
        annotator_name = "weighted_alpaca_eval_gpt4_turbo"
        # Location where the annotations for this annotator are expected to be
        # written by alpaca_eval inside output_path.
        annotation_path = os.path.join(output_path, annotator_name, "annotations.json")
        os.makedirs(output_path, exist_ok=True)
        df_leaderboard, _ = alpaca_farm_evaluate(
            model_outputs=model_results,
            reference_outputs=constants.ALPACAEVAL_REFERENCE_OUTPUTS_2,
            # Same name as used for annotation_path, so the two cannot drift apart.
            annotators_config=annotator_name,
            fn_metric="get_length_controlled_winrate",
            output_path=output_path,
            is_return_instead_of_print=True,
            caching_path=os.path.join(current_dir, "cache", "alpaca_eval_annotator_cache.json"),
        )

        # pick the row with model_id
        selected_row = df_leaderboard[df_leaderboard.index == model_id]
        selected_row = selected_row.to_dict(orient="records")[0]

        with open(annotation_path, "r", encoding="utf-8") as f:
            annotations = json.load(f)

        return selected_row, annotations
