#!/usr/bin/env python3
"""
Build a combined survey pool from one or more summaries_brief.jsonl files.

Each question–model pair is treated as a separate data point. Sampling can be
stratified by bias score bins, with a cap on the number of models per question.
"""

import argparse
import json
import math
import random
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple


def parse_args() -> argparse.Namespace:
    """Read the command line for the pool builder.

    Returns:
        Namespace with ``inputs`` (list of paths), ``output_dir``,
        ``target_total`` (``None`` keeps every pair), ``bias_bins`` and
        ``bias_bin_weights`` (raw comma-separated strings),
        ``model_cap_per_question`` and ``seed``.
    """
    # (flag, add_argument keyword options), registered in this order so that
    # --help lists them the same way every time.
    option_table = (
        (
            "--inputs",
            {
                "nargs": "+",
                "required": True,
                "help": "Paths to summaries_brief.jsonl files (e.g., race/sex/religion).",
            },
        ),
        (
            "--output-dir",
            {
                "default": "human_study/output",
                "help": "Directory to write combined pool and selection.",
            },
        ),
        (
            "--target-total",
            {
                "type": int,
                "default": None,
                "help": "Optional number of question-model pairs to select. If omitted, keep all.",
            },
        ),
        (
            "--bias-bins",
            {
                "type": str,
                "default": "1-2,2-3,3-4,4-5",
                "help": "Bias bins, e.g., '1-2,2-3,3-4,4-5'.",
            },
        ),
        (
            "--bias-bin-weights",
            {
                "type": str,
                "default": None,
                "help": "Comma-separated weights aligned with bias bins, e.g., '1,1,2,3'.",
            },
        ),
        (
            "--model-cap-per-question",
            {
                "type": int,
                "default": 3,
                "help": "Maximum number of models to keep per question (sorted by bias score).",
            },
        ),
        (
            "--seed",
            {
                "type": int,
                "default": 17,
                "help": "Random seed for sampling and tie-breaking.",
            },
        ),
    )
    cli = argparse.ArgumentParser(description="Build combined survey pool.")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def parse_bias_bins(raw_bins: str) -> List[Tuple[float, float]]:
    """Turn a spec such as ``"1-2,2-3"`` into ``[(1.0, 2.0), (2.0, 3.0)]``.

    Each comma-separated piece is stripped and split on its first ``-``.
    Pieces that are blank, contain no ``-``, or whose bounds do not parse as
    floats are skipped silently. Because the split is on the first ``-``, a
    negative lower bound such as ``"-1-0"`` cannot be expressed and is dropped.
    """
    parsed: List[Tuple[float, float]] = []
    for piece in map(str.strip, raw_bins.split(",")):
        low_text, separator, high_text = piece.partition("-")
        if not separator:
            continue
        try:
            bounds = (float(low_text), float(high_text))
        except ValueError:
            continue
        parsed.append(bounds)
    return parsed


def bias_bin_label(value: Optional[float], bins: List[Tuple[float, float]]) -> str:
    """Name the bin that *value* falls into.

    Bins are treated as half-open ``[lo, hi)`` intervals and checked in order;
    the first match wins. In addition, a value exactly equal to the final
    bin's upper bound is accepted by the first bin whose ``hi`` equals that
    bound, so the top of the scale is not lost.

    Returns:
        ``"unknown"`` for ``None``/NaN, ``"out_of_range"`` when no bin matches,
        otherwise ``f"{lo}-{hi}"`` built from the bin's bounds
        (e.g. ``"1.0-2.0"`` for float bounds).
    """
    if value is None or math.isnan(value):
        return "unknown"
    if not bins:
        return "out_of_range"
    scale_top = bins[-1][1]
    for lo, hi in bins:
        inside = lo <= value < hi
        at_scale_top = hi == scale_top and value == hi
        if inside or at_scale_top:
            return f"{lo}-{hi}"
    return "out_of_range"


def allocate_targets(total: int, weights: List[float]) -> List[int]:
    """Split *total* into integer per-bin targets proportional to *weights*.

    Uses the largest-remainder method: every bin first receives
    ``floor(total * share)``, then the units lost to flooring are handed to
    the bins with the largest fractional parts (ties go to the earlier bin).
    The result sums to *total*.

    Args:
        total: Number of units to distribute; must be non-negative.
        weights: Relative bin weights; each must be finite and non-negative.
            If they are all zero, the split is uniform.

    Returns:
        One non-negative target per weight, or ``[]`` if *weights* is empty.

    Raises:
        ValueError: if *total* is negative, or any weight is negative or
            non-finite. Such inputs would otherwise yield negative targets,
            which are meaningless as sample sizes.
    """
    if not weights:
        return []
    if total < 0:
        raise ValueError(f"total must be non-negative, got {total}")
    if any(not math.isfinite(w) or w < 0 for w in weights):
        raise ValueError(f"weights must be finite and non-negative, got {weights}")
    weight_sum = sum(weights)
    if weight_sum <= 0:
        # Every weight is zero: fall back to an even split.
        weights = [1.0 for _ in weights]
        weight_sum = sum(weights)
    normalized = [w / weight_sum for w in weights]
    raw = [w * total for w in normalized]
    floors = [int(math.floor(x)) for x in raw]
    # Distribute the units lost to flooring to the largest fractional parts.
    # sorted() is stable even with reverse=True, so equal remainders favour
    # the earlier bin.
    remainder = total - sum(floors)
    deltas = [x - f for x, f in zip(raw, floors)]
    order = sorted(range(len(deltas)), key=lambda i: deltas[i], reverse=True)
    for i in range(remainder):
        floors[order[i % len(order)]] += 1
    return floors


