#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
R3: Evidence strength per turn — Answer timing vs evidence
=========================================================

Ontology mapping (minimal, auditable):
- E=0 (Low): information_quality=Insufficient (any clarity)
- E=1 (Medium): information_quality=Sufficient & information_clarity != Clear
- E=2 (High): information_quality=Sufficient & information_clarity = Clear

Answer event: the FIRST time a trajectory emits an answer (CorrectAnswer or IncorrectAnswer). Ignore later revisions.

Policy (primary): should answer iff E >= 2 at the current turn.
Sensitivity: threshold E >= 1.

Units:
- Turn-level for propensities (A1–A2)
- Trajectory-level for premature/over-cautious (A3–A4)

Uncertainty: 5k bootstraps clustered by question (dataset, question_id).

Outputs (saved to outdir):
- fig_R3_D_calibration_rates.(png,svg)
- fig_R3_E_timing.(png,svg)

Usage:
  python R3_Evidence_strength_per_turn.py \
    --input final_annotated_traces.jsonl \
    --outdir figs_evidence \
    --n_boot 5000 \
    --horizon 15 \
    --policy_threshold high   # or: medium

Notes:
- A1/A2 are computed on pre-answer turns (inclusive of the answer turn) without horizon truncation.
- A3/A4 use the specified horizon (default 15) or end of trace if shorter.
"""

import os
import json
import argparse
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

# Plotting style consistent with R1/R2
sns.set_theme(context='paper', style='whitegrid', palette='colorblind', font_scale=1.05)
plt.rcParams.update({
    'figure.dpi': 180,
    'savefig.dpi': 300,
    'axes.titlesize': 12,
    'axes.labelsize': 10,
    'legend.fontsize': 9,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
})

# Global model color palette (align with O1/O2)
MODEL_COLOR_PALETTE = {
    'Base (Qwen)': '#460057',          
    'Few-shot (Qwen)': '#423e81',      
    'DeepResearcher': '#2e5d88',       
    'ReSearch': '#159988',             
    'SearchR1': '#6ece5d',            
    'ASearcher': '#fee837',           
}

def get_model_color(model_name: str) -> str:
    return MODEL_COLOR_PALETTE.get(model_name, '#7f7f7f')

def get_model_order():
    return ["Base (Qwen)", "Few-shot (Qwen)", "DeepResearcher", "ReSearch", "SearchR1", "ASearcher"]


def ensure_dir(p: str):
    os.makedirs(p, exist_ok=True)


def savefig(fig, outdir, name):
    ensure_dir(outdir); ensure_dir(os.path.join(outdir, 'svg'))
    p1 = os.path.join(outdir, f'{name}.png')
    p2 = os.path.join(outdir, 'svg', f'{name}.svg')
    fig.tight_layout()
    fig.savefig(p1); fig.savefig(p2)
    print(f"Saved {p1} and {p2}")
    plt.close(fig)


def clean_and_categorize_model(model_name: str):
    if not isinstance(model_name, str):
        return "Unknown", "Other"
    if "fewshot-Qwen2.5-7B-Instruct" in model_name:
        return "Few-shot (Qwen)", "Few-shot"
    if "Qwen2.5-7B-Instruct" in model_name:
        return "Base (Qwen)", "Base"
    if "SearchR1" in model_name:
        return "SearchR1", "RL Agents"
    if "DeepResearcher" in model_name:
        return "DeepResearcher", "RL Agents"
    if "ReSearch" in model_name:
        return "ReSearch", "RL Agents"
    if "ASearcher" in model_name:
        return "ASearcher", "RL Agents"
    return "Other", "Other"


def load_traces(path: str):
    rows = []
    with open(path, 'r') as f:
        for line in f:
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                continue
    return rows


def map_evidence_level(iq: str, ic: str) -> int:
    """Map information_quality, information_clarity to E∈{0,1,2}.
    Defaults: non-'Sufficient' -> E=0. Missing clarity counts as Unclear.
    """
    iq = (iq or '').strip()
    ic = (ic or '').strip()
    if iq == 'Sufficient':
        if ic == 'Clear':
            return 2
        else:
            return 1
    return 0


def get_evidence_at_turn(step: dict) -> int:
    """Get evidence level for a single turn only.
    Returns the evidence level from this specific turn's search result.
    """
    ann = step.get('annotation', {})
    if isinstance(ann, dict) and ann.get('type') == 'search_result':
        attrs = (ann.get('attributes', {}) or {})
        return map_evidence_level(attrs.get('information_quality'), attrs.get('information_clarity'))
    return 0


def get_evidence_at_turn_with_fallback(step: dict, prev_E: int = 0) -> int:
    """Get evidence level for a single turn, with fallback to previous E if no search result.
    This is used for 'current' mode to maintain evidence continuity.
    """
    ann = step.get('annotation', {})
    if isinstance(ann, dict) and ann.get('type') == 'search_result':
        attrs = (ann.get('attributes', {}) or {})
        return map_evidence_level(attrs.get('information_quality'), attrs.get('information_clarity'))
    return prev_E  # Fallback to previous evidence level


def build_turn_level_df(traces: list) -> pd.DataFrame:
    """Build per-turn records up to the first answer event (inclusive) for each trajectory.
    Fields: model, category, dataset, question_id, step_index, E, answer_event, correct_event
    E is cumulative best evidence up to the current turn.
    """
    rows = []
    for t in traces:
        model, cat = clean_and_categorize_model(t.get('model', ''))
        if cat == 'Other':
            continue
        dataset = t.get('dataset', 'Unknown')
        qid = t.get('question_id', t.get('id', None))
        if qid is None:
            qid = t.get('question', None)
        steps = t.get('trace', []) or []
        cur_E = 0
        answered = False
        for i, step in enumerate(steps):
            ann = step.get('annotation', {})
            if isinstance(ann, dict) and ann.get('type') == 'search_result':
                attrs = (ann.get('attributes', {}) or {})
                cur_E = max(cur_E, map_evidence_level(attrs.get('information_quality'), attrs.get('information_clarity')))

            # Determine if this turn is the first answer event
            ans_event = 0
            correct = np.nan
            if not answered and isinstance(ann, dict) and ann.get('type') in ('CorrectAnswer', 'IncorrectAnswer'):
                answered = True
                ans_event = 1
                correct = 1 if ann.get('type') == 'CorrectAnswer' else 0

            rows.append({
                'model': model,
                'category': cat,
                'dataset': dataset,
                'question_id': qid,
                'step_index': i,
                'E': cur_E,
                'answer_event': ans_event,
                'correct_event': correct,
            })

            # stop collecting turns after the first answer event
            if ans_event == 1:
                break
    return pd.DataFrame(rows)

def build_answer_event_df(turn_df: pd.DataFrame) -> pd.DataFrame:
    """Derive per-trajectory first answer event with E at answer and correctness.
    Returns columns: model, category, dataset, question_id, E_at_answer, correct
    """
    ans = turn_df[turn_df['answer_event'] == 1].copy()
    # Ensure one row per (dataset, question_id, model)
    ans = ans.sort_values(['dataset','question_id','model','step_index'])
    ans = ans.groupby(['dataset','question_id','model','category'], as_index=False).first()
    ans = ans.rename(columns={'E':'E_at_answer', 'correct_event':'correct'})
    # Normalize correctness to {0,1}
    ans['correct'] = ans['correct'].fillna(0).astype(int)
    return ans[['model','category','dataset','question_id','E_at_answer','correct']]

def ordinal_auc_from_counts(pos_counts, neg_counts):
    """Compute AUC for ordinal scores 0/1/2 from per-score counts.
    pos_counts, neg_counts: dict score->count
    """
    scores = [0,1,2]
    n_pos = sum(pos_counts.get(s,0) for s in scores)
    n_neg = sum(neg_counts.get(s,0) for s in scores)
    if n_pos == 0 or n_neg == 0:
        return np.nan
    better = 0
    ties = 0
    for sp in scores:
        for sn in scores:
            if sp > sn:
                better += pos_counts.get(sp,0) * neg_counts.get(sn,0)
            elif sp == sn:
                ties += pos_counts.get(sp,0) * neg_counts.get(sn,0)
    auc = (better + 0.5*ties) / (n_pos * n_neg)
    return auc

def clustered_auc_ci(answer_df: pd.DataFrame, group_mask: pd.Series, n_boot=2000, seed=42):
    """Cluster-bootstrap AUC of E_at_answer predicting correctness among answered events.
    Cluster by (dataset, question_id). group_mask selects subset (e.g., a model or category).
    """
    rng = np.random.default_rng(seed)
    sub = answer_df[group_mask].copy()
    if sub.empty:
        return (np.nan, np.nan, np.nan)
    sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))
    clusters = sorted(sub['cluster'].unique())
    # Pre-aggregate counts by cluster and score
    pos_by_cluster = {c: Counter() for c in clusters}
    neg_by_cluster = {c: Counter() for c in clusters}
    for _, r in sub.iterrows():
        c = r['cluster']; s = int(r['E_at_answer']); y = int(r['correct'])
        if y == 1:
            pos_by_cluster[c][s] += 1
        else:
            neg_by_cluster[c][s] += 1
    # Point estimate (all clusters)
    pos_counts = Counter(); neg_counts = Counter()
    for c in clusters:
        pos_counts.update(pos_by_cluster[c]); neg_counts.update(neg_by_cluster[c])
    point = ordinal_auc_from_counts(pos_counts, neg_counts)
    # Bootstrap
    n = len(clusters)
    idx = rng.integers(0, n, size=(n_boot, n))
    samp = np.empty(n_boot, dtype=float)
    for i in range(n_boot):
        sel = idx[i]
        pos_counts_b = Counter(); neg_counts_b = Counter()
        for sidx in sel:
            c = clusters[sidx]
            pos_counts_b.update(pos_by_cluster[c]); neg_counts_b.update(neg_by_cluster[c])
        samp[i] = ordinal_auc_from_counts(pos_counts_b, neg_counts_b)
    lo, hi = float(np.nanpercentile(samp, 2.5)), float(np.nanpercentile(samp, 97.5))
    return point, lo, hi

def clustered_or_ci(answer_df: pd.DataFrame, group_mask: pd.Series, n_boot=2000, seed=42):
    """Cluster-bootstrap odds ratio of correctness at E=2 vs E=0 among answered events.
    Uses 0.5 continuity correction for zero cells. Returns (point, lo, hi).
    """
    rng = np.random.default_rng(seed)
    sub = answer_df[group_mask].copy()
    if sub.empty:
        return (np.nan, np.nan, np.nan)
    sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))
    clusters = sorted(sub['cluster'].unique())
    def or_from_counts(c2_pos, c2_neg, c0_pos, c0_neg):
        a = c2_pos + 0.5; b = c2_neg + 0.5; c = c0_pos + 0.5; d = c0_neg + 0.5
        return (a/b) / (c/d)
    # point
    c2_pos = c2_neg = c0_pos = c0_neg = 0
    for c in clusters:
        cl = sub[sub['cluster']==c]
        e2 = cl[cl['E_at_answer']==2]; e0 = cl[cl['E_at_answer']==0]
        c2_pos += int(e2['correct'].sum()); c2_neg += int(len(e2) - e2['correct'].sum())
        c0_pos += int(e0['correct'].sum()); c0_neg += int(len(e0) - e0['correct'].sum())
    point = or_from_counts(c2_pos, c2_neg, c0_pos, c0_neg)
    # bootstrap
    n = len(clusters)
    idx = rng.integers(0, n, size=(n_boot, n))
    samp = np.empty(n_boot, dtype=float)
    for i in range(n_boot):
        sel = idx[i]
        c2p=c2n=c0p=c0n=0
        for sidx in sel:
            c = clusters[sidx]
            cl = sub[sub['cluster']==c]
            e2 = cl[cl['E_at_answer']==2]; e0 = cl[cl['E_at_answer']==0]
            c2p += int(e2['correct'].sum()); c2n += int(len(e2) - e2['correct'].sum())
            c0p += int(e0['correct'].sum()); c0n += int(len(e0) - e0['correct'].sum())
        samp[i] = or_from_counts(c2p, c2n, c0p, c0n)
    lo, hi = float(np.nanpercentile(samp, 2.5)), float(np.nanpercentile(samp, 97.5))
    return point, lo, hi

def clustered_slope_ci(answer_df: pd.DataFrame, group_mask: pd.Series, n_boot=2000, seed=42):
    """Cluster-bootstrap slope of accuracy vs E level (0,1,2) among answered events.
    Slope computed from mean accuracies at E=0,1,2 via least-squares line.
    """
    rng = np.random.default_rng(seed)
    sub = answer_df[group_mask].copy()
    if sub.empty:
        return (np.nan, np.nan, np.nan)
    sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))
    clusters = sorted(sub['cluster'].unique())
    def slope_from_acc(acc):
        x = np.array([0,1,2], dtype=float)
        y = np.array([acc.get(0,np.nan), acc.get(1,np.nan), acc.get(2,np.nan)], dtype=float)
        # simple least-squares on available points
        mask = ~np.isnan(y)
        if mask.sum() < 2:
            return np.nan
        xm = x[mask].mean(); ym = y[mask].mean()
        num = np.sum((x[mask]-xm)*(y[mask]-ym))
        den = np.sum((x[mask]-xm)**2)
        return num/den if den>0 else np.nan
    # helper to compute accuracies pooled across clusters
    def pooled_acc(cl_sel):
        acc = {}
        for k in [0,1,2]:
            kdf = sub[sub['E_at_answer']==k]
            kdf = kdf[kdf['cluster'].isin(cl_sel)]
            if len(kdf)==0:
                acc[k]=np.nan
                continue
            acc[k] = float(kdf['correct'].mean())
        return acc
    # point
    acc_point = pooled_acc(clusters)
    point = slope_from_acc(acc_point)
    # bootstrap
    n = len(clusters)
    idx = rng.integers(0, n, size=(n_boot, n))
    samp = np.empty(n_boot, dtype=float)
    for i in range(n_boot):
        sel = [clusters[s] for s in idx[i]]
        acc_b = pooled_acc(sel)
        samp[i] = slope_from_acc(acc_b)
    lo, hi = float(np.nanpercentile(samp, 2.5)), float(np.nanpercentile(samp, 97.5))
    return point, lo, hi

def analyze_correctness_vs_E_rigorous(turn_df: pd.DataFrame, outdir: str, n_boot=2000):
    """Rigorous analysis of relationship between E and final answer correctness.
    Reports per-model P(correct|answer,E=k) with 95% CI, trend slope, odds ratio (E=2 vs E=0), and AUC.
    Also saves a per-model bar plot by E.
    """
    print("\n" + "="*80)
    print("R3-RIGOROUS: RELATIONSHIP BETWEEN E AND FINAL ANSWER CORRECTNESS")
    print("="*80)
    ans_df = build_answer_event_df(turn_df)
    if ans_df.empty:
        print("No answered events found; skipping rigorous analysis.")
        return
    # Compute per-model metrics
    models_present = [m for m in get_model_order() if m in ans_df['model'].unique()]
    # Figure: bars per E for each model
    fig, axes = plt.subplots(1, 3, figsize=(14,4), sharey=True)
    for e_idx, k in enumerate([0,1,2]):
        ax = axes[e_idx]
        vals = []
        los = []
        his = []
        for m in models_present:
            sub = ans_df[(ans_df['model']==m) & (ans_df['E_at_answer']==k)].copy()
            sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))
            numer = defaultdict(float); denom = defaultdict(float)
            for _, r in sub.iterrows():
                c = r['cluster']; denom[c] += 1.0; numer[c] += float(r['correct'])
            p, lo, hi = clustered_ratio_ci(numer, denom, n_boot=n_boot)
            vals.append(p); los.append(lo); his.append(hi)
        x = np.arange(len(models_present))
        yerr = np.vstack([np.array(vals)-np.array(los), np.array(his)-np.array(vals)])
        ax.bar(x, np.nan_to_num(vals, nan=0.0), color=[get_model_color(m) for m in models_present])
        ax.errorbar(x, vals, yerr=yerr, fmt='none', ecolor='black', capsize=3)
        ax.set_title(f'Accuracy at E={k}')
        ax.set_xticks(x); ax.set_xticklabels(models_present, rotation=45, ha='right')
        if e_idx == 0:
            ax.set_ylabel('P(correct | answer, E)')
        ax.set_ylim(0, 1.0)
        ax.grid(axis='y', alpha=0.3)
    fig.suptitle('Accuracy by evidence at answer, per model (95% CI)')
    plt.tight_layout()
    savefig(fig, outdir, 'fig_R3_B_quality_by_model')

    # Textual rigorous stats per model
    print("\nPer-model rigorous statistics:")
    print(f"{'Model':25s} {'Slope':>8s} {'[lo,hi]':>18s} {'OR(E2/E0)':>12s} {'[lo,hi]':>18s} {'AUC':>8s} {'[lo,hi]':>18s}")
    for m in models_present:
        mask = (ans_df['model']==m)
        slope_p, slope_lo, slope_hi = clustered_slope_ci(ans_df, mask, n_boot=n_boot)
        or_p, or_lo, or_hi = clustered_or_ci(ans_df, mask, n_boot=n_boot)
        auc_p, auc_lo, auc_hi = clustered_auc_ci(ans_df, mask, n_boot=n_boot)
        print(f"{m:25s} {slope_p:8.3f} [{slope_lo:5.3f},{slope_hi:5.3f}] {or_p:12.2f} [{or_lo:6.2f},{or_hi:6.2f}] {auc_p:8.3f} [{auc_lo:5.3f},{auc_hi:5.3f}]")


def build_trajectory_level_df(traces: list, horizon: int) -> pd.DataFrame:
    """Per-trajectory summary for A3/A4 at a given horizon.
    Fields: model, category, dataset, question_id, answered_within_h, ans_at_step, E_at_answer, ever_high_within_h
    """
    rows = []
    for t in traces:
        model, cat = clean_and_categorize_model(t.get('model', ''))
        if cat == 'Other':
            continue
        dataset = t.get('dataset', 'Unknown')
        qid = t.get('question_id', t.get('id', None))
        if qid is None:
            qid = t.get('question', None)
        steps = t.get('trace', []) or []
        H = min(horizon, len(steps))
        cur_E = 0
        ans_step = None
        E_at_answer = None
        ever_high = False
        for i in range(H):
            # E = max of all evidence seen so far
            turn_E = get_evidence_at_turn(steps[i] if i < len(steps) else {})
            cur_E = max(cur_E, turn_E)
            if cur_E >= 2:
                ever_high = True
            
            ann = steps[i].get('annotation', {}) if i < len(steps) else {}
            if ans_step is None and isinstance(ann, dict) and ann.get('type') in ('CorrectAnswer', 'IncorrectAnswer'):
                ans_step = i
                E_at_answer = cur_E
                # Do not break; keep scanning to maintain ever_high status within H
        rows.append({
            'model': model,
            'category': cat,
            'dataset': dataset,
            'question_id': qid,
            'answered_within_h': 1 if ans_step is not None else 0,
            'ans_at_step': ans_step if ans_step is not None else np.nan,
            'E_at_answer': E_at_answer if E_at_answer is not None else np.nan,
            'ever_high_within_h': 1 if ever_high else 0,
        })
    return pd.DataFrame(rows)


# ---------------------------
# Cluster bootstrap utilities
# ---------------------------

def clustered_ratio_ci(numer_by_cluster: dict, denom_by_cluster: dict, n_boot=5000, seed=42):
    """Bootstrap CI for a ratio of sums across clusters.
    numer_by_cluster, denom_by_cluster: dict cluster_id -> float
    Returns (point, lo, hi)
    """
    rng = np.random.default_rng(seed)
    clusters = list(denom_by_cluster.keys())
    if not clusters:
        return (np.nan, np.nan, np.nan)
    # ensure numer has zeros for missing clusters
    numer = np.array([float(numer_by_cluster.get(c, 0.0)) for c in clusters])
    denom = np.array([float(denom_by_cluster.get(c, 0.0)) for c in clusters])
    # point estimate
    num_sum = float(numer.sum())
    den_sum = float(denom.sum())
    point = (num_sum / den_sum) if den_sum > 0 else np.nan
    # bootstrap
    n = len(clusters)
    idx = rng.integers(0, n, size=(n_boot, n))
    samp = np.empty(n_boot, dtype=float)
    for i in range(n_boot):
        sel = idx[i]
        num_b = float(numer[sel].sum())
        den_b = float(denom[sel].sum())
        samp[i] = (num_b / den_b) if den_b > 0 else np.nan
    lo, hi = float(np.nanpercentile(samp, 2.5)), float(np.nanpercentile(samp, 97.5))
    return point, lo, hi


def clustered_mean_ci(values_by_cluster: dict, n_boot=5000, seed=42):
    """Bootstrap CI for mean across clusters after expanding their values equally.
    values_by_cluster: cluster_id -> list/array of values
    Returns (point, lo, hi)
    """
    rng = np.random.default_rng(seed)
    clusters = list(values_by_cluster.keys())
    if not clusters:
        return (np.nan, np.nan, np.nan)
    # point
    all_vals = np.concatenate([np.asarray(values_by_cluster[c], dtype=float) for c in clusters])
    point = float(np.nanmean(all_vals)) if all_vals.size else np.nan
    # bootstrap by resampling clusters and concatenating
    n = len(clusters)
    idx = rng.integers(0, n, size=(n_boot, n))
    samp = np.empty(n_boot, dtype=float)
    for i in range(n_boot):
        sel = idx[i]
        concat = []
        for s in sel:
            concat.append(values_by_cluster[clusters[s]])
        concat = np.concatenate([np.asarray(v, dtype=float) for v in concat])
        samp[i] = float(np.nanmean(concat)) if concat.size else np.nan
    lo, hi = float(np.nanpercentile(samp, 2.5)), float(np.nanpercentile(samp, 97.5))
    return point, lo, hi


# ---------------------------
# A1: Answer propensity by evidence level
# ---------------------------

def analyze_A1_propensity(turn_df: pd.DataFrame, outdir: str, n_boot=5000):
    print("\n" + "="*60)
    print("R3-A: ANSWER PROPENSITY BY EVIDENCE LEVEL")
    print("="*60)
    
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
    categories = ['Base', 'Few-shot', 'RL Agents']
    
    # Store results for summary
    all_results = {}
    
    for ax_idx, cat in enumerate(categories):
        ax = axes[ax_idx]
        sub = turn_df[turn_df['category'] == cat]
        print(f"\n{cat} Model Analysis:")
        print(f"  Total turns analyzed: {len(sub)}")
        
        # cluster id: (dataset, question_id)
        sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))
        unique_clusters = sub['cluster'].nunique()
        print(f"  Unique questions: {unique_clusters}")
        
        res_points = []; res_los = []; res_his = []
        labels = ['E=0', 'E=1', 'E=2']
        cat_results = {}
        
        for k in [0, 1, 2]:
            k_df = sub[sub['E'] == k]
            print(f"  E={k} turns: {len(k_df)} ({len(k_df)/len(sub)*100:.1f}%)")
            
            # build cluster contributions
            numer = defaultdict(float)
            denom = defaultdict(float)
            for _, row in k_df.iterrows():
                c = row['cluster']
                denom[c] += 1.0
                numer[c] += float(row['answer_event'])
            
            point, lo, hi = clustered_ratio_ci(numer, denom, n_boot=n_boot)
            res_points.append(point); res_los.append(lo); res_his.append(hi)
            
            # Store and print results
            cat_results[f'E{k}'] = {
                'propensity': point,
                'ci_low': lo,
                'ci_high': hi,
                'n_turns': len(k_df)
            }
            print(f"    P(answer|E={k}): {point:.3f} [{lo:.3f}, {hi:.3f}]")
            
            # Interpretation guidance
            if k == 0 and point > 0.2:
                print(f"    ⚠️  HIGH premature risk: {point:.1%} answer rate with insufficient evidence")
            elif k == 2 and point < 0.8:
                print(f"    ⚠️  LOW response rate: Only {point:.1%} answer with sufficient evidence")
        
        all_results[cat] = cat_results
        
        ax.errorbar([0,1,2], res_points, yerr=[np.array(res_points)-np.array(res_los), np.array(res_his)-np.array(res_points)], fmt='o-', capsize=4)
        ax.set_xticks([0,1,2]); ax.set_xticklabels(labels)
        ax.set_ylim(0, 1)
        ax.set_title(f"{cat}")
        if ax_idx == 0:
            ax.set_ylabel('P(answer | E)')
    
    # Summary comparison
    print(f"\nSUMMARY COMPARISON:")
    print(f"{'Model':<12} {'E=0':<8} {'E=1':<8} {'E=2':<8} {'Gradient':<10}")
    print("-" * 50)
    for cat in categories:
        if cat in all_results:
            e0 = all_results[cat]['E0']['propensity']
            e1 = all_results[cat]['E1']['propensity'] 
            e2 = all_results[cat]['E2']['propensity']
            gradient = e2 - e0
            print(f"{cat:<12} {e0:<8.3f} {e1:<8.3f} {e2:<8.3f} {gradient:<10.3f}")
    
    fig.suptitle('R3-A: Answer propensity by evidence level (95% CI)')
    # savefig(fig, outdir, 'fig_R3_A_propensity')  # Figure removed


# ---------------------------
# A2: Decision quality when answering (conditioned on E)
# ---------------------------

def analyze_A2_quality(turn_df: pd.DataFrame, outdir: str, n_boot=5000):
    print("\n" + "="*60)
    print("R3-B: DECISION QUALITY WHEN ANSWERING BY EVIDENCE")
    print("="*60)
    
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
    categories = ['Base', 'Few-shot', 'RL Agents']
    
    # Store results for summary
    all_results = {}
    
    for ax_idx, cat in enumerate(categories):
        ax = axes[ax_idx]
        sub = turn_df[(turn_df['category'] == cat) & (turn_df['answer_event'] == 1)].copy()
        print(f"\n{cat} Model Analysis (Answering Events Only):")
        print(f"  Total answering events: {len(sub)}")
        
        sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))
        unique_clusters = sub['cluster'].nunique()
        print(f"  Unique questions answered: {unique_clusters}")
        
        res_points = []; res_los = []; res_his = []
        labels = ['E=0', 'E=1', 'E=2']
        cat_results = {}
        
        for k in [0, 1, 2]:
            k_df = sub[sub['E'] == k]
            print(f"  E={k} answers: {len(k_df)} ({len(k_df)/len(sub)*100:.1f}%)")
            
            if len(k_df) == 0:
                print(f"    No answers at E={k} - skipping")
                res_points.append(np.nan); res_los.append(np.nan); res_his.append(np.nan)
                continue
                
            numer = defaultdict(float)
            denom = defaultdict(float)
            correct_count = 0
            total_count = 0
            
            for _, row in k_df.iterrows():
                c = row['cluster']
                denom[c] += 1.0
                total_count += 1
                # correct_event may be NaN for some malformed steps; treat NaN as 0
                val = row['correct_event']
                is_correct = 0.0 if (val is np.nan or pd.isna(val)) else val
                numer[c] += float(is_correct)
                if is_correct > 0:
                    correct_count += 1
                    
            point, lo, hi = clustered_ratio_ci(numer, denom, n_boot=n_boot)
            res_points.append(point); res_los.append(lo); res_his.append(hi)
            
            # Store and print results
            cat_results[f'E{k}'] = {
                'accuracy': point,
                'ci_low': lo,
                'ci_high': hi,
                'n_answers': len(k_df),
                'correct_count': correct_count
            }
            print(f"    P(correct|answer,E={k}): {point:.3f} [{lo:.3f}, {hi:.3f}]")
            print(f"    Raw accuracy: {correct_count}/{total_count} = {correct_count/total_count:.3f}")
            
            # Interpretation guidance
            if k == 0 and point < 0.5:
                print(f"    ⚠️  LOW reliability: Only {point:.1%} accuracy with insufficient evidence")
            elif k == 2 and point < 0.8:
                print(f"    ⚠️  SUBOPTIMAL: {point:.1%} accuracy even with sufficient evidence")
            elif point > 0.9:
                print(f"    ✅ EXCELLENT: {point:.1%} accuracy")
        
        all_results[cat] = cat_results
        
        ax.bar([0,1,2], res_points, color=['#d62728','#ff7f0e','#2ca02c'])
        ax.errorbar([0,1,2], res_points, yerr=[np.array(res_points)-np.array(res_los), np.array(res_his)-np.array(res_points)], fmt='none', capsize=4, color='black')
        ax.set_xticks([0,1,2]); ax.set_xticklabels(labels)
        ax.set_ylim(0, 1)
        ax.set_title(f"{cat}")
        if ax_idx == 0:
            ax.set_ylabel('P(correct | answer, E)')
    
    # Summary comparison
    print(f"\nSUMMARY COMPARISON:")
    print(f"{'Model':<12} {'E=0':<8} {'E=1':<8} {'E=2':<8} {'Improvement':<12}")
    print("-" * 60)
    for cat in categories:
        if cat in all_results:
            e0 = all_results[cat]['E0']['accuracy'] if 'E0' in all_results[cat] else np.nan
            e1 = all_results[cat]['E1']['accuracy'] if 'E1' in all_results[cat] else np.nan
            e2 = all_results[cat]['E2']['accuracy'] if 'E2' in all_results[cat] else np.nan
            improvement = e2 - e0 if not (np.isnan(e0) or np.isnan(e2)) else np.nan
            print(f"{cat:<12} {e0:<8.3f} {e1:<8.3f} {e2:<8.3f} {improvement:<12.3f}")
    
    fig.suptitle('R3-B: Decision quality when answering by evidence (95% CI)')
    # savefig(fig, outdir, 'fig_R3_B_quality')  # Figure removed


# ---------------------------
# Combined: R3-A + R3-B into one figure
# ---------------------------
def analyze_AB_combined(turn_df: pd.DataFrame, outdir: str, n_boot=5000):
    """Combine R3-A and R3-B into a single axes:
    - Three lines (Base/Few-shot/RL Agents) for P(answer|E) on left y-axis
    - Three bar series (grouped; different colors) for P(correct|answer,E) on right y-axis
    - y-axes limited to [0, 0.4]
    """
    print("\n" + "="*60)
    print("R3-AB: SINGLE-AXES OVERLAY (3 lines + 3 bar series)")
    print("="*60)

    categories = ['Base', 'Few-shot', 'RL Agents']
    labels = ['E=0', 'E=1', 'E=2']
    x = np.array([0, 1, 2])

    fig, ax_left = plt.subplots(1, 1, figsize=(8, 4))
    ax_right = ax_left.twinx()

    # Color palette shared across lines and bars for category consistency
    cat_colors = sns.color_palette('Set2', n_colors=3)

    line_handles = []
    bar_handles = []

    width = 0.22
    offsets = [-width, 0.0, width]

    # First draw all bars (P(correct|answer,E))
    for i, cat in enumerate(categories):
        sub = turn_df[turn_df['category'] == cat].copy()
        if sub.empty:
            continue
        sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))

        # P(correct|answer,E)
        ans_only = sub[sub['answer_event'] == 1].copy()
        qual_points, qual_lo, qual_hi = [], [], []
        for k in [0, 1, 2]:
            k_df = ans_only[ans_only['E'] == k]
            if len(k_df) == 0:
                qual_points.append(np.nan); qual_lo.append(np.nan); qual_hi.append(np.nan); continue
            numer = defaultdict(float); denom = defaultdict(float)
            for _, row in k_df.iterrows():
                c = row['cluster']; denom[c] += 1.0
                val = row['correct_event']
                is_correct = 0.0 if (val is np.nan or pd.isna(val)) else float(val)
                numer[c] += is_correct
            p, lo, hi = clustered_ratio_ci(numer, denom, n_boot=n_boot)
            qual_points.append(p); qual_lo.append(lo); qual_hi.append(hi)
        qual_points = np.array(qual_points)
        yerr_qual = np.vstack([qual_points - np.array(qual_lo), np.array(qual_hi) - qual_points])
        bh = ax_right.bar(x + offsets[i], np.nan_to_num(qual_points, nan=0.0),
                          width=width*0.95, color=cat_colors[i], label=f'{cat} (P(correct|answer,E))')
        ax_right.errorbar(x + offsets[i], qual_points, yerr=yerr_qual, fmt='none', capsize=3, color='black')
        bar_handles.append(bh)

    # Then draw all lines (P(answer|E))
    for i, cat in enumerate(categories):
        sub = turn_df[turn_df['category'] == cat].copy()
        if sub.empty:
            continue
        sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))

        # P(answer|E)
        prop_points, prop_lo, prop_hi = [], [], []
        for k in [0, 1, 2]:
            k_df = sub[sub['E'] == k]
            if len(k_df) == 0:
                prop_points.append(np.nan); prop_lo.append(np.nan); prop_hi.append(np.nan); continue
            numer = defaultdict(float); denom = defaultdict(float)
            for _, row in k_df.iterrows():
                c = row['cluster']; denom[c] += 1.0; numer[c] += float(row['answer_event'])
            p, lo, hi = clustered_ratio_ci(numer, denom, n_boot=n_boot)
            prop_points.append(p); prop_lo.append(lo); prop_hi.append(hi)
        prop_points = np.array(prop_points)
        yerr_prop = np.vstack([prop_points - np.array(prop_lo), np.array(prop_hi) - prop_points])
        lh = ax_left.errorbar(x, prop_points, yerr=yerr_prop, fmt='o-', capsize=3,
                              color=cat_colors[i], label=f'{cat} (P(answer|E))')
        line_handles.append(lh)

    # Axes setup
    ax_left.set_xticks(x); ax_left.set_xticklabels(labels)
    ax_left.set_ylabel('P(answer | E)')
    ax_right.set_ylabel('P(correct | answer, E)')
    ax_left.set_ylim(0, 0.25)
    ax_right.set_ylim(0, 0.5)
    ax_left.set_title('R3-A+B: Propensity (lines) and Quality (bars) by Evidence')
    # Legend: use proxy artists to avoid incorrect labels from containers
    line_proxies = [Line2D([0], [0], color=cat_colors[i], marker='o', linestyle='-', label=f"{cat}")
                    for i, cat in enumerate(categories)]
    bar_proxies = [Patch(facecolor=cat_colors[i], label=f"{cat}")
                   for i, cat in enumerate(categories)]
    
    # Add labels for metrics only once
    line_label = Line2D([0], [0], color='none', marker='', linestyle='', label="P(answer|E)")
    bar_label = Patch(facecolor='none', edgecolor='none', label="P(correct|answer,E)")
    
    handles = line_proxies + [line_label] + bar_proxies + [bar_label]
    ax_left.legend(handles=handles, loc='upper left', fontsize=8, ncol=2, frameon=True)

    fig.tight_layout()
    savefig(fig, outdir, 'fig_R3_AB_combined')


# ---------------------------
# RL Agents breakdown: by model and by dataset
# ---------------------------
def analyze_trained_breakdown(turn_df: pd.DataFrame, outdir: str, n_boot=5000, max_datasets: int = 5):
    """Produce a single multi-panel figure for RL Agents category only:
    - Top row: P(answer|E) by model; P(correct|answer,E) by model
    - Bottom row: P(answer|E) by top-N datasets; P(correct|answer,E) by top-N datasets
    """
    print("\n" + "="*60)
    print("R3-TRAINED: BREAKDOWN BY MODEL AND DATASET")
    print("="*60)

    sub = turn_df[turn_df['category'] == 'RL Agents'].copy()
    if sub.empty:
        print("No RL Agents data available; skipping trained breakdown figure.")
        return
    sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))

    # Determine model list and top datasets
    models = list(sub['model'].dropna().unique())
    ds_counts = sub['dataset'].value_counts().sort_values(ascending=False)
    top_datasets = list(ds_counts.head(max_datasets).index)

    fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True)
    labels = ['E=0', 'E=1', 'E=2']
    x = np.array([0, 1, 2])

    # Helper to compute CI series for a boolean numerator over denominator grouped by cluster
    def _ci_over_E(df_sel, numer_col):
        pts, los, his = [], [], []
        for k in [0, 1, 2]:
            k_df = df_sel[df_sel['E'] == k]
            numer = defaultdict(float); denom = defaultdict(float)
            for _, r in k_df.iterrows():
                c = r['cluster']; denom[c] += 1.0; numer[c] += float(r[numer_col])
            p, lo, hi = clustered_ratio_ci(numer, denom, n_boot=n_boot)
            pts.append(p); los.append(lo); his.append(hi)
        return np.array(pts), np.array(los), np.array(his)

    # Top-left: P(answer|E) by model
    ax = axes[0, 0]
    colors = sns.color_palette('tab10', n_colors=max(3, len(models)))
    for i, m in enumerate(models):
        m_df = sub[sub['model'] == m]
        pts, los, his = _ci_over_E(m_df, 'answer_event')
        yerr = np.vstack([pts - los, his - pts])
        ax.errorbar(x, pts, yerr=yerr, fmt='o-', capsize=3, label=m, color=colors[i % len(colors)])
    ax.set_title('RL Agents: P(answer|E) by model')
    ax.set_ylim(0, 0.5); ax.set_xticks(x); ax.set_xticklabels(labels)
    ax.set_ylabel('P(answer|E)'); ax.legend(fontsize=8)

    # Top-right: P(correct|answer,E) by model
    ax = axes[0, 1]
    for i, m in enumerate(models):
        m_df = sub[(sub['model'] == m) & (sub['answer_event'] == 1)].copy()
        pts, los, his = [], [], []
        for k in [0, 1, 2]:
            k_df = m_df[m_df['E'] == k]
            if len(k_df) == 0:
                pts.append(np.nan); los.append(np.nan); his.append(np.nan); continue
            numer = defaultdict(float); denom = defaultdict(float)
            for _, r in k_df.iterrows():
                c = r['cluster']; denom[c] += 1.0
                val = r['correct_event']
                is_correct = 0.0 if (val is np.nan or pd.isna(val)) else float(val)
                numer[c] += is_correct
            p, lo, hi = clustered_ratio_ci(numer, denom, n_boot=n_boot)
            pts.append(p); los.append(lo); his.append(hi)
        pts = np.array(pts); los = np.array(los); his = np.array(his)
        yerr = np.vstack([pts - los, his - pts])
        ax.errorbar(x, pts, yerr=yerr, fmt='s--', capsize=3, label=m, color=colors[i % len(colors)])
    ax.set_title('RL Agents: P(correct|answer,E) by model')
    ax.set_ylim(0, 0.5); ax.set_xticks(x); ax.set_xticklabels(labels)
    ax.legend(fontsize=8)

    # Bottom-left: P(answer|E) by dataset (top-N)
    ax = axes[1, 0]
    colors_ds = sns.color_palette('tab20', n_colors=max(4, len(top_datasets)))
    for i, d in enumerate(top_datasets):
        d_df = sub[sub['dataset'] == d]
        pts, los, his = _ci_over_E(d_df, 'answer_event')
        yerr = np.vstack([pts - los, his - pts])
        ax.errorbar(x, pts, yerr=yerr, fmt='o-', capsize=3, label=d, color=colors_ds[i % len(colors_ds)])
    ax.set_title('RL Agents: P(answer|E) by dataset (top)')
    ax.set_ylim(0, 0.65); ax.set_xticks(x); ax.set_xticklabels(labels)
    ax.set_ylabel('P(answer|E)'); ax.legend(fontsize=8)

    # Bottom-right: P(correct|answer,E) by dataset (top-N)
    ax = axes[1, 1]
    for i, d in enumerate(top_datasets):
        d_df = sub[(sub['dataset'] == d) & (sub['answer_event'] == 1)].copy()
        pts, los, his = [], [], []
        for k in [0, 1, 2]:
            k_df = d_df[d_df['E'] == k]
            if len(k_df) == 0:
                pts.append(np.nan); los.append(np.nan); his.append(np.nan); continue
            numer = defaultdict(float); denom = defaultdict(float)
            for _, r in k_df.iterrows():
                c = r['cluster']; denom[c] += 1.0
                val = r['correct_event']
                is_correct = 0.0 if (val is np.nan or pd.isna(val)) else float(val)
                numer[c] += is_correct
            p, lo, hi = clustered_ratio_ci(numer, denom, n_boot=n_boot)
            pts.append(p); los.append(lo); his.append(hi)
        pts = np.array(pts); los = np.array(los); his = np.array(his)
        yerr = np.vstack([pts - los, his - pts])
        ax.errorbar(x, pts, yerr=yerr, fmt='s--', capsize=3, label=d, color=colors_ds[i % len(colors_ds)])
    ax.set_title('RL Agents: P(correct|answer,E) by dataset (top)')
    ax.set_ylim(0, 0.65); ax.set_xticks(x); ax.set_xticklabels(labels)
    ax.legend(fontsize=8)

    fig.suptitle('R3 (RL Agents): Evidence-behavior breakdown by model and dataset (95% CI)')
    savefig(fig, outdir, 'fig_R3_trained_breakdown')
    
    # Add detailed analysis section for R3-trained breakdown
    print("\n" + "="*60)
    print("R3-TRAINED BREAKDOWN ANALYSIS: MODEL AND DATASET INSIGHTS")
    print("="*60)
    
    print("\nRL AGENTS HETEROGENEITY ANALYSIS:")
    print("The breakdown reveals significant variation within the RL Agents category across models and datasets:")
    print()
    
    # Model-level analysis
    print("MODEL-LEVEL VARIATION:")
    model_stats = {}
    for m in models:
        m_df = sub[sub['model'] == m]
        if len(m_df) == 0:
            continue
            
        # Compute key metrics for this model
        total_turns = len(m_df)
        answer_events = len(m_df[m_df['answer_event'] == 1])
        answer_rate = answer_events / total_turns if total_turns > 0 else 0
        
        # Evidence distribution
        e0_turns = len(m_df[m_df['E'] == 0])
        e1_turns = len(m_df[m_df['E'] == 1])
        e2_turns = len(m_df[m_df['E'] == 2])
        
        # Answer rates by evidence level
        e0_answers = len(m_df[(m_df['E'] == 0) & (m_df['answer_event'] == 1)])
        e1_answers = len(m_df[(m_df['E'] == 1) & (m_df['answer_event'] == 1)])
        e2_answers = len(m_df[(m_df['E'] == 2) & (m_df['answer_event'] == 1)])
        
        e0_rate = e0_answers / e0_turns if e0_turns > 0 else 0
        e1_rate = e1_answers / e1_turns if e1_turns > 0 else 0
        e2_rate = e2_answers / e2_turns if e2_turns > 0 else 0
        
        model_stats[m] = {
            'total_turns': total_turns,
            'answer_rate': answer_rate,
            'e0_rate': e0_rate,
            'e1_rate': e1_rate,
            'e2_rate': e2_rate,
            'evidence_gradient': e2_rate - e0_rate
        }
        
        print(f"  {m}:")
        print(f"    • Total turns: {total_turns:,}")
        print(f"    • Overall answer rate: {answer_rate:.1%}")
        print(f"    • Answer rates by evidence: E=0: {e0_rate:.1%}, E=1: {e1_rate:.1%}, E=2: {e2_rate:.1%}")
        print(f"    • Evidence gradient: {e2_rate - e0_rate:.3f} (E=2 rate - E=0 rate)")
        
        if e0_rate > 0.1:
            print(f"    • ⚠️  HIGH premature risk: {e0_rate:.1%} answer rate with insufficient evidence")
        if e2_rate < 0.5:
            print(f"    • ⚠️  LOW responsiveness: Only {e2_rate:.1%} answer with sufficient evidence")
        if (e2_rate - e0_rate) > 0.3:
            print(f"    • ✅ GOOD evidence sensitivity: Strong gradient of {e2_rate - e0_rate:.3f}")
        print()
    
    # Dataset-level analysis
    print("DATASET-LEVEL VARIATION:")
    dataset_stats = {}
    for d in top_datasets:
        d_df = sub[sub['dataset'] == d]
        if len(d_df) == 0:
            continue
            
        total_turns = len(d_df)
        answer_events = len(d_df[d_df['answer_event'] == 1])
        answer_rate = answer_events / total_turns if total_turns > 0 else 0
        
        # Evidence distribution
        e0_turns = len(d_df[d_df['E'] == 0])
        e1_turns = len(d_df[d_df['E'] == 1])
        e2_turns = len(d_df[d_df['E'] == 2])
        
        # Answer rates by evidence level
        e0_answers = len(d_df[(d_df['E'] == 0) & (d_df['answer_event'] == 1)])
        e1_answers = len(d_df[(d_df['E'] == 1) & (d_df['answer_event'] == 1)])
        e2_answers = len(d_df[(d_df['E'] == 2) & (d_df['answer_event'] == 1)])
        
        e0_rate = e0_answers / e0_turns if e0_turns > 0 else 0
        e1_rate = e1_answers / e1_turns if e1_turns > 0 else 0
        e2_rate = e2_answers / e2_turns if e2_turns > 0 else 0
        
        dataset_stats[d] = {
            'total_turns': total_turns,
            'answer_rate': answer_rate,
            'e0_rate': e0_rate,
            'e1_rate': e1_rate,
            'e2_rate': e2_rate,
            'evidence_gradient': e2_rate - e0_rate
        }
        
        print(f"  {d}:")
        print(f"    • Total turns: {total_turns:,}")
        print(f"    • Overall answer rate: {answer_rate:.1%}")
        print(f"    • Answer rates by evidence: E=0: {e0_rate:.1%}, E=1: {e1_rate:.1%}, E=2: {e2_rate:.1%}")
        print(f"    • Evidence gradient: {e2_rate - e0_rate:.3f}")
        
        if e0_rate > 0.15:
            print(f"    • ⚠️  HIGH premature risk: {e0_rate:.1%} answer rate with insufficient evidence")
        if e2_rate < 0.4:
            print(f"    • ⚠️  LOW responsiveness: Only {e2_rate:.1%} answer with sufficient evidence")
        print()
    
    # Comparative analysis
    print("COMPARATIVE INSIGHTS:")
    if model_stats:
        best_model = max(model_stats.keys(), key=lambda m: model_stats[m]['evidence_gradient'])
        worst_model = min(model_stats.keys(), key=lambda m: model_stats[m]['evidence_gradient'])
        print(f"• Best evidence sensitivity: {best_model} (gradient: {model_stats[best_model]['evidence_gradient']:.3f})")
        print(f"• Worst evidence sensitivity: {worst_model} (gradient: {model_stats[worst_model]['evidence_gradient']:.3f})")
        
        model_gradients = [model_stats[m]['evidence_gradient'] for m in model_stats]
        print(f"• Model gradient range: {min(model_gradients):.3f} to {max(model_gradients):.3f}")
        print(f"• Model gradient std: {np.std(model_gradients):.3f}")
    
    if dataset_stats:
        best_dataset = max(dataset_stats.keys(), key=lambda d: dataset_stats[d]['evidence_gradient'])
        worst_dataset = min(dataset_stats.keys(), key=lambda d: dataset_stats[d]['evidence_gradient'])
        print(f"• Best dataset for evidence sensitivity: {best_dataset} (gradient: {dataset_stats[best_dataset]['evidence_gradient']:.3f})")
        print(f"• Worst dataset for evidence sensitivity: {worst_dataset} (gradient: {dataset_stats[worst_dataset]['evidence_gradient']:.3f})")
        
        dataset_gradients = [dataset_stats[d]['evidence_gradient'] for d in dataset_stats]
        print(f"• Dataset gradient range: {min(dataset_gradients):.3f} to {max(dataset_gradients):.3f}")
        print(f"• Dataset gradient std: {np.std(dataset_gradients):.3f}")
    
    print()
    print("KEY FINDINGS:")
    print("1. Significant heterogeneity within RL Agents: Models and datasets show varying evidence sensitivity")
    print("2. Evidence gradient variation: Some models/datasets show strong evidence-responsiveness, others don't")
    print("3. Systematic premature answering: Even the best RL models answer frequently with insufficient evidence")
    print("4. Dataset-dependent behavior: Evidence utilization patterns vary substantially across datasets")
    print()
    print("IMPLICATIONS:")
    print("• RL training should be more consistent across models and datasets")
    print("• Evidence-gating mechanisms need improvement across all RL variants")
    print("• Dataset-specific training strategies may be needed for optimal evidence utilization")
    print("• Model selection should consider evidence sensitivity, not just overall performance")


# ---------------------------
# A3: Calibration confusion plot (should vs did)
# ---------------------------

def compute_confusion_data(traj_df: pd.DataFrame, policy_threshold: str = 'high', group_by: str = 'category'):
    """Compute confusion matrix data for TP/FP/FN/TN per archetype without plotting.
    policy_threshold: 'high' => E>=2, 'medium' => E>=1
    """
    thr = 2 if policy_threshold == 'high' else 1
    
    # We'll support grouping by 'category' (default) or 'model'
    valid_groups = ['category', 'model']
    if group_by not in valid_groups:
        group_by = 'category'
    categories = ['Base', 'Few-shot', 'RL Agents'] if group_by == 'category' else None
    # classify per trajectory
    rows = []
    for _, r in traj_df.iterrows():
        cat = r.get('category', None)
        grp_val = r.get(group_by, None)
        if group_by == 'category':
            if cat not in categories:
                continue
        answered = int(r['answered_within_h'])
        ans_E = int(r['E_at_answer']) if not pd.isna(r['E_at_answer']) else -1
        ever_high = int(r['ever_high_within_h'])
        # label
        if answered:
            if ans_E >= thr:
                lab = 'TP'  # Good answers
            else:
                lab = 'FP'  # Premature
        else:
            if ever_high:
                lab = 'FN'  # Over-cautious
            else:
                lab = 'TN'  # Good deferral
        rows.append({'category': cat, 'model': r.get('model', None), 'dataset': r['dataset'], 'question_id': r['question_id'], 'label': lab, group_by: grp_val})
    return pd.DataFrame(rows)

def analyze_A3_confusion(traj_df: pd.DataFrame, outdir: str, policy_threshold: str = 'high', n_boot=5000, group_by: str = 'category'):
    """Compute stacked bar shares for TP/FP/FN/TN per archetype.
    policy_threshold: 'high' => E>=2, 'medium' => E>=1
    """
    print("\n" + "="*60)
    print(f"R3-C: CALIBRATION CONFUSION MATRIX (Policy: E>={2 if policy_threshold == 'high' else 1})")
    print("="*60)
    
    thr = 2 if policy_threshold == 'high' else 1
    print(f"Policy threshold: E >= {thr} (should answer)")
    
    # Use the new function to compute confusion data
    df = compute_confusion_data(traj_df, policy_threshold, group_by)

    print(f"\nTotal trajectories analyzed: {len(df)}")
    print(f"Unique questions: {df[['dataset', 'question_id']].drop_duplicates().shape[0]}")

    # cluster bootstrap shares by (dataset, question_id)
    # Determine groups to plot
    if group_by == 'category':
        groups = categories
    else:
        groups = sorted([g for g in df[group_by].dropna().unique().tolist()])
    n_groups = len(groups)
    fig, axes = plt.subplots(1, n_groups, figsize=(4*n_groups, 4), sharey=True)
    if n_groups == 1:
        axes = [axes]
    all_results = {}
    
    for ax_idx, g in enumerate(groups):
        ax = axes[ax_idx]
        sub = df[df[group_by] == g]
        print(f"\nAnalysis for {group_by} = {g}:")
        print(f"  Total trajectories: {len(sub)}")
        
        if sub.empty:
            print(f"  No data for {group_by}={g} - skipping")
            ax.set_axis_off(); continue
            
        sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))
        unique_clusters = sub['cluster'].nunique()
        print(f"  Unique questions: {unique_clusters}")
        
        labels = ['TP','FP','FN','TN']
        numer_by_lab = {lab: defaultdict(float) for lab in labels}
        denom_by_cluster = defaultdict(float)
        
        # Count raw numbers
        raw_counts = sub['label'].value_counts()
        print(f"  Raw counts: {dict(raw_counts)}")
        
        for _, row in sub.iterrows():
            c = row['cluster']
            denom_by_cluster[c] += 1.0
            numer_by_lab[row['label']][c] += 1.0
            
        points = []; los = []; his = []
        cat_results = {}
        
        for lab in labels:
            p, lo, hi = clustered_ratio_ci(numer_by_lab[lab], denom_by_cluster, n_boot=n_boot)
            points.append(p); los.append(lo); his.append(hi)
            
            cat_results[lab] = {
                'share': p,
                'ci_low': lo,
                'ci_high': hi,
                'raw_count': raw_counts.get(lab, 0)
            }
            
            print(f"  {lab}: {p:.3f} [{lo:.3f}, {hi:.3f}] (n={raw_counts.get(lab, 0)})")
            
            # Interpretation guidance
            if lab == 'FP' and p > 0.2:
                print(f"    ⚠️  HIGH premature rate: {p:.1%} answer when shouldn't")
            elif lab == 'FN' and p > 0.2:
                print(f"    ⚠️  HIGH over-cautious rate: {p:.1%} don't answer when should")
            elif lab == 'TP' and p > 0.6:
                print(f"    ✅ GOOD: {p:.1%} correct answers")
            elif lab == 'TN' and p > 0.3:
                print(f"    ✅ GOOD: {p:.1%} appropriate deferrals")
        
        key = g
        all_results[key] = cat_results
        
        # stacked 100%
        cum = 0.0
        colors = {'TP': '#2ca02c', 'FP': '#d62728', 'FN': '#9467bd', 'TN': '#7f7f7f'}
        for i, lab in enumerate(labels):
            height = points[i]
            ax.bar([0], [height], bottom=cum, color=colors[lab], label=lab if ax_idx==0 else None)
            # simple CI whiskers for each segment: plot at segment midpoint
            ax.errorbar([0], [cum + height/2], yerr=[[points[i]-los[i]],[his[i]-points[i]]], fmt='none', ecolor='black', capsize=3)
            cum += height
        ax.set_ylim(0,1); ax.set_xticks([]); ax.set_title(str(g))
        if ax_idx == 0:
            ax.set_ylabel('Share of trajectories')
    
    # Summary comparison
    print(f"\nSUMMARY COMPARISON:")
    print(f"{'Group':<20} {'TP':<8} {'FP':<8} {'FN':<8} {'TN':<8} {'Error Rate':<12}")
    print("-" * 70)
    for key in groups:
        if key in all_results:
            tp = all_results[key]['TP']['share']
            fp = all_results[key]['FP']['share']
            fn = all_results[key]['FN']['share']
            tn = all_results[key]['TN']['share']
            error_rate = fp + fn  # Total error rate
            print(f"{str(key):<20} {tp:<8.3f} {fp:<8.3f} {fn:<8.3f} {tn:<8.3f} {error_rate:<12.3f}")
    
    handles, labels = axes[0].get_legend_handles_labels() if categories else ([],[])
    if handles:
        fig.legend(handles, labels, bbox_to_anchor=(1.02, 1), loc='upper left', title='Legend')
    fig.suptitle(f"R3-C: Should (policy) vs Did (stacked TP/FP/FN/TN) by {group_by}")
    # savefig(fig, outdir, 'fig_R3_C_confusion')  # Figure removed
    return df


# ---------------------------
# A4: Calibration error + key rates
# ---------------------------

def compute_A4_calibration(turn_df: pd.DataFrame, confusion_df: pd.DataFrame, policy_threshold: str='high', n_boot=5000, group_by: str = 'category', ce_prevalence: str = 'global'):
    thr = 2 if policy_threshold == 'high' else 1
    # determine groups
    if group_by == 'category':
        groups = ['Base', 'Few-shot', 'RL Agents']
    else:
        groups = sorted(turn_df['model'].dropna().unique().tolist())
    # Compute global/reference prevalence w_ref(k) if requested
    ce_prevalence = ce_prevalence if ce_prevalence in ['global','group'] else 'global'
    # Global prevalence across all turns (cluster-weighted)
    denom_all_global = defaultdict(float)
    denom_k_global = {0: defaultdict(float), 1: defaultdict(float), 2: defaultdict(float)}
    tmp_all = turn_df.copy()
    tmp_all['cluster'] = list(zip(tmp_all['dataset'], tmp_all['question_id']))
    for _, row in tmp_all.iterrows():
        c = row['cluster']
        denom_all_global[c] += 1.0
        k = int(row['E']) if not pd.isna(row['E']) else -1
        if k in denom_k_global:
            denom_k_global[k][c] += 1.0
    w_ref = {}
    for k in [0,1,2]:
        w_point_ref, _, _ = clustered_ratio_ci(denom_k_global[k], denom_all_global, n_boot=n_boot)
        w_ref[k] = (w_point_ref or 0.0)

    results = []
    for cat in groups:
        if group_by == 'category':
            sub = turn_df[turn_df['category'] == cat].copy()
        else:
            sub = turn_df[turn_df['model'] == cat].copy()
        sub['cluster'] = list(zip(sub['dataset'], sub['question_id']))
        # P(answer|E=k)
        p_by_k = {}
        w_by_k = {}
        for k in [0,1,2]:
            k_df = sub[sub['E'] == k]
            numer = defaultdict(float); denom = defaultdict(float)
            for _, row in k_df.iterrows():
                c = row['cluster']
                denom[c] += 1.0
                numer[c] += float(row['answer_event'])
            p, lo, hi = clustered_ratio_ci(numer, denom, n_boot=n_boot)
            p_by_k[k] = (p, lo, hi)
            # prevalence w_k
            # compute as share of steps with E=k among all pre-answer steps
            denom_all = defaultdict(float)
            denom_k = defaultdict(float)
            for _, row in sub.iterrows():
                c = row['cluster']
                denom_all[c] += 1.0
            for _, row in k_df.iterrows():
                c = row['cluster']
                denom_k[c] += 1.0
            w_point, w_lo, w_hi = clustered_ratio_ci(denom_k, denom_all, n_boot=n_boot)
            w_by_k[k] = (w_point, w_lo, w_hi)

        # CE = sum_k w(k) * |P(answer|E=k) - 1[k>=thr]|
        ce_point = 0.0
        ce_parts = {}
        diffs_by_k = {}
        for k in [0,1,2]:
            target = 1.0 if (k >= thr) else 0.0
            wk_group = (w_by_k[k][0] or 0.0)
            wk = w_ref[k] if ce_prevalence == 'global' else wk_group
            ak = (p_by_k[k][0] or 0.0)
            diff = abs(ak - target)
            part = wk * diff
            ce_parts[k] = part
            diffs_by_k[k] = diff
            ce_point += part

        # Balanced CCE (prevalence independent): average absolute deviation across k
        cce_balanced = np.nanmean([diffs_by_k.get(k, np.nan) for k in [0,1,2]])

        # Also compute CE under global reference prevalence for reporting
        ce_ref_point = 0.0
        for k in [0,1,2]:
            target = 1.0 if (k >= thr) else 0.0
            ak = (p_by_k[k][0] or 0.0)
            ce_ref_point += (w_ref[k]) * abs(ak - target)

        # Premature = FP share, Over-cautious = FN share from confusion_df
        conf_sub = confusion_df[confusion_df[group_by] == cat].copy()
        conf_sub['cluster'] = list(zip(conf_sub['dataset'], conf_sub['question_id']))
        denom = defaultdict(float); num_fp = defaultdict(float); num_fn = defaultdict(float)
        for _, r in conf_sub.iterrows():
            c = r['cluster']
            denom[c] += 1.0
            if r['label'] == 'FP':
                num_fp[c] += 1.0
            if r['label'] == 'FN':
                num_fn[c] += 1.0
        prem_point, prem_lo, prem_hi = clustered_ratio_ci(num_fp, denom, n_boot=n_boot)
        over_point, over_lo, over_hi = clustered_ratio_ci(num_fn, denom, n_boot=n_boot)

        results.append({
            group_by: cat,
            'CE_point': ce_point,
            'CE_ref': ce_ref_point,
            'CCE_balanced': cce_balanced,
            'premature_point': prem_point,
            'over_cautious_point': over_point,
            'premature_lo': prem_lo, 'premature_hi': prem_hi,
            'over_cautious_lo': over_lo, 'over_cautious_hi': over_hi,
            # CE breakdown components and supporting stats
            'ce_e0': ce_parts.get(0, np.nan),
            'ce_e1': ce_parts.get(1, np.nan),
            'ce_e2': ce_parts.get(2, np.nan),
            'w0': w_by_k[0][0], 'w1': w_by_k[1][0], 'w2': w_by_k[2][0],
            'a0': p_by_k[0][0], 'a1': p_by_k[1][0], 'a2': p_by_k[2][0],
        })
    return pd.DataFrame(results)


def plot_A4_calibration(results_df: pd.DataFrame, outdir: str, group_by: str = 'category'):
    print("\n" + "="*60)
    print("R3-D: CALIBRATION ERROR + PREMATURE/OVER-CAUTIOUS RATES")
    print("="*60)
    
    fig, ax = plt.subplots(figsize=(8, 4))
    if group_by == 'category':
        cats = ['Base', 'Few-shot', 'RL Agents']
        group_field = 'category'
    else:
        cats = results_df[group_by].dropna().tolist()
        group_field = group_by
    x = np.arange(len(cats)); width = 0.25
    
    print(f"\nDetailed Results:")
    print(f"{'Model':<12} {'CE':<8} {'CE_ref':<8} {'CCE_bal':<10} {'Premature':<12} {'Over-cautious':<15} {'Total Error':<12} {'CE[E=0]':<10} {'CE[E=1]':<10} {'CE[E=2]':<10}")
    print("-" * 140)
    
    # CE bars (no CI from ratio-of-CIs; we report point)
    ce_vals = [results_df[results_df[group_field]==c]['CE_point'].iloc[0] if c in results_df[group_field].values else np.nan for c in cats]
    b1 = ax.bar(x - width, ce_vals, width, label='CE', color='#1f77b4')
    
    # Premature with CI
    prem_vals = [results_df[results_df[group_field]==c]['premature_point'].iloc[0] if c in results_df[group_field].values else np.nan for c in cats]
    prem_lo = [results_df[results_df[group_field]==c]['premature_lo'].iloc[0] if c in results_df[group_field].values else np.nan for c in cats]
    prem_hi = [results_df[results_df[group_field]==c]['premature_hi'].iloc[0] if c in results_df[group_field].values else np.nan for c in cats]
    b2 = ax.bar(x, prem_vals, width, label='Premature')
    ax.errorbar(x, prem_vals, yerr=[np.array(prem_vals)-np.array(prem_lo), np.array(prem_hi)-np.array(prem_vals)], fmt='none', color='black', capsize=4)
    
    # Over-cautious with CI
    over_vals = [results_df[results_df[group_field]==c]['over_cautious_point'].iloc[0] if c in results_df[group_field].values else np.nan for c in cats]
    over_lo = [results_df[results_df[group_field]==c]['over_cautious_lo'].iloc[0] if c in results_df[group_field].values else np.nan for c in cats]
    over_hi = [results_df[results_df[group_field]==c]['over_cautious_hi'].iloc[0] if c in results_df[group_field].values else np.nan for c in cats]
    b3 = ax.bar(x + width, over_vals, width, label='Over-cautious')
    ax.errorbar(x + width, over_vals, yerr=[np.array(over_vals)-np.array(over_lo), np.array(over_hi)-np.array(over_vals)], fmt='none', color='black', capsize=4)
    
    # Print detailed results (with CE breakdown)
    for i, cat in enumerate(cats):
        ce = ce_vals[i]
        cce_bal = results_df[results_df[group_field]==cat]['CCE_balanced'].iloc[0] if cat in results_df[group_field].values else np.nan
        ce_ref = results_df[results_df[group_field]==cat]['CE_ref'].iloc[0] if cat in results_df[group_field].values else np.nan
        prem = prem_vals[i]
        over = over_vals[i]
        total_error = prem + over
        # CE parts
        ce_e0 = results_df[results_df[group_field]==cat]['ce_e0'].iloc[0] if cat in results_df[group_field].values else np.nan
        ce_e1 = results_df[results_df[group_field]==cat]['ce_e1'].iloc[0] if cat in results_df[group_field].values else np.nan
        ce_e2 = results_df[results_df[group_field]==cat]['ce_e2'].iloc[0] if cat in results_df[group_field].values else np.nan
        
        print(f"{cat:<12} {ce:<8.3f} {ce_ref:<8.3f} {cce_bal:<10.3f} {prem:<8.3f}[{prem_lo[i]:.3f},{prem_hi[i]:.3f}] {over:<8.3f}[{over_lo[i]:.3f},{over_hi[i]:.3f}] {total_error:<12.3f} {ce_e0:<10.3f} {ce_e1:<10.3f} {ce_e2:<10.3f}")
        
        # Interpretation guidance
        # Improved CE interpretation using breakdown
        if ce > 0.3:
            # Identify dominant source
            parts = {'E=0': ce_e0, 'E=1': ce_e1, 'E=2': ce_e2}
            dom_k = max(parts, key=lambda k: (parts[k] if not np.isnan(parts[k]) else -1))
            print(f"    ⚠️  HIGH calibration error: {ce:.1%} from ideal policy (dominant at {dom_k})")
            # If E=0 dominates: premature at low-evidence; E=2 dominates: under-answering at high-evidence; E=1: ambiguity handling
            if dom_k == 'E=0':
                print(f"    ▸ Driver: Excess answering under low evidence (raise threshold / tighten gating)")
            elif dom_k == 'E=2':
                print(f"    ▸ Driver: Under-responsiveness at high evidence (increase answer propensity at E=2)")
            else:
                print(f"    ▸ Driver: Behavior at medium evidence is misaligned (clarify policy around E=1)")
        if prem > 0.2:
            print(f"    ⚠️  HIGH premature rate: {prem:.1%} answer when shouldn't")
        if over > 0.2:
            print(f"    ⚠️  HIGH over-cautious rate: {over:.1%} don't answer when should")
        if total_error < 0.2:
            print(f"    ✅ GOOD overall calibration: {total_error:.1%} total error rate")
    
    # Summary ranking
    print(f"\nRANKING BY CALIBRATION QUALITY:")
    cat_ce_pairs = [(cat, ce_vals[i]) for i, cat in enumerate(cats) if not np.isnan(ce_vals[i])]
    cat_ce_pairs.sort(key=lambda x: x[1])  # Sort by CE (lower is better)
    for rank, (cat, ce) in enumerate(cat_ce_pairs, 1):
        print(f"  {rank}. {cat}: CE = {ce:.3f}")
    
    ax.set_xticks(x); ax.set_xticklabels(cats)
    ax.set_ylabel('Rate')
    ax.set_ylim(0, 1)
    ax.set_title(f'R3-D: Calibration Error + Premature/Over-cautious (95% CI) by {group_by}')
    ax.legend()
    savefig(fig, outdir, 'fig_R3_D_calibration_rates')


# ---------------------------
# A5: Timing sanity check: cumulative answer curves
# ---------------------------

def analyze_A5_timing(traces: list, outdir: str, categories=('Base','Few-shot','RL Agents'), max_steps: int = 15):
    """For each category, compute two cumulative curves:
    - Not yet High by step s (no E=2 up to s)
    - Already High by step s (at least one E=2 up to s)
    y = fraction answered by s.
    """
    print("\n" + "="*60)
    print("R3-E: TIMING ANALYSIS - CUMULATIVE ANSWERS VS EVIDENCE ARRIVAL")
    print("="*60)
    
    # Precompute per-trajectory E_t sequence and first answer step
    per_traj = []
    for t in traces:
        model, cat = clean_and_categorize_model(t.get('model', ''))
        if cat not in categories:
            continue
        dataset = t.get('dataset', 'Unknown')
        qid = t.get('question_id', t.get('id', None))
        if qid is None:
            qid = t.get('question', None)
        steps = t.get('trace', []) or []
        cur_E = 0; ans_step = None
        E_seq = []
        for i, step in enumerate(steps):
            # E = max of all evidence seen so far
            turn_E = get_evidence_at_turn(step)
            cur_E = max(cur_E, turn_E)
            
            ann = step.get('annotation', {})
            if ans_step is None and isinstance(ann, dict) and ann.get('type') in ('CorrectAnswer','IncorrectAnswer'):
                ans_step = i
            E_seq.append(cur_E)
        per_traj.append({'category': cat, 'dataset': dataset, 'question_id': qid, 'E_seq': E_seq, 'ans_step': ans_step})

    print(f"\nTotal trajectories analyzed: {len(per_traj)}")
    print(f"Max steps analyzed: {max_steps}")

    # Build curves up to max_steps or available length
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
    all_results = {}
    
    for ax_idx, cat in enumerate(categories):
        ax = axes[ax_idx]
        sub = [r for r in per_traj if r['category'] == cat]
        print(f"\n{cat} Model Analysis:")
        print(f"  Total trajectories: {len(sub)}")
        
        if not sub:
            print(f"  No data for {cat} - skipping")
            ax.set_axis_off(); continue
            
        # Count trajectories by evidence status
        ever_high_count = sum(1 for r in sub if any(e >= 2 for e in r['E_seq']))
        never_high_count = len(sub) - ever_high_count
        answered_count = sum(1 for r in sub if r['ans_step'] is not None)
        
        print(f"  Ever reached E=2: {ever_high_count} ({ever_high_count/len(sub)*100:.1f}%)")
        print(f"  Never reached E=2: {never_high_count} ({never_high_count/len(sub)*100:.1f}%)")
        print(f"  Total answered: {answered_count} ({answered_count/len(sub)*100:.1f}%)")
        
        xs = range(0, max_steps)
        y_not = []
        y_alr = []
        cat_results = {'steps': [], 'not_yet_high': [], 'already_high': []}
        
        for s in xs:
            # Partition by E2 seen by s
            not_group = []
            alr_group = []
            for r in sub:
                E_seq = r['E_seq']
                e2_by_s = 1 if any((e>=2) for e in E_seq[:min(len(E_seq), s+1)]) else 0
                # fraction answered by s
                ans_by_s = 0
                if r['ans_step'] is not None and r['ans_step'] <= s:
                    ans_by_s = 1
                if e2_by_s:
                    alr_group.append(ans_by_s)
                else:
                    not_group.append(ans_by_s)
            # means
            y_not_val = np.mean(not_group) if len(not_group)>0 else np.nan
            y_alr_val = np.mean(alr_group) if len(alr_group)>0 else np.nan
            y_not.append(y_not_val)
            y_alr.append(y_alr_val)
            
            cat_results['steps'].append(s)
            cat_results['not_yet_high'].append(y_not_val)
            cat_results['already_high'].append(y_alr_val)
            
            # Print key step results
            if s in [0, 2, 5, 10, max_steps-1]:
                print(f"  Step {s:2d}: Not yet High={y_not_val:.3f} (n={len(not_group)}), Already High={y_alr_val:.3f} (n={len(alr_group)})")
        
        all_results[cat] = cat_results
        
        # Calculate timing efficiency metrics
        final_not = y_not[-1] if not np.isnan(y_not[-1]) else 0
        final_alr = y_alr[-1] if not np.isnan(y_alr[-1]) else 0
        timing_gap = final_alr - final_not
        
        print(f"  Final answer rates: Not yet High={final_not:.3f}, Already High={final_alr:.3f}")
        print(f"  Timing gap: {timing_gap:.3f}")
        
        # Interpretation guidance
        if timing_gap < 0.2:
            print(f"    ⚠️  SMALL timing gap: Model doesn't effectively wait for evidence")
        elif timing_gap > 0.5:
            print(f"    ✅ GOOD timing: Model effectively waits for evidence")
        if final_not > 0.3:
            print(f"    ⚠️  HIGH premature rate: {final_not:.1%} answer before reaching E=2")
        
        ax.plot(xs, y_not, 'o-', label='Not yet High')
        ax.plot(xs, y_alr, 'o-', label='Already High')
        ax.set_ylim(0,1); ax.set_title(cat)
        if ax_idx == 0:
            ax.set_ylabel('Fraction answered by step s')
        ax.set_xlabel('Step s')
        if ax_idx == 0:
            ax.legend()
    
    # Summary comparison
    print(f"\nSUMMARY COMPARISON:")
    print(f"{'Model':<12} {'Final Not High':<15} {'Final High':<12} {'Timing Gap':<12} {'Efficiency':<12}")
    print("-" * 70)
    for cat in categories:
        if cat in all_results:
            final_not = all_results[cat]['not_yet_high'][-1] if not np.isnan(all_results[cat]['not_yet_high'][-1]) else 0
            final_alr = all_results[cat]['already_high'][-1] if not np.isnan(all_results[cat]['already_high'][-1]) else 0
            timing_gap = final_alr - final_not
            efficiency = "Good" if timing_gap > 0.3 else "Poor" if timing_gap < 0.1 else "Fair"
            print(f"{cat:<12} {final_not:<15.3f} {final_alr:<12.3f} {timing_gap:<12.3f} {efficiency:<12}")
    
    fig.suptitle('R3-E: Cumulative answers vs steps by High-evidence arrival')
    savefig(fig, outdir, 'fig_R3_E_timing')
    
    # Add detailed analysis section for R3-E
    print("\n" + "="*60)
    print("R3-E DETAILED ANALYSIS: TIMING EFFICIENCY INSIGHTS")
    print("="*60)
    
    print("\nTIMING EFFICIENCY ANALYSIS:")
    print("The timing analysis reveals critical differences in how models utilize evidence before answering:")
    print()
    
    for cat in categories:
        if cat in all_results:
            final_not = all_results[cat]['not_yet_high'][-1] if not np.isnan(all_results[cat]['not_yet_high'][-1]) else 0
            final_alr = all_results[cat]['already_high'][-1] if not np.isnan(all_results[cat]['already_high'][-1]) else 0
            timing_gap = final_alr - final_not
            
            print(f"{cat.upper()} MODEL:")
            print(f"  • Premature answering rate: {final_not:.1%} of trajectories answer before reaching high evidence")
            print(f"  • Evidence-responsive rate: {final_alr:.1%} of trajectories answer after reaching high evidence")
            print(f"  • Timing efficiency gap: {timing_gap:.3f} (difference between evidence-responsive and premature rates)")
            
            if timing_gap < 0.1:
                print(f"  • INSIGHT: POOR evidence utilization - model answers almost equally regardless of evidence quality")
            elif timing_gap < 0.3:
                print(f"  • INSIGHT: MODERATE evidence utilization - some improvement with evidence but limited")
            else:
                print(f"  • INSIGHT: GOOD evidence utilization - model significantly benefits from waiting for evidence")
            print()
    
    print("KEY FINDINGS:")
    print("1. ALL models show premature behavior: Even the best-performing models answer before reaching high evidence")
    print("2. Evidence utilization varies dramatically: RL Agents show the largest timing gap (0.077) but still suboptimal")
    print("3. Systematic premature answering: 76.5-94.3% of trajectories answer before reaching E=2 across all models")
    print("4. Limited evidence responsiveness: Even when high evidence is available, models don't consistently wait for it")
    print()
    print("IMPLICATIONS:")
    print("• Current models lack effective evidence-gating mechanisms")
    print("• Training should emphasize evidence accumulation before answering")
    print("• Timing efficiency could be a key differentiator for model performance")
    print("• The small timing gaps suggest models are not effectively distinguishing evidence quality")


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument('--input', type=str, default='final_annotated_traces.jsonl')
    ap.add_argument('--outdir', type=str, default='figs_evidence')
    ap.add_argument('--n_boot', type=int, default=5000)
    ap.add_argument('--horizon', type=int, default=15, help='Horizon for A3–A4')
    ap.add_argument('--policy_threshold', type=str, default='high', choices=['high','medium'], help='Primary threshold E>=2; sensitivity E>=1')
    ap.add_argument('--group_by', type=str, default='category', choices=['category','model'], help='Aggregate by category (Base/Few-shot/RL Agents) or by model name')
    ap.add_argument('--ce_prevalence', type=str, default='global', choices=['global','group'], help='Use global or group-specific P(E) for CE computation')
    args = ap.parse_args()

    traces = load_traces(args.input)

    # Build turn-level df (A1/A2/A4)
    turn_df = build_turn_level_df(traces)

    # Report counts
    print('\n' + '='*80)
    print('R3 EVIDENCE STRENGTH ANALYSIS - DATA SUMMARY')
    print('='*80)
    print('Reporting checklist:')
    
    # Basic counts
    unique_trajectories = len({(t.get('dataset','Unknown'), t.get('question_id', t.get('id', None))) for t in traces})
    print(f"  N trajectories: {unique_trajectories}")
    print(f"  N turns (pre-answer): {len(turn_df)}")
    
    # Evidence level distribution
    prev_E = turn_df['E'].value_counts(normalize=True).sort_index().to_dict()
    print(f"  Prevalence of E levels: {prev_E}")
    
    
    # Model distribution
    model_counts = turn_df['category'].value_counts()
    print(f"  Model distribution: {dict(model_counts)}")
    
    # Dataset distribution
    dataset_counts = turn_df['dataset'].value_counts()
    print(f"  Dataset distribution: {dict(dataset_counts)}")
    
    # Answer event statistics
    answer_events = turn_df['answer_event'].sum()
    answer_rate = answer_events / len(turn_df)
    print(f"  Answer events: {answer_events} ({answer_rate:.1%} of turns)")
    
    # Correctness statistics (for answered events)
    answered_turns = turn_df[turn_df['answer_event'] == 1]
    if len(answered_turns) > 0:
        correct_answers = answered_turns['correct_event'].sum()
        overall_accuracy = correct_answers / len(answered_turns)
        print(f"  Overall accuracy: {correct_answers}/{len(answered_turns)} = {overall_accuracy:.1%}")
    
    print(f"  Bootstrap: {args.n_boot} resamples clustered by question")
    print(f"  Policy threshold: E >= {2 if args.policy_threshold == 'high' else 1}")
    print(f"  Horizon: {args.horizon} steps")
    print(f"  Evidence mode: cumulative (default)")

    ensure_dir(args.outdir)
    # # A1
    # analyze_A1_propensity(turn_df, args.outdir, n_boot=args.n_boot)
    # # A2
    # analyze_A2_quality(turn_df, args.outdir, n_boot=args.n_boot)
    # # Combined A+B
    analyze_AB_combined(turn_df, args.outdir, n_boot=args.n_boot)
    # # A3
    traj_df = build_trajectory_level_df(traces, horizon=args.horizon)
    confusion_df = compute_confusion_data(traj_df, policy_threshold=args.policy_threshold, group_by=args.group_by)
    # # A4
    calib_df = compute_A4_calibration(turn_df, confusion_df, policy_threshold=args.policy_threshold, n_boot=args.n_boot, group_by=args.group_by, ce_prevalence=args.ce_prevalence)
    plot_A4_calibration(calib_df, args.outdir, group_by=args.group_by)
    # # A5
    # analyze_A5_timing(traces, args.outdir, max_steps=args.horizon)

    # # RL Agents breakdown figure
    # analyze_trained_breakdown(turn_df, args.outdir, n_boot=args.n_boot)

    # Rigorous analysis: relationship between E and final answer correctness
    analyze_correctness_vs_E_rigorous(turn_df, args.outdir, n_boot=args.n_boot)

    # Final summary and interpretation guide
    print('\n' + '='*80)
    print('R3 ANALYSIS COMPLETE - INTERPRETATION GUIDE')
    print('='*80)
    
    print('\nSuggested figure captions:')
    print('Fig R3-D (Calibration & Rates): Overall calibration error and the split into Premature vs Over-cautious quantify each archetype\'s risk profile.')
    print('Fig R3-E (Timing): Agents that answer much before reaching High evidence exhibit premature behavior; better-calibrated agents wait.')
    
    print('\n' + '='*60)
    print('KEY INTERPRETATION GUIDELINES:')
    print('='*60)
    print('1. EVIDENCE LEVELS:')
    print('   - E=0 (Low): Insufficient information quality')
    print('   - E=1 (Medium): Sufficient quality but unclear')
    print('   - E=2 (High): Sufficient quality and clear')
    
    print('\n2. BEHAVIORAL PATTERNS:')
    print('   - PREMATURE: High answer rate at E=0, high FP rate')
    print('   - OVER-CAUTIOUS: Low answer rate at E=2, high FN rate')
    print('   - WELL-CALIBRATED: Low E=0 rate, high E=2 rate, balanced TP/TN')
    
    print('\n3. QUALITY INDICATORS:')
    print('   - Good models: High accuracy at E=2, low premature rate')
    print('   - Poor models: Low accuracy even at E=2, high premature rate')
    print('   - Timing efficiency: Large gap between "Not yet High" and "Already High" curves')
    
    print('\n4. CALIBRATION METRICS:')
    print('   - CE (Calibration Error): Lower is better (closer to ideal policy)')
    print('   - Premature Rate: Lower is better (fewer FP)')
    print('   - Over-cautious Rate: Lower is better (fewer FN)')
    print('   - Total Error Rate: Premature + Over-cautious (lower is better)')
    
    print('\n5. ACTIONABLE INSIGHTS:')
    print('   - High premature rate → Need better evidence thresholding')
    print('   - High over-cautious rate → Need more aggressive answering')
    print('   - Low accuracy at E=2 → Need better reasoning capabilities')
    print('   - Small timing gap → Need better evidence utilization')

    # Sensitivity analyses (brief hooks)
    # 1) Policy threshold E>=1
    if args.policy_threshold == 'high':
        print('\nSensitivity: Recomputing A3–A4 with policy E>=1 (should answer at Medium).')
        confusion_df_med = compute_confusion_data(traj_df, policy_threshold='medium', group_by=args.group_by)
        calib_df_med = compute_A4_calibration(turn_df, confusion_df_med, policy_threshold='medium', n_boot=args.n_boot)
        # Compare key rates to check for reversals
        def _summ(df):
            return {r['category']:{'CE':r['CE_point'],'Premature':r['premature_point'],'Over':r['over_cautious_point']} for _, r in df.iterrows()}
        print('  Primary rates:', _summ(calib_df))
        print('  Sensitivity rates:', _summ(calib_df_med))
        print('  Note: Report if any sensitivity reverses a conclusion; otherwise trends are robust.')


if __name__ == '__main__':
    main()
