# Run dsfilter_3_save_datasetdict.py first, before using this script.



import argparse
from datasets import Dataset
from utils import load_single_dataset, save_dataset
from tqdm import tqdm
import os
from typing import List, Dict
import copy


def get_lf_data_row(row) -> Dict :
    """Reduce a multi-response row to a single chosen/rejected response pair.

    The first response whose score is above 0.5 becomes the chosen response
    and the first response whose score is below 0.5 becomes the rejected one.
    Responses scoring exactly 0.5 are never selected.

    Args:
        row: Dict-like record with parallel "responses" and "scores"
            sequences. It is not modified.

    Returns:
        A deep copy of ``row`` in which "responses" is ``[chosen, rejected]``
        and "scores" is ``[1, 0]``; every other field is carried over.

    Raises:
        IndexError: If no score is above 0.5 or no score is below 0.5, so no
            pair can be formed.
    """
    scores = row["scores"]
    # Only the first match on each side is used, so stop scanning as soon as
    # it is found instead of collecting every index.
    chosen_idx = next((i for i, score in enumerate(scores) if score > 0.5), None)
    rejected_idx = next((i for i, score in enumerate(scores) if score < 0.5), None)

    # Validate before the (potentially expensive) deep copy and report which
    # side is missing; IndexError is kept so the failure type is unchanged.
    if chosen_idx is None or rejected_idx is None:
        missing = "chosen (score > 0.5)" if chosen_idx is None else "rejected (score < 0.5)"
        raise IndexError(
            f"row has no {missing} response; scores={list(scores)!r}"
        )

    responses = row["responses"]
    pair = [responses[chosen_idx], responses[rejected_idx]]

    result = copy.deepcopy(row)
    result["responses"] = pair
    result["scores"] = [1, 0]
    return result


def main():
    """Merge the "true_and_false" datasets of several directories into one file.

    Parses ``--input_files`` (comma-separated directories) and
    ``--output_file`` from the command line. For each directory, loads the
    dataset at ``<directory>/true_and_false``, converts every row with
    ``get_lf_data_row``, shuffles the combined rows and passes them to
    ``save_dataset`` together with the output path.
    """
    import random

    parser = argparse.ArgumentParser(description='Merge datasets under directories (expects subdir "true_and_false").')
    parser.add_argument('--input_files', required=True,
                        help='Comma-separated list of input **directories**, each containing a "true_and_false" dataset')
    parser.add_argument('--output_file', required=True, help='Path to the output merged file')
    args = parser.parse_args()

    # Strip each entry, not just the emptiness test, so "a, b" yields the
    # directories "a" and "b" rather than "a" and " b".
    input_files = [f.strip() for f in args.input_files.split(',') if f.strip()]

    data_all = []
    for input_fp in tqdm(input_files):
        print(input_fp)

        # true_and_false
        ds: Dataset = load_single_dataset(os.path.join(input_fp, "true_and_false"))
        tf_length = len(ds)
        print("true_and_false length: ", tf_length)
        data_all.extend(get_lf_data_row(row) for row in ds)

    # Unseeded shuffle: the output order differs from run to run.
    random.shuffle(data_all)

    save_dataset(data_all, args.output_file)


# Run the merge only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

"""
~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot0_pairwise.py \
    --input_files ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048_classify \
    --output_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048_pairwise.json
"""
