from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset
from huggingface_hub import Repository, snapshot_download


if __name__ == "__main__":
    # Download the whole "train_prefs" split once; every shard pushed below
    # is a contiguous row slice of this single dataset.
    # NOTE(review): newer `datasets` releases deprecate `ignore_verifications`
    # in favour of `verification_mode`; confirm against the installed version.
    full_prefs = load_dataset(
        "ZHLiu627/ultrafeedback_binarized_with_response_full",
        split="train_prefs",
        download_mode="force_redownload",
        ignore_verifications=True,
    )

    # Upload plan, in upload order: (hub repo name, first row, one-past-last row).
    # Three 15,000-row training shards cover rows 0-44999. Three 5,000-row
    # generation shards cover rows 45000-59999. Rows at index 60000 or beyond,
    # if the split has any, are not uploaded anywhere.
    shard_plan = [
        ("ultrafeedback_binarized_new_train_part_1", 0, 15000),
        ("ultrafeedback_binarized_new_train_part_2", 15000, 30000),
        ("ultrafeedback_binarized_new_train_part_3", 30000, 45000),
        ("ultrafeedback_binarized_new_gen_part_1", 45000, 50000),
        ("ultrafeedback_binarized_new_gen_part_2", 50000, 55000),
        ("ultrafeedback_binarized_new_gen_part_3", 55000, 60000),
    ]

    # Materialise every slice before touching the hub. If any slice fails
    # (e.g. the split is shorter than expected), no repo receives a push.
    sliced = [
        (repo_name, full_prefs.select(range(first_row, stop_row)))
        for repo_name, first_row, stop_row in shard_plan
    ]

    # Each shard is stored under the same "train_prefs" split name, as a
    # public repo.
    for repo_name, shard in sliced:
        shard.push_to_hub(repo_name, split="train_prefs", private=False)

    # --- Leftover from a ranking/merging pipeline -------------------------
    # Not used by this script: it refers to dataset_opt, n_part,
    # rank_responses, dataset_left, output_dir and concatenate_datasets,
    # none of which are defined here.
    # interval = len(dataset_opt)//4
    # start = interval*n_part
    # end = interval*(n_part+1) if n_part != 3 else len(dataset_opt)
    # dataset_opt = dataset_opt.select(range(start, end))

    # new_dataset = rank_responses(dataset_opt)

    # new_dataset = new_dataset.remove_columns(
    #     ["resp0", "resp1", "resp2", "resp3", "minpi", "random"])
    # if dataset_left != "None":
    #     dataset_rest = load_dataset(dataset_left, split="train_prefs",
    #                                 download_mode="force_redownload", ignore_verifications=True)
    #     new_dataset = concatenate_datasets([new_dataset, dataset_rest])

    # new_dataset.push_to_hub(
    #     output_dir+f"_mini_{n_part}", split="train_prefs", private=False)
