#!/usr/bin/env python
import json, random, argparse
from pathlib import Path
from json import JSONDecodeError
from typing import Any, Dict, List, Tuple

def load_json_any(path: Path) -> List[Dict[str, Any]]:
    """Load records from a JSON or JSON Lines file.

    The file is read from disk exactly once. Its content is first parsed as
    a single JSON document; if that fails, it is parsed as JSON Lines (one
    JSON value per non-blank line).

    Args:
        path: Location of the UTF-8 encoded input file.

    Returns:
        The decoded list when the document is a JSON array, a one-element
        list wrapping any other single JSON value, or the list of values
        decoded from each non-blank line in the JSON Lines case.

    Raises:
        JSONDecodeError: If the content is neither one valid JSON document
            nor valid JSON Lines.
    """
    text = path.read_text(encoding="utf-8").strip()
    try:
        obj = json.loads(text)
    except JSONDecodeError:
        # Not a single document: fall back to JSON Lines, reusing the text
        # already in memory instead of opening the file a second time.
        # Split on "\n" only -- str.splitlines() would also break on
        # characters such as U+2028 that may appear inside JSON strings.
        stripped_lines = (line.strip() for line in text.split("\n"))
        return [json.loads(line) for line in stripped_lines if line]
    return obj if isinstance(obj, list) else [obj]

