from datasets import load_dataset
from utils.prompter import (
    ZeroPrompter,
    GSM8KPrompter,
    OpenBookQAPrompter,
    PIQAPrompter,
    MMLUPrompter,
    BoolQPrompter,
    SIQAPrompter,
    HellaswagPrompter,
    MedQAPrompter,
    AlpacaPrompter,
    SamSumPrompter
)


class InstructionDataset:
    """Load a dataset with ``load_dataset`` and tokenize it into training examples.

    Each raw record is passed through ``prompter.generate_prompt`` (which must
    return a mapping with ``"prompt"`` and ``"answer"`` entries) and then
    tokenized into ``input_ids`` / ``attention_mask`` / ``labels``.

    After construction:
      * ``train_data`` holds the shuffled, tokenized training split;
      * ``val_data`` is a dict mapping ``path`` to the tokenized validation split.
    """

    def __init__(self,
                 path,
                 tokenizer,
                 max_length,
                 prompter=None,
                 name=None,
                 train_split="train",
                 val_split=None,
                 remove_columns=None,
                 val_size=2000,
                 ):
        """Store the settings and immediately load and tokenize both splits.

        Args:
            path: First argument given to ``load_dataset``; also the key used
                in ``val_data``.
            tokenizer: Callable tokenizer exposing ``eos_token_id``.
            max_length: Maximum token count passed to the tokenizer.
            prompter: Object with ``generate_prompt(data_point)``; a
                ``ZeroPrompter`` is created when omitted.
            name: Second argument given to ``load_dataset``.
            train_split: Key of the training split in the loaded dataset.
            val_split: Key of the validation split, or ``None`` to carve the
                validation set out of the training split.
            remove_columns: Forwarded to ``map`` for both splits.
            val_size: ``test_size`` used for ``train_test_split`` when
                ``val_split`` is ``None``.
        """
        self.path = path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.prompter = prompter if prompter is not None else ZeroPrompter()
        self.load_train_and_val(name, train_split, val_split, remove_columns, val_size)

    def load_train_and_val(self, name, train_split, val_split, remove_columns, val_size=2000):
        """Load the dataset and populate ``train_data`` and ``val_data``.

        When ``val_split`` is ``None`` the training split is divided with
        ``train_test_split(test_size=val_size, shuffle=True, seed=42)`` and the
        ``"test"`` part becomes the validation set; otherwise the named split
        is used directly. The training data is shuffled (unseeded) before
        tokenization; the validation data is not.
        """
        data = load_dataset(self.path, name)

        if val_split is None:
            train_val = data[train_split].train_test_split(
                test_size=val_size, shuffle=True, seed=42
            )
            train_raw, val_raw = train_val["train"], train_val["test"]
        else:
            train_raw, val_raw = data[train_split], data[val_split]

        # Both branches share the same tokenization pipeline.
        self.train_data = train_raw.shuffle().map(
            self.generate_and_tokenize_prompt, remove_columns=remove_columns
        )
        self.val_data = {
            self.path: val_raw.map(
                self.generate_and_tokenize_prompt, remove_columns=remove_columns
            )
        }

    def get_train_and_val(self):
        """Return the ``(train_data, val_data)`` pair built at construction."""
        return self.train_data, self.val_data

    def generate_and_tokenize_prompt(self, data_point):
        """Build the prompt for one raw record and return its tokenized form."""
        full_prompt = self.prompter.generate_prompt(data_point)
        return self.tokenize_function(full_prompt)

    def tokenize_function(self, item, padding=False):
        """Tokenize ``item["prompt"] + item["answer"]`` and mask the prompt in the labels.

        Args:
            item: Mapping with string entries ``"prompt"`` and ``"answer"``.
            padding: Forwarded to the tokenizer's ``padding`` argument.
                NOTE(review): it is applied to the prompt-only tokenization as
                well, so with padding enabled the measured prompt length
                presumably includes pad tokens - verify before relying on it.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``.
            ``labels`` always has the same length as ``input_ids``; the
            positions covered by the prompt are set to -100 (presumably the
            loss ignore index - confirm against the training loop).
        """
        tokenizer = self.tokenizer
        encode_kwargs = dict(
            truncation=True,
            padding=padding,
            max_length=self.max_length,
            return_tensors=None,
        )

        tokenized_full_prompt = tokenizer(item["prompt"] + item["answer"], **encode_kwargs)
        input_ids = tokenized_full_prompt["input_ids"]
        attention_mask = tokenized_full_prompt["attention_mask"]

        # Append EOS when there is room and it is not already the last token.
        # The emptiness check avoids indexing [-1] on an empty tokenization.
        if len(input_ids) < self.max_length and (
                not input_ids or input_ids[-1] != tokenizer.eos_token_id
        ):
            input_ids.append(tokenizer.eos_token_id)
            attention_mask.append(1)

        tokenized_user_prompt = tokenizer(item["prompt"], **encode_kwargs)

        # Clamp so the label list can never be longer than input_ids, which
        # would otherwise happen if the prompt alone tokenizes to more tokens
        # than prompt + answer.
        user_prompt_len = min(len(tokenized_user_prompt["input_ids"]), len(input_ids))

        labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


