"""
General place for datasets. Handle conversion to prompt - label form, for use in influence calculation.

We have two variants of test datasets:
- Prompts from the original task, randomly selected. If there are multiple subtasks we sample evenly from them (e.g. multiple categories in MMLU, or tasks in BBH). We round-robin sample to fill the number of samples requested.
    The default provided num of samples ensures even sampling.
- Prompts from the few-shot provided tasks, where possible. This allows doing data selection with these, and then evaluating on the 'entire' test set.

We also prompt a file dataset for random prompt testing. Input data should have prompt and label films already prepared.
"""
from minimal_multitask.eval.codex_humaneval.data import read_problems
import json
from datasets import load_dataset, Dataset
from minimal_multitask.eval.mmlu.run_mmlu_eval import construct_prompts
from minimal_multitask.utils import get_appropriate_data_dir, create_prompt_with_tulu_chat_format, encode_with_messages_format
import glob
import os
import tqdm
import random
import torch
import re
from minimal_multitask.eval.gsm.examplars import EXAMPLARS as GSM_EXAMPLARS
from minimal_multitask.eval.squad.run_squad_eval import SQUAD_SHOTS


class TestDataset:
    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer

    def get_all_test_prompts(self, num_samples=500, seed=42):
        raise NotImplementedError


# turns sample into something we can feed into a model
def construct_test_sample(tokenizer, sample, max_length=2048, prompt_only=False, response_only=False):
    prompt, label = sample["prompts"], sample["labels"]
    # tokenizer label and prompt independently, then concatenate.
    # this is so we match the test format (tokenize the prompt and generate)
    if prompt_only:
        prompt = prompt + tokenizer.eos_token  # add eos token to prompt
        prompt = prompt.replace("\n<|assistant|>\n", "")  # remove assistant turn
    if response_only:
        label = tokenizer.bos_token + label  # add bos token to label
    inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)
    labels = tokenizer(
        label + tokenizer.eos_token,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        add_special_tokens=False,
    )
    if prompt_only:
        return {
            "input_ids": inputs["input_ids"][0],
            "attention_mask": inputs["attention_mask"][0],
            "labels": torch.ones_like(inputs["input_ids"][0]) * -100,
        }
    elif response_only:
        return {
            "input_ids": labels["input_ids"][0],
            "attention_mask": labels["attention_mask"][0],
            "labels": labels["input_ids"][0],
        }
    return {
        "input_ids": torch.cat([inputs["input_ids"][0], labels["input_ids"][0]], dim=0),
        "attention_mask": torch.cat([inputs["attention_mask"][0], labels["attention_mask"][0]], dim=0),
        "labels": torch.cat([torch.ones_like(inputs["input_ids"][0]) * -100, labels["input_ids"][0]], dim=0),
    }


# MMLU
class MMLU(TestDataset):
    shots = 0
    # default 982 prompts, just all test prompts

    def get_all_test_prompts(self, num_samples=-1, seed=42, max_length=512, prompt_only=False, response_only=False):
        # get the prompts for each subject
        prompts_per_subject = construct_prompts(
            self.tokenizer,
            use_chat_format=True,
            ntrain=self.shots,
            use_dev_samples=False,
            data_dir=os.path.join(get_appropriate_data_dir(), "eval/mmlu"),
        )
        prompts, labels = [], []
        for subject in prompts_per_subject:
            for prompt, label in prompts_per_subject[subject]:
                prompts.append(prompt)
                labels.append(label)
        # convert as normal.
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})

        def make_test_sample(x):
            return construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only)

        test_dataset = test_dataset.map(make_test_sample, load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed)
        if num_samples > 0:
            test_dataset = test_dataset.select(range(num_samples))
        return test_dataset


class MMLUShots(TestDataset):
    # default 285 prompts, 5 per category (these are the shots) used during eval.
    def get_all_test_prompts(self, num_samples=285, seed=42, max_length=512, prompt_only=False, response_only=False):
        random_gen = random.Random(seed)
        # get the prompts for each subject
        prompts_per_subject = construct_prompts(
            self.tokenizer,
            use_chat_format=True,
            ntrain=0,
            use_dev_samples=True,
            data_dir=os.path.join(get_appropriate_data_dir(), "eval/mmlu"),
        )
        prompts, labels = [], []
        while len(prompts) < num_samples:
            for subject in prompts_per_subject:
                # random sample and remove a prompt
                random_gen.shuffle(prompts_per_subject[subject])
                prompt, label = prompts_per_subject[subject].pop()
                prompts.append(prompt)
                labels.append(label)
        # convert as normal.
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})

        def make_test_sample(x):
            return construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only)

        test_dataset = test_dataset.map(make_test_sample, load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(num_samples))
        return test_dataset


class MMLUShotsShots(TestDataset):
    # default 285 prompts, 5 per category (these are the shots) used during eval.
    def get_all_test_prompts(self, num_samples=285, seed=42, max_length=512, prompt_only=False, response_only=False):
        random_gen = random.Random(seed)
        # get the prompts for each subject
        prompts_per_subject = construct_prompts(
            self.tokenizer,
            use_chat_format=True,
            ntrain=5,
            use_dev_samples=True,
            data_dir=os.path.join(get_appropriate_data_dir(), "eval/mmlu"),
        )
        prompts, labels = [], []
        while len(prompts) < num_samples:
            for subject in prompts_per_subject:
                # random sample and remove a prompt
                random_gen.shuffle(prompts_per_subject[subject])
                prompt, label = prompts_per_subject[subject].pop()
                prompts.append(prompt)
                labels.append(label)
        # convert as normal.
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})

        def make_test_sample(x):
            return construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only)

        test_dataset = test_dataset.map(make_test_sample, load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(num_samples))
        return test_dataset


# GSM8k


class GSM8kEval(TestDataset):
    data_dir = os.path.join(get_appropriate_data_dir(), "eval/gsm")
    n_shot = 8
    cot = True

    def get_all_test_prompts(self, num_samples=None, seed=42, max_length=2048, prompt_only=False, response_only=False):
        test_data = []
        tokenizer = self.tokenizer
        with open(os.path.join(self.data_dir, "test.jsonl")) as fin:
            for line in fin:
                example = json.loads(line)
                test_data.append(
                    {"question": example["question"], "answer": example["answer"].split("####")[1].strip()}
                )

        # some numbers are in the `x,xxx` format, and we want to remove the comma
        for example in test_data:
            example["answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"])
            assert float(example["answer"]), f"answer is not a valid number: {example['answer']}"

        if num_samples and len(test_data) > num_samples:
            test_data = random.sample(test_data, num_samples)

        global GSM_EXAMPLARS
        if self.n_shot:
            if len(GSM_EXAMPLARS) > self.n_shot:
                GSM_EXAMPLARS = random.sample(GSM_EXAMPLARS, self.n_shot)
            demonstrations = []
            for example in GSM_EXAMPLARS:
                if not self.cot:
                    demonstrations.append(
                        "Quesion: " + example["question"] + "\n" + "Answer: " + example["short_answer"]
                    )
                else:
                    demonstrations.append(
                        "Question: " + example["question"] + "\n" + "Answer: " + example["cot_answer"]
                    )
            prompt_prefix = "Answer the following questions.\n\n" + "\n\n".join(demonstrations) + "\n\n"
        else:
            prompt_prefix = "Answer the following question.\n\n"

        prompts = []
        labels = []

        def apply_chat_format(example, tokenizer):
            messages = [{"role": "user", "content": prompt_prefix + "Question: " + example["question"].strip()}]
            prompt = create_prompt_with_tulu_chat_format(messages, tokenizer, add_bos=False)
            prompt += "Answer:" if prompt[-1] in ["\n", " "] else " Answer:"
            return prompt

        for example in test_data:
            prompts.append(apply_chat_format(example, tokenizer))
            labels.append(example["answer"])
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


