import argparse
from datasets import Dataset
from utils import load_single_dataset, save_dataset
from tqdm import tqdm
import os
from typing import List, Dict


def get_lf_data_row(row: Dict) -> List[Dict]:
    """Convert one scored rollout row into at most one pairwise preference record.

    The first response whose score is strictly greater than 0.5 becomes
    ``chosen``; the first response whose score is strictly less than 0.5
    becomes ``rejected``. Responses scoring exactly 0.5 are never used.

    Args:
        row: Mapping with keys
            ``"prompt"`` -- list of messages with a ``"content"`` field; index 0
            is used as the system prompt and index 1 as the human turn;
            ``"responses"`` -- indexed in parallel with ``"scores"``;
            ``"scores"`` -- numeric scores (missing or ``None`` is treated as
            an empty list).

    Returns:
        A one-element list holding a dict with ``system``, ``conversations``,
        ``chosen`` and ``rejected`` keys (messages are ``from``/``value``
        dicts), or ``[]`` when no chosen/rejected pair exists. A list is
        returned so callers can simply ``extend`` an accumulator.
    """
    # `or []` also covers a present-but-None "scores" value, which would
    # otherwise make enumerate() raise TypeError.
    scores = row.get("scores") or []

    # Only the first index on each side is used, so stop at the first match
    # instead of materialising every matching index.
    chosen_index = next((i for i, s in enumerate(scores) if s > 0.5), None)
    rejected_index = next((i for i, s in enumerate(scores) if s < 0.5), None)

    # Need at least one response on each side of 0.5 to form a pair.
    # Compare with None explicitly: index 0 is a valid match.
    if chosen_index is None or rejected_index is None:
        return []

    prompt = row["prompt"]
    responses = row["responses"]
    return [{
        "system": prompt[0]["content"],
        "conversations": [
            {"from": "human", "value": prompt[1]["content"]}
        ],
        "chosen":   {"from": "gpt", "value": responses[chosen_index]},
        "rejected": {"from": "gpt", "value": responses[rejected_index]},
    }]



def main():
    """Build a shuffled pairwise (chosen/rejected) dataset from rollout directories.

    Each comma-separated entry of ``--input_files`` is treated as a directory
    containing a ``true_and_false`` sub-dataset, loaded with
    ``load_single_dataset``. Every row is converted with ``get_lf_data_row``.
    All resulting records are shuffled and written to ``--output_file`` via
    ``save_dataset``.
    """
    import random

    parser = argparse.ArgumentParser(description='Merge datasets under directories (expects subdir "true_and_false").')
    parser.add_argument('--input_files', required=True,
                        help='Comma-separated list of input **directories**, each containing a "true_and_false" dataset')
    parser.add_argument('--output_file', required=True, help='Path to the output merged file')
    parser.add_argument('--seed', type=int, default=None,
                        help='Optional random seed for the final shuffle (default: unseeded)')
    args = parser.parse_args()

    # Strip each entry and expand "~" here. The shell only tilde-expands at the
    # start of a word, so every path after the first comma would otherwise
    # arrive as a literal "~/..." that does not exist on disk.
    input_files = [
        os.path.expanduser(f.strip())
        for f in args.input_files.split(',')
        if f.strip()
    ]
    output_file = os.path.expanduser(args.output_file)

    data_all = []
    for input_fp in tqdm(input_files):
        print(input_fp)

        # true_and_false
        ds: Dataset = load_single_dataset(os.path.join(input_fp, "true_and_false"))
        print("true_and_false length: ", len(ds))
        for row in ds:
            data_all.extend(get_lf_data_row(row))

    # With seed=None this draws fresh entropy, matching the original unseeded shuffle.
    random.Random(args.seed).shuffle(data_all)

    save_dataset(data_all, output_file)


if __name__ == '__main__':
    main()



"""

# NOTE: the shell only expands "~" at the start of a word, so the
# comma-separated --input_files paths below use $HOME instead of "~".
~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/dsfilter_41_single_prepare_for_pairdpo_trainset.py \
--input_files $HOME/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_0_56712,$HOME/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_56712_113424,$HOME/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_113424_170136,$HOME/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_170136_226848,$HOME/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_226848_283560,$HOME/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_283560_340272,$HOME/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_340272_396984,$HOME/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/ds_396984_end  \
--output_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-sft-llama32-3b-pairwise-rl-data-single.jsonl



~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/dsfilter_5_prepare_for_kto_trainset.py \
--input_files ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-sft-llama32-3b-pairwise-rl-data-single.jsonl \
--output_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-sft-llama32-3b-kto-rl-data-single.jsonl



"""
