#!/usr/bin/env python3
"""
Assign selected question-model pairs to participants.

Each question-model pair must be answered by exactly K participants.
"""

import argparse
import csv
import json
import math
import random
import string
from pathlib import Path
from typing import Dict, List, Tuple


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the assignment script.

    Options:
        --selection: required path to the selection JSONL file.
        --output-dir: directory that receives the generated files.
        --participants, --questions-per-participant, --k-per-question,
        --seed: integer settings, each with a built-in default.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Assign survey items to participants.")

    # Path-like options.
    cli.add_argument(
        "--selection",
        required=True,
        help="Path to selection.jsonl produced by build_dataset.py.",
    )
    cli.add_argument(
        "--output-dir",
        default="human_study/output",
        help="Directory to write assignments and codes.",
    )

    # Integer options share the same shape, so register them from a table.
    # The order is kept so that --help lists them as before.
    integer_options = (
        ("--participants", 75, "Number of participants."),
        ("--questions-per-participant", 30, "Number of questions per participant."),
        (
            "--k-per-question",
            3,
            "Number of participants answering each question-model pair.",
        ),
        ("--seed", 17, "Random seed for shuffling."),
    )
    for flag, default_value, help_text in integer_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    return cli.parse_args()


def read_selection(path: str) -> List[Dict]:
    """Load a JSON Lines file into a list of decoded objects.

    Lines that are empty or contain only whitespace are skipped; every other
    line is decoded with ``json.loads`` and kept in file order.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Returns:
        One decoded object per non-blank line.
    """
    with Path(path).open("r", encoding="utf-8") as handle:
        return [json.loads(raw_line) for raw_line in handle if raw_line.strip()]


def generate_code(rng: random.Random, length: int = 8) -> str:
    """Build a code of uppercase ASCII letters and digits.

    Args:
        rng: Random source; one ``rng.choice`` call is made per character, so
            the result is reproducible for a given generator state.
        length: Number of characters in the code.

    Returns:
        The generated code string.
    """
    symbols = string.ascii_uppercase + string.digits
    picked = [rng.choice(symbols) for _ in range(length)]
    return "".join(picked)


def bias_bin_label(value: float, bins: List[Tuple[float, float]]) -> str:
    """Return the label of the bin that contains ``value``.

    Bins are half-open ``[lo, hi)`` intervals and are checked in the given
    order. The upper edge of the highest bin is treated as inclusive, so a
    value sitting exactly on the top boundary lands in the top bin instead of
    being reported as out of range.

    Args:
        value: Score to classify. Anything ``float()`` cannot convert
            (including ``None``) and NaN are reported as ``"unknown"``.
        bins: ``(lo, hi)`` pairs describing the bins.

    Returns:
        ``"{lo}-{hi}"`` for the matching bin, ``"unknown"`` for missing or
        non-numeric input, or ``"out_of_range"`` when no bin matches.
    """
    try:
        score = float(value)
    except (TypeError, ValueError):
        return "unknown"
    if math.isnan(score):
        return "unknown"

    for lo, hi in bins:
        if lo <= score < hi:
            return f"{lo}-{hi}"

    # Close the scale at the top: a score equal to the largest upper edge
    # belongs to the bin that ends there.
    if bins:
        top_lo, top_hi = max(bins, key=lambda edges: edges[1])
        if score == top_hi and top_lo <= score:
            return f"{top_lo}-{top_hi}"
    return "out_of_range"