# for this, we just use the examplars. We don't need to sample from the test set.


class GSM8kShots(TestDataset):
    n_shot = 8
    cot = True

    def get_all_test_prompts(self, num_samples=8, seed=42, max_length=2048, prompt_only=False, response_only=False):
        self.tokenizer

        global GSM_EXAMPLARS
        if len(GSM_EXAMPLARS) > num_samples:
            GSM_EXAMPLARS = random.sample(GSM_EXAMPLARS, num_samples)
        prompt_prefix = "Answer the following question.\n\nQuestion: "

        prompts = []
        labels = []

        def apply_chat_format(example, tokenizer):
            messages = [{"role": "user", "content": prompt_prefix + example["question"].strip()}]
            prompt = create_prompt_with_tulu_chat_format(messages, tokenizer, add_bos=False)
            prompt += "Answer:" if prompt[-1] in ["\n", " "] else " Answer:"
            return prompt

        for example in GSM_EXAMPLARS:
            prompts.append(apply_chat_format(example, self.tokenizer))
            labels.append(example["short_answer"] if not self.cot else example["cot_answer"])
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


class GSM8kShotsShots(TestDataset):
    n_shot = 7
    cot = True

    def get_all_test_prompts(self, num_samples=8, seed=42, max_length=2048, prompt_only=False, response_only=False):
        self.tokenizer

        global GSM_EXAMPLARS

        def construct_prompt_prefix(question_text):
            exemplars = [x for x in GSM_EXAMPLARS if x['question'].lower() != question_text.lower()]
            if self.n_shot:
                if len(exemplars) > self.n_shot:
                    exemplars = random.sample(exemplars, self.n_shot)
                demonstrations = []
                for example in exemplars:
                    if not self.cot:
                        demonstrations.append(
                            "Quesion: " + example["question"] + "\n" + "Answer: " + example["short_answer"]
                        )
                    else:
                        demonstrations.append(
                            "Question: " + example["question"] + "\n" + "Answer: " + example["cot_answer"]
                        )
                prompt_prefix = "Answer the following questions.\n\n" + "\n\n".join(demonstrations) + "\n\n"
            else:
                prompt_prefix = "Answer the following question.\n\n"
            return prompt_prefix

        prompts = []
        labels = []

        def apply_chat_format(example, tokenizer, prompt_prefix):
            messages = [{"role": "user", "content": prompt_prefix + "Question: " + example["question"].strip()}]
            prompt = create_prompt_with_tulu_chat_format(messages, tokenizer, add_bos=False)
            prompt += "Answer:" if prompt[-1] in ["\n", " "] else " Answer:"
            return prompt

        for example in GSM_EXAMPLARS:
            prompts.append(apply_chat_format(example, self.tokenizer, construct_prompt_prefix(example['question'])))
            labels.append(example["short_answer"] if not self.cot else example["cot_answer"])
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


# BBH


class BBHEval(TestDataset):
    subsets = [
        "boolean_expressions",
        "causal_judgement",
        "date_understanding",
        "disambiguation_qa",
        "dyck_languages",
        "formal_fallacies",
        "geometric_shapes",
        "hyperbaton",
        "logical_deduction_five_objects",
        "logical_deduction_seven_objects",
        "logical_deduction_three_objects",
        "movie_recommendation",
        "multistep_arithmetic_two",
        "navigate",
        "object_counting",
        "penguins_in_a_table",
        "reasoning_about_colored_objects",
        "ruin_names",
        "salient_translation_error_detection",
        "snarks",
        "sports_understanding",
        "temporal_sequences",
        "tracking_shuffled_objects_five_objects",
        "tracking_shuffled_objects_seven_objects",
        "tracking_shuffled_objects_three_objects",
        "web_of_lies",
        "word_sorting",
    ]
    cot = True
    data_dir = os.path.join(get_appropriate_data_dir(), "eval/bbh")
    max_num_examples_per_task = 40  # following eval setting

    # 270 prompts: 10 per task.
    def get_all_test_prompts(self, num_samples=270, seed=42, max_length=2048, prompt_only=False, response_only=False):
        random.seed(42)
        all_tasks = {}
        task_files = glob.glob(os.path.join(self.data_dir, "bbh", "*.json"))
        for task_file in tqdm.tqdm(task_files, desc="Loading tasks"):
            with open(task_file, "r") as f:
                task_name = os.path.basename(task_file).split(".")[0]
                all_tasks[task_name] = json.load(f)["examples"]
                if self.max_num_examples_per_task:
                    all_tasks[task_name] = random.sample(all_tasks[task_name], self.max_num_examples_per_task)
        all_prompts = {}
        cot_prompt_files = glob.glob(os.path.join(self.data_dir, "cot-prompts", "*.txt"))
        for cot_prompt_file in tqdm.tqdm(cot_prompt_files, desc="Loading prompts"):
            with open(cot_prompt_file, "r") as f:
                task_name = os.path.basename(cot_prompt_file).split(".")[0]
                task_prompt = "".join(f.readlines()[2:])
                if not self.cot:
                    prompt_fields = task_prompt.split("\n\n")
                    new_prompt_fields = []
                    for prompt_field in prompt_fields:
                        if prompt_field.startswith("Q:"):
                            assert (
                                "So the answer is" in prompt_field
                            ), f"`So the answer is` not found in prompt field of {task_name}.txt."
                            assert "\nA:" in prompt_field, "`\nA:` not found in prompt field."
                            answer = prompt_field.split("So the answer is")[-1].strip()
                            question = prompt_field.split("\nA:")[0].strip()
                            new_prompt_fields.append(question + "\nA: " + answer)
                        else:
                            new_prompt_fields.append(prompt_field)
                    task_prompt = "\n\n".join(new_prompt_fields)
                all_prompts[task_name] = task_prompt
        prompts = []
        labels = []
        for task_name in tqdm.tqdm(all_tasks.keys(), desc="Evaluating"):
            task_examples = all_tasks[task_name]
            task_prompt = all_prompts[task_name]
            for example in task_examples:
                prompt = task_prompt.strip() + "\n\nQ: " + example["input"]
                messages = [{"role": "user", "content": prompt}]
                prompt = create_prompt_with_tulu_chat_format(messages, self.tokenizer, add_bos=False)
                prompt += "A:" if prompt[-1] in ["\n", " "] else " A:"
                prompts.append(prompt)
                labels.append(example["target"])
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


