from datasets import Dataset, DatasetDict
from src.utils.logging_utils import get_logger

# Module-level logger obtained from the project's get_logger helper, keyed by this module's name.
logger = get_logger(name=__name__)


def push_ds_to_hub(
        *,
        dataset: DatasetDict | Dataset,
        dataset_name: str | None,
        dataset_name_template: str | None = None,
        hf_token: str,
        limit: int | None = None,
        **kwargs,
) -> DatasetDict | Dataset:
    """Push a dataset to the Hugging Face Hub as a private repository.

    The target name is either given directly via ``dataset_name`` or built
    from ``dataset_name_template`` by calling ``str.format`` with ``kwargs``.
    A bare ``Dataset`` is wrapped in a ``DatasetDict`` under the ``train``
    split before being pushed.

    Args:
        dataset: The dataset to push.
        dataset_name: Target name on the hub. Must be ``None`` when
            ``dataset_name_template`` is provided.
        dataset_name_template: Optional format string used to build the
            target name from ``kwargs``.
        hf_token: Token forwarded to ``push_to_hub``.
        limit: When not ``None``, nothing is pushed and ``dataset`` is
            returned as-is. The value itself is only used in the warning
            message; this function does not truncate the dataset.
        **kwargs: Values substituted into ``dataset_name_template``.

    Returns:
        The ``dataset`` argument, unchanged.

    Raises:
        ValueError: If both ``dataset_name`` and ``dataset_name_template``
            are given, if the formatted name still contains ``{``, or if no
            name is available when the push is about to happen.
        KeyError: If the template references a key missing from ``kwargs``
            (raised by ``str.format``).
    """
    if dataset_name_template:
        # Explicit raises instead of ``assert``: assertions are stripped
        # under ``python -O`` and would let a conflicting name slip through.
        if dataset_name is not None:
            raise ValueError(
                "Only one of dataset_name or dataset_name_template should be provided"
            )
        dataset_name = dataset_name_template.format(**kwargs)
        if "{" in dataset_name:
            raise ValueError("All template keys should be filled")

    logger.info(f"Pushing dataset {dataset_name} to hub: \n{dataset}")

    if limit is not None:
        # A set limit marks a dry run: skip the upload entirely.
        logger.warning(f"Limiting dataset to {limit} samples. No push to hub will be done")
        return dataset

    if dataset_name is None:
        # Fail early with a clear message instead of handing None to push_to_hub.
        raise ValueError(
            "Either dataset_name or dataset_name_template must be provided to push to hub"
        )

    # The hub upload is done on a DatasetDict; a single Dataset becomes its "train" split.
    ds_dict = DatasetDict(train=dataset) if isinstance(dataset, Dataset) else dataset

    ds_dict.push_to_hub(dataset_name, private=True, token=hf_token)
    logger.info(f"Dataset {dataset_name} pushed to hub")

    return dataset