import logging
import sys
from typing import Optional, Union

from datasets import Dataset, DatasetDict, load_dataset

# Shared logger for this module, registered under the fixed name "benchmark".
logger = logging.getLogger("benchmark")
logger.setLevel(logging.INFO)
# Records are handled by this logger's own handlers only; they are not passed
# on to ancestor loggers.
logger.propagate = False

# Attach the stdout handler only when the logger has none yet, so that
# executing this module-level code again does not add a second handler.
if not logger.handlers:
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    logger.addHandler(handler)


class StoryKeys:
    """
    String constants naming the keys used in the simple_stories_dataset.

    Referencing these attributes instead of repeating the raw strings keeps
    key names defined in a single place.
    """

    # Record-level keys
    TEMPLATE: str = "template"
    VARIABLE_TEXT: str = "variable_text"
    DESIRED_TEXT: str = "desired_text"
    UNDESIRED_TEXT: str = "undesired_text"
    PREDICTED_WORD: str = "predicted_word"
    DESIRED_DETAILS: str = "desired_details"
    UNDESIRED_DETAILS: str = "undesired_details"
    STORY_TYPE: str = "story_type"
    HUMAN_ANSWER: str = "human_answer"

    # Sub-keys for the details dictionaries (presumably the values stored
    # under DESIRED_DETAILS / UNDESIRED_DETAILS -- confirm against the
    # dataset schema).
    LOGIT: str = "logit"
    RANK: str = "rank"
    WORD: str = "word"


class SAEKeys:
    """
    String constants naming the keys used in the sae_benchmark_single dataset.

    Referencing these attributes instead of repeating the raw strings keeps
    key names defined in a single place.
    """

    DENSITY: str = "density"
    VOCAB_DIVERSITY: str = "vocab_diversity"
    LOCAL_VS_GLOBAL: str = "local_vs_global"
    TAGS: str = "tags"
    NECESSARY_CONTEXT: str = "necessary_context"
    NECESSARY_CONDITION: str = "necessary_condition"
    SUCCESS_CRITERION: str = "success_criterion"
    HUMAN_EXPLANATION: str = "human_explanation"
    FEATURE_GRADE: str = "feature_grade"
    NEURONPEDIA_ID: str = "neuronpedia_id"
    INDEX: str = "index"


class BackdoorKeys:
    """
    String constants naming the keys used in the backdoors benchmark dataset.

    NOTE(review): presumably these match the records of the dataset fetched
    by ``download_backdoors_dataset`` -- confirm against the dataset schema.
    """

    LORA_ID: str = "lora_id"
    BASE_MODEL: str = "base_model"
    TEMPLATE: str = "template"
    UNDESIRED_TEXT: str = "undesired_text"
    DESIRED_TEXT: str = "desired_text"
    DATASET_TYPE: str = "dataset_type"
    DATASET_INFO: str = "dataset_info"


def download_inpainting_stories_dataset(
    dataset_name: str = "contextmodification/simple_stories_new",
    split: Optional[str] = None,
    cache_dir: Optional[str] = None,
    trust_remote_code: bool = False,
    hf_token: Optional[str] = None,
) -> Union[DatasetDict, Dataset]:
    """
    Load the inpainting stories dataset through ``datasets.load_dataset``.

    The request is logged before the call, success is logged after it, and
    any failure is logged and then re-raised unchanged.

    Args:
        dataset_name: Dataset identifier handed to ``load_dataset``. Defaults to
            "contextmodification/simple_stories_new".
        split: Optional split to load (e.g., "train", "validation"). Forwarded as-is;
            None requests all splits.
        cache_dir: Optional cache directory, forwarded as-is to ``load_dataset``.
        trust_remote_code: Forwarded as-is to ``load_dataset``.
        hf_token: Optional Hugging Face token, forwarded as ``token``. The log line
            only records that a token is in use, never its value.

    Returns:
        Whatever ``load_dataset`` returns; annotated as a DatasetDict or a Dataset.

    Raises:
        ValueError: Re-raised after being logged as "Dataset or configuration not found".
        Exception: Any other error raised during the download, logged and re-raised.
    """
    try:
        # Build the start-up log line piece by piece; optional fragments are
        # appended only when the corresponding argument is truthy.
        message_parts = [f"Downloading dataset '{dataset_name}'"]
        message_parts.append(f", split: {split}" if split else ", all splits")
        if cache_dir:
            message_parts.append(f", cache_dir: {cache_dir}")
        if trust_remote_code:
            message_parts.append(", trusting remote code")
        if hf_token:
            message_parts.append(", using HF token")
        logger.info("".join(message_parts))

        loaded = load_dataset(
            dataset_name,
            split=split,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
            token=hf_token,
        )
        logger.info(f"Dataset '{dataset_name}' download complete.")
        return loaded
    except ValueError as err:
        logger.error(
            f"Error: Dataset or configuration not found for '{dataset_name}'. Details: {err}"
        )
        raise
    except Exception as err:
        logger.error(f"An error occurred during download for '{dataset_name}': {err}")
        raise


