import os

from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)

from egu.dataset.base import BaseDataset


class WPU(BaseDataset):
    """Configuration for the Wikipedia Person Unlearn (WPU) QA dataset.

    The class attributes describe where the data lives (``path``), which
    subsets exist, and which record fields are used as prompt/answer for
    evaluation and generation. Download/loading behavior (e.g. the
    ``download()`` call used under ``__main__``) is provided by
    ``BaseDataset`` and is not defined here.
    """

    dataset_type = "qa"
    # Presumably the Hugging Face Hub repo id (the module imports
    # ``datasets.load_dataset``) — confirm against BaseDataset.
    path = "Shiyu-Lab/Wikipedia_Person_Unlearn"
    name = "wpu"
    subsets = [
        "forget_100",
        "forget_100_hard_retain",
        "forget_20_1",
        "forget_20_1_hard_retain",
        "forget_20_2",
        "forget_20_2_hard_retain",
        "forget_20_3",
        "forget_20_3_hard_retain",
        "forget_2_1",
        "forget_2_1_hard_retain",
        "forget_2_2",
        "forget_2_2_hard_retain",
        "forget_2_3",
        "forget_2_3_hard_retain",
        "forget_2_4",
        "forget_2_4_hard_retain",
        "forget_2_5",
        "forget_2_5_hard_retain",
        "general_retain",
    ]
    # NOTE(review): none of these keys ("forget01", "forget05", "forget10")
    # appear in ``subsets`` above; this mapping looks carried over from a
    # different dataset config — confirm how BaseDataset uses it.
    match_retain = {
        "forget01": "retain99",
        "forget05": "retain95",
        "forget10": "retain90",
    }
    # Record fields exposed by this dataset.
    keys = ["prompt", "answer", "prompt_formatted"]
    # Field names used as model input / reference output for evaluation
    # and for generation.
    eval_prompt_key = "prompt_formatted"
    eval_answer_key = "answer"
    gen_prompt_key = "prompt_formatted"
    gen_answer_key = "answer"
    eval_dataset_keys = ["retain", "forget", "test"]
    # Local directory for raw data: "egu/dataset/raw/" followed by ``path``
    # (evaluated once, when the class body runs).
    raw_path = "egu/dataset/raw/" + path

    def __init__(self, formatting_tokens=None, eos_token=None, *args, **kwargs):
        """Store prompt/answer formatting tokens on the instance.

        Args:
            formatting_tokens: Optional mapping that must provide the keys
                ``prompt_prefix``, ``prompt_suffix``, ``answer_prefix`` and
                ``answer_suffix``. Each value is stored as an instance
                attribute of the same name. When ``None``, all four
                attributes are set to ``""``. The mapping itself is also
                kept as ``self.formatting_tokens``.
            eos_token: End-of-sequence string; ``None`` is stored as ``""``.
            *args: Accepted but ignored (not forwarded to ``BaseDataset``).
            **kwargs: Accepted but ignored (not forwarded to ``BaseDataset``).

        Raises:
            KeyError: If ``formatting_tokens`` is given but lacks one of the
                four required keys.
        """
        super().__init__()
        self.formatting_tokens = formatting_tokens
        self.eos_token = eos_token if eos_token is not None else ""
        has_tokens = formatting_tokens is not None
        for key in (
            "prompt_prefix",
            "prompt_suffix",
            "answer_prefix",
            "answer_suffix",
        ):
            # Indexing (not .get) is deliberate: a missing key is a config
            # error and should surface as KeyError.
            value = formatting_tokens[key] if has_tokens else ""
            setattr(self, key, value)


if __name__ == "__main__":
    # Script entry point: build a WPU config with default (empty) formatting
    # tokens and trigger the base-class download routine.
    wpu_dataset = WPU()
    wpu_dataset.download()
