import glob
import json
import os
import random
import re

import evaluate

from evaluation.schemas import OpenEndedTaskBase, GeneratorModelBase
from evaluation.tasks.metrics import get_average

# Exact-match metric shared by the task below; note that this load runs at
# import time of this module.
exact_match = evaluate.load("exact_match")


def download_bbh():
    """Download the BIG-Bench-Hard repository and unpack it next to this module.

    The repository archive is fetched into ``data/downloads`` (relative to the
    current working directory), extracted, and every non-hidden top-level
    entry of the extracted ``BIG-Bench-Hard-main`` folder is copied into the
    directory containing this file. The temporary archive and extraction
    folder are removed afterwards, whether or not the download succeeded.

    Raises:
        urllib.error.URLError: If the archive cannot be downloaded.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
        OSError: If files cannot be written or copied.
    """
    # Function-scope imports: these modules are only needed for this one-off
    # download step.
    import shutil
    import urllib.request
    import zipfile

    current_dir = os.path.abspath(os.path.dirname(__file__))
    # Archive of the main branch of the Big-Bench-Hard dataset repository.
    url = "https://github.com/suzgunmirac/BIG-Bench-Hard/archive/refs/heads/main.zip"
    extract_dir = "data/downloads/bbh"
    zip_path = "data/downloads/bbh_data.zip"

    os.makedirs(extract_dir, exist_ok=True)
    try:
        # Stream the archive to disk; any network failure raises instead of
        # being silently ignored.
        with urllib.request.urlopen(url) as response, open(zip_path, "wb") as zip_file:
            shutil.copyfileobj(response, zip_file)

        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(extract_dir)

        repo_root = os.path.join(extract_dir, "BIG-Bench-Hard-main")
        for entry in os.listdir(repo_root):
            # Skip dotfiles so repository metadata is not copied next to the code.
            if entry.startswith("."):
                continue
            src = os.path.join(repo_root, entry)
            dst = os.path.join(current_dir, entry)
            if os.path.isdir(src):
                # Merge into an already existing directory instead of failing.
                shutil.copytree(src, dst, dirs_exist_ok=True)
            else:
                shutil.copy2(src, dst)
    finally:
        # Always clean up the temporary download artifacts.
        shutil.rmtree(extract_dir, ignore_errors=True)
        if os.path.exists(zip_path):
            os.remove(zip_path)