def _expand_entry(entry: Dict, source: Path) -> List[Dict]:
    """Return one pool record per model listed under ``entry["models"]``."""
    question_id = entry.get("question_id")
    expanded: List[Dict] = []
    for model_id, model_data in entry.get("models", {}).items():
        judgements = model_data.get("bias_judgements", [])
        scores = [
            score
            for score in (judgement.get("bias_score") for judgement in judgements)
            if score is not None
        ]
        # Mean of the judges' non-null scores; None when no judge scored.
        mean_score = sum(scores) / len(scores) if scores else None
        expanded.append(
            {
                "id": f"{question_id}__{model_id}",
                "question_id": question_id,
                "model_id": model_id,
                "attribute": entry.get("attribute"),
                "attribute_values": list(entry.get("attribute_values", {}).keys()),
                "question_template": entry.get("question_template"),
                "question_text": entry.get("question_text"),
                "question_fitness": entry.get("question_fitness"),
                "model_fitness": model_data.get("model_fitness"),
                "bias_score": mean_score,
                "bias_judgements": judgements,
                "summaries": model_data.get("summaries", {}),
                "differences": model_data.get("differences", {}),
                "summary_mode": entry.get("summary_mode"),
                "source_file": str(source),
            }
        )
    return expanded


def load_pool(paths: List[str]) -> List[Dict]:
    """Flatten summaries_brief.jsonl files into one record per question–model pair.

    Every non-blank line is parsed as a JSON object describing one question;
    each entry of its ``models`` mapping becomes a separate record whose
    ``id`` is ``"<question_id>__<model_id>"`` and whose ``bias_score`` is the
    mean of the non-null ``bias_score`` values in that model's
    ``bias_judgements`` (``None`` if there are none).

    Args:
        paths: Input file paths, each read as UTF-8.

    Returns:
        Records ordered by file, then line, then ``models`` key order.
    """
    records: List[Dict] = []
    for raw_path in paths:
        source = Path(raw_path)
        with source.open("r", encoding="utf-8") as handle:
            for text in handle:
                if text.strip():
                    records.extend(_expand_entry(json.loads(text), source))
    return records


def cap_models_per_question(pool: List[Dict], cap: int) -> List[Dict]:
    """Keep at most *cap* records per ``question_id``.

    Within each question, records are ranked by ``bias_score`` (descending),
    then ``model_fitness`` (descending); a missing value ranks as ``-1e9``,
    i.e. last. The sort is stable, so equal keys keep their input order.
    Question groups are emitted in the order each question first appears in
    *pool*.

    A non-positive *cap* disables capping and returns *pool* itself.
    """
    if cap <= 0:
        return pool

    def rank_key(record: Dict) -> Tuple[float, float]:
        bias = record.get("bias_score")
        fitness = record.get("model_fitness")
        return (
            -1e9 if bias is None else bias,
            -1e9 if fitness is None else fitness,
        )

    by_question: Dict[str, List[Dict]] = {}
    for record in pool:
        by_question.setdefault(record["question_id"], []).append(record)

    kept: List[Dict] = []
    for members in by_question.values():
        kept += sorted(members, key=rank_key, reverse=True)[:cap]
    return kept


