
from .utils import ArgumentParserPlus, FlatArguments, get_datasets




def main():
    """Load the configured dataset mixture and print a few training samples.

    Arguments are parsed into ``FlatArguments`` via ``ArgumentParserPlus``.
    The ``train`` split is requested from ``get_datasets`` with only the
    ``messages`` column kept, and up to the first five samples are printed.

    Raises:
        ValueError: If ``dataset_mixer`` is not set in the parsed arguments.
    """
    parser = ArgumentParserPlus(FlatArguments)
    args = parser.parse()

    # Raise explicitly instead of asserting: asserts are stripped under
    # ``python -O``, which would let a missing mixer slip through silently.
    if args.dataset_mixer is None:
        raise ValueError("dataset_mixer is required in config")

    raw_datasets = get_datasets(
        args.dataset_mixer,
        configs=args.dataset_config_name,
        splits=["train"],
        save_data_dir=args.dataset_mix_dir,  # location where dataset is saved as json
        columns_to_keep=["messages"],
    )

    # Print the first 5 samples, or fewer when the split is smaller, so a
    # tiny mixture does not fail with an IndexError.
    train_split = raw_datasets["train"]
    for i in range(min(5, len(train_split))):
        print(train_split[i])


# Script entry point: run main() only when this file is executed, not imported.
# NOTE(review): the relative import at the top of the file means this likely
# has to be launched as a module (``python -m <package>.<module>``) rather than
# as a bare script path — confirm against how the project invokes it.
if __name__ == "__main__":
    main()