class BBHShots(TestDataset):
    # basically the prompts, but we have to split them up a bunch.
    subsets = [
        "boolean_expressions",
        "causal_judgement",
        "date_understanding",
        "disambiguation_qa",
        "dyck_languages",
        "formal_fallacies",
        "geometric_shapes",
        "hyperbaton",
        "logical_deduction_five_objects",
        "logical_deduction_seven_objects",
        "logical_deduction_three_objects",
        "movie_recommendation",
        "multistep_arithmetic_two",
        "navigate",
        "object_counting",
        "penguins_in_a_table",
        "reasoning_about_colored_objects",
        "ruin_names",
        "salient_translation_error_detection",
        "snarks",
        "sports_understanding",
        "temporal_sequences",
        "tracking_shuffled_objects_five_objects",
        "tracking_shuffled_objects_seven_objects",
        "tracking_shuffled_objects_three_objects",
        "web_of_lies",
        "word_sorting",
    ]
    cot = True
    data_dir = os.path.join(get_appropriate_data_dir(), "eval/bbh")
    max_num_examples_per_task = 10

    # 270 prompts: 3 per task.
    def get_all_test_prompts(self, num_samples=81, seed=42, max_length=2048, prompt_only=False, response_only=False):
        random.seed(42)
        all_prompts = {}
        cot_prompt_files = glob.glob(os.path.join(self.data_dir, "cot-prompts", "*.txt"))
        for cot_prompt_file in tqdm.tqdm(cot_prompt_files, desc="Loading prompts"):
            with open(cot_prompt_file, "r") as f:
                task_name = os.path.basename(cot_prompt_file).split(".")[0]
                task_prompt = "".join(f.readlines()[2:])
                if not self.cot:
                    prompt_fields = task_prompt.split("\n\n")
                    new_prompt_fields = []
                    for prompt_field in prompt_fields:
                        if prompt_field.startswith("Q:"):
                            assert (
                                "So the answer is" in prompt_field
                            ), f"`So the answer is` not found in prompt field of {task_name}.txt."
                            assert "\nA:" in prompt_field, "`\nA:` not found in prompt field."
                            answer = prompt_field.split("So the answer is")[-1].strip()
                            question = prompt_field.split("\nA:")[0].strip()
                            new_prompt_fields.append(question + "\nA: " + answer)
                        else:
                            new_prompt_fields.append(prompt_field)
                    task_prompt = "\n\n".join(new_prompt_fields)
                all_prompts[task_name] = task_prompt
        # post process the prompts: split into the initial prompt, the question, and the answer.
        prompts = []
        labels = []
        for task_name in all_prompts:
            prompt_questions = all_prompts[task_name].split("\n\nQ:")
            initial_prompts = prompt_questions[0]
            for prompt_question in prompt_questions[1:]:
                question, answer = prompt_question.split("\nA: ")
                prompts.append(initial_prompts + "\n\nQ:" + question)
                labels.append(answer)
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


class BBHShotsShots(TestDataset):
    subsets = [
        "boolean_expressions",
        "causal_judgement",
        "date_understanding",
        "disambiguation_qa",
        "dyck_languages",
        "formal_fallacies",
        "geometric_shapes",
        "hyperbaton",
        "logical_deduction_five_objects",
        "logical_deduction_seven_objects",
        "logical_deduction_three_objects",
        "movie_recommendation",
        "multistep_arithmetic_two",
        "navigate",
        "object_counting",
        "penguins_in_a_table",
        "reasoning_about_colored_objects",
        "ruin_names",
        "salient_translation_error_detection",
        "snarks",
        "sports_understanding",
        "temporal_sequences",
        "tracking_shuffled_objects_five_objects",
        "tracking_shuffled_objects_seven_objects",
        "tracking_shuffled_objects_three_objects",
        "web_of_lies",
        "word_sorting",
    ]
    cot = True
    data_dir = os.path.join(get_appropriate_data_dir(), "eval/bbh")
    max_num_examples_per_task = 40  # following eval setting

    # 270 prompts: 10 per task.
    def get_all_test_prompts(self, num_samples=270, seed=42, max_length=2048, prompt_only=False, response_only=False):
        random.seed(42)
        all_tasks = {}
        task_files = glob.glob(os.path.join(self.data_dir, "bbh", "*.json"))
        for task_file in tqdm.tqdm(task_files, desc="Loading tasks"):
            with open(task_file, "r") as f:
                task_name = os.path.basename(task_file).split(".")[0]
                all_tasks[task_name] = json.load(f)["examples"]
                if self.max_num_examples_per_task:
                    all_tasks[task_name] = random.sample(all_tasks[task_name], self.max_num_examples_per_task)
        all_prompts = {}
        cot_prompt_files = glob.glob(os.path.join(self.data_dir, "cot-prompts", "*.txt"))
        for cot_prompt_file in tqdm.tqdm(cot_prompt_files, desc="Loading prompts"):
            with open(cot_prompt_file, "r") as f:
                task_name = os.path.basename(cot_prompt_file).split(".")[0]
                task_prompt = "".join(f.readlines()[2:])
                if not self.cot:
                    prompt_fields = task_prompt.split("\n\n")
                    new_prompt_fields = []
                    for prompt_field in prompt_fields:
                        if prompt_field.startswith("Q:"):
                            assert (
                                "So the answer is" in prompt_field
                            ), f"`So the answer is` not found in prompt field of {task_name}.txt."
                            assert "\nA:" in prompt_field, "`\nA:` not found in prompt field."
                            answer = prompt_field.split("So the answer is")[-1].strip()
                            question = prompt_field.split("\nA:")[0].strip()
                            new_prompt_fields.append(question + "\nA: " + answer)
                        else:
                            new_prompt_fields.append(prompt_field)
                    task_prompt = "\n\n".join(new_prompt_fields)
                all_prompts[task_name] = task_prompt
        prompts = []
        labels = []
        for task_name in tqdm.tqdm(all_tasks.keys(), desc="Evaluating"):
            task_prompt = all_prompts[task_name]
            prompt_questions = all_prompts[task_name].split("\n\nQ:")
            for prompt_question in prompt_questions[1:]:
                question, answer = prompt_question.split("\nA: ")
                q_task_prompt = task_prompt.strip().replace("Q:" + prompt_question, "").replace("\n\n\n\n", "\n\n")
                prompt = q_task_prompt.strip() + "\n\nQ: " + question
                messages = [{"role": "user", "content": prompt}]
                prompt = create_prompt_with_tulu_chat_format(messages, self.tokenizer, add_bos=False)
                prompt += "A:" if prompt[-1] in ["\n", " "] else " A:"
                prompts.append(prompt)
                labels.append(answer)
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset

# TydiQA


