from datasets import Dataset, DatasetDict, load_dataset

if __name__ == "__main__":
    # Hub repositories involved. The "other" target is the same repository the
    # data is loaded from, so running this script overwrites the source
    # dataset with its filtered version.
    new_dataset_name_jailbreaking = "contextmodification/jailbreaking"
    new_dataset_name_discovery = "contextmodification/discover"
    other_dataset_name = "contextmodification/simple_stories_new"

    dataset = load_dataset("contextmodification/simple_stories_new")

    # Rows with these story types belong to the dedicated jailbreaking /
    # discovery datasets; every remaining row stays in the main dataset.
    split_off_story_types = ("jailbreak", "discover")

    # DatasetDict.filter runs on every split and keeps each split's declared
    # features. Rebuilding the splits from Python row lists would re-infer the
    # schema from the values, and a split left with no rows would end up with
    # no columns at all, making the splits inconsistent with each other.
    dataset_dict_other = dataset.filter(
        lambda datum: datum["story_type"] not in split_off_story_types
    )

    dataset_dict_other.push_to_hub(other_dataset_name)

    # Publishing the split-off subsets is currently disabled. To re-enable:
    # dataset.filter(
    #     lambda datum: datum["story_type"] == "jailbreak"
    # ).push_to_hub(new_dataset_name_jailbreaking)
    # dataset.filter(
    #     lambda datum: datum["story_type"] == "discover"
    # ).push_to_hub(new_dataset_name_discovery)