def assign(
    items: List[Dict],
    participant_count: int,
    questions_per_participant: int,
    k_per_question: int,
    rng: random.Random,
) -> Tuple[List[Dict], List[Dict]]:
    """Give every item to exactly ``k_per_question`` distinct participants.

    Each item is expanded into ``k_per_question`` copies, the copies are
    shuffled with ``rng``, and each copy is handed greedily to the participant
    with the lowest score. The score combines how full the participant already
    is with how far the copy would push them above their per-participant
    share of the item's attribute and bias bin.

    The greedy pass can reach a state where the only participants with free
    slots already hold the current item. In that case one item is moved from a
    full participant to a participant with a free slot, which frees a slot for
    the current item. This repair uses no randomness and only runs where the
    greedy pass would otherwise fail, so results for seeds that never hit the
    dead end are unaffected.

    Args:
        items: Item dicts. ``"id"`` is required and must be unique;
            ``"attribute"``, ``"bias_score"``, ``"question_id"`` and
            ``"model_id"`` are read with ``dict.get``.
        participant_count: Number of participants.
        questions_per_participant: Number of items each participant receives.
        k_per_question: Number of participants that receive each item.
        rng: Random source used for shuffling and tie-breaking.

    Returns:
        ``(participant_assignments, item_records)``: one dict per participant
        with ``participant_id`` and ``item_ids``, and one dict per item with
        its metadata and the list of ``participants`` it was assigned to.

    Raises:
        ValueError: If a count is not positive, if
            ``len(items) * k_per_question`` differs from
            ``participant_count * questions_per_participant``, if there are
            fewer participants than ``k_per_question``, or if two items share
            an id.
        KeyError: If an item has no ``"id"`` key.
        RuntimeError: If no slot can be found for an item (kept as a
            safeguard; not expected once the inputs pass validation).
    """
    if participant_count <= 0 or questions_per_participant <= 0 or k_per_question <= 0:
        raise ValueError(
            "participant_count, questions_per_participant and k_per_question "
            "must all be positive."
        )

    bins_default = [(1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0)]
    required_items = participant_count * questions_per_participant
    if len(items) * k_per_question != required_items:
        raise ValueError(
            f"Mismatch: need {required_items} assignments "
            f"but have {len(items)} items * {k_per_question}."
        )
    if participant_count < k_per_question:
        raise ValueError("participant_count must be >= k_per_question to avoid duplicate assignments per participant.")

    # Label every item once and accumulate how many assignments each
    # attribute / bias bin contributes overall.
    labels_by_id: Dict[str, Tuple[str, str]] = {}
    attr_totals: Dict[str, int] = {}
    bin_totals: Dict[str, int] = {}
    for item in items:
        item_id = item["id"]
        if item_id in labels_by_id:
            # Assignments are tracked per id, so a repeated id would merge two
            # items and break the "exactly k participants" guarantee.
            raise ValueError(f"Duplicate item id {item_id!r} in selection.")
        attr = item.get("attribute") or "unknown_attr"
        bin_label = bias_bin_label(item.get("bias_score"), bins_default)
        labels_by_id[item_id] = (attr, bin_label)
        attr_totals[attr] = attr_totals.get(attr, 0) + k_per_question
        bin_totals[bin_label] = bin_totals.get(bin_label, 0) + k_per_question

    # Expected number of assignments per participant for each label.
    attr_target = {k: v / participant_count for k, v in attr_totals.items()}
    bin_target = {k: v / participant_count for k, v in bin_totals.items()}

    # One entry per required assignment: each item index repeated k times.
    expanded: List[int] = [
        idx for idx in range(len(items)) for _ in range(k_per_question)
    ]
    rng.shuffle(expanded)

    per_participant_items: List[List[str]] = [[] for _ in range(participant_count)]
    attr_counts: List[Dict[str, int]] = [{} for _ in range(participant_count)]
    bin_counts: List[Dict[str, int]] = [{} for _ in range(participant_count)]
    item_assigned: Dict[str, set] = {}

    def place(pid: int, item_id: str) -> None:
        """Record that participant ``pid`` now holds ``item_id``."""
        attr, bin_label = labels_by_id[item_id]
        per_participant_items[pid].append(item_id)
        attr_counts[pid][attr] = attr_counts[pid].get(attr, 0) + 1
        bin_counts[pid][bin_label] = bin_counts[pid].get(bin_label, 0) + 1
        item_assigned.setdefault(item_id, set()).add(pid)

    def unplace(pid: int, item_id: str) -> None:
        """Undo a previous ``place(pid, item_id)``."""
        attr, bin_label = labels_by_id[item_id]
        per_participant_items[pid].remove(item_id)
        attr_counts[pid][attr] -= 1
        bin_counts[pid][bin_label] -= 1
        item_assigned[item_id].discard(pid)

    def free_slot_by_swap(item_id: str) -> int:
        """Free a slot at a participant that does not hold ``item_id``.

        Called only when every participant with a free slot already holds
        ``item_id``. Moves one item from a full non-holder (the donor) to a
        participant with a free slot (the receiver) and returns the donor,
        which can then take ``item_id``.
        """
        holders = item_assigned.get(item_id, set())
        for receiver in range(participant_count):
            if len(per_participant_items[receiver]) >= questions_per_participant:
                continue
            for donor in range(participant_count):
                if donor in holders:
                    continue
                for other_id in per_participant_items[donor]:
                    if receiver in item_assigned[other_id]:
                        continue
                    unplace(donor, other_id)
                    place(receiver, other_id)
                    return donor
        raise RuntimeError("No available participant slots to assign item")

    def pick_participant(attr: str, bin_label: str, item_id: str) -> int:
        """Choose the participant that should receive one copy of the item."""
        best_pid = None
        best_score = None
        already = item_assigned.get(item_id, set())
        attr_goal = attr_target.get(attr, 0)
        bin_goal = bin_target.get(bin_label, 0)
        for pid in range(participant_count):
            if len(per_participant_items[pid]) >= questions_per_participant:
                continue
            if pid in already:
                continue
            total_fill = len(per_participant_items[pid]) / questions_per_participant
            next_attr = attr_counts[pid].get(attr, 0) + 1
            next_bin = bin_counts[pid].get(bin_label, 0) + 1
            # Only overshoot beyond the target is penalised, scaled by target.
            attr_gap = max(next_attr - attr_goal, 0) / (attr_goal + 1)
            bin_gap = max(next_bin - bin_goal, 0) / (bin_goal + 1)
            score = total_fill + attr_gap + bin_gap
            if best_score is None or score < best_score:
                best_score = score
                best_pid = pid
            # Coin-flip tie-break, kept exactly as-is so that existing seeds
            # reproduce the same assignments (later candidates win more often).
            elif score == best_score and rng.random() < 0.5:
                best_pid = pid
        if best_pid is None:
            # Greedy dead end: repair instead of failing.
            return free_slot_by_swap(item_id)
        return best_pid

    for idx in expanded:
        item_id = items[idx]["id"]
        attr, bin_label = labels_by_id[item_id]
        place(pick_participant(attr, bin_label, item_id), item_id)

    participant_assignments: List[Dict] = [
        {"participant_id": f"p{pid:04d}", "item_ids": item_list}
        for pid, item_list in enumerate(per_participant_items)
    ]

    # Invert the per-participant lists; participants appear in id order.
    item_assignments: Dict[str, List[str]] = {}
    for pa in participant_assignments:
        for item_id in pa["item_ids"]:
            item_assignments.setdefault(item_id, []).append(pa["participant_id"])

    item_records: List[Dict] = [
        {
            "id": item["id"],
            "question_id": item.get("question_id"),
            "model_id": item.get("model_id"),
            "attribute": item.get("attribute"),
            "bias_score": item.get("bias_score"),
            "participants": item_assignments.get(item["id"], []),
        }
        for item in items
    ]

    return participant_assignments, item_records