class TydiqaEval(TestDataset):
    data_dir = os.path.join(get_appropriate_data_dir(), "eval/tydiqa")
    max_num_examples_per_lang = 100
    n_shot = 1
    encoding_templates_with_context = {
        "english": (
            "Answer the following question based on the information in the given passage.",
            "Passage:",
            "Question:",
            "Answer:",
        ),
        "arabic": ("أجب على السؤال التالي بناءً على المعلومات في المقطع المعطى.", "المقطع:", "السؤال:", "الإجابة:"),
        "bengali": (
            "প্রদত্ত অধ্যায়ের তথ্যের উপর ভিত্তি করে নিম্নলিখিত প্রশ্নের উত্তর দিন।",
            "অধ্যায়:",
            "প্রশ্ন:",
            "উত্তর:",
        ),
        "finnish": (
            "Vastaa seuraavaan kysymykseen annetun kappaleen tiedon perusteella.",
            "Kappale:",
            "Kysymys:",
            "Vastaus:",
        ),
        "indonesian": (
            "Jawab pertanyaan berikut berdasarkan informasi di bagian yang diberikan.",
            "Bagian:",
            "Pertanyaan:",
            "Jawaban:",
        ),
        "korean": ("주어진 문단의 정보에 기반하여 다음 질문에 답하십시오.", "문단:", "질문:", "답변:"),
        "russian": (
            "Ответьте на следующий вопрос на основе информации в данном отрывке.",
            "Отрывок:",
            "Вопрос:",
            "Ответ:",
        ),
        "swahili": (
            "Jibu swali lifuatalo kulingana na habari kwenye kifungu kilichotolewa.",
            "Kifungu:",
            "Swali:",
            "Jibu:",
        ),
        "telugu": ("ఇచ్చిన పేరాలోని సమాచారం ఆధారంగా కింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "పేరా:", "ప్రశ్న:", "సమాధానం:"),
    }
    encoding_templates_without_context = {
        "english": ("Answer the following question.", "Question:", "Answer:"),
        "arabic": ("أجب على السؤال التالي.", "السؤال:", "الإجابة:"),
        "bengali": ("নিম্নলিখিত প্রশ্নের উত্তর দিন।", "প্রশ্ন:", "উত্তর:"),
        "finnish": ("Vastaa seuraavaan kysymykseen.", "Kysymys:", "Vastaus:"),
        "indonesian": ("Jawab pertanyaan berikut.", "Pertanyaan:", "Jawaban:"),
        "korean": ("다음 질문에 답하십시오.", "질문:", "답변:"),
        "russian": ("Ответьте на следующий вопрос.", "Вопрос:", "Ответ:"),
        "swahili": ("Jibu swali lifuatalo.", "Swali:", "Jibu:"),
        "telugu": ("క్రింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "ప్రశ్న:", "సమాధానం:"),
    }
    max_context_length = 1024
    no_context = False
    # 90 prompts: 10 per language.

    def get_all_test_prompts(self, num_samples=90, seed=42, max_length=2048, prompt_only=False, response_only=False):
        test_data = []
        tokenizer = self.tokenizer
        with open(os.path.join(self.data_dir, "tydiqa-goldp-v1.1-dev.json")) as fin:
            dev_data = json.load(fin)
            for article in dev_data["data"]:
                for paragraph in article["paragraphs"]:
                    for qa in paragraph["qas"]:
                        example = {
                            "id": qa["id"],
                            "lang": qa["id"].split("-")[0],
                            "context": paragraph["context"],
                            "question": qa["question"],
                            "answers": qa["answers"],
                        }
                        test_data.append(example)
        data_languages = sorted(list(set([example["lang"] for example in test_data])))
        if self.max_num_examples_per_lang:
            sampled_examples = []
            for lang in data_languages:
                examples_for_lang = [example for example in test_data if example["lang"] == lang]
                if len(examples_for_lang) > self.max_num_examples_per_lang:
                    examples_for_lang = random.sample(examples_for_lang, self.max_num_examples_per_lang)
                sampled_examples += examples_for_lang
            test_data = sampled_examples

        print(f"Loaded {len(test_data)} examples from {len(data_languages)} languages: {data_languages}")

        if self.n_shot > 0:
            train_data_for_langs = {lang: [] for lang in data_languages}
            with open(os.path.join(self.data_dir, "tydiqa-goldp-v1.1-train.json")) as fin:
                train_data = json.load(fin)
                for article in train_data["data"]:
                    for paragraph in article["paragraphs"]:
                        for qa in paragraph["qas"]:
                            lang = qa["id"].split("-")[0]
                            if lang in data_languages:
                                example = {
                                    "id": qa["id"],
                                    "lang": lang,
                                    "context": paragraph["context"],
                                    "question": qa["question"],
                                    "answers": qa["answers"],
                                }
                                train_data_for_langs[lang].append(example)
                for lang in data_languages:
                    # sample n_shot examples from each language
                    train_data_for_langs[lang] = random.sample(train_data_for_langs[lang], self.n_shot)
            # assert that we have exactly n_shot examples for each language
            assert all([len(train_data_for_langs[lang]) == self.n_shot for lang in data_languages])

        # assert we have templates for all data languages
        assert all([lang in self.encoding_templates_with_context.keys() for lang in data_languages])

        if self.max_context_length:
            for example in test_data:
                tokenized_context = tokenizer.encode(example["context"])
                if len(tokenized_context) > self.max_context_length:
                    example["context"] = tokenizer.decode(tokenized_context[: self.max_context_length])
            if self.n_shot > 0:
                for lang in data_languages:
                    for example in train_data_for_langs[lang]:
                        tokenized_context = tokenizer.encode(example["context"])
                        if len(tokenized_context) > self.max_context_length:
                            example["context"] = tokenizer.decode(tokenized_context[: self.max_context_length])

        prompts = []
        labels = []
        for example in test_data:
            lang = example["lang"]

            if self.no_context:
                prompt, q_template, a_template = self.encoding_templates_without_context[lang]
                p_template = ""
            else:
                prompt, p_template, q_template, a_template = self.encoding_templates_with_context[lang]

            prompt += "\n\n"

            if self.n_shot > 0:
                formatted_demo_examples = []
                for train_example in train_data_for_langs[lang]:
                    if self.no_context:
                        formatted_demo_examples.append(
                            q_template
                            + " "
                            + train_example["question"]
                            + "\n"
                            + a_template
                            + " "
                            + train_example["answers"][0]["text"]
                        )
                    else:
                        formatted_demo_examples.append(
                            p_template
                            + " "
                            + train_example["context"]
                            + "\n"
                            + q_template
                            + " "
                            + train_example["question"]
                            + "\n"
                            + a_template
                            + " "
                            + train_example["answers"][0]["text"]
                        )
                prompt += "\n\n".join(formatted_demo_examples) + "\n\n"

            if self.no_context:
                prompt += q_template + " " + format(example["question"]) + "\n"
            else:
                prompt += (
                    p_template
                    + " "
                    + format(example["context"])
                    + "\n"
                    + q_template
                    + " "
                    + format(example["question"])
                    + "\n"
                )

            messages = [{"role": "user", "content": prompt}]
            prompt = create_prompt_with_tulu_chat_format(messages, self.tokenizer, add_bos=False)
            prompt += a_template if prompt[-1] in ["\n", " "] else " " + a_template
            prompts.append(prompt)
            labels.append(example["answers"][0]["text"])
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


class TydiQAShots(TestDataset):
    data_dir = os.path.join(get_appropriate_data_dir(), "eval/tydiqa")
    max_num_examples_per_lang = 10
    n_shot = 1
    encoding_templates_with_context = {
        "english": (
            "Answer the following question based on the information in the given passage.",
            "Passage:",
            "Question:",
            "Answer:",
        ),
        "arabic": ("أجب على السؤال التالي بناءً على المعلومات في المقطع المعطى.", "المقطع:", "السؤال:", "الإجابة:"),
        "bengali": (
            "প্রদত্ত অধ্যায়ের তথ্যের উপর ভিত্তি করে নিম্নলিখিত প্রশ্নের উত্তর দিন।",
            "অধ্যায়:",
            "প্রশ্ন:",
            "উত্তর:",
        ),
        "finnish": (
            "Vastaa seuraavaan kysymykseen annetun kappaleen tiedon perusteella.",
            "Kappale:",
            "Kysymys:",
            "Vastaus:",
        ),
        "indonesian": (
            "Jawab pertanyaan berikut berdasarkan informasi di bagian yang diberikan.",
            "Bagian:",
            "Pertanyaan:",
            "Jawaban:",
        ),
        "korean": ("주어진 문단의 정보에 기반하여 다음 질문에 답하십시오.", "문단:", "질문:", "답변:"),
        "russian": (
            "Ответьте на следующий вопрос на основе информации в данном отрывке.",
            "Отрывок:",
            "Вопрос:",
            "Ответ:",
        ),
        "swahili": (
            "Jibu swali lifuatalo kulingana na habari kwenye kifungu kilichotolewa.",
            "Kifungu:",
            "Swali:",
            "Jibu:",
        ),
        "telugu": ("ఇచ్చిన పేరాలోని సమాచారం ఆధారంగా కింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "పేరా:", "ప్రశ్న:", "సమాధానం:"),
    }
    encoding_templates_without_context = {
        "english": ("Answer the following question.", "Question:", "Answer:"),
        "arabic": ("أجب على السؤال التالي.", "السؤال:", "الإجابة:"),
        "bengali": ("নিম্নলিখিত প্রশ্নের উত্তর দিন।", "প্রশ্ন:", "উত্তর:"),
        "finnish": ("Vastaa seuraavaan kysymykseen.", "Kysymys:", "Vastaus:"),
        "indonesian": ("Jawab pertanyaan berikut.", "Pertanyaan:", "Jawaban:"),
        "korean": ("다음 질문에 답하십시오.", "질문:", "답변:"),
        "russian": ("Ответьте на следующий вопрос.", "Вопрос:", "Ответ:"),
        "swahili": ("Jibu swali lifuatalo.", "Swali:", "Jibu:"),
        "telugu": ("క్రింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "ప్రశ్న:", "సమాధానం:"),
    }
    max_context_length = 1024
    no_context = False
    # 90 prompts: 10 per language.

    def get_all_test_prompts(self, num_samples=90, seed=42, max_length=2048, prompt_only=False, response_only=False):
        test_data = []
        self.tokenizer
        with open(os.path.join(self.data_dir, "tydiqa-goldp-v1.1-dev.json")) as fin:
            dev_data = json.load(fin)
            for article in dev_data["data"]:
                for paragraph in article["paragraphs"]:
                    for qa in paragraph["qas"]:
                        example = {
                            "id": qa["id"],
                            "lang": qa["id"].split("-")[0],
                            "context": paragraph["context"],
                            "question": qa["question"],
                            "answers": qa["answers"],
                        }
                        test_data.append(example)
        data_languages = sorted(list(set([example["lang"] for example in test_data])))
        if self.max_num_examples_per_lang:
            sampled_examples = []
            for lang in data_languages:
                examples_for_lang = [example for example in test_data if example["lang"] == lang]
                if len(examples_for_lang) > self.max_num_examples_per_lang:
                    examples_for_lang = random.sample(examples_for_lang, self.max_num_examples_per_lang)
                sampled_examples += examples_for_lang
            test_data = sampled_examples

        print(f"Loaded {len(test_data)} examples from {len(data_languages)} languages: {data_languages}")

        if self.n_shot > 0:
            train_data_for_langs = {lang: [] for lang in data_languages}
            with open(os.path.join(self.data_dir, "tydiqa-goldp-v1.1-train.json")) as fin:
                train_data = json.load(fin)
                for article in train_data["data"]:
                    for paragraph in article["paragraphs"]:
                        for qa in paragraph["qas"]:
                            lang = qa["id"].split("-")[0]
                            if lang in data_languages:
                                example = {
                                    "id": qa["id"],
                                    "lang": lang,
                                    "context": paragraph["context"],
                                    "question": qa["question"],
                                    "answers": qa["answers"],
                                }
                                train_data_for_langs[lang].append(example)
                for lang in data_languages:
                    # sample n_shot examples from each language
                    train_data_for_langs[lang] = random.sample(train_data_for_langs[lang], self.n_shot)
            # assert that we have exactly n_shot examples for each language
            assert all([len(train_data_for_langs[lang]) == self.n_shot for lang in data_languages])

        # construct prompts, labels directly from few-shots
        prompts, labels = [], []
        for lang in train_data_for_langs:
            for example in train_data_for_langs[lang]:
                if self.no_context:
                    prompt, q_template, a_template = self.encoding_templates_without_context[lang]
                    p_template = ""
                else:
                    prompt, p_template, q_template, a_template = self.encoding_templates_with_context[lang]
                prompt += "\n\n"
                if self.no_context:
                    prompt += q_template + " " + example["question"] + "\n"
                else:
                    prompt += (
                        p_template + " " + example["context"] + "\n" + q_template + " " + example["question"] + "\n"
                    )
                messages = [{"role": "user", "content": prompt}]
                prompt = create_prompt_with_tulu_chat_format(messages, self.tokenizer, add_bos=False)
                prompt += a_template if prompt[-1] in ["\n", " "] else " " + a_template
                prompts.append(prompt)
                labels.append(example["answers"][0]["text"])
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


class TydiQAShotsShots(TestDataset):
    data_dir = os.path.join(get_appropriate_data_dir(), "eval/tydiqa")
    max_num_examples_per_lang = 10
    n_shot = 2
    encoding_templates_with_context = {
        "english": (
            "Answer the following question based on the information in the given passage.",
            "Passage:",
            "Question:",
            "Answer:",
        ),
        "arabic": ("أجب على السؤال التالي بناءً على المعلومات في المقطع المعطى.", "المقطع:", "السؤال:", "الإجابة:"),
        "bengali": (
            "প্রদত্ত অধ্যায়ের তথ্যের উপর ভিত্তি করে নিম্নলিখিত প্রশ্নের উত্তর দিন।",
            "অধ্যায়:",
            "প্রশ্ন:",
            "উত্তর:",
        ),
        "finnish": (
            "Vastaa seuraavaan kysymykseen annetun kappaleen tiedon perusteella.",
            "Kappale:",
            "Kysymys:",
            "Vastaus:",
        ),
        "indonesian": (
            "Jawab pertanyaan berikut berdasarkan informasi di bagian yang diberikan.",
            "Bagian:",
            "Pertanyaan:",
            "Jawaban:",
        ),
        "korean": ("주어진 문단의 정보에 기반하여 다음 질문에 답하십시오.", "문단:", "질문:", "답변:"),
        "russian": (
            "Ответьте на следующий вопрос на основе информации в данном отрывке.",
            "Отрывок:",
            "Вопрос:",
            "Ответ:",
        ),
        "swahili": (
            "Jibu swali lifuatalo kulingana na habari kwenye kifungu kilichotolewa.",
            "Kifungu:",
            "Swali:",
            "Jibu:",
        ),
        "telugu": ("ఇచ్చిన పేరాలోని సమాచారం ఆధారంగా కింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "పేరా:", "ప్రశ్న:", "సమాధానం:"),
    }
    encoding_templates_without_context = {
        "english": ("Answer the following question.", "Question:", "Answer:"),
        "arabic": ("أجب على السؤال التالي.", "السؤال:", "الإجابة:"),
        "bengali": ("নিম্নলিখিত প্রশ্নের উত্তর দিন।", "প্রশ্ন:", "উত্তর:"),
        "finnish": ("Vastaa seuraavaan kysymykseen.", "Kysymys:", "Vastaus:"),
        "indonesian": ("Jawab pertanyaan berikut.", "Pertanyaan:", "Jawaban:"),
        "korean": ("다음 질문에 답하십시오.", "질문:", "답변:"),
        "russian": ("Ответьте на следующий вопрос.", "Вопрос:", "Ответ:"),
        "swahili": ("Jibu swali lifuatalo.", "Swali:", "Jibu:"),
        "telugu": ("క్రింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "ప్రశ్న:", "సమాధానం:"),
    }
    max_context_length = 1024
    no_context = False
    # 90 prompts: 10 per language.

    def get_all_test_prompts(self, num_samples=90, seed=42, max_length=2048, prompt_only=False, response_only=False):
        test_data = []
        self.tokenizer
        with open(os.path.join(self.data_dir, "tydiqa-goldp-v1.1-dev.json")) as fin:
            dev_data = json.load(fin)
            for article in dev_data["data"]:
                for paragraph in article["paragraphs"]:
                    for qa in paragraph["qas"]:
                        example = {
                            "id": qa["id"],
                            "lang": qa["id"].split("-")[0],
                            "context": paragraph["context"],
                            "question": qa["question"],
                            "answers": qa["answers"],
                        }
                        test_data.append(example)
        data_languages = sorted(list(set([example["lang"] for example in test_data])))
        if self.max_num_examples_per_lang:
            sampled_examples = []
            for lang in data_languages:
                examples_for_lang = [example for example in test_data if example["lang"] == lang]
                if len(examples_for_lang) > self.max_num_examples_per_lang:
                    examples_for_lang = random.sample(examples_for_lang, self.max_num_examples_per_lang)
                sampled_examples += examples_for_lang
            test_data = sampled_examples

        print(f"Loaded {len(test_data)} examples from {len(data_languages)} languages: {data_languages}")

        test_data = []
        if self.n_shot > 0:
            train_data_for_langs = {lang: [] for lang in data_languages}
            with open(os.path.join(self.data_dir, "tydiqa-goldp-v1.1-train.json")) as fin:
                train_data = json.load(fin)
                for article in train_data["data"]:
                    for paragraph in article["paragraphs"]:
                        for qa in paragraph["qas"]:
                            lang = qa["id"].split("-")[0]
                            if lang in data_languages:
                                example = {
                                    "id": qa["id"],
                                    "lang": lang,
                                    "context": paragraph["context"],
                                    "question": qa["question"],
                                    "answers": qa["answers"],
                                }
                                train_data_for_langs[lang].append(example)
                for lang in data_languages:
                    # sample n_shot examples from each language, take the first as test example
                    # rest are few-shot samples
                    two_shots = random.sample(train_data_for_langs[lang], self.n_shot)
                    train_data_for_langs[lang] = two_shots[1:]
                    test_data.append(two_shots[0])
            # assert that we have exactly n_shot examples for each language
            assert all([len(train_data_for_langs[lang]) == self.n_shot - 1 for lang in data_languages])

        # construct prompts, labels directly from few-shots
        prompts, labels = [], []
        prompts = []
        labels = []
        for example in test_data:
            lang = example["lang"]

            if self.no_context:
                prompt, q_template, a_template = self.encoding_templates_without_context[lang]
                p_template = ""
            else:
                prompt, p_template, q_template, a_template = self.encoding_templates_with_context[lang]

            prompt += "\n\n"

            if self.n_shot > 0:
                formatted_demo_examples = []
                for train_example in train_data_for_langs[lang]:
                    if self.no_context:
                        formatted_demo_examples.append(
                            q_template
                            + " "
                            + train_example["question"]
                            + "\n"
                            + a_template
                            + " "
                            + train_example["answers"][0]["text"]
                        )
                    else:
                        formatted_demo_examples.append(
                            p_template
                            + " "
                            + train_example["context"]
                            + "\n"
                            + q_template
                            + " "
                            + train_example["question"]
                            + "\n"
                            + a_template
                            + " "
                            + train_example["answers"][0]["text"]
                        )
                prompt += "\n\n".join(formatted_demo_examples) + "\n\n"

            if self.no_context:
                prompt += q_template + " " + format(example["question"]) + "\n"
            else:
                prompt += (
                    p_template
                    + " "
                    + format(example["context"])
                    + "\n"
                    + q_template
                    + " "
                    + format(example["question"])
                    + "\n"
                )

            messages = [{"role": "user", "content": prompt}]
            prompt = create_prompt_with_tulu_chat_format(messages, self.tokenizer, add_bos=False)
            prompt += a_template if prompt[-1] in ["\n", " "] else " " + a_template
            prompts.append(prompt)
            labels.append(example["answers"][0]["text"])
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


# CODEX-EVAL


class CodexEval(TestDataset):
    data_file = os.path.join(get_appropriate_data_dir(), "eval/codex_humaneval/HumanEval_dev.jsonl.gz")
    data_file_hep = os.path.join(get_appropriate_data_dir(), "eval/codex_humaneval/humanevalpack.jsonl")

    def get_all_test_prompts(self, num_samples=100, seed=42, max_length=512, prompt_only=False, response_only=False):
        random.seed(42)
        test_data = list(read_problems(self.data_file).values())
        print("Number of examples:", len(test_data))

        # these stop sequences are those mentioned in the codex paper.
        stop_sequences = ["\nclass", "\ndef", "\n#", "\nif", "\nprint"]

        prompts = []
        labels = []
        # If available use more realistic instructions from HumanEvalPack (https://hf.co/datasets/bigcode/humanevalpack)
        if os.path.exists(self.data_file_hep):
            with open(self.data_file_hep, "r") as f:
                instructions = [json.loads(line) for line in f]
                instructions_dict = {
                    x["task_id"].replace("Python", "HumanEval"): x["instruction"] for x in instructions
                }
            answer = "Here is the function:\n\n```python\n"
            stop_sequences.append("\n```")
        else:
            print(
                f"Could not find HumanEvalPack file at {self.data_file_hep}, which will result in significantly worse performance. You can download it at https://hf.co/datasets/bigcode/humanevalpack/blob/main/data/python/data/humanevalpack.jsonl"
            )
            instructions_dict = None
            answer = "Here is the completed function:\n\n\n"

        def apply_chat_format(tokenizer, inst, suffix):
            messages = [{"role": "user", "content": inst}]
            prompt = create_prompt_with_tulu_chat_format(messages, tokenizer, add_bos=False)
            prefix = "" if prompt[-1] in ["\n", " "] else " "
            return prompt + prefix + suffix

        instruction = "Complete the following python function.\n\n\n"
        for example in test_data:
            if instructions_dict is not None:
                instruction = instructions_dict[example["task_id"]]
                prompts.append(apply_chat_format(self.tokenizer, instruction, answer + example["prompt"]))
                labels.append(example["canonical_solution"])
            else:
                prompts.append(apply_chat_format(self.tokenizer, instruction + example["prompt"], answer))
                labels.append(example["canonical_solution"])
        # labels are taken from the "canonical solution" in the data provided.
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})

        def construct_test_sample_tok(x):
            return construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only)

        test_dataset = test_dataset.map(construct_test_sample_tok, load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


# codex test - just point to diff file.


class CodexEvalTest(CodexEval):
    # test file
    data_file = os.path.join(get_appropriate_data_dir(), "eval/codex_humaneval/HumanEval.jsonl.gz")


# ALPACAEVAL


class AlpacaEval(TestDataset):
    data_file = os.path.join(get_appropriate_data_dir(), "eval/alpacaeval/alpaca_eval_dev.json")

    def get_all_test_prompts(self, num_samples=50, seed=42, max_length=512, prompt_only=False, response_only=False):
        data = json.load(open(self.data_file, "r"))
        # shuffle and select first num_samples
        random.seed(42)
        random.shuffle(data)
        data = data[:num_samples]
        prompts, labels = [], []
        for example in data:
            prompt = example["instruction"].strip()
            messages = [{"role": "user", "content": prompt}]
            prompt = create_prompt_with_tulu_chat_format(messages, self.tokenizer, add_bos=False)
            prompts.append(prompt)
            labels.append(example["output"].strip())
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})

        def construct_test_sample_tok(x):
            return construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only)

        test_dataset = test_dataset.map(construct_test_sample_tok, load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


# no few-shot examples for alpacaeval

# version with test prompts


class AlpacaEvalTest(AlpacaEval):
    data_file = os.path.join(get_appropriate_data_dir(), "eval/alpacaeval/alpaca_eval_test.json")


# Toxigen
# toxigen is a little weird, so we actually do something different: we sample a label we don't want to generate, and we can consider 'harmful' influences for this.
# this might be annoying, so we will typically ignore toxigen in our setups.
# but maybe we'll tackle this in the future.


class ToxigenEval(TestDataset):
    data_dir = os.path.join(get_appropriate_data_dir(), "eval/toxigen/dev")
    max_prompts_per_group = 10
    # 140 default: 10 per group.

    def get_all_test_prompts(self, num_samples=140, seed=42, max_length=512, prompt_only=False, response_only=False):
        random.seed(42)
        examples = []
        prompt_files = glob.glob(os.path.join(self.data_dir, "*.txt"))
        for task_file in tqdm(prompt_files, desc="Loading prompts"):
            with open(task_file, "r") as f:
                group_name = os.path.basename(task_file).split(".")[0]
                label = group_name.split("_")[0]
                minority_group = "_".join(group_name.split("_")[1:])
                group_prompts = [line.strip() for line in f]
                random.shuffle(group_prompts)
                group_prompts = group_prompts[: self.max_prompts_per_group]
                label_prompt = random.choice(group_prompts[self.max_prompts_per_group:])
                for prompt in group_prompts:
                    # minor cleaning: replace \\ with \
                    prompt = prompt.replace("\\\\", "\\")
                    prompt = prompt.replace("\\n", "\n")
                    examples.append(
                        {
                            "text": prompt,
                            "label": label,
                            "target_groups": [minority_group],
                            "label_prompt": label_prompt,
                        }
                    )
        prompts, labels = [], []
        for example in examples:
            messages = [{"role": "user", "content": "Complete the following: " + example["text"]}]
            prompt = create_prompt_with_tulu_chat_format(messages, self.tokenizer, add_bos=False)
            prompts.append(prompt)
            labels.append(example["label_prompt"])
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})

        def construct_test_sample_tok(x):
            return construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only)

        test_dataset = test_dataset.map(construct_test_sample_tok, load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


# no few-shot examples for toxigen.

# TRUTHFULQA

# I am skipping this one, since it is even weirder than the others. Hard to say what a good label might be.


# SQUAD


# not part of the original tulu eval set, but we can use this as a nice simpler task.
class SquadEval(TestDataset):
    add_context = True

    def get_all_test_prompts(self, num_samples=500, seed=42, max_length=512, prompt_only=False, response_only=False):
        test_dataset = load_dataset("squad", split="train")
        tokenizer = self.tokenizer

        def convert_squad_sample(sample):
            prompt = sample["question"]  # no context lmao.
            if self.add_context:
                prompt = sample["context"] + "\n\n" + prompt
            label = sample["answers"]["text"][0]
            messages = [{"role": "user", "content": prompt}]
            return {
                "prompts": create_prompt_with_tulu_chat_format(messages, tokenizer, add_bos=False),
                "labels": label,
            }

        test_dataset = test_dataset.map(
            convert_squad_sample, remove_columns=["id", "title", "context", "question", "answers"]
        )
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


class SquadEvalShots(TestDataset):
    def get_all_test_prompts(self, num_samples=500, seed=42, max_length=512, prompt_only=False, response_only=False):
        test_dataset = load_dataset("squad", split="train")
        tokenizer = self.tokenizer

        def convert_squad_sample(sample):
            prompt = "\n".join(SQUAD_SHOTS) + '\n' + sample["context"] + "\n\n" + sample["question"]
            label = sample["answers"]["text"][0]
            messages = [{"role": "user", "content": prompt}]
            return {
                "prompts": create_prompt_with_tulu_chat_format(messages, tokenizer, add_bos=False),
                "labels": label,
            }

        test_dataset = test_dataset.map(
            convert_squad_sample, remove_columns=["id", "title", "context", "question", "answers"]
        )
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset

# test version for squad - just the same but we use the validation set.


class SquadEvalTest(SquadEval):
    def get_all_test_prompts(self, num_samples=500, seed=42, max_length=512, prompt_only=False, response_only=False):
        test_dataset = load_dataset("squad", split="validation")
        tokenizer = self.tokenizer

        def convert_squad_sample(sample):
            prompt = sample["question"]  # no context lmao.
            if self.add_context:
                prompt = sample["context"] + "\n\n" + prompt
            label = sample["answers"]["text"][0]
            messages = [{"role": "user", "content": prompt}]
            return {
                "prompts": create_prompt_with_tulu_chat_format(messages, tokenizer, add_bos=False),
                "labels": label,
            }

        test_dataset = test_dataset.map(
            convert_squad_sample, remove_columns=["id", "title", "context", "question", "answers"]
        )
        test_dataset = test_dataset.map(lambda x: construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only))
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


