"""Load the StrongREJECT datasets and an open-ended version of the WMDP dataset.
"""

import pathlib

from datasets import Dataset, load_dataset

# Absolute path of the "datasets" directory that sits next to this module file.
# NOTE(review): none of the loaders visible in this file reference it --
# confirm it is used elsewhere before relying on or removing it.
DATA_DIR = pathlib.Path(__file__).parent.resolve() / "datasets"


def load_strongreject_small() -> Dataset:
    """Fetch the small StrongREJECT dataset.

    The data is read with the ``datasets`` CSV loader from the CSV file at
    the URL below (StrongREJECT GitHub repository, ``main`` branch).

    Returns:
        Dataset: The "train" split of the loaded StrongREJECT small CSV.
    """
    csv_url = "https://raw.githubusercontent.com/alexandrasouly/strongreject/main/strongreject_dataset/strongreject_small_dataset.csv"
    # The loaded result is indexed by split name; only "train" is returned.
    splits = load_dataset("csv", data_files=csv_url)
    return splits["train"]


def load_strongreject() -> Dataset:
    """Fetch the full StrongREJECT dataset.

    The data is read with the ``datasets`` CSV loader from the CSV file at
    the URL below (StrongREJECT GitHub repository, ``main`` branch).

    Returns:
        Dataset: The "train" split of the loaded full StrongREJECT CSV.
    """
    csv_url = "https://raw.githubusercontent.com/alexandrasouly/strongreject/main/strongreject_dataset/strongreject_dataset.csv"
    # The loaded result is indexed by split name; only "train" is returned.
    splits = load_dataset("csv", data_files=csv_url)
    return splits["train"]


def load_wmdp_open_ended() -> Dataset:
    """Fetch the open-ended version of the WMDP chem and bio datasets.

    Loads ``dsbowen/wmdp_open_ended``, takes its "train" split, and renames
    the ``question`` column to ``forbidden_prompt``.

    Returns:
        Dataset: Open-ended WMDP chem and bio data with the prompt text in
            the ``forbidden_prompt`` column.
    """
    train_split = load_dataset("dsbowen/wmdp_open_ended")["train"]
    return train_split.rename_column("question", "forbidden_prompt")
