import os

import torch
from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)
from torch.utils.data import Dataset as BaseDataLoader

from egu.dataset.base import BaseDataset


class MatchedForgetRetainDataset(BaseDataLoader):
    """Serve forget examples side by side with retain examples.

    The wrapper is as long as the forget dataset. The retain example for a
    given index is chosen by taking that index modulo the retain dataset's
    length, so a shorter retain dataset is reused cyclically.
    """

    def __init__(self, forget_dataset, retain_dataset):
        """Store both datasets and cache the retain dataset's length."""
        self.forget_dataset, self.retain_dataset = forget_dataset, retain_dataset
        # Computed once here; used as the modulus on every lookup.
        self.retain_len = len(self.retain_dataset)

    def __len__(self):
        """Return the number of forget examples."""
        return len(self.forget_dataset)

    def __getitem__(self, idx):
        """Return ``{"forget": ..., "retain": ...}`` for position ``idx``."""
        # The forget lookup happens first, then the wrapped retain lookup.
        pair = {"forget": self.forget_dataset[idx]}
        pair["retain"] = self.retain_dataset[idx % self.retain_len]
        return pair


class RWKU(BaseDataset):
    """Wrapper around the ``jinzhuoran/RWKU`` dataset.

    ``download`` fetches (or reloads from the on-disk cache) the subsets named
    in ``train_subsets``. A split listed in ``match_retain`` is returned
    wrapped in a ``MatchedForgetRetainDataset`` together with its retain
    split; any other split is returned as loaded.
    """

    dataset_type = "qa"
    path = "jinzhuoran/RWKU"
    name = "rwku"
    train_subsets = [
        "forget_target",
        "train_refusal_llama3",
    ]
    # NOTE(review): ``download`` only iterates ``train_subsets``, so these
    # subsets are not present in ``self.dataset`` after it runs.
    eval_subsets = [
        "forget_level1",
        "forget_level2",
        "forget_level3",
        "mia_forget",
        "mia_retain",
        "neighbor_level1",
        "neighbor_level2",
        "train_negative_llama3",
        "train_negative_phi3",
        "train_original_passage",
        "train_pair_llama3",
        "train_pair_phi3",
        "train_positive_llama3",
        "train_positive_phi3",
        "train_refusal_phi3",
        "utility_factuality",
        "utility_fluency",
        "utility_general",
        "utility_reason",
        "utility_truthfulness",
    ]
    # TODO: the retain set currently defaults to ``train_refusal_llama3``;
    # double-check that this is the right pairing.
    match_retain = {
        "forget_target": "train_refusal_llama3",
    }
    keys = ["prompt", "answer", "prompt_formatted"]
    eval_prompt_key = "prompt_formatted"
    eval_answer_key = "answer"
    gen_prompt_key = "prompt_formatted"
    gen_answer_key = "answer"
    eval_dataset_keys = ["retain", "forget", "test"]
    raw_path = "egu/dataset/raw/" + path

    def __init__(self, formatting_tokens=None, eos_token=None, *args, **kwargs):
        """Store the formatting tokens and EOS token on the instance.

        Args:
            formatting_tokens: Optional mapping indexed by ``prompt_prefix``,
                ``prompt_suffix``, ``answer_prefix`` and ``answer_suffix``.
                When ``None``, each of those attributes is set to ``""``.
            eos_token: End-of-sequence string; ``None`` is stored as ``""``.
            *args: Accepted but not forwarded to the base class.
            **kwargs: Accepted but not forwarded to the base class.

        Raises:
            KeyError: If ``formatting_tokens`` is a dict that lacks one of the
                four keys above.
        """
        super().__init__()
        self.eos_token = "" if eos_token is None else eos_token
        for k in (
            "prompt_prefix",
            "prompt_suffix",
            "answer_prefix",
            "answer_suffix",
        ):
            value = "" if formatting_tokens is None else formatting_tokens[k]
            setattr(self, k, value)

    def download(self):
        """Populate ``self.dataset`` with every subset in ``train_subsets``.

        Each subset is cached under ``raw_path/<subset>``. If that path
        already exists the subset is read back with ``load_from_disk``;
        otherwise its ``"train"`` split is fetched with ``load_dataset`` and
        saved there.
        """
        data_subsets = {}

        for s in self.train_subsets:
            output_path = os.path.join(self.raw_path, s)
            if os.path.exists(output_path):
                print(f"path exists: {output_path}")
                data_subsets[s] = load_from_disk(output_path)
                continue

            subset = load_dataset(
                self.path, s, keep_in_memory=True, trust_remote_code=True
            )["train"]
            subset.save_to_disk(output_path)
            data_subsets[s] = subset

        self.dataset = DatasetDict(data_subsets)

    def get_matched_retain_split(self, forget_split_name):
        """Return the retain split paired with ``forget_split_name``, or ``None``."""
        return self.match_retain.get(forget_split_name, None)

    def format_example(self, example, retain_field="answer", prompt_prefix=""):
        """Build the ``prompt_formatted`` and ``answer`` fields for one example.

        Args:
            example: Mapping containing ``"prompt"`` and ``retain_field``.
            retain_field: Key of ``example`` whose value is used as the answer.
            prompt_prefix: Currently unused.

        Returns:
            A dict with ``"prompt_formatted"`` set to ``example["prompt"]`` and
            ``"answer"`` set to ``example[retain_field]``.
        """
        # NOTE: the prefix/suffix/EOS tokens stored on the instance are not
        # applied here; both values are passed through unchanged.
        return {
            "prompt_formatted": example["prompt"],
            "answer": example[retain_field],
        }

    def load_dataset_for_training(self, split_name, prompt_prefix=""):
        """Return the training data for ``split_name``.

        Calls ``download`` first when ``self.dataset`` is ``None``. An
        ``"intro"`` column on the requested split is renamed to ``"prompt"``.

        Args:
            split_name: Key into ``self.dataset``.
            prompt_prefix: Forwarded to ``format_example``.

        Returns:
            A ``MatchedForgetRetainDataset`` when ``split_name`` is a key of
            ``match_retain``; otherwise the split itself, without the
            ``format_example`` mapping applied.
        """
        if self.dataset is None:
            self.download()
        dataset = self.dataset[split_name]

        if "intro" in dataset.column_names:
            dataset = dataset.rename_column("intro", "prompt")

        if split_name not in self.match_retain:
            return dataset

        # This is a forget split: pair it with its retain split.
        retain_split_name = self.get_matched_retain_split(split_name)
        retain_dataset = self.dataset[retain_split_name]
        if "instruction" in retain_dataset.column_names:
            retain_dataset = retain_dataset.rename_column("instruction", "prompt")

        # Retain answers are read from "output", forget answers from "target".
        retain_dataset = retain_dataset.map(
            lambda x: self.format_example(x, "output", prompt_prefix)
        )
        dataset = dataset.map(
            lambda x: self.format_example(x, "target", prompt_prefix)
        )

        return MatchedForgetRetainDataset(dataset, retain_dataset)

    def load_dataset_for_eval(self, split_name, prompt_prefix=""):
        """Return the same result as ``load_dataset_for_training``."""
        return self.load_dataset_for_training(split_name, prompt_prefix)


if __name__ == "__main__":
    # Manual smoke test: load the "forget_target" training split and show
    # the resulting object along with the type of its first element.
    rwku = RWKU()
    paired = rwku.load_dataset_for_training("forget_target")

    # rwku.download() can be called here instead to fetch the subsets only.
    print(paired)
    print(type(paired[0]))
