from typing import Optional, Union

from datasets import Dataset, DatasetDict, load_dataset

from eliciting_contexts.benchmark.external.utils.logger import logger


class StoryKeys:
    """
    Namespace of string constants for the keys used in the
    simple_stories_dataset.

    Use these attributes in place of raw string literals so that each key
    is spelled in exactly one place.
    """

    # --- Sub-keys for the details dictionaries ---
    WORD = "word"
    RANK = "rank"
    LOGIT = "logit"

    # --- Text fields ---
    TEMPLATE = "template"
    VARIABLE_TEXT = "variable_text"
    DESIRED_TEXT = "desired_text"
    UNDESIRED_TEXT = "undesired_text"

    # --- Details fields (presumably the dictionaries that hold the
    # sub-keys above; verify against the dataset schema) ---
    DESIRED_DETAILS = "desired_details"
    UNDESIRED_DETAILS = "undesired_details"

    # --- Remaining fields ---
    PREDICTED_WORD = "predicted_word"
    STORY_TYPE = "story_type"
    HUMAN_ANSWER = "human_answer"


def download_tiny_stories_dataset(
    split: Optional[str] = None,
    cache_dir: Optional[str] = None,
    trust_remote_code: bool = False,
    hf_token: Optional[str] = None,
    dataset_name: str = "contextmodification/simple_stories_dataset_with_human_answers",
) -> Union[DatasetDict, Dataset]:
    """
    Load a dataset through ``datasets.load_dataset`` with logging around the call.

    Args:
        split: Split name forwarded to ``load_dataset`` (e.g., "train").
            When None (or empty) the log line reports "all splits".
        cache_dir: Cache directory forwarded to ``load_dataset``; None leaves
            the library default in effect.
        trust_remote_code: Forwarded unchanged to ``load_dataset``.
        hf_token: Hugging Face token, forwarded as ``token``. Only the fact
            that a token is in use is logged, never its value.
        dataset_name: Identifier of the dataset to load. Defaults to the
            simple stories dataset with human answers.

    Returns:
        Whatever ``load_dataset`` returns for the given arguments, annotated
        as a DatasetDict or a Dataset.

    Raises:
        ValueError: Re-raised unchanged after logging a "not found" error.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        # Assemble the announcement piece by piece; optional fragments are
        # included only for the options that are actually set.
        fragments = [f"Downloading dataset '{dataset_name}'"]
        fragments.append(f", split: {split}" if split else ", all splits")
        if cache_dir:
            fragments.append(f", cache_dir: {cache_dir}")
        if trust_remote_code:
            fragments.append(", trusting remote code")
        if hf_token:
            fragments.append(", using HF token")
        logger.info("".join(fragments))

        loaded = load_dataset(
            dataset_name,
            split=split,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
            token=hf_token,
        )
        logger.info(f"Dataset '{dataset_name}' download complete.")
        return loaded
    except ValueError as not_found:
        # ValueError gets its own message so lookup problems stand out in logs.
        not_found_msg = (
            f"Error: Dataset or configuration not found for '{dataset_name}'"
            f". Details: {not_found}"
        )
        logger.error(not_found_msg)
        raise
    except Exception as failure:
        # Log every other failure, then let the caller handle it.
        logger.error(f"An error occurred during download for '{dataset_name}': {failure}")
        raise


if __name__ == "__main__":
    # Manual smoke run: fetch the default dataset with all splits, show its
    # summary, then dump every record of the "test" split.
    all_splits = download_tiny_stories_dataset()
    print(all_splits)
    test_records = all_splits["test"]
    for record in test_records:
        print(record)
