#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import random
from typing import List, Dict, Any

from utils import load_single_dataset, save_dataset
from tqdm import tqdm


def _extract_text(x: Any) -> str:
    """从多种可能的结构中提取文本。"""
    if isinstance(x, dict):
        # 兼容常见字段名
        return x.get("value") or x.get("text") or x.get("content") or ""
    if x is None:
        return ""
    return str(x)


def _find_user_message(conversations: Any) -> str:
    """Return the text of the first human/user turn in ``conversations``.

    Args:
        conversations: Expected to be a list of message dicts whose ``"from"``
            key names the speaker; any other type is treated as "no messages".

    Returns:
        The text extracted (via ``_extract_text``) from the first dict whose
        ``"from"`` is ``"human"`` or ``"user"``, or ``""`` when there is no
        such entry.
    """
    if not isinstance(conversations, list):
        return ""

    user_turns = (
        turn
        for turn in conversations
        if isinstance(turn, dict) and turn.get("from") in ("human", "user")
    )
    first_turn = next(user_turns, None)
    if first_turn is None:
        return ""
    # Only the first matching turn is considered, even if its text is empty.
    return _extract_text(first_turn)


def split_pair_row_to_kto_rows(row: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Split one pairwise row (chosen / rejected) into up to two KTO rows.

    Each produced row keeps the ``system`` field and has
    ``conversations == [human, gpt]`` plus a boolean ``kto_tag``
    (``True`` for the chosen answer, ``False`` for the rejected one).

    The human text is taken from the first human/user turn of
    ``row["conversations"]``. If that yields nothing and ``row["prompt"]`` is
    a list, the first message whose ``"role"`` is ``"user"``/``"human"`` is
    used; when no message carries such a role, the second entry
    (``prompt[1]``) is used as a positional fallback.

    Args:
        row: A pairwise record. Keys read: ``system``, ``conversations``,
            ``prompt``, ``chosen``, ``rejected``.

    Returns:
        A list with zero, one or two KTO rows. A side (chosen or rejected)
        whose extracted text is empty is skipped. The chosen row, when
        present, comes first.
    """
    user_text = _find_user_message(row.get("conversations"))

    prompt = row.get("prompt")
    if not user_text and isinstance(prompt, list):
        # Select the user turn by role rather than by position: a prompt
        # without a system message, or a multi-turn prompt, does not have the
        # user message at index 1.
        user_msg = next(
            (
                msg
                for msg in prompt
                if isinstance(msg, dict) and msg.get("role") in ("user", "human")
            ),
            None,
        )
        # Positional fallback for prompts whose messages carry no usable role.
        if user_msg is None and len(prompt) > 1 and isinstance(prompt[1], dict):
            user_msg = prompt[1]
        if user_msg is not None:
            user_text = user_msg.get("content", "") or ""

    system = row.get("system", "")

    out: List[Dict[str, Any]] = []
    # A missing (None) side extracts to "" and is skipped below, so no separate
    # "both sides missing" check is needed.
    for answer, preferred in ((row.get("chosen"), True), (row.get("rejected"), False)):
        answer_text = _extract_text(answer)
        if not answer_text:
            continue
        out.append({
            "system": system,  # keep the system field
            "conversations": [
                {"from": "human", "value": user_text},
                {"from": "gpt", "value": answer_text},
            ],
            "kto_tag": preferred,  # True = preferred answer, False = rejected
        })

    return out


def main():
    """Command-line entry point: convert pairwise datasets into one KTO dataset.

    Reads every path given in ``--input_files`` (comma-separated) with
    ``load_single_dataset``, splits each row with
    ``split_pair_row_to_kto_rows``, writes all resulting rows to
    ``--output_file`` with ``save_dataset`` and prints a one-line summary.

    Exits with an argparse usage error when ``--input_files`` contains no
    path at all (for example ``","``).
    """
    import os  # local import: only needed to expand "~" in the CLI paths

    parser = argparse.ArgumentParser(
        description="Split pairwise dataset (chosen/rejected) into KTO dataset with kto_tag."
    )
    parser.add_argument(
        "--input_files",
        required=True,
        help="逗号分隔的输入路径列表（文件或目录），每个会用 utils.load_single_dataset 读取",
    )
    parser.add_argument(
        "--output_file",
        required=True,
        help="输出文件路径（使用 utils.save_dataset 保存）",
    )
    args = parser.parse_args()

    # The shell expands "~" only at the start of a word, so in "a,~/b" the
    # second path arrives with a literal "~". Expand every entry ourselves
    # (a no-op for paths that are already expanded).
    inputs = [
        os.path.expanduser(part.strip())
        for part in args.input_files.split(",")
        if part.strip()
    ]
    if not inputs:
        # Without this, an empty dataset would be written silently.
        parser.error("--input_files does not contain any path")

    kto_rows: List[Dict[str, Any]] = []
    full_pairs = 0  # rows that produced both a chosen and a rejected KTO row

    for inp in tqdm(inputs, desc="Loading"):
        for row in load_single_dataset(inp):
            items = split_pair_row_to_kto_rows(row)
            kto_rows.extend(items)
            if len(items) == 2:
                full_pairs += 1

    save_dataset(kto_rows, args.output_file)
    print(f"[done] wrote {len(kto_rows)} KTO rows from ~{full_pairs} full pairs -> {args.output_file}")


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

"""
Example invocations:

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/dsfilter_5_prepare_for_kto_trainset.py \
--input_files ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-sft-llama3.2-1B-pairwise-rl-data.jsonl \
--output_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-sft-llama3.2-1B-kto-rl-data.jsonl




~/verl_250713/.conda/bin/python \
~/verl_250713/scripts/dsfilter_5_prepare_for_kto_trainset.py \
--input_files ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/ds_part1,~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/ds_part2 \
--output_file ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-sft-qwen3-8b-kto-rl-data.jsonl


"""
