"""Compare v1 (discourse-marker) and v2 (paragraph) segmenters on a sample.

For each trace in the sample:
    - Segment with v1 (analysis.exploration.segmentation.segment_response)
    - Segment with v2 (segmentation_v2.segment_response_v2)
    - Run the heuristic classifier on each
    - Report span counts, episode counts, primitive distributions

The parquet is only read, never modified; the script writes per-trace stats as JSONL and prints aggregate statistics.
"""
from __future__ import annotations

import argparse
import json
import sys
from collections import Counter
from pathlib import Path

import pandas as pd

from analysis.exploration.segmentation import segment_response
from analysis.exploration.primitive_classification import (
    classify_trace_spans,
    merge_episodes,
)
from analysis.exploration.llm_validation.segmentation_v2 import (
    segment_response_v2,
)
from analysis.exploration.llm_validation.audit_drift import (
    CHECKPOINT_DIR_MAP,
    find_samples_jsonl,
)


def load_responses(rows: pd.DataFrame, raw_root: Path) -> dict[tuple, str]:
    """Load the raw response for every requested trace in *rows*.

    Rows are grouped by ``(checkpoint_id, task_name)``. For each group the
    samples JSONL located by ``find_samples_jsonl`` is scanned once, and the
    responses whose ``(doc_id, trace_id)`` pair was requested are collected.

    Args:
        rows: Frame with ``checkpoint_id``, ``task_name``, ``doc_id`` and
            ``trace_id`` columns.
        raw_root: Root directory handed to ``find_samples_jsonl``.

    Returns:
        Mapping ``(checkpoint_id, task_name, doc_id, trace_id) -> response``.
        Requested traces that do not occur in the JSONL are absent from the
        mapping.
    """
    out: dict[tuple, str] = {}
    grouped = rows.groupby(["checkpoint_id", "task_name"])
    for (ckpt, task), grp in grouped:
        jsonl = find_samples_jsonl(raw_root, ckpt, task)
        wanted_pairs = {(int(r.doc_id), int(r.trace_id)) for r in grp.itertuples()}
        # Cheap pre-filter so records for unrelated documents are skipped
        # before their responses are inspected.
        wanted_doc_ids = {doc_id for doc_id, _ in wanted_pairs}
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(jsonl, encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at EOF),
                # which json.loads would reject.
                if not line.strip():
                    continue
                d = json.loads(line)
                # Records carry the document id as "doc_id", else "idx".
                doc_id = int(d.get("doc_id", d.get("idx", 0)))
                if doc_id not in wanted_doc_ids:
                    continue
                # A missing or explicit-null "resps" means no responses.
                resps = d.get("resps") or []
                # "resps" is either a list wrapping the list of responses or
                # the list of responses itself; unwrap the former.
                responses = resps[0] if resps and isinstance(resps[0], list) else resps
                # The position within the response list is the trace id.
                for tid, resp in enumerate(responses):
                    if (doc_id, tid) in wanted_pairs:
                        out[(ckpt, task, doc_id, tid)] = resp
    return out


def trace_stats(response: str) -> dict:
    """Segment *response* with both segmenters and summarise the outcome.

    Each segmentation is passed through ``classify_trace_spans`` and then
    ``merge_episodes``. For v1 and v2 alike the returned dict holds the span
    count, the episode count, the sequence of episode labels, the per-label
    episode counts, and the middle ``n_tokens`` value over the spans
    (0 when there are no spans).
    """

    def _middle_tokens(spans):
        # Element at index len // 2 of the sorted token counts; with an even
        # number of spans this is the upper of the two middle values.
        if not spans:
            return 0
        ordered = sorted(span.n_tokens for span in spans)
        return ordered[len(spans) // 2]

    v1_spans = segment_response(response)
    v2_spans = segment_response_v2(response)

    v1_labeled = classify_trace_spans(v1_spans)
    v2_labeled = classify_trace_spans(v2_spans)
    v1_episodes = merge_episodes(v1_labeled)
    v2_episodes = merge_episodes(v2_labeled)

    v1_labels = [episode.label for episode in v1_episodes]
    v2_labels = [episode.label for episode in v2_episodes]

    # Keys are inserted in a fixed order so serialized rows keep their layout.
    stats: dict = {}
    stats["n_spans_v1"] = len(v1_spans)
    stats["n_spans_v2"] = len(v2_spans)
    stats["n_episodes_v1"] = len(v1_episodes)
    stats["n_episodes_v2"] = len(v2_episodes)
    stats["primitive_seq_v1"] = v1_labels
    stats["primitive_seq_v2"] = v2_labels
    stats["primitive_counts_v1"] = dict(Counter(v1_labels))
    stats["primitive_counts_v2"] = dict(Counter(v2_labels))
    stats["tokens_per_span_v1_median"] = _middle_tokens(v1_spans)
    stats["tokens_per_span_v2_median"] = _middle_tokens(v2_spans)
    return stats


def main():
    """Sample traces, run both segmenters on them, and report the results.

    Reads the parquet given by ``--parquet``, draws up to ``--n`` rows using
    ``--seed``, loads each row's raw response from ``--raw-results-root``,
    computes per-trace stats with ``trace_stats``, writes them as JSONL to
    ``--out``, and prints aggregate span/episode counts plus the primitive
    distribution to stdout. Traces that cannot be processed are reported on
    stderr and left out of the output.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--parquet", required=True, type=Path)
    ap.add_argument("--raw-results-root", required=True, type=Path)
    ap.add_argument("--n", type=int, default=50)
    ap.add_argument("--seed", type=int, default=17)
    ap.add_argument("--out", required=True, type=Path,
                    help="Per-trace JSONL stats")
    args = ap.parse_args()

    df = pd.read_parquet(args.parquet)
    # DataFrame.sample raises ValueError when n exceeds the number of rows
    # (sampling without replacement), so clamp to what is available.
    n_sample = min(args.n, len(df))
    sample = df.sample(n=n_sample, random_state=args.seed).reset_index(drop=True)
    print(f"Comparing v1 vs v2 segmenters on {len(sample)} traces.")
    responses = load_responses(sample, args.raw_results_root)

    args.out.parent.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = []
    agg_v1: Counter = Counter()
    agg_v2: Counter = Counter()
    n_spans_v1 = 0
    n_spans_v2 = 0
    n_episodes_v1 = 0
    n_episodes_v2 = 0
    n_missing = 0

    for r in sample.itertuples():
        key = (r.checkpoint_id, r.task_name, int(r.doc_id), int(r.trace_id))
        resp = responses.get(key)
        if resp is None:
            # No raw response was found for this trace; count it so the
            # shortfall is visible instead of silently shrinking the sample.
            n_missing += 1
            continue
        try:
            st = trace_stats(resp)
        except Exception as e:
            # Best effort: one bad trace must not abort the whole comparison.
            print(f"  skip {key}: {e}", file=sys.stderr)
            continue
        st["key"] = list(key)
        st["task_name"] = r.task_name
        st["correct"] = bool(r.correct)
        rows.append(st)
        agg_v1.update(st["primitive_counts_v1"])
        agg_v2.update(st["primitive_counts_v2"])
        n_spans_v1 += st["n_spans_v1"]
        n_spans_v2 += st["n_spans_v2"]
        n_episodes_v1 += st["n_episodes_v1"]
        n_episodes_v2 += st["n_episodes_v2"]

    if n_missing:
        print(f"  {n_missing} sampled traces had no raw response and were skipped",
              file=sys.stderr)

    with open(args.out, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")
    print(f"Wrote {len(rows)} rows -> {args.out}")

    # Guard the per-trace averages the same way as the ratios: with zero
    # usable traces the totals are 0 and the averages print as 0.0 instead of
    # raising ZeroDivisionError.
    n_rows = max(1, len(rows))

    print()
    print(f"=== Aggregate over {len(rows)} traces ===")
    print(f"  Total spans    : v1={n_spans_v1:6d}  v2={n_spans_v2:6d}  ratio={n_spans_v2/max(1,n_spans_v1):.2f}")
    print(f"  Total episodes : v1={n_episodes_v1:6d}  v2={n_episodes_v2:6d}  ratio={n_episodes_v2/max(1,n_episodes_v1):.2f}")
    print(f"  Avg spans/trace: v1={n_spans_v1/n_rows:.1f}  v2={n_spans_v2/n_rows:.1f}")
    print(f"  Avg eps/trace  : v1={n_episodes_v1/n_rows:.1f}  v2={n_episodes_v2/n_rows:.1f}")

    print()
    print("=== Heuristic primitive distribution (across all episodes) ===")
    PRIMS = ["PLAN", "DECOMPOSE", "ENUMERATE", "HYPOTHESIZE", "COMPUTE",
             "VERIFY", "ERROR_DETECT", "BACKTRACK", "SUMMARIZE", "OTHER"]
    # Labels outside the known list still count toward the totals, so list
    # them after the known ones; otherwise the percentages would not sum to
    # 100 and the extra labels would go unnoticed.
    extras = sorted((set(agg_v1) | set(agg_v2)).difference(PRIMS), key=str)
    print(f"  {'class':12} {'v1':>8} {'v1 %':>7} {'v2':>8} {'v2 %':>7} {'Δ %':>6}")
    # "or 1" avoids dividing by zero when no episodes were counted at all.
    tot_v1 = sum(agg_v1.values()) or 1
    tot_v2 = sum(agg_v2.values()) or 1
    for p in [*PRIMS, *extras]:
        c1 = agg_v1.get(p, 0)
        c2 = agg_v2.get(p, 0)
        p1 = c1 / tot_v1 * 100
        p2 = c2 / tot_v2 * 100
        print(f"  {str(p):12} {c1:>8} {p1:>6.1f}% {c2:>8} {p2:>6.1f}% {p2-p1:+6.1f}")


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