def extract_completions(example: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Collect the completions of one example that carry both ratings.

    A completion is kept only when both
    ``annotations["helpfulness"]["Rating"]`` and
    ``annotations["truthfulness"]["Rating"]`` exist and convert to ``int``.
    Completions with missing, null, or non-numeric ratings (for example
    ``"N/A"``) are skipped rather than aborting the whole run.

    Args:
        example: One dataset record; its ``"completions"`` entry, if any,
            is the list of candidate responses.

    Returns:
        One dict per usable completion with keys ``"idx"`` (position in the
        original completions list), ``"response"`` (empty string when
        absent), ``"truth"`` and ``"help"`` (integer ratings).
    """
    out = []
    # `or []` also tolerates an explicit null stored under "completions".
    for idx, comp in enumerate(example.get("completions") or []):
        try:
            anns = comp["annotations"]
            help_score = int(anns["helpfulness"]["Rating"])
            truth_score = int(anns["truthfulness"]["Rating"])
        except (KeyError, TypeError, ValueError):
            # KeyError: a field is missing; ValueError: non-numeric text;
            # TypeError: a null or otherwise non-subscriptable value where
            # a mapping or a number was expected.
            continue
        out.append(
            {
                "idx": idx,
                "response": comp.get("response", ""),
                "truth": truth_score,
                "help": help_score,
            }
        )
    return out

def candidate_pairs_for_objective(completions: List[Dict[str, Any]], objective: str, margin: int) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """Enumerate (better, worse) completion pairs separated by at least ``margin``.

    Completions are ranked by the objective's score in descending order
    (ties keep their input order). Every pair in which the higher-ranked
    entry beats the lower-ranked one by ``margin`` or more is returned.

    Args:
        completions: Dicts carrying integer ``"truth"`` and ``"help"`` scores.
        objective: ``"truth"`` ranks by the ``"truth"`` score; any other
            value ranks by the ``"help"`` score.
        margin: Minimum required score difference (better minus worse).

    Returns:
        ``(better, worse)`` tuples, ordered by the rank of the better entry
        and then by the rank of the worse entry.
    """
    if objective == "truth":
        score_key = "truth"
    else:
        score_key = "help"

    # Sort a copy so the caller's list is left untouched; list.sort is
    # stable, so equally scored completions keep their relative order.
    ranked = list(completions)
    ranked.sort(key=lambda entry: entry[score_key], reverse=True)

    return [
        (better, worse)
        for pos, better in enumerate(ranked)
        for worse in ranked[pos + 1:]
        if better[score_key] - worse[score_key] >= margin
    ]

def build_pairs(examples: List[Dict[str, Any]], objective: str, margin: int = 2, max_pairs_per_prompt: int = 3, seed: int = 42) -> List[Dict[str, Any]]:
    """Build chosen/rejected preference pairs for one objective.

    For every example, the rated completions are paired up so that the
    chosen response beats the rejected one on the target objective by at
    least ``margin``; up to ``max_pairs_per_prompt`` of those pairs are
    sampled per example.

    Each pair also records how the two responses compare on the *other*
    objective (helpfulness when targeting truthfulness, and vice versa):
    ``"conflict"`` is ``"true"`` when the chosen response scores lower on
    the other objective, ``"equal"`` when the two tie, and ``"false"``
    otherwise; ``"conflict_strength"`` is the negated other-objective delta.

    Args:
        examples: Dataset records; the prompt is read from ``"instruction"``
            (falling back to ``"prompt"``) and the origin from ``"source"``.
        objective: ``"truth"`` targets truthfulness; any other value targets
            helpfulness.
        margin: Minimum score difference (chosen minus rejected) on the
            target objective.
        max_pairs_per_prompt: Upper bound on sampled pairs per example.
        seed: Seed for the pair sampling.

    Returns:
        One dict per sampled pair with the prompt, both responses, their
        scores, the score deltas and the conflict fields.
    """
    # A private generator yields the same samples as seeding the module-level
    # RNG with `seed`, without clobbering the RNG state of the caller.
    rng = random.Random(seed)
    targets_truth = objective == "truth"
    obj_name = "truthfulness" if targets_truth else "helpfulness"
    pairs_out = []
    for ex_idx, ex in enumerate(examples):
        completions = extract_completions(ex)
        if len(completions) < 2:
            continue
        cand_pairs = candidate_pairs_for_objective(completions, objective, margin)
        if not cand_pairs:
            continue
        k = min(max_pairs_per_prompt, len(cand_pairs))
        chosen_pairs = rng.sample(cand_pairs, k)
        instruction = ex.get("instruction", ex.get("prompt", ""))
        source = ex.get("source", None)
        for high, low in chosen_pairs:
            delta_truth = high["truth"] - low["truth"]
            delta_help = high["help"] - low["help"]
            # Conflict is measured on the objective that is NOT being
            # optimised; comparing against the target objective itself
            # would never report a conflict.
            other_delta = delta_help if targets_truth else delta_truth
            conflict_strength = -other_delta
            if other_delta < 0:
                conflict = "true"
            elif other_delta == 0:
                conflict = "equal"
            else:
                conflict = "false"
            pairs_out.append(
                {
                    "example_id": ex_idx,
                    "source": source,
                    "instruction": instruction,
                    "objective": obj_name,
                    "chosen_idx": high["idx"],
                    "rejected_idx": low["idx"],
                    "chosen_response": high["response"],
                    "rejected_response": low["response"],
                    "scores": {
                        "truth_chosen": high["truth"],
                        "truth_rejected": low["truth"],
                        "help_chosen": high["help"],
                        "help_rejected": low["help"],
                    },
                    "deltas": {
                        "truth": delta_truth,
                        "help": delta_help,
                    },
                    "conflict": conflict,
                    "conflict_strength": conflict_strength,
                }
            )
    return pairs_out

def main():
    """Command-line entry point: read the input file and write truthfulness pairs.

    Parses the CLI options, loads the examples, builds the truthfulness
    pairs, writes them as indented JSON (creating the output directory if
    needed) and prints a one-line summary.
    """
    cli = argparse.ArgumentParser(description="Build UltraFeedback DPO pairs for truthfulness.")
    # Declared as data so each flag and its settings sit on one row;
    # registration order matches the order shown in --help.
    option_specs = (
        ("--input", {"type": str, "required": True, "help": "Path to UltraFeedback JSON / JSONL file."}),
        ("--margin", {"type": int, "default": 2, "help": "Minimum score difference (high - low)."}),
        ("--max_pairs_per_prompt", {"type": int, "default": 3, "help": "Max pairs per prompt per objective."}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed."}),
        ("--truth_output", {"type": str, "default": "ultra_truth_pairs.json", "help": "Output JSON for truthfulness pairs."}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    opts = cli.parse_args()

    records = load_json_any(Path(opts.input))
    pairs = build_pairs(
        records,
        objective="truth",
        margin=opts.margin,
        max_pairs_per_prompt=opts.max_pairs_per_prompt,
        seed=opts.seed,
    )

    destination = Path(opts.truth_output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(pairs, ensure_ascii=False, indent=2)
    destination.write_text(serialized, encoding="utf-8")

    print(f"Saved {len(pairs)} truthfulness pairs to {opts.truth_output}")

# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