class BigBenchHard(OpenEndedTaskBase):
    """Few-shot evaluation on a fixed-size sample of every BIG-Bench-Hard task.

    Each example is prompted with the task's few-shot prompt followed by the
    question; the model's completion is scored against the target with the
    ``exact_match`` metric, per task and averaged over tasks.
    """

    # Upper bound on the number of examples sampled from each task file.
    MAX_NUM_EXAMPLES_PER_TASK = 40

    # Captures the text between "the answer is " and the next period.
    _ANSWER_PATTERN = re.compile(r"[tT]he answer is (.*?)\.")

    def __init__(self, no_cot: bool):
        """
        Args:
            no_cot: If True, the few-shot exemplars are reduced to the question
                and its final answer (the reasoning text is dropped).
        """
        self.max_new_tokens, self.temperature, self.top_p = self.prepare_hparams()
        # Set before calling the base initializer.
        # NOTE(review): presumably the base class invokes load(), which reads
        # self.no_cot -- confirm against OpenEndedTaskBase.
        self.no_cot = no_cot
        super().__init__()

    def prepare_hparams(self):
        """
        Return the generation hyper-parameters as a tuple.

        max_new_tokens: int, maximum number of tokens to generate.
        temperature: float, temperature for sampling.
        top_p: float, top_p for nucleus sampling.
        """
        return 512, 0.0, 1.0

    def required_input_fields(self) -> list[str]:
        """
        Get the required input fields for the task.

        Returns:
            List of required input fields.
        """
        return ["instruction"]

    @staticmethod
    def _strip_cot(task_prompt: str, task_name: str) -> str:
        """Rewrite every few-shot exemplar as question + final answer only."""
        new_prompt_fields = []
        for prompt_field in task_prompt.split("\n\n"):
            if prompt_field.startswith("Q:"):
                assert "So the answer is" in prompt_field, f"`So the answer is` not found in prompt field of {task_name}.txt."
                assert "\nA:" in prompt_field, "`\nA:` not found in prompt field."
                answer = prompt_field.split("So the answer is")[-1].strip()
                question = prompt_field.split("\nA:")[0].strip()
                prompt_field = question + "\nA: " + answer
            # Fields that are not exemplars (e.g. the task description) are kept as-is.
            new_prompt_fields.append(prompt_field)
        return "\n\n".join(new_prompt_fields)

    def load(self) -> list[dict]:
        """
        Load the task examples and build one prompt per example.

        Downloads the dataset first when the ``bbh`` folder next to this file
        is missing.

        Returns:
            List of dicts with the keys ``instruction`` (few-shot prompt plus
            question), ``target`` and ``task_name``.
        """
        current_dir = os.path.abspath(os.path.dirname(__file__))
        if not os.path.exists(os.path.join(current_dir, "bbh")):
            download_bbh()

        # A private RNG keeps the sampling reproducible without reseeding the
        # process-wide `random` state.
        rng = random.Random(42)
        all_tasks = {}
        # glob order depends on the filesystem; sorting makes the seeded
        # sampling identical on every machine.
        for task_file in sorted(glob.glob(os.path.join(current_dir, "bbh", "*.json"))):
            task_name = os.path.basename(task_file).split(".")[0]
            with open(task_file, "r", encoding="utf-8") as f:
                examples = json.load(f)["examples"]
            # Never ask for more examples than the task actually has.
            sample_size = min(len(examples), self.MAX_NUM_EXAMPLES_PER_TASK)
            all_tasks[task_name] = rng.sample(examples, sample_size)

        all_prompts = {}
        for cot_prompt_file in glob.glob(os.path.join(current_dir, "cot-prompts", "*.txt")):
            task_name = os.path.basename(cot_prompt_file).split(".")[0]
            with open(cot_prompt_file, "r", encoding="utf-8") as f:
                # The first two lines of each prompt file are not part of the prompt.
                task_prompt = "".join(f.readlines()[2:])
            if self.no_cot:
                task_prompt = self._strip_cot(task_prompt, task_name)
            all_prompts[task_name] = task_prompt

        assert set(all_tasks.keys()) == set(all_prompts.keys()), \
            "task names in task data and task prompts are not the same."

        loaded_data = []
        for task_name, task_inputs in all_tasks.items():
            cot_prompt = all_prompts[task_name].strip()
            for task_input in task_inputs:
                loaded_data.append({
                    "instruction": cot_prompt + "\n\nQ: " + task_input["input"],
                    "target": task_input["target"],
                    "task_name": task_name,
                })

        return loaded_data

    def _evaluate(self, model: GeneratorModelBase) -> tuple[dict, list[dict]]:
        """
        Generate completions for all loaded examples and score them per task.

        Args:
            model: Generator whose ``model_input_template`` is temporarily
                extended with an ``A:`` answer prefix during generation.

        Returns:
            A tuple ``(performance, final_outputs)``: ``performance`` maps each
            task name (plus ``"overall"``) to its exact-match score, and
            ``final_outputs`` holds one record per example.
        """
        # Temporarily append the answer prefix to the model template, adding a
        # separating space unless the template already ends with whitespace.
        current_model_input_template = model.model_input_template
        separator = "" if current_model_input_template.endswith(("\n", " ")) else " "
        model.model_input_template = current_model_input_template + separator + "A:"
        try:
            completions = model.generate_completions(self.data,
                                                     max_new_tokens=self.max_new_tokens,
                                                     temperature=self.temperature,
                                                     top_p=self.top_p)
        finally:
            # Restore the caller's template even if generation fails.
            model.model_input_template = current_model_input_template

        # Cluster per task_name, keeping tasks in first-seen order so the
        # output ordering and ids are the same on every run.
        results_per_task_name = {}
        for data, completion in zip(self.data, completions):
            results_per_task_name.setdefault(data["task_name"], []).append({
                "completion": completion,
                "target": data["target"],
                "input": data["instruction"],
            })

        performance = {}
        final_outputs = []
        for task_name, results in results_per_task_name.items():
            targets = [result["target"] for result in results]
            predictions = []
            for result in results:
                # Prefer the span after "the answer is"; otherwise fall back
                # to the whole completion.
                extracted_answer = self._ANSWER_PATTERN.search(result["completion"])
                if extracted_answer:
                    prediction = extracted_answer.group(1).strip()
                else:
                    prediction = result["completion"].strip()
                predictions.append(prediction)
                final_outputs.append({
                    "id": len(final_outputs),
                    "task_name": task_name,
                    "input": result["input"],
                    "target": result["target"],
                    "output": prediction,
                })
            performance[task_name] = exact_match.compute(predictions=predictions,
                                                         references=targets,
                                                         ignore_case=True,
                                                         ignore_punctuation=True)["exact_match"]
        # The overall score is computed over the per-task scores gathered above.
        performance["overall"] = get_average(list(performance.values()))
        return performance, final_outputs
