"""Stage 2: load silver labels, drop bad rows, group-split by trace.

Produces train/val/test partitions where every span from a given trace
lives in exactly one partition (prevents leakage from trace-level features
like position or trace-style phrasing).
"""
from __future__ import annotations

import json
from collections import Counter
from pathlib import Path

import numpy as np
from sklearn.model_selection import GroupShuffleSplit


def load_silver(path: Path) -> list[dict]:
    """Load silver_train.jsonl, keep rows with a valid llm_label.

    Each non-blank line is parsed as one JSON object. A row is kept only when
    its ``llm_label`` is truthy and is not one of the failure sentinels
    ``"PARSE_ERROR"`` / ``"API_FAILURE"``.

    Args:
        path: Path to a JSON Lines file (read as UTF-8).

    Returns:
        The kept rows, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Built once instead of once per line.
    invalid_labels = frozenset({"PARSE_ERROR", "API_FAILURE"})
    rows: list[dict] = []
    # JSON text is UTF-8; don't depend on the platform's locale default.
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank / whitespace-only lines (e.g. trailing newlines),
            # which json.loads would otherwise reject.
            if not line.strip():
                continue
            r = json.loads(line)
            label = r.get("llm_label")
            if label and label not in invalid_labels:
                rows.append(r)
    return rows


def trace_group_key(row: dict) -> tuple:
    """Return the composite key shared by every span of one trace.

    Two spans belong to the same trace exactly when their checkpoint_id,
    task_name, doc_id and trace_id values all match.
    """
    key_fields = ("checkpoint_id", "task_name", "doc_id", "trace_id")
    return tuple(row[field] for field in key_fields)


def split_by_trace(
    rows: list[dict],
    train_frac: float = 0.70,
    val_frac: float = 0.15,
    seed: int = 42,
) -> tuple[list[dict], list[dict], list[dict]]:
    """Group split: 70/15/15 by trace.

    Two-stage: first carve off (val+test) from train, then split val/test.
    Uses sklearn.model_selection.GroupShuffleSplit for reproducibility.
    The fractions are passed to GroupShuffleSplit as ``train_size`` over
    trace groups, so per-partition *span* counts can deviate from them.

    Args:
        rows: Silver rows; each must carry the fields read by trace_group_key.
        train_frac: Fraction of traces assigned to train.
        val_frac: Fraction of traces assigned to val; test gets the rest.
        seed: random_state for the first split (the second uses seed + 1).

    Returns:
        (train, val, test) row lists; no trace appears in more than one.

    Raises:
        ValueError: if ``1 - train_frac - val_frac`` is not in (0, 1).
    """
    test_frac = 1.0 - train_frac - val_frac
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0 < test_frac < 1.0:
        raise ValueError(
            f"train_frac + val_frac must leave a test fraction in (0, 1); "
            f"got train_frac={train_frac}, val_frac={val_frac}"
        )

    # Dense integer group ids in first-seen order. hash() of these tuples is
    # salted per process (PYTHONHASHSEED) because they contain strings, and
    # GroupShuffleSplit's assignment depends on the group label values, so
    # hash-based labels gave a different split on every run despite a fixed
    # seed (and could collide). Dense ids are deterministic and collision-free.
    key_to_id: dict[tuple, int] = {}
    groups = np.array(
        [key_to_id.setdefault(trace_group_key(r), len(key_to_id)) for r in rows],
        dtype=np.int64,
    )

    splitter1 = GroupShuffleSplit(n_splits=1, train_size=train_frac, random_state=seed)
    train_idx, holdout_idx = next(splitter1.split(rows, groups=groups))

    holdout_groups = groups[holdout_idx]
    # Re-express val's share relative to the holdout pool (val + test).
    val_size = val_frac / (val_frac + test_frac)
    splitter2 = GroupShuffleSplit(
        n_splits=1, train_size=val_size, random_state=seed + 1,
    )
    rel_val_idx, rel_test_idx = next(
        splitter2.split([rows[i] for i in holdout_idx], groups=holdout_groups)
    )
    # Map positions within the holdout subset back to indices into `rows`.
    val_idx = holdout_idx[rel_val_idx]
    test_idx = holdout_idx[rel_test_idx]

    train = [rows[i] for i in train_idx]
    val = [rows[i] for i in val_idx]
    test = [rows[i] for i in test_idx]
    return train, val, test


def report_split(name: str, rows: list[dict]) -> None:
    """Print span/trace totals and the llm_label distribution for one partition.

    Labels are listed from most to least frequent, each with its count and
    its percentage of the partition's spans.
    """
    total = len(rows)
    distinct_traces = set(map(trace_group_key, rows))
    label_counts = Counter(row["llm_label"] for row in rows)
    print(f"  {name:>5}: {total:>5} spans, {len(distinct_traces):>4} traces")
    for label, count in label_counts.most_common():
        share = count / total * 100
        print(f"           {label:13} {count:>5} ({share:>5.1f}%)")


def main():
    """CLI diagnostic: load the silver file, split it by trace, print stats per split."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--silver", required=True, type=Path)
    parser.add_argument("--seed", type=int, default=42)
    opts = parser.parse_args()

    silver_rows = load_silver(opts.silver)
    print(f"Loaded {len(silver_rows)} valid silver rows from {opts.silver}")
    partitions = split_by_trace(silver_rows, seed=opts.seed)
    print()
    for split_name, split_rows in zip(("train", "val", "test"), partitions):
        report_split(split_name, split_rows)


# Run the standalone diagnostic only when executed as a script, not on import.
if __name__ == "__main__":
    main()
