import argparse
from typing import Any, Dict

from datasets import DatasetDict, load_dataset

# Hugging Face Hub dataset IDs, used when the matching CLI flag is not given.
DEFAULT_SOURCE_DATASET_ID: str = "contextmodification/sandbagging-sciq"
DEFAULT_TARGET_DATASET_ID: str = "contextmodification/sandbagging-sciq-template"

# Dataset column names. PREFIX/PROMPT are read by add_template and TEMPLATE is
# written by it; SANDBAGGING_ENVIRONMENT is not referenced in this file.
PREFIX_COLUMN: str = "prefix"
PROMPT_COLUMN: str = "prompt"
NEW_TEMPLATE_COLUMN: str = "template"
SANDBAGGING_ENVIRONMENT_COLUMN: str = "sandbagging_environment"

# Chat-turn markup; the system string takes one positional str.format argument.
# NOTE(review): neither string is used in this file — presumably imported by
# other scripts; confirm before removing.
SYSTEM_PROMPT_TEXT: str = "<start_of_turn>system log\n{0}<end_of_turn>\n"
USER_PROMPT_TEXT: str = "<start_of_turn>user\n"

# NOTE(review): no module-level global is defined below; this comment appears stale.


def parse_args():
    """Build the command-line parser and return the parsed namespace.

    Two string options are exposed, ``--source-dataset`` and
    ``--target-dataset`` (available as ``args.source_dataset`` and
    ``args.target_dataset``). Each defaults to the matching module-level
    dataset ID constant.
    """
    cli = argparse.ArgumentParser(description="Add template to sciq dataset")

    # (flag, default, help) triples, registered in this exact order.
    option_specs = (
        (
            "--source-dataset",
            DEFAULT_SOURCE_DATASET_ID,
            f"Source dataset ID (default: {DEFAULT_SOURCE_DATASET_ID})",
        ),
        (
            "--target-dataset",
            DEFAULT_TARGET_DATASET_ID,
            f"Target dataset ID where results will be pushed (default: {DEFAULT_TARGET_DATASET_ID})",
        ),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    return cli.parse_args()


def add_template(example: Dict[str, Any]) -> Dict[str, Any]:
    """Add a ``{0}``-placeholder template derived from the example's prompt.

    The first occurrence of the example's prefix text inside its prompt is
    replaced with the positional placeholder ``{0}``. The result is stored
    under ``NEW_TEMPLATE_COLUMN``. The example is mutated in place and also
    returned, which is the shape ``Dataset.map`` expects.

    If the prefix is empty, or does not occur in the prompt, the template is
    the prompt unchanged and contains no placeholder.

    Args:
        example: One dataset row containing ``PROMPT_COLUMN`` and
            ``PREFIX_COLUMN``. Both are assumed to be strings (TODO confirm
            the dataset schema).

    Returns:
        The same ``example`` dict with ``NEW_TEMPLATE_COLUMN`` set.
    """
    original_prompt = example[PROMPT_COLUMN]
    prefix = example[PREFIX_COLUMN]

    if prefix:
        # Replace only the first occurrence. The prefix presumably appears once
        # near the start of the prompt; if the same text recurs later (e.g.
        # inside the question), that recurrence must stay as literal text.
        template_prompt = original_prompt.replace(prefix, "{0}", 1)
    else:
        # str.replace("", "{0}") would insert "{0}" between every character
        # and yield a garbage template. With no prefix there is nothing to
        # templatize, so keep the prompt as-is.
        template_prompt = original_prompt

    # NOTE(review): if consumers fill this template via str.format, any literal
    # "{" / "}" already present in the prompt are not escaped — confirm.
    example[NEW_TEMPLATE_COLUMN] = template_prompt
    return example


if __name__ == "__main__":
    # Script entry point: load the source dataset, add a template column to
    # every split, then push the result to the target Hub dataset.
    args = parse_args()
    source_dataset_id = args.source_dataset
    target_dataset_id = args.target_dataset

    print(f"Loading dataset: {source_dataset_id}")
    dataset_dict = load_dataset(source_dataset_id)

    print("Preprocessing dataset...")
    modified_datasets = {}
    for split, dataset in dataset_dict.items():
        # Map each split separately so progress is reported per split.
        print(f"  Processing split: {split}")
        modified_datasets[split] = dataset.map(add_template)

    final_dataset_dict = DatasetDict(modified_datasets)
    print(final_dataset_dict)

    print(f"Pushing modified dataset to Hugging Face Hub: {target_dataset_id}")
    print(f"https://huggingface.co/datasets/{target_dataset_id}")
    try:
        final_dataset_dict.push_to_hub(target_dataset_id)
    except Exception as e:
        # Top-level boundary: explain the likely cause, then exit non-zero so
        # shells and CI see the failure (previously the script exited 0).
        print(f"Error pushing dataset to the Hub: {e}")
        print(
            "Please ensure you are logged in (`huggingface-cli login`) and have the correct permissions."
        )
        raise SystemExit(1) from e
    print("Dataset successfully pushed to the Hub!")
