from tqdm import tqdm
import torch
from datasets import load_dataset, concatenate_datasets
import numpy as np
from transformers import AutoTokenizer
import datasets
import argparse
import os
import json
from huggingface_hub import delete_repo

# Command-line interface. Arguments are parsed at module level, so this file is
# meant to be executed as a script, not imported.
parser = argparse.ArgumentParser(
    description="Concatenate the shard datasets '<dataset>_mini_<i>' "
    "(i = 0 .. total_part-1) and push the merged 'train_prefs' split to '<dataset>'."
)
# Required: without it the shard names below would be built from None.
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    help="Base dataset repo id; shards are read from '<dataset>_mini_<i>'.",
)
# Previously a hard-coded constant; the default keeps the old behavior.
parser.add_argument(
    "--total_part",
    type=int,
    default=8,
    help="Number of shards to merge (default: 8).",
)

args = parser.parse_args()

# Base repo id / path: shard names are derived from it and the merged result
# is pushed back to it.
dataset_dir = args.dataset

# Number of shards to load; must be positive because concatenating an empty
# list of datasets is an error.
total_part = args.total_part
if total_part < 1:
    parser.error("--total_part must be a positive integer")


# def get_json(dataset):
#     result = []
#     original = []
#     for row in dataset:
#         result.append({"instruction": row["prompt"],
#                        "output": row[col_name]})
#         original.append({"instruction": row["original"],
#                          "output": row[col_name]})
#     print("over!")
#     output_dir = "/home/zli5570/yibowang/nash_rlhf/RDPO-main/alpaca/"+col_name
#     os.makedirs(output_dir, exist_ok=True)
#     file_path = os.path.join(output_dir, "data.json")
#     with open(file_path, 'w') as json_file:
#         json.dump(result, json_file, indent=2)
#     file_path = os.path.join(output_dir, "ref_data.json")
#     with open(file_path, 'w') as json_file:
#         json.dump(original, json_file, indent=2)


if __name__ == "__main__":
    # Load every shard "<dataset_dir>_mini_<i>" for i = 0 .. total_part-1.
    #   * download_mode="force_redownload" ignores any locally cached copy, so
    #     shards that were re-pushed since the last run are picked up.
    #   * verification_mode="no_checks" is the documented replacement for the
    #     former ignore_verifications=True (deprecated in datasets 2.9.1 and
    #     removed in 3.0, where passing it raises TypeError).
    dataset_list = [
        load_dataset(
            f"{dataset_dir}_mini_{i}",
            split="train_prefs",
            download_mode="force_redownload",
            verification_mode="no_checks",
        )
        for i in range(total_part)
    ]

    # Merge the shards row-wise in shard order and publish the result as the
    # "train_prefs" split of the base repo.
    train_dataset = concatenate_datasets(dataset_list)
    train_dataset.push_to_hub(dataset_dir, split="train_prefs", private=False)

    # Optional cleanup of the per-shard repos once the merged dataset is pushed.
    # Note: repo_type="dataset" is required, otherwise delete_repo targets a
    # *model* repo with the same name.
    # for i in range(total_part):
    #     delete_repo(f"{dataset_dir}_mini_{i}", repo_type="dataset")
