import argparse
from datasets import Dataset
from utils import load_single_dataset, save_dataset
from tqdm import tqdm
import os
from typing import List, Dict


def get_lf_data_row(row) -> List[Dict] :
    """Build chosen/rejected preference records from one scored row.

    Responses with a score above 0.5 are treated as "chosen" and those with
    a score below 0.5 as "rejected"; a score of exactly 0.5 belongs to
    neither group. Every chosen response produces one record, paired with
    the rejected response found at position
    ``chosen_index % len(rejected_indices)`` of the rejected list.

    Args:
        row: Mapping with the keys ``"scores"`` (one number per response),
            ``"responses"`` (indexed in parallel with the scores) and
            ``"prompt"`` (a sequence whose item 0 supplies the system text
            and item 1 the human turn, each under ``"content"``).

    Returns:
        A list of dicts with the keys ``"system"``, ``"conversations"``,
        ``"chosen"`` and ``"rejected"``. The list is empty when the row has
        no chosen response or no rejected response, because no pair can be
        formed in either case.
    """
    scores = row["scores"]
    chosen_indices = [i for i, score in enumerate(scores) if score > 0.5]
    rejected_indices = [i for i, score in enumerate(scores) if score < 0.5]

    # Without at least one response on each side there is nothing to pair;
    # this also avoids a modulo-by-zero when no response was rejected.
    if not chosen_indices or not rejected_indices:
        return []

    system_text = row["prompt"][0]['content']
    human_text = row["prompt"][1]['content']
    responses = row["responses"]

    results = []
    for chosen_index in chosen_indices:
        # Cycle through the rejected responses, keyed by the chosen index.
        rejected_index = rejected_indices[chosen_index % len(rejected_indices)]
        results.append({
            "system": system_text,
            "conversations": [{"from": "human", "value": human_text}],
            "chosen":   {"from": "gpt", "value": responses[chosen_index]},
            "rejected": {"from": "gpt", "value": responses[rejected_index]},
        })
    return results


def main():
    """Merge the "true_and_false" datasets of several directories into one file.

    Parses ``--input_files`` (a comma-separated list of directories) and
    ``--output_file`` from the command line, loads the ``true_and_false``
    dataset found under each directory, converts every row into preference
    records with ``get_lf_data_row``, shuffles all records together and
    writes them with ``save_dataset``.
    """
    import random

    parser = argparse.ArgumentParser(description='Merge datasets under directories (expects subdir "true_and_false").')
    parser.add_argument('--input_files', required=True,
                        help='Comma-separated list of input **directories**, each containing a "true_and_false" dataset')
    parser.add_argument('--output_file', required=True, help='Path to the output merged file')
    args = parser.parse_args()

    # The shell expands "~" only at the start of a word, so every entry after
    # the first comma arrives with a literal "~"; expand it here. Surrounding
    # whitespace is stripped so "a, b" does not produce a path " b".
    input_files = [
        os.path.expanduser(f.strip())
        for f in args.input_files.split(',')
        if f.strip()
    ]
    output_file = os.path.expanduser(args.output_file)

    data_all = []
    for input_fp in tqdm(input_files):
        print(input_fp)

        # true_and_false
        ds: Dataset = load_single_dataset(os.path.join(input_fp, "true_and_false"))
        tf_length = len(ds)
        print("true_and_false length: ", tf_length)
        for row in ds:
            data_all.extend(get_lf_data_row(row))

    # Mix records coming from the different input directories.
    random.shuffle(data_all)

    save_dataset(data_all, output_file)


# Run the merge only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()



"""

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/dsfilter_4_prepare_for_pairdpo_trainset.py \
--input_files ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_0_56712,~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_56712_113424,~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_113424_170136,~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_170136_226848,~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_226848_283560,~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_283560_340272,~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_340272_396984,~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_396984_end \
--output_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-sft-llama32-3b-pairwise-rl-data.jsonl




~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/dsfilter_5_prepare_for_kto_trainset.py \
--input_files ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-sft-llama32-3b-pairwise-rl-data.jsonl \
--output_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-sft-llama32-3b-kto-rl-data.jsonl



"""
