import os

from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)

from egu.dataset.base import BaseDataset


class WDMP(BaseDataset):
    """WMDP cyber corpora (``cais/wmdp-corpora``) split into forget/retain subsets.

    The class attributes configure the dataset for ``BaseDataset``.
    ``__init__`` only stores the prompt/answer formatting strings and the
    end-of-sequence token.

    NOTE(review): the class is spelled ``WDMP`` while ``name`` is ``"wmdp"``.
    The class name is kept unchanged because callers import it by this name.
    """

    dataset_type = "qa"
    # Dataset identifier; presumably a Hugging Face Hub id consumed by
    # BaseDataset's loader (``load_dataset`` is imported here) — confirm.
    path = "cais/wmdp-corpora"
    name = "wmdp"
    subsets = ["cyber-retain-corpus", "cyber-forget-corpus"]
    # Pairs each forget subset with the retain subset it corresponds to.
    match_retain = {
        "cyber-forget-corpus": "cyber-retain-corpus",
    }
    # Field names used for evaluation / generation prompts and answers.
    # NOTE(review): assumed to be columns produced by BaseDataset's
    # preprocessing (e.g. "prompt_formatted") — verify against the base class.
    eval_prompt_key = "prompt_formatted"
    eval_answer_key = "answer"
    gen_prompt_key = "prompt_formatted"
    gen_answer_key = "answer"
    # Local raw-data location: "egu/dataset/raw/cais/wmdp-corpora".
    raw_path = "egu/dataset/raw/" + path

    def __init__(self, formatting_tokens=None, eos_token=None, *args, **kwargs):
        """Store prompt/answer formatting strings and the EOS token.

        Args:
            formatting_tokens: Optional mapping that may contain the keys
                ``prompt_prefix``, ``prompt_suffix``, ``answer_prefix`` and
                ``answer_suffix``. Each key becomes an instance attribute of
                the same name. Keys that are absent, or a ``None`` mapping,
                yield ``""``.
            eos_token: End-of-sequence string; ``None`` is stored as ``""``.
            *args, **kwargs: Accepted for signature compatibility. They are
                currently ignored and not forwarded to ``BaseDataset``.
        """
        super().__init__()
        self.formatting_tokens = formatting_tokens
        self.eos_token = eos_token if eos_token is not None else ""
        tokens = formatting_tokens if formatting_tokens is not None else {}
        for key in (
            "prompt_prefix",
            "prompt_suffix",
            "answer_prefix",
            "answer_suffix",
        ):
            # A plain statement rather than a conditional expression evaluated
            # only for its side effect; missing keys fall back to "" just like
            # a missing mapping does, instead of raising KeyError.
            setattr(self, key, tokens.get(key, ""))


if __name__ == "__main__":
    # Script entry point: build the dataset with default (empty) formatting
    # and fetch it via the download() method called here.
    wmdp_dataset = WDMP()
    wmdp_dataset.download()
