import datasets
import argparse
import os
from nesim.utils.json_stuff import dict_to_json
from tqdm import tqdm

# Command-line interface. --start-idx and --output-folder are mandatory;
# --end-idx is optional and stays None when omitted (argparse's default), in
# which case the script later substitutes the size of the train split.
parser = argparse.ArgumentParser(
    description="Save dataset in the form of lots of json files in a folder"
)
parser.add_argument(
    "--start-idx",
    required=True,
    type=int,
    help="start",
)
parser.add_argument(
    "--end-idx",
    type=int,
    help="end",
)
parser.add_argument(
    "--output-folder",
    required=True,
    type=str,
    help="folder",
)
args = parser.parse_args()

# Load the pre-tokenized dataset from its on-disk location without pulling it
# into RAM (keep_in_memory=False).
# NOTE(review): the membership test and the ["train"] lookup below presume
# load_from_disk returns a DatasetDict keyed by split name -- confirm.
print("Loading tokenized dataset...")
tokenized_dataset = datasets.load_from_disk(
    "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/datasets/openwebtext_tokenized",
    keep_in_memory=False,
)

if "test" in tokenized_dataset:
    print("There is already a pre defined test set in the dataset. But we dont touch it here!")
else:
    # No test split was stored: carve 10% out of "train" with a fixed seed so
    # every invocation of this script sees the same train subset.
    train_only = tokenized_dataset["train"]
    tokenized_dataset = train_only.train_test_split(test_size=0.1, seed=0)
    print("Made a train test split")

print("Loading tokenized dataset complete! Will generate jsons for only the train samples. Wont touch test samples")

# Size of the train split; this is the exclusive upper bound for sample indices.
total_n_samples = len(tokenized_dataset["train"])

print(f"Total number of samples in original dataset: {total_n_samples}")

# Validate the command-line arguments with explicit checks rather than
# `assert`, which is stripped under `python -O`. A direct range comparison
# avoids materialising a list of every index just for one membership test.
if not 0 <= args.start_idx < total_n_samples:
    parser.error(
        f"--start-idx must be in [0, {total_n_samples}), got {args.start_idx}"
    )

# The output location must be a directory; a regular file with that name would
# only fail later, on the first write.
if not os.path.isdir(args.output_folder):
    parser.error(
        f"--output-folder is not an existing directory: {args.output_folder}"
    )

if args.end_idx is None:
    # No explicit end: export everything from start_idx to the last sample.
    args.end_idx = total_n_samples
elif args.end_idx > total_n_samples:
    # An oversized end index would otherwise raise an out-of-range error only
    # after all valid samples had been written; clamp it so the run finishes
    # cleanly.
    print(
        f"--end-idx {args.end_idx} exceeds the number of train samples; "
        f"clamping to {total_n_samples}"
    )
    args.end_idx = total_n_samples

# Write one "<index>.json" file per train sample in [start_idx, end_idx).
# Indices whose file already exists are skipped, so an interrupted run can be
# restarted with the same arguments without rewriting finished samples.
train_split = tokenized_dataset["train"]
progress_label = (
    f"[start_idx = {args.start_idx} end_idx: {args.end_idx}] Saving samples"
)

for dataset_idx in tqdm(range(args.start_idx, args.end_idx), desc=progress_label):
    filename = os.path.join(args.output_folder, f"{dataset_idx}.json")
    if os.path.exists(filename):
        continue
    # presumably dict_to_json serialises the sample dict to `filename` --
    # its implementation lives in nesim.utils.json_stuff.
    dict_to_json(train_split[dataset_idx], filename=filename)
print("Done")

