import os
from typing import Any

from datasets import Dataset, DatasetDict

from eliciting_contexts.benchmark.external.utils.logger import logger
from eliciting_contexts.benchmark.internal.tiny_stories.raw_data import (
    simple_stories_dataset,
)


def upload_hf_dataset(
    dataset: Dataset | DatasetDict,
    repo_id: str,
    hf_token: str | None = None,
    private: bool = False,
    **kwargs: Any,
) -> None:
    """
    Upload a dataset to the Hugging Face Hub.

    Args:
        dataset: The dataset object (Dataset or DatasetDict) to upload.
        repo_id: Target repository ID on the Hub, of the form
            "username/repository_name" or "org_name/repository_name".
        hf_token: Hugging Face API token. ``None`` is forwarded unchanged to
            ``push_to_hub``, which is expected to fall back to the token from
            the environment or the Hugging Face CLI login.
        private: If True, create the repository as private. Defaults to False.
        **kwargs: Additional keyword arguments forwarded to
            ``dataset.push_to_hub()``.

    Raises:
        ValueError: If ``repo_id`` is not exactly two non-empty parts separated
            by "/", or if ``push_to_hub`` raises a ValueError (the original
            exception is logged and re-raised with its message intact).
        Exception: Any other error from ``push_to_hub`` is logged and re-raised
            unchanged, so callers can still catch its specific type.
    """
    # Both the namespace and the repository name must be present; a plain
    # ``len(split) == 2`` check would also accept "user/" or "/repo".
    parts = repo_id.split("/") if repo_id else []
    if len(parts) != 2 or not all(parts):
        raise ValueError(
            "Invalid repo_id format. Expected 'username/repository_name' or 'org_name/repository_name'."
        )

    logger.info(
        f"Uploading dataset to Hugging Face Hub repository: '{repo_id}' (Private: {private})"
    )
    # Keep the try body to the single call that can fail remotely.
    try:
        dataset.push_to_hub(
            repo_id=repo_id,
            token=hf_token,
            private=private,
            **kwargs,
        )
    except ValueError as e:  # Handles auth errors, invalid repo_id etc. from push_to_hub
        logger.error(
            f"Failed to upload dataset to '{repo_id}'. Check repository ID and authentication. Details: {e}"
        )
        # Re-raise the original: a bare ``raise ValueError from e`` would drop
        # the message that explains what went wrong.
        raise
    except Exception as e:
        logger.error(f"An unexpected error occurred during upload to '{repo_id}': {e}")
        # Re-raise unchanged so the concrete exception type reaches callers.
        raise
    logger.info(f"Successfully uploaded dataset to '{repo_id}'.")


def upload_tiny_stories_dataset(
    repo_id: str = "contextmodification/simple_stories_dataset",
    private: bool = True,
) -> None:
    """
    Build a single-split DatasetDict from ``simple_stories_dataset`` and upload it.

    Each row of ``simple_stories_dataset`` is read positionally: element 0 is
    the template, 1 the variable text, 2 the desired text and 3 the undesired
    text. The rows are transposed into one column per field and stored under
    the "test" split.

    The Hub token is read from the ``HF_TOKEN`` environment variable; if it is
    unset, ``None`` is passed through to ``upload_hf_dataset``.

    Args:
        repo_id: Target Hub repository. Defaults to the repository this script
            has always uploaded to.
        private: Whether to create the repository as private. Defaults to True.
    """
    column_names = ["template", "variable_text", "desired_text", "undesired_text"]

    # Materialize once: the source is consumed once per column below, which
    # would silently produce empty columns if it were a one-shot iterator.
    rows = list(simple_stories_dataset)
    data_dict = {
        name: [row[index] for row in rows]
        for index, name in enumerate(column_names)
    }
    dataset = Dataset.from_dict(data_dict)
    dataset_dict = DatasetDict({"test": dataset})
    logger.info(f"Created DatasetDict with {len(rows)} rows in split 'test'.")

    api_token = os.environ.get("HF_TOKEN")

    upload_hf_dataset(
        dataset_dict,
        repo_id,
        hf_token=api_token,
        private=private,
    )


if __name__ == "__main__":
    upload_tiny_stories_dataset()