class MBPPPlus(TestDataset):
    dataset = load_dataset("evalplus/mbppplus")["test"]
    # Don't change this, relying on this to ensure no overlap between selection set and test set
    dataset.shuffle(seed=42)
    # Always held-out first 100 samples

    def get_all_test_prompts(self, num_samples=100, seed=42, max_length=512, prompt_only=False, response_only=False):
        test_data = self.dataset.select(range(min(num_samples, len(self.dataset))))
        print("Number of examples:", len(test_data))
        prompts = []
        labels = []

        answer = "Here is the completed function:\n\n\n"

        def apply_chat_format(tokenizer, inst, suffix):
            messages = [{"role": "user", "content": inst}]
            prompt = create_prompt_with_tulu_chat_format(messages, tokenizer, add_bos=False)
            prefix = "" if prompt[-1] in ["\n", " "] else " "
            return prompt + prefix + suffix

        instruction = "Complete the following python function.\n\n\n"
        for example in test_data:
            prompts.append(apply_chat_format(self.tokenizer, instruction + example["prompt"], answer))
            labels.append(example["code"])

        # labels are taken from the "canonical solution" in the data provided.
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})

        def construct_test_sample_tok(x):
            return construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only)

        test_dataset = test_dataset.map(construct_test_sample_tok, load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        breakpoint()
        return test_dataset


class ArenaHardSelection(TestDataset):
    data_file = os.path.join(get_appropriate_data_dir(), "eval/arena_hard.jsonl")

    def get_all_test_prompts(self, num_samples=500, seed=42, max_length=2048, prompt_only=False, response_only=False):
        data = [json.loads(line) for line in open(self.data_file, "r")]
        # shuffle and select first num_samples
        random.seed(42)
        random.shuffle(data)
        data = data[:num_samples]
        prompts, labels = [], []
        for example in data:
            prompt = example["prompt"].strip()
            messages = [{"role": "user", "content": prompt}]
            prompt = create_prompt_with_tulu_chat_format(messages, self.tokenizer, add_bos=False)
            prompts.append(prompt)
            labels.append(example["label"].strip())
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})

        def construct_test_sample_tok(x):
            return construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only)

        test_dataset = test_dataset.map(construct_test_sample_tok, load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


class WildChatSelection(TestDataset):
    data_file = os.path.join(get_appropriate_data_dir(), "eval/wildchat_100_prompts.jsonl")

    def get_all_test_prompts(self, num_samples=100, seed=42, max_length=2048, prompt_only=False, response_only=False):
        data = [json.loads(line) for line in open(self.data_file, "r")]
        # shuffle and select first num_samples
        random.seed(42)
        random.shuffle(data)
        data = data[:num_samples]
        prompts, labels = [], []
        for example in data:
            prompt = example["prompt"].strip()
            messages = [{"role": "user", "content": prompt}]
            prompt = create_prompt_with_tulu_chat_format(messages, self.tokenizer, add_bos=False)
            prompts.append(prompt)
            labels.append(example["response"].strip())
        test_dataset = Dataset.from_dict({"prompts": prompts, "labels": labels})

        def construct_test_sample_tok(x):
            return construct_test_sample(self.tokenizer, x, max_length=max_length, prompt_only=prompt_only, response_only=response_only)

        test_dataset = test_dataset.map(construct_test_sample_tok, load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


# debug dataset: take 100 samples that I know exist in the index,
# and make sure we select them correctly.
class SubsetSelection(TestDataset):
    data_file = os.path.join(get_appropriate_data_dir(), "eval/100_200k_debug.jsonl")

    def get_all_test_prompts(self, num_samples=100, seed=42, max_length=2048, prompt_only=False, response_only=False):
        assert not prompt_only and not response_only, "Prompt only and response only not supported for debug dataset."
        test_dataset = load_dataset("json", data_files=self.data_file)["train"]
        test_dataset = test_dataset.map(lambda x: encode_with_messages_format(x, self.tokenizer, max_length), load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.shuffle(seed=seed).select(range(min(num_samples, len(test_dataset))))
        return test_dataset


# dataset to just ingest a file and encode it. Assumes follows the training dataset format.
class FileDataset(TestDataset):
    def __init__(self, data_file, tokenizer):
        self.data_file = data_file
        self.tokenizer = tokenizer

    def get_all_test_prompts(self, num_samples=1000000, max_length=2048, prompt_only=False, response_only=False):
        assert not prompt_only and not response_only, "Prompt only and response only not supported for file dataset."
        test_dataset = load_dataset("json", data_files=self.data_file)["train"]
        test_dataset = test_dataset.map(lambda x: encode_with_messages_format(x, self.tokenizer, max_length), load_from_cache_file=False)
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_dataset = test_dataset.select(range(min(num_samples, len(test_dataset))))
        return test_dataset


# todo: humaneval, truthfulqa mc.
# these are maybe tricky since they dont have hard gold.
DATASETS = {
    "mmlu": MMLU,  # all test prompts
    "mmlu_shots": MMLUShots,  # dev version
    "mmlu_shots_shots": MMLUShotsShots,  # dev version, with shots format.
    "gsm8k": GSM8kEval,  # all test prompts by default
    "gsm8k_shots": GSM8kShots,  # dev version
    "gsm8k_shots_shots": GSM8kShotsShots,  # dev version, with shots format.
    "bbh": BBHEval,  # all test prompts by default (that we eval on)
    "bbh_shots": BBHShots,  # dev version
    "bbh_shots_shots": BBHShotsShots,  # dev version, with shots format.
    "tydiqa": TydiqaEval,  # all test prompts by default
    "tydiqa_shots": TydiQAShots,  # dev version
    "tydiqa_shots_shots": TydiQAShotsShots,  # dev version, with shots format.
    "codex": CodexEval,
    "codex_test": CodexEvalTest,  # test version
    "alpacafarm": AlpacaEval,
    "alpacafarm_test": AlpacaEvalTest,  # test version
    "squad": SquadEval,
    "squad_test": SquadEvalTest,
    "squad_shots": SquadEvalShots,  # selection dev version, with shots.
    "mbppplus": MBPPPlus,
    "arena_hard": ArenaHardSelection,
    "wildchat": WildChatSelection,
    'debug_subset': SubsetSelection,
}

# How do we handle ensuring a difference between selection time and test time?
# MMLU -> we just have overlap, deal with it.
# MMLU Shots -> these are dev prompts, we just use these.
# GSM8k -> we just have overlap, deal with it.
# GSM8k Shots -> these are dev prompts, we just use these.
# BBH -> we just have overlap, deal with it.
# BBH Shots -> these are dev prompts, we just use these.
# TydiQA -> we just have overlap, deal with it.
# TydiQA Shots -> these are dev prompts, we just use these.
# Codex -> have a dev and test codex file and just split.
# AlpacaFarm -> I have split the data into dev and test, so we can use these.
# Toxigen -> Remove half the targets, and eval on the rest.
# Squad -> we use the train set to select, and eval on validation. Fine.

if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("oobabooga/llama-tokenizer")
    # test that we can load all the datasets.
    ds = {
        "mmlu_shots": MMLUShots,
        "bbh_shots": BBHShots,
        "tydiqa_shots": TydiQAShots
    }
    for dataset_name, dataset_class in ds.items():
        print(f"Testing {dataset_name}")
        dataset = dataset_class(tokenizer)
        dataset = dataset.get_all_test_prompts()
        print("----")
        print(len(dataset))
        print("----")
        print(tokenizer.decode(dataset[0]["input_ids"]))
        print("-")
        print(tokenizer.decode([x for x in dataset[0]["labels"] if x != -100]))
        print("----")
        print(tokenizer.decode(dataset[1]["input_ids"]))
        print("-")
        print(tokenizer.decode([x for x in dataset[1]["labels"] if x != -100]))
        print("----")
        print(f"Success for {dataset_name}")
