from typing import Optional, Union

from datasets import Dataset, DatasetDict, load_dataset

from eliciting_contexts.benchmark.external.utils.logger import logger


class SAEKeys:
    """
    Names of the keys used in the sae_benchmark_single dataset.

    Each attribute is a plain string, so callers can write
    ``row[SAEKeys.DENSITY]`` instead of repeating string literals.
    """

    # Descriptive properties of a feature.
    DENSITY: str = "density"
    VOCAB_DIVERSITY: str = "vocab_diversity"
    LOCAL_VS_GLOBAL: str = "local_vs_global"
    TAGS: str = "tags"

    # Context / condition / success fields.
    NECESSARY_CONTEXT: str = "necessary_context"
    NECESSARY_CONDITION: str = "necessary_condition"
    SUCCESS_CRITERION: str = "success_criterion"

    # Explanation, grading, and identifiers.
    HUMAN_EXPLANATION: str = "human_explanation"
    FEATURE_GRADE: str = "feature_grade"
    NEURONPEDIA_ID: str = "neuronpedia_id"
    INDEX: str = "index"


def download_sae_dataset(
    dataset_name: str = "Eliciting-Contexts/SAE_single_benchmark",
    split: Optional[str] = None,
    cache_dir: Optional[str] = None,
    trust_remote_code: bool = False,
    hf_token: Optional[str] = None,
) -> Union[DatasetDict, Dataset]:
    """
    Downloads the SAE benchmark dataset from the Hugging Face Hub.

    Args:
        dataset_name: Hub repository id of the dataset to load.
        split: Optional split to download (e.g., "train", "test"). Downloads
            all splits if None (or empty).
        cache_dir: Optional directory to cache the downloaded data. Defaults
            to the Hugging Face default cache location.
        trust_remote_code: Set to True if the dataset requires executing
            custom code.
        hf_token: Optional Hugging Face token for accessing private datasets.
            Only its presence is logged, never its value.

    Returns:
        A ``DatasetDict`` keyed by split name when ``split`` is None, otherwise
        the single requested ``Dataset`` (this is ``load_dataset``'s contract).

    Raises:
        ValueError: Raised by ``load_dataset`` for invalid arguments such as an
            unknown configuration or split; logged and re-raised unchanged.
        Exception: Any other error from ``load_dataset`` (e.g. network issues
            or a missing/inaccessible repository) is logged and re-raised
            unchanged.
    """
    # Build the human-readable summary once; the token value itself is never
    # included, only whether one was supplied.
    details = [f"split: {split}" if split else "all splits"]
    if cache_dir:
        details.append(f"cache_dir: {cache_dir}")
    if trust_remote_code:
        details.append("trusting remote code")
    if hf_token:
        details.append("using HF token")
    logger.info(f"Downloading dataset '{dataset_name}', " + ", ".join(details))

    # Keep the try body to the one call that can actually fail to download.
    try:
        dataset = load_dataset(
            dataset_name,
            split=split,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
            token=hf_token,
        )
    except ValueError as e:
        # Recent `datasets` releases signal a missing repo with a
        # FileNotFoundError subclass; ValueError is for bad config/split args.
        logger.error(
            f"Error: Invalid dataset configuration or split for '{dataset_name}'"
            + f". Details: {e}"
        )
        raise
    except Exception as e:
        logger.error(f"An error occurred during download for '{dataset_name}': {e}")
        raise
    else:
        logger.info(f"Dataset '{dataset_name}' download complete.")
        return dataset


if __name__ == "__main__":
    # Smoke check: fetch every split, then show the first "test" example
    # (prints nothing if that split is empty).
    dataset = download_sae_dataset()
    first_example = next(iter(dataset["test"]), None)
    if first_example is not None:
        print(first_example)