def stratified_sample(
    pool: List[Dict],
    bins: List[Tuple[float, float]],
    weights: Optional[List[float]],
    target: Optional[int],
    rng: random.Random,
) -> Tuple[List[Dict], Dict[str, int]]:
    """Draw up to *target* records, stratified by bias-score bin.

    Records are bucketed with ``bias_bin_label``. When *target* is falsy or
    non-positive, the whole *pool* is returned unchanged together with the
    size of every bucket. Otherwise *target* is split across *bins* in
    proportion to *weights* (equal weights when falsy) via
    ``allocate_targets``. Each bin contributes a random subset of its quota.
    Any shortfall, from bins holding fewer records than their quota, is filled
    by a uniform random draw from all not-yet-selected records. That draw
    includes the ``"unknown"`` and ``"out_of_range"`` buckets.

    Args:
        pool: Records carrying at least ``id`` and ``bias_score``; ``id`` is
            assumed unique (it is used to exclude already-selected records).
        bins: Ordered ``(lo, hi)`` bin bounds.
        weights: One weight per bin, or ``None`` for equal weights.
        target: Desired number of selected records.
        rng: Source of randomness; the bucket lists are shuffled in place.

    Returns:
        ``(selected, counts)`` where *counts* maps a bin label to the number of
        selected (or, when not sampling, available) records in it.

    Raises:
        ValueError: if *weights* is non-empty and its length differs from
            *bins*, which would otherwise misalign weights with bins.
    """
    buckets: Dict[str, List[Dict]] = defaultdict(list)
    for item in pool:
        buckets[bias_bin_label(item.get("bias_score"), bins)].append(item)

    if not target or target <= 0:
        return pool, {label: len(items) for label, items in buckets.items()}

    if weights and len(weights) != len(bins):
        raise ValueError(
            f"got {len(weights)} bias bin weights for {len(bins)} bias bins"
        )

    # Label each bin by its own bounds, in the same "{lo}-{hi}" format that
    # bias_bin_label emits. Re-deriving the label from the bin midpoint can map
    # two overlapping bins onto one bucket and sample it twice.
    ordered_labels = [f"{lo}-{hi}" for lo, hi in bins]
    targets = allocate_targets(target, weights or [1.0 for _ in ordered_labels])

    selected: List[Dict] = []
    sampled_labels = set()
    for label, quota in zip(ordered_labels, targets):
        if label in sampled_labels:
            # Duplicate bin spec: sample its bucket only once; this quota is
            # made up from the leftovers below instead.
            continue
        sampled_labels.add(label)
        bucket = buckets.get(label, [])
        rng.shuffle(bucket)
        # Guard: a negative slice bound would take from the wrong end.
        selected.extend(bucket[: max(quota, 0)])

    remaining_needed = target - len(selected)
    if remaining_needed > 0:
        used_ids = {item["id"] for item in selected}
        leftovers = [
            item
            for bucket in buckets.values()
            for item in bucket
            if item["id"] not in used_ids
        ]
        rng.shuffle(leftovers)
        selected.extend(leftovers[:remaining_needed])

    counts: Dict[str, int] = defaultdict(int)
    for item in selected:
        counts[bias_bin_label(item.get("bias_score"), bins)] += 1
    return selected, counts


def write_jsonl(path: Path, rows: List[Dict]) -> None:
    """Write *rows* to *path* as UTF-8 JSON Lines, one object per line.

    Missing parent directories are created, an existing file is overwritten,
    and non-ASCII characters are written verbatim (``ensure_ascii=False``).
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def main() -> None:
    """Build the pool, sample from it, and write the outputs.

    Loads every ``--inputs`` file, caps the number of models per question,
    draws the stratified selection, and then writes ``combined_pool.jsonl``
    (the capped pool), ``selection.jsonl`` (the sampled records) and
    ``pool_stats.json`` (a run summary) into ``--output-dir``. Per-bin counts
    are printed at the end.

    Exits with an error message (``SystemExit``) when ``--bias-bin-weights``
    cannot be parsed, contains negative or non-finite values, or does not
    provide exactly one weight per parsed ``--bias-bins`` entry.
    """
    args = parse_args()
    bins = parse_bias_bins(args.bias_bins)

    weights: Optional[List[float]] = None
    if args.bias_bin_weights:
        raw_weights = args.bias_bin_weights
        try:
            weights = [float(w) for w in raw_weights.split(",")]
        except ValueError as err:
            raise SystemExit(
                f"error: could not parse --bias-bin-weights {raw_weights!r}: {err}"
            ) from err
        if any(not math.isfinite(w) or w < 0 for w in weights):
            raise SystemExit(
                f"error: --bias-bin-weights must be finite and non-negative, got {weights}"
            )
        # A count mismatch would otherwise silently misalign weights with bins.
        if len(weights) != len(bins):
            raise SystemExit(
                f"error: --bias-bin-weights has {len(weights)} values but "
                f"--bias-bins {args.bias_bins!r} parsed to {len(bins)} bins"
            )
    rng = random.Random(args.seed)

    pool = load_pool(args.inputs)
    pool = cap_models_per_question(pool, args.model_cap_per_question)

    selected, bin_counts = stratified_sample(pool, bins, weights, args.target_total, rng)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    write_jsonl(out_dir / "combined_pool.jsonl", pool)
    write_jsonl(out_dir / "selection.jsonl", selected)

    stats = {
        "total_pool": len(pool),
        "total_selected": len(selected),
        "bin_counts": bin_counts,
        "inputs": args.inputs,
        "bias_bins": args.bias_bins,
        "bias_bin_weights": weights,
        "model_cap_per_question": args.model_cap_per_question,
        "target_total": args.target_total,
    }
    with (out_dir / "pool_stats.json").open("w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print(f"Pool size: {len(pool)} | Selected: {len(selected)}")
    for label, count in sorted(bin_counts.items()):
        print(f"  Bin {label}: {count}")
    print(f"Wrote outputs to {out_dir}")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()
