import json
import argparse
from pathlib import Path
from datasets import Dataset
from utils import load_single_dataset, save_dataset
from tqdm import tqdm


def main():
    """Merge several JSON/JSONL dataset files into a single output file.

    Command-line arguments:
        --input_files: comma-separated list of input paths (required). Blank
            entries are skipped, surrounding whitespace is stripped, and a
            leading ``~`` is expanded for every entry. The shell only expands a
            tilde at the start of a word, so paths after the first comma would
            otherwise keep a literal ``~``.
        --output_file: destination path (required), passed to ``save_dataset``
            after ``~`` expansion.

    Each input is loaded with ``load_single_dataset`` and converted to a list
    of records. The records are concatenated in the order the files are
    listed, and the combined list is written with ``save_dataset``.
    """
    parser = argparse.ArgumentParser(description='Merge JSONL or JSON files.')
    parser.add_argument('--input_files', type=str, required=True,
                        help='Comma-separated list of JSONL or JSON files to merge')
    parser.add_argument('--output_file', type=str, required=True,
                        help='Path to the output merged file')
    args = parser.parse_args()

    # Strip each entry and expand "~" ourselves: when the list is passed as one
    # comma-separated shell word, only the first path gets tilde-expanded.
    input_files = [
        str(Path(f.strip()).expanduser())
        for f in args.input_files.split(',')
        if f.strip()
    ]

    if not input_files:
        print("No input files provided.")
        return

    output_file = str(Path(args.output_file).expanduser())

    data_all = []
    for input_fp in tqdm(input_files):
        # tqdm.write prints above the progress bar instead of breaking it like print().
        tqdm.write(input_fp)
        ds: Dataset = load_single_dataset(input_fp)
        data_all.extend(ds.to_list())
    save_dataset(data_all, output_file)


# Script entry point: run the merge only when executed directly, not on import.
if __name__ == "__main__":
    main()


"""

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/merge_file.py \
    --input_files ~/datasets/PRIME-RL-Eurus-2-RL-Data/validation_0_256.json,~/datasets/PRIME-RL-Eurus-2-RL-Data/validation_256_512.json,~/datasets/PRIME-RL-Eurus-2-RL-Data/validation_512_768.json,~/datasets/PRIME-RL-Eurus-2-RL-Data/validation_768_1024.json,~/datasets/PRIME-RL-Eurus-2-RL-Data/validation_1024_1280.json,~/datasets/PRIME-RL-Eurus-2-RL-Data/validation_1280_1536.json,~/datasets/PRIME-RL-Eurus-2-RL-Data/validation_1536_1792.json,~/datasets/PRIME-RL-Eurus-2-RL-Data/validation_1792_2048.json \
    --output_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048.json


"""