"""Stage 0 — preflight: confirm re-segmentation reproduces parquet labels.

Sampling/judging are blocked unless the re-run produces identical
`primitive_sequence` lists for >=95% of audited rows. Drift means the
(episode_idx -> primitive) mapping has broken since the parquet was built
(tokenizer cache, code edit, etc.) and any sampling would attach the
wrong heuristic label to spans.
"""
from __future__ import annotations

import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path

import pandas as pd

from analysis.exploration.segmentation import segment_response
from analysis.exploration.primitive_classification import (
    classify_trace_spans,
    merge_episodes,
)


# Translates a `checkpoint_id` value (as it appears in the audited rows) into
# the directory name used for that checkpoint under the raw-results root.
# find_samples_jsonl indexes this dict directly, so a checkpoint that is not
# listed here raises KeyError there.
CHECKPOINT_DIR_MAP: dict[str, str] = {
    "dsr_sft_v1": "dsr_sft",
    "gspo_exploration_fix_step20": "gspo_step20",
}


def find_samples_jsonl(raw_root: Path, checkpoint_id: str, task_name: str) -> Path:
    """Locate the samples file for one (checkpoint, task) pair.

    The checkpoint id is translated to its on-disk directory name through
    ``CHECKPOINT_DIR_MAP``; ``raw_root/<dir>/<task_name>`` is then searched
    recursively for ``samples_*.jsonl``. When several files match, the one
    that sorts first by path is returned.

    Raises:
        KeyError: ``checkpoint_id`` has no entry in ``CHECKPOINT_DIR_MAP``.
        FileNotFoundError: no matching file exists under the task directory.
    """
    ckpt_dir = CHECKPOINT_DIR_MAP[checkpoint_id]
    search_root = raw_root / ckpt_dir / task_name
    # Smallest path == first element of the sorted listing.
    first_match = min(search_root.rglob("samples_*.jsonl"), default=None)
    if first_match is None:
        raise FileNotFoundError(
            f"No samples_*.jsonl under {raw_root}/{ckpt_dir}/{task_name}"
        )
    return first_match


def resolve_responses(rows: pd.DataFrame, raw_root: Path) -> dict[tuple, str]:
    """Return {(checkpoint_id, task_name, doc_id, trace_id): response_text}.

    ``rows`` is grouped by ``(checkpoint_id, task_name)``; for each group the
    matching samples file is located with ``find_samples_jsonl`` and scanned
    once. A record's document id is read from ``doc_id`` (falling back to
    ``idx``, then 0) and the position of a response inside the record's
    ``resps`` list is used as its trace id. Only the (doc_id, trace_id) pairs
    present in ``rows`` are kept; pairs that never show up in the file are
    simply absent from the result.

    Args:
        rows: Frame with ``checkpoint_id``, ``task_name``, ``doc_id`` and
            ``trace_id`` columns.
        raw_root: Root directory handed to ``find_samples_jsonl``.

    Raises:
        KeyError / FileNotFoundError: propagated from ``find_samples_jsonl``.
        json.JSONDecodeError: a non-blank line is not valid JSON.
    """
    out: dict[tuple, str] = {}
    for (ckpt, task), grp in rows.groupby(["checkpoint_id", "task_name"]):
        jsonl_path = find_samples_jsonl(raw_root, ckpt, task)
        wanted_pairs = {
            (int(doc_id), int(trace_id))
            for doc_id, trace_id in zip(grp["doc_id"], grp["trace_id"])
        }
        # Cheap pre-filter so unrelated documents skip the per-response loop.
        wanted_doc_ids = {doc_id for doc_id, _ in wanted_pairs}
        # Decode explicitly as UTF-8: with the platform default encoding,
        # non-ASCII response text could be mis-decoded and then re-segment
        # differently, which would show up as spurious drift.
        with open(jsonl_path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                d = json.loads(line)
                doc_id = int(d.get("doc_id", d.get("idx", 0)))
                if doc_id not in wanted_doc_ids:
                    continue
                # A missing or null "resps" is treated as "no responses".
                resps = d.get("resps") or []
                # "resps" may be a list of lists (use the first inner list)
                # or already a flat list of responses.
                responses = resps[0] if resps and isinstance(resps[0], list) else resps
                for tid, resp in enumerate(responses):
                    if (doc_id, tid) in wanted_pairs:
                        out[(ckpt, task, doc_id, tid)] = resp
    return out


def reproduce_sequence(response_text: str) -> list[str]:
    """Rebuild the label sequence for one response from scratch.

    Chains ``segment_response`` -> ``classify_trace_spans`` ->
    ``merge_episodes`` and returns the ``label`` attribute of every merged
    episode, in order.
    """
    merged = merge_episodes(classify_trace_spans(segment_response(response_text)))
    return [episode.label for episode in merged]


def main():
    """Run the preflight audit from the command line.

    Samples up to ``--n`` rows from ``--parquet`` (seeded by ``--seed``),
    resolves each row's raw response under ``--raw-results-root``, re-runs
    the labelling pipeline and compares the result with the stored
    ``primitive_sequence`` (a JSON-encoded list). Rows whose response cannot
    be found are counted as missing and excluded from the denominator; rows
    whose re-run raises are counted as audited but not matched.

    Prints a summary and exits with status 1 when the match rate is below
    ``--gate``; returns normally otherwise.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--parquet", required=True, type=Path)
    ap.add_argument("--raw-results-root", required=True, type=Path)
    ap.add_argument("--n", type=int, default=50)
    ap.add_argument("--seed", type=int, default=17)
    ap.add_argument(
        "--gate", type=float, default=0.95,
        help="Required match rate to declare drift acceptable",
    )
    args = ap.parse_args()

    df = pd.read_parquet(args.parquet)
    # DataFrame.sample (without replacement) raises ValueError when asked for
    # more rows than exist, so a small parquet is audited in full instead.
    n_sample = min(args.n, len(df))
    if n_sample < args.n:
        print(f"Requested {args.n} rows but parquet has only {len(df)}; "
              f"auditing all of them")
    sample = df.sample(n=n_sample, random_state=args.seed).reset_index(drop=True)
    print(f"Auditing {len(sample)} traces from {args.parquet}")

    responses = resolve_responses(sample, args.raw_results_root)

    max_mismatch_examples = 3  # cap on recorded label mismatches (errors are always kept)
    n_match = 0
    n_missing = 0
    mismatches: list[dict] = []
    length_diffs: dict[int, int] = defaultdict(int)

    for row in sample.itertuples():
        key = (row.checkpoint_id, row.task_name, int(row.doc_id), int(row.trace_id))
        resp = responses.get(key)
        if resp is None:
            n_missing += 1
            continue
        stored = json.loads(row.primitive_sequence)
        try:
            reproduced = reproduce_sequence(resp)
        except Exception as e:
            # Deliberately broad: any failure in the re-run counts as a
            # non-match for this row and is reported, rather than aborting
            # the whole audit.
            mismatches.append({
                "key": key,
                "stored_len": len(stored),
                "reproduced_len": None,
                "error": str(e),
                "stored": stored,
                "reproduced": None,
            })
            continue
        if reproduced == stored:
            n_match += 1
            continue
        length_diffs[len(reproduced) - len(stored)] += 1
        if len(mismatches) < max_mismatch_examples:
            mismatches.append({
                "key": key,
                "stored_len": len(stored),
                "reproduced_len": len(reproduced),
                "stored": stored,
                "reproduced": reproduced,
            })

    n_audited = len(sample) - n_missing
    # With nothing audited the rate is reported as 0 so the gate blocks.
    rate = n_match / n_audited if n_audited else 0.0

    print("\nResults:")
    print(f"  matched          : {n_match} / {n_audited}  ({rate*100:.1f}%)")
    print(f"  missing responses: {n_missing}")
    if length_diffs:
        print("  length-delta histogram (reproduced - stored):")
        for delta in sorted(length_diffs):
            print(f"    {delta:+d}: {length_diffs[delta]}")

    if mismatches:
        print(f"\nFirst {len(mismatches)} mismatches:")
        for i, m in enumerate(mismatches, 1):
            print(f"  [{i}] {m['key']}: stored_len={m['stored_len']} "
                  f"reproduced_len={m['reproduced_len']}")
            if m.get("error"):
                print(f"       error: {m['error']}")
                continue
            print(f"       stored    : {m['stored']}")
            print(f"       reproduced: {m['reproduced']}")

    if rate < args.gate:
        print(f"\nFAIL: match rate {rate*100:.1f}% < gate {args.gate*100:.1f}%")
        print("Sampling and judging are blocked until drift is resolved.")
        sys.exit(1)
    print(f"\nOK: match rate {rate*100:.1f}% >= gate {args.gate*100:.1f}%")


if __name__ == "__main__":
    main()
