from typing import Optional, Union

from datasets import Dataset, DatasetDict, load_dataset

from eliciting_contexts.benchmark.external.utils.logger import logger


class ApplicationsKeys:
    """Namespace of string key constants.

    Each class attribute binds a symbolic name to its literal key string, so
    callers can reference keys without repeating string literals. How the keys
    are consumed is not visible here (presumably as record/column names of the
    backdoors benchmark dataset -- confirm against callers).
    """

    LORA_ID: str = "lora_id"
    BASE_MODEL: str = "base_model"
    TEMPLATE: str = "template"
    UNDESIRED_TEXT: str = "undesired_text"
    DESIRED_TEXT: str = "desired_text"
    DATASET_TYPE: str = "dataset_type"
    DATASET_INFO: str = "dataset_info"


def download_backdoors_dataset(
    split: Optional[str] = None,
    cache_dir: Optional[str] = None,
    trust_remote_code: bool = False,
    hf_token: Optional[str] = None,
    dataset_name: str = "contextmodification/backdoors-benchmark-dataset",
) -> Union[DatasetDict, Dataset]:
    """
    Download a dataset from the Hugging Face Hub via ``datasets.load_dataset``.

    Args:
        split: Optional split to load (e.g., "train", "validation"). When None,
            all splits are loaded.
        cache_dir: Optional directory to cache the downloaded data. Defaults to
            the Hugging Face default cache location.
        trust_remote_code: Set to True if the dataset requires executing custom code.
        hf_token: Optional Hugging Face token for accessing private datasets.
        dataset_name: Hub repository id of the dataset to load. Defaults to the
            backdoors benchmark dataset.

    Returns:
        Whatever ``load_dataset`` returns: a DatasetDict keyed by split name when
        ``split`` is None (even if the dataset has only one split), otherwise a
        single Dataset for the requested ``split``.

    Raises:
        FileNotFoundError: If the dataset cannot be found or accessed (recent
            ``datasets`` versions raise ``DatasetNotFoundError``, a subclass).
        ValueError: For invalid arguments, e.g. an unknown configuration or split.
        Exception: Any other error from ``load_dataset`` (e.g., network issues).
            Every error is logged and then re-raised unchanged.
    """
    # Describe exactly which options were requested, without leaking the token.
    request_summary = f"Downloading dataset '{dataset_name}'"
    request_summary += f", split: {split}" if split else ", all splits"
    if cache_dir:
        request_summary += f", cache_dir: {cache_dir}"
    if trust_remote_code:
        request_summary += ", trusting remote code"
    if hf_token:
        request_summary += ", using HF token"
    logger.info(request_summary)

    # Keep the try body to the single call that can fail so the handlers
    # below only ever describe download/load errors.
    try:
        dataset = load_dataset(
            dataset_name,
            split=split,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
            token=hf_token,
        )
    except FileNotFoundError as e:
        # A missing or inaccessible repository surfaces here, not as ValueError.
        logger.error(
            f"Error: Dataset '{dataset_name}' not found or not accessible"
            + f". Details: {e}"
        )
        raise
    except ValueError as e:
        # ValueError is raised for bad arguments such as an unknown config or split.
        logger.error(
            f"Error: Invalid configuration or split for '{dataset_name}'"
            + f". Details: {e}"
        )
        raise
    except Exception as e:
        logger.error(f"An error occurred during download for '{dataset_name}': {e}")
        raise

    logger.info(f"Dataset '{dataset_name}' download complete.")
    return dataset


if __name__ == "__main__":
    # Manual smoke check: fetch every split, print the overall summary, then
    # print each row of the "test" split on its own line.
    backdoors = download_backdoors_dataset()
    print(backdoors)
    test_rows = backdoors["test"]
    for row in test_rows:
        print(row)
