import json
import os

import evaluate

from evaluation.schemas import OpenEndedTaskBase, GeneratorModelBase, VllmGeneratorModel
from evaluation.tasks.generation.codex_eval.utils import read_problems, evaluate_functional_correctness, write_jsonl

# Loads the "exact_match" metric from the `evaluate` library at import time.
# NOTE(review): nothing in the visible part of this file references
# `exact_match`; confirm whether other modules import it before removing.
exact_match = evaluate.load("exact_match")


def download_codex():
    """Download the HumanEval problems and HumanEvalPack instructions.

    Both files are fetched with ``wget -P`` into the directory containing this
    module: ``HumanEval.jsonl.gz`` and ``humanevalpack.jsonl``.

    The commands are run as argument lists (no shell), so a module path that
    contains spaces or shell metacharacters is passed to ``wget`` intact.

    Raises:
        FileNotFoundError: if the ``wget`` executable is not available.
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
    """
    # Function-scope import so this block does not depend on the file header.
    import subprocess

    current_dir = os.path.abspath(os.path.dirname(__file__))
    urls = (
        "https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz",
        "https://huggingface.co/datasets/bigcode/humanevalpack/raw/main/data/python/data/humanevalpack.jsonl",
    )
    for url in urls:
        # check=True: surface a failed download here instead of as a confusing
        # missing-file error later when the data is read.
        subprocess.run(["wget", "-P", current_dir, url], check=True)


class CodexEval(OpenEndedTaskBase):
    """HumanEval (Codex) code-generation task scored by functional correctness.

    Each HumanEval problem is paired with the matching HumanEvalPack
    instruction. ``UNBIASED_SAMPLING_SIZE_N`` completions are sampled per
    problem and scored with ``evaluate_functional_correctness`` for every k in
    ``EVAL_PASS_AT_KS``.
    """

    # Values of k passed to evaluate_functional_correctness.
    EVAL_PASS_AT_KS = [1, 5, 10, 20]
    # Number of completions sampled for every problem.
    UNBIASED_SAMPLING_SIZE_N = 20
    # Worker count handed to evaluate_functional_correctness; override in a
    # subclass to tune for the host machine.
    EVAL_N_WORKERS = 64

    def __init__(self):
        # Sampling hyper-parameters are set before the base initializer runs.
        # NOTE(review): presumably the base __init__ may rely on them (e.g. by
        # triggering load()); confirm before reordering.
        self.max_new_tokens, self.temperature, self.top_p = self.prepare_hparams()
        super().__init__()

    def prepare_hparams(self):
        """
        Return the generation hyper-parameters as a tuple.

        Returns:
            (max_new_tokens, temperature, top_p) where
            max_new_tokens: int, maximum number of tokens to generate.
            temperature: float, temperature for sampling.
            top_p: float, top_p for nucleus sampling.
        """
        return 512, 0.8, 0.95

    def required_input_fields(self) -> list[str]:
        """
        Get the required input fields for the task.

        Returns:
            List of required input fields.
        """
        return ["instruction"]

    def load(self) -> list[dict]:
        """
        Load the HumanEval problems and attach an instruction to each one.

        Downloads the data files when either of them is missing next to this
        module. Also sets ``self.stop_sequences`` and ``self.instructions_dict``
        (task_id -> instruction), which ``_evaluate`` reads.

        Returns:
            One dict per problem: the original problem fields plus
            ``"instruction"``.

        Raises:
            KeyError: if a problem has no matching instruction.
        """
        current_dir = os.path.abspath(os.path.dirname(__file__))
        problems_path = os.path.join(current_dir, "HumanEval.jsonl.gz")
        instructions_path = os.path.join(current_dir, "humanevalpack.jsonl")

        # Both files are required below, so a missing instructions file must
        # trigger the download just like a missing problems file.
        if not (os.path.exists(problems_path) and os.path.exists(instructions_path)):
            download_codex()

        test_data = list(read_problems(problems_path).values())

        # these stop sequences are those mentioned in the codex paper, plus a
        # closing code fence to match the ```python fence that _evaluate opens
        # in the answer prefix.
        self.stop_sequences = ["\nclass", "\ndef", "\n#", "\nif", "\nprint", "\n```"]

        # Explicit encoding so reading does not depend on the platform default.
        with open(instructions_path, encoding="utf-8") as fin:
            instructions = [json.loads(line) for line in fin]

        # Instruction task ids use a "Python" prefix; rewrite them so they can
        # be looked up by the problems' "HumanEval" task ids.
        self.instructions_dict = {
            x["task_id"].replace("Python", "HumanEval"): x["instruction"] for x in instructions
        }

        return [
            {"instruction": self.instructions_dict[example["task_id"]], **example}
            for example in test_data
        ]

    def _evaluate(self, model: VllmGeneratorModel) -> tuple[dict, list[dict]]:
        """
        Sample completions for every problem and score them.

        The model's input template is temporarily extended with an answer
        prefix that opens a Python code fence followed by ``{prompt}``; the
        original template is always restored, even if generation fails.
        Predictions are written to ``codex_eval_predictions.jsonl`` next to
        this module before scoring.

        Args:
            model: the generator to evaluate; must be a VllmGeneratorModel.

        Returns:
            (pass_at_k_results, predictions): the result of
            ``evaluate_functional_correctness`` and the list of prediction
            records that were scored.
        """
        assert isinstance(model, VllmGeneratorModel), \
            "Model should be an instance of VllmGeneratorModel. Not supported for other models."
        current_dir = os.path.abspath(os.path.dirname(__file__))

        # update model template
        original_template = model.model_input_template
        answer = "Here is the function:\n\n```python\n"
        # Insert a space only when the template does not already end in
        # whitespace, so the answer prefix is not glued to the preceding text.
        separator = "" if original_template.endswith(("\n", " ")) else " "
        model.model_input_template = original_template + separator + answer + "{prompt}"
        try:
            raw_outputs = model.generate_completions(self.data,
                                                     max_new_tokens=self.max_new_tokens,
                                                     temperature=self.temperature,
                                                     top_p=self.top_p,
                                                     n=self.UNBIASED_SAMPLING_SIZE_N,
                                                     return_full_outputs=True,
                                                     stop=self.stop_sequences)
        finally:
            # The model object is shared with the caller: never leave the
            # patched template behind, including on failure.
            model.model_input_template = original_template

        # Flatten: one text per sampled output, grouped by input example.
        completions = [output.text for it in raw_outputs for output in it.outputs]

        # duplicates test data to match the number of outputs.
        duplicate_test_data = [
            example for example in self.data for _ in range(self.UNBIASED_SAMPLING_SIZE_N)
        ]
        assert len(duplicate_test_data) == len(completions)

        predictions = [
            {
                # Suffix is the sample index within its problem (0..N-1).
                "id": example["task_id"] + f"_{i % self.UNBIASED_SAMPLING_SIZE_N}",
                "task_id": example["task_id"],
                "instruction": self.instructions_dict[example["task_id"]],
                "prompt": example["prompt"],
                "completion": output,
            }
            for i, (example, output) in enumerate(zip(duplicate_test_data, completions))
        ]
        prediction_save_path = os.path.join(current_dir, "codex_eval_predictions.jsonl")
        write_jsonl(prediction_save_path, predictions)

        pass_at_k_results = evaluate_functional_correctness(
            sample_file=prediction_save_path,
            k=self.EVAL_PASS_AT_KS,
            problems={example["task_id"]: example for example in self.data},
            n_workers=self.EVAL_N_WORKERS
        )

        return pass_at_k_results, predictions