def main() -> None:
    """Run the assignment pipeline and write its outputs.

    Reads the selection file, assigns items to participants using a generator
    seeded from ``--seed``, draws one completion code per participant from the
    same generator, and writes ``participants.jsonl``,
    ``question_assignments.jsonl``, ``codes.csv`` and
    ``assignment_stats.json`` into the output directory (created if missing).
    """
    args = parse_args()
    rng = random.Random(args.seed)
    items = read_selection(args.selection)

    participant_assignments, item_records = assign(
        items,
        args.participants,
        args.questions_per_participant,
        args.k_per_question,
        rng,
    )

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Codes are drawn after the assignment, one per participant in order.
    code_rows = [
        [pa["participant_id"], generate_code(rng)] for pa in participant_assignments
    ]

    jsonl_outputs = (
        ("participants.jsonl", participant_assignments),
        ("question_assignments.jsonl", item_records),
    )
    for filename, records in jsonl_outputs:
        with (out_dir / filename).open("w", encoding="utf-8") as f:
            f.writelines(json.dumps(record) + "\n" for record in records)

    with (out_dir / "codes.csv").open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["participant_id", "completion_code"])
        writer.writerows(code_rows)

    item_total = len(items)
    stats = {
        "participants": args.participants,
        "questions_per_participant": args.questions_per_participant,
        "k_per_question": args.k_per_question,
        "items": item_total,
    }
    with (out_dir / "assignment_stats.json").open("w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print(f"Wrote assignments to {out_dir}")
    print(f"Participants: {args.participants} | Items: {item_total}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
