import os

import torch
from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)

from egu.dataset.base import BaseDataset

from .utils import MatchedForgetRetainDataset, MatchedForgetRetainRandomDataset


class TOFU(BaseDataset):
    """TOFU question-answering dataset wrapper (Hugging Face ``locuslab/TOFU``).

    Subsets are fetched from the Hub on first use and cached on disk under
    ``raw_path``. Forget splits listed in ``match_retain`` are paired with
    their retain split and returned as a ``MatchedForgetRetainRandomDataset``;
    every other split is returned as the raw (column-renamed) HF dataset.
    """

    dataset_type = "qa"
    path = "locuslab/TOFU"  # Hugging Face Hub dataset id
    name = "tofu"
    # Subset (config) names downloaded by ``download``.
    subsets = [
        "retain90",
        "retain95",
        "retain99",
        "forget01",
        "forget05",
        "forget10",
        "real_authors",
        "world_facts",
    ]
    # Forget split -> retain split it is paired with in load_dataset_for_training.
    match_retain = {
        "forget01": "retain99",
        "forget05": "retain95",
        "forget10": "retain90",
    }
    keys = ["prompt", "answer", "prompt_formatted"]
    eval_prompt_key = "prompt_formatted"
    eval_answer_key = "answer"
    gen_prompt_key = "prompt_formatted"
    gen_answer_key = "answer"
    eval_dataset_keys = ["retain", "forget", "test"]
    # Local cache root; relative, so it resolves against the current working directory.
    raw_path = "egu/dataset/raw/" + path

    def __init__(self, formatting_tokens=None, eos_token=None, *args, **kwargs):
        """Store prompt/answer formatting tokens.

        Args:
            formatting_tokens: Optional mapping that must contain the keys
                ``prompt_prefix``, ``prompt_suffix``, ``answer_prefix`` and
                ``answer_suffix``; each becomes an attribute of the same name.
                When ``None``, all four attributes are set to ``""``.
            eos_token: End-of-sequence string; ``None`` is stored as ``""``.
            *args, **kwargs: Accepted for interface compatibility; ignored.

        Raises:
            KeyError: if ``formatting_tokens`` is given but lacks one of the keys.

        Note:
            ``format_example`` currently does not apply these tokens.
        """
        super().__init__()
        self.eos_token = eos_token if eos_token is not None else ""
        for k in ("prompt_prefix", "prompt_suffix", "answer_prefix", "answer_suffix"):
            value = formatting_tokens[k] if formatting_tokens is not None else ""
            setattr(self, k, value)

    def download(self):
        """Populate ``self.dataset`` with every subset in ``subsets``.

        Each subset is read from ``<raw_path>/<subset>`` if that path exists;
        otherwise its ``train`` split is loaded from the Hub and saved there.
        The result is a ``DatasetDict`` keyed by subset name.
        """
        data_subsets = {}
        for s in self.subsets:
            output_path = os.path.join(self.raw_path, s)
            if os.path.exists(output_path):
                data_subsets[s] = load_from_disk(output_path)
                continue
            # NOTE(review): trust_remote_code=True lets the Hub repo execute code
            # locally; keep only if the dataset genuinely requires a loading script.
            data_subsets[s] = load_dataset(
                self.path, s, keep_in_memory=True, trust_remote_code=True
            )["train"]
            data_subsets[s].save_to_disk(output_path)

        self.dataset = DatasetDict(data_subsets)

    def get_matched_retain_split(self, forget_split_name):
        """Return the retain split paired with ``forget_split_name``, or ``None``."""
        return self.match_retain.get(forget_split_name, None)

    def format_example(self, example, prompt_prefix=""):
        """Map one example to the ``prompt_formatted`` / ``answer`` columns.

        Prefix/suffix/EOS formatting is currently disabled: the prompt is passed
        through unchanged as ``prompt_formatted`` and the answer is copied as-is.
        ``prompt_prefix`` and the formatting attributes set in ``__init__`` are
        therefore unused.

        Args:
            example: Row mapping with at least ``prompt`` and ``answer`` keys.
            prompt_prefix: Unused; kept for interface compatibility.

        Returns:
            dict with ``prompt_formatted`` and ``answer`` entries.
        """
        return {"prompt_formatted": example["prompt"], "answer": example["answer"]}

    @staticmethod
    def _with_prompt_column(dataset):
        """Rename a ``question`` column to ``prompt`` when present."""
        if "question" in dataset.column_names:
            dataset = dataset.rename_column("question", "prompt")
        return dataset

    def load_dataset_for_training(self, split_name, prompt_prefix=""):
        """Return the requested split, downloading all subsets on first use.

        Args:
            split_name: Subset name, e.g. ``"forget05"`` or ``"retain90"``.
            prompt_prefix: Forwarded to ``format_example`` (currently unused there).

        Returns:
            For a split listed in ``match_retain``: a
            ``MatchedForgetRetainRandomDataset`` built from the formatted forget
            split and its formatted retain split. Otherwise the raw split with
            ``question`` renamed to ``prompt``.

        Raises:
            KeyError: if ``split_name`` is not a downloaded subset.
        """
        # getattr guards against BaseDataset not pre-initializing ``dataset``.
        if getattr(self, "dataset", None) is None:
            self.download()

        dataset = self._with_prompt_column(self.dataset[split_name])

        retain_split_name = self.get_matched_retain_split(split_name)
        if retain_split_name is None:
            # NOTE(review): unmatched splits are not passed through format_example,
            # so they lack the ``prompt_formatted`` column named by eval_prompt_key.
            return dataset

        def _format(example):
            return self.format_example(example, prompt_prefix)

        retain_dataset = self._with_prompt_column(self.dataset[retain_split_name])
        retain_dataset = retain_dataset.map(_format)
        dataset = dataset.map(_format)

        return MatchedForgetRetainRandomDataset(dataset, retain_dataset)

    def load_dataset_for_eval(self, split_name, prompt_prefix=""):
        """Return the split for evaluation; identical to ``load_dataset_for_training``."""
        return self.load_dataset_for_training(split_name, prompt_prefix)


if __name__ == "__main__":
    # Smoke test: load forget05 paired with its matched retain split and show it.
    dataset = TOFU()
    test = dataset.load_dataset_for_training("forget05")
    print(test)