def download_sae_dataset(
    dataset_name: str = "contextmodification/SAE_single_benchmark",
    split: Optional[str] = None,
    cache_dir: Optional[str] = None,
    trust_remote_code: bool = False,
    hf_token: Optional[str] = None,
) -> Union[DatasetDict, Dataset]:
    """
    Load the SAE benchmark dataset through ``datasets.load_dataset``.

    The request is logged before the call, success is logged after it, and
    any failure is logged and then re-raised unchanged.

    Args:
        dataset_name: Dataset identifier handed to ``load_dataset``. Defaults to
            "contextmodification/SAE_single_benchmark".
        split: Optional split to load (e.g., "train", "validation"). Forwarded as-is;
            None requests all splits.
        cache_dir: Optional cache directory, forwarded as-is to ``load_dataset``.
        trust_remote_code: Forwarded as-is to ``load_dataset``.
        hf_token: Optional Hugging Face token, forwarded as ``token``. The log line
            only records that a token is in use, never its value.

    Returns:
        Whatever ``load_dataset`` returns; annotated as a DatasetDict or a Dataset.

    Raises:
        ValueError: Re-raised after being logged as "Dataset or configuration not found".
        Exception: Any other error raised during the download, logged and re-raised.
    """
    try:
        # Build the start-up log line piece by piece; optional fragments are
        # appended only when the corresponding argument is truthy.
        message_parts = [f"Downloading dataset '{dataset_name}'"]
        message_parts.append(f", split: {split}" if split else ", all splits")
        if cache_dir:
            message_parts.append(f", cache_dir: {cache_dir}")
        if trust_remote_code:
            message_parts.append(", trusting remote code")
        if hf_token:
            message_parts.append(", using HF token")
        logger.info("".join(message_parts))

        loaded = load_dataset(
            dataset_name,
            split=split,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
            token=hf_token,
        )
        logger.info(f"Dataset '{dataset_name}' download complete.")
        return loaded
    except ValueError as err:
        logger.error(
            f"Error: Dataset or configuration not found for '{dataset_name}'. Details: {err}"
        )
        raise
    except Exception as err:
        logger.error(f"An error occurred during download for '{dataset_name}': {err}")
        raise


def download_backdoors_dataset(
    split: Optional[str] = None,
    cache_dir: Optional[str] = None,
    trust_remote_code: bool = False,
    hf_token: Optional[str] = None,
    dataset_name: str = "contextmodification/backdoors-benchmark-dataset",
) -> Union[DatasetDict, Dataset]:
    """
    Load the backdoors benchmark dataset through ``datasets.load_dataset``.

    The request is logged before the call, success is logged after it, and
    any failure is logged and then re-raised unchanged.

    Args:
        split: Optional split to load (e.g., "train", "validation"). Forwarded as-is;
            None requests all splits.
        cache_dir: Optional cache directory, forwarded as-is to ``load_dataset``.
        trust_remote_code: Forwarded as-is to ``load_dataset``.
        hf_token: Optional Hugging Face token, forwarded as ``token``. The log line
            only records that a token is in use, never its value.
        dataset_name: Dataset identifier handed to ``load_dataset``. Defaults to
            "contextmodification/backdoors-benchmark-dataset". Note that it is the
            last parameter of this function.

    Returns:
        Whatever ``load_dataset`` returns; annotated as a DatasetDict or a Dataset.

    Raises:
        ValueError: Re-raised after being logged as "Dataset or configuration not found".
        Exception: Any other error raised during the download, logged and re-raised.
    """
    try:
        # Build the start-up log line piece by piece; optional fragments are
        # appended only when the corresponding argument is truthy.
        message_parts = [f"Downloading dataset '{dataset_name}'"]
        message_parts.append(f", split: {split}" if split else ", all splits")
        if cache_dir:
            message_parts.append(f", cache_dir: {cache_dir}")
        if trust_remote_code:
            message_parts.append(", trusting remote code")
        if hf_token:
            message_parts.append(", using HF token")
        logger.info("".join(message_parts))

        loaded = load_dataset(
            dataset_name,
            split=split,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
            token=hf_token,
        )
        logger.info(f"Dataset '{dataset_name}' download complete.")
        return loaded
    except ValueError as err:
        logger.error(
            f"Error: Dataset or configuration not found for '{dataset_name}'. Details: {err}"
        )
        raise
    except Exception as err:
        logger.error(f"An error occurred during download for '{dataset_name}': {err}")
        raise
