import argparse
from datasets import Dataset, concatenate_datasets
from utils import load_single_dataset, save_dataset
from tqdm import tqdm
import os


def main():
    """Merge the sub-datasets found under each input path and save the result.

    Command-line arguments:
        --input_files: comma-separated list of base paths. Under each base
            path three sub-datasets are loaded with ``load_single_dataset``:
            ``true_and_false``, ``all_false`` and ``all_true``.
        --output_file: destination handed to ``save_dataset``.

    For every base path, with ``n = len(true_and_false)``:
        * all rows of ``true_and_false`` are kept;
        * the first ``min(n // 3, len(all_false))`` rows of ``all_false``
          are kept;
        * the first ``min(2 * n // 3, len(all_true))`` rows of ``all_true``
          are kept.

    The selected pieces are concatenated in that order, across all base
    paths, and written once at the end.

    Exits through ``parser.error`` (status 2) when a required argument is
    missing or when ``--input_files`` contains no non-blank entry.
    """
    parser = argparse.ArgumentParser(description='Merge JSONL or JSON files.')
    parser.add_argument('--input_files', required=True,
                        help='Comma-separated list of JSONL or JSON files to merge')
    parser.add_argument('--output_file', required=True,
                        help='Path to the output merged file')
    args = parser.parse_args()

    # The shell expands "~" only at the start of a word, so in a
    # comma-separated list every entry after the first still carries a
    # literal "~". Expand it here (a no-op for already-expanded paths).
    input_files = [
        os.path.expanduser(entry.strip())
        for entry in args.input_files.split(',')
        if entry.strip()
    ]
    if not input_files:
        # Fail with a usage message instead of concatenating an empty list.
        parser.error('--input_files must contain at least one path')
    output_file = os.path.expanduser(args.output_file)

    data_all = []
    for input_fp in tqdm(input_files):
        print(input_fp)

        # true_and_false: kept in full; its size drives the other two caps.
        tf_ds: Dataset = load_single_dataset(os.path.join(input_fp, "true_and_false"))
        data_all.append(tf_ds)
        tf_length = len(tf_ds)

        # all_false: at most one third of the true_and_false size.
        # Integer floor division keeps the cap exact (no float rounding).
        af_ds: Dataset = load_single_dataset(os.path.join(input_fp, "all_false"))
        af_length = min(tf_length // 3, len(af_ds))
        data_all.append(af_ds.select(range(af_length)))

        # all_true: at most two thirds of the true_and_false size.
        at_ds: Dataset = load_single_dataset(os.path.join(input_fp, "all_true"))
        at_length = min(2 * tf_length // 3, len(at_ds))
        data_all.append(at_ds.select(range(at_length)))

    merged = concatenate_datasets(data_all)
    save_dataset(merged, output_file)


# Script entry point: run the merge only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()

"""


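Example usage (each --input_files entry is a directory holding the
true_and_false, all_false and all_true sub-datasets):

~/verl_cs/.conda/bin/python \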
~/verl_cs/scripts/dsfilter_4_prepare_for_verl_dataset.py \
--input_files ~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-rl-dataset-part1,~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-rl-dataset-part2,~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-rl-dataset-part3,~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-rl-dataset-part4,~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-rl-dataset-part5,~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-rl-dataset-part6,~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-rl-dataset-part7,~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-rl-dataset-part8 \
--output_file ~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-verl-data.parquet


"""
