import os

import torch
from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)
from torch.utils.data import Dataset as BaseDataLoader

from egu.dataset.base import BaseDataset


class MatchedForgetRetainDataset(BaseDataLoader):
    """Map-style dataset that pairs each forget example with a retain example.

    The length is that of the forget dataset. Item ``idx`` combines forget
    example ``idx`` with retain example ``idx % len(retain_dataset)``, so a
    shorter retain dataset is cycled through and a longer one is only
    partially used.
    """

    def __init__(self, forget_dataset, retain_dataset):
        """Store both datasets and record the retain dataset's length.

        Args:
            forget_dataset: Indexable, sized collection of forget examples.
            retain_dataset: Indexable, sized collection of retain examples.
        """
        self.forget_dataset, self.retain_dataset = forget_dataset, retain_dataset
        # Length is captured once here and reused for every lookup.
        self.retain_len = len(retain_dataset)

    def __len__(self):
        """Return the number of forget examples."""
        return len(self.forget_dataset)

    def __getitem__(self, idx):
        """Return ``{"forget": ..., "retain": ...}`` for position ``idx``."""
        # The forget example is fetched first; the retain index wraps around
        # via modulo when the two datasets differ in size.
        pair = {"forget": self.forget_dataset[idx]}
        pair["retain"] = self.retain_dataset[idx % self.retain_len]
        return pair


class TDEC(BaseDataset):
    """Loader for the TDEC unlearning splits.

    Every configuration named in ``subsets`` is loaded with
    ``datasets.load_dataset`` from the repository id in ``path``, cached on
    disk under ``raw_path`` and exposed through ``self.dataset`` as a
    ``DatasetDict`` keyed by subset name.
    """

    dataset_type = "qa"
    path = "white-rabbitz/tdec-unlearning-splits-neo2p7b"
    name = "tdec"
    subsets = [
        "forget1",
        "forget2",
        "forget3",
        "forget4",
        "forget5",
        "retain",
        "validation",
    ]
    # On-disk cache directory; relative, so it resolves against the current
    # working directory.
    raw_path = "egu/dataset/raw/" + path
    # Number of examples every ``forget*`` split is required to contain.
    expected_forget_size = 32

    def download(self):
        """Populate ``self.dataset`` with every subset listed in ``subsets``.

        A subset whose cache directory already exists under ``raw_path`` is
        read back with ``load_from_disk``. Otherwise its ``"train"`` split is
        loaded with ``load_dataset`` and saved to that directory for later
        runs.
        """
        data_subsets = {}

        for s in self.subsets:
            output_path = os.path.join(self.raw_path, s)
            if not os.path.exists(output_path):
                print(self.path)
                # NOTE(review): trust_remote_code=True allows the dataset
                # repository to execute its own loading code; confirm this is
                # still required for this repository.
                data_subsets[s] = load_dataset(
                    self.path, s, keep_in_memory=True, trust_remote_code=True
                )["train"]
                print(data_subsets)

                data_subsets[s].save_to_disk(output_path)
            else:
                print("path exists" + s)
                data_subsets[s] = load_from_disk(output_path)

        self.dataset = DatasetDict(data_subsets)

    def load_dataset_for_training(self, split_name, prompt_prefix=""):
        """Return the training data for ``split_name``.

        Args:
            split_name: One of the names in ``subsets``.
            prompt_prefix: Accepted for interface compatibility; not used by
                this method.

        Returns:
            The bare retain dataset when ``split_name`` is ``"retain"``;
            otherwise a ``MatchedForgetRetainDataset`` pairing the requested
            split with the retain dataset.

        Raises:
            ValueError: If ``split_name`` is not a loaded subset, or if a
                ``forget*`` split does not contain exactly
                ``expected_forget_size`` examples.
        """
        # BaseDataset is not visible from this file; getattr keeps the lazy
        # download working even if it never initialised ``dataset``.
        if getattr(self, "dataset", None) is None:
            self.download()

        if split_name == "retain":
            return self.dataset["retain"]

        if split_name not in self.dataset:
            raise ValueError(
                f"Unknown split '{split_name}'. Expected one of {self.subsets}."
            )

        # Sanity check: each forget split must have the expected size.
        if split_name.startswith("forget"):
            n_forget = len(self.dataset[split_name])
            if n_forget != self.expected_forget_size:
                raise ValueError(
                    f"{split_name} must contain {self.expected_forget_size} "
                    f"items, found {n_forget}"
                )

        forget_ds = self.dataset[split_name]
        retain_ds = self.dataset["retain"]
        return MatchedForgetRetainDataset(forget_ds, retain_ds)

    def load_dataset_for_validation(self, split_name, prompt_prefix=""):
        """Return the ``"validation"`` subset paired with the retain dataset.

        Args:
            split_name: Only included in the log line; the ``"validation"``
                subset is always the one loaded.
            prompt_prefix: Accepted for interface compatibility; not used by
                this method.
        """
        print(f"loading validation dataset: {split_name}")
        return self.load_dataset_for_training("validation")


if __name__ == "__main__":
    # Manual smoke check: build the loader, then print the paired training
    # data for "forget1" followed by the validation data.
    tdec = TDEC()
    print(tdec.load_dataset_for_training("forget1"))
    print(tdec.load_dataset_for_validation("forget1"))