# Per-benchmark keyword arguments that are unpacked into ``InstructionDataset``
# by ``load_instruction_dataset``. The dataset path itself is supplied by the
# caller; the path noted next to each entry is the one recorded by the
# original author for that benchmark.
dataset_configs = {
    # path: "openbookqa"
    "openbookqa": dict(name="main", train_split="train", val_split="validation"),
    # path: "gsm8k"
    "gsm8k": dict(name="main", train_split="train", val_split="test"),
    # path: "piqa"
    "piqa": dict(
        name="plain_text",
        train_split="train",
        val_split="validation",
        remove_columns=["label"],
    ),
    # path: "cais/mmlu"
    "mmlu": dict(name="all", train_split="auxiliary_train", val_split="validation"),
    # path: "super_glue"
    "boolq": dict(
        name="boolq",
        train_split="train",
        val_split="validation",
        remove_columns=["label"],
    ),
    # path: "social_i_qa"
    "siqa": dict(train_split="train", val_split="validation", remove_columns=["label"]),
    # path: "hellaswag"
    "hellaswag": dict(train_split="train", val_split="validation", remove_columns=["label"]),
    # path: "GBaker/MedQA-USMLE-4-options-hf"
    "medqa": dict(train_split="train", val_split="validation", remove_columns=["label"]),
    # path: "yahma/alpaca-cleaned" (no validation split configured)
    "alpaca": dict(train_split="train", val_split=None),
    # path: "knkarthick/samsum"
    "samsum": dict(train_split="train", val_split="validation"),
}


def load_instruction_dataset(name, path, tokenizer, max_length=256):
    """Build the ``InstructionDataset`` matching ``name`` and return its splits.

    ``name`` is lower-cased and searched for each known benchmark keyword in
    the order listed below; the first keyword found selects the prompter and
    the entry of ``dataset_configs`` to use. If no keyword matches, the
    dataset is built with ``InstructionDataset``'s defaults.

    Args:
        name: Benchmark/task name used only to pick prompter and config.
        path: Dataset path handed to ``InstructionDataset``.
        tokenizer: Tokenizer handed to ``InstructionDataset``.
        max_length: Maximum token length handed to ``InstructionDataset``.

    Returns:
        The result of ``InstructionDataset.get_train_and_val()``.
    """
    # (keyword searched in the name, prompter class, key into dataset_configs)
    routes = (
        ("openbookqa", OpenBookQAPrompter, "openbookqa"),  # 10-shot
        ("gsm8k", GSM8KPrompter, "gsm8k"),  # 0-shot
        ("piqa", PIQAPrompter, "piqa"),  # 5-shot
        ("mmlu", MMLUPrompter, "mmlu"),  # 5-shot
        ("boolq", BoolQPrompter, "boolq"),  # 0-shot
        ("social_iqa", SIQAPrompter, "siqa"),  # 5-shot
        ("hellaswag", HellaswagPrompter, "hellaswag"),  # 5-shot
        ("medqa_4options", MedQAPrompter, "medqa"),  # 2-shot
        ("alpaca", AlpacaPrompter, "alpaca"),
        ("samsum", SamSumPrompter, "samsum"),
    )

    lowered = name.lower()
    for keyword, prompter_cls, config_key in routes:
        if keyword in lowered:
            dataset = InstructionDataset(path, tokenizer, max_length,
                                         prompter=prompter_cls(), **dataset_configs[config_key])
            break
    else:
        # Unknown benchmark: fall back to the class defaults.
        dataset = InstructionDataset(path, tokenizer, max_length)

    return dataset.get_train_and_val()
