"""
LA-COCO Publication-Quality Plots
Generates figures for top-venue paper submission.
Only includes plots for experiments that support theoretical claims.
"""

import numpy as np
import json
from pathlib import Path
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator, NullFormatter
import matplotlib.ticker as ticker

# Publication style applied to every figure saved by this module
# (serif fonts, thicker lines, 300-dpi output with a tight bounding box).
plt.rcParams.update({
    'font.size': 12,
    'font.family': 'serif',
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'legend.fontsize': 10,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'lines.linewidth': 2,
    'lines.markersize': 6,
})

# Input JSON is read from <project>/results and figures are written to
# <project>/figures, where <project> is the parent of this script's directory.
RESULTS_DIR = Path(__file__).parent.parent / "results"
FIGURES_DIR = Path(__file__).parent.parent / "figures"
FIGURES_DIR.mkdir(exist_ok=True)

# Per-method colors. NOTE: this palette is not fully colorblind-safe -- the
# green (Hedge) and red (PD) entries are hard to tell apart under red-green
# color-vision deficiency -- so every method also gets a distinct marker in
# MARKERS below; keep the two dicts in sync when adding a method.
COLORS = {
    'B': '#2196F3',      # Blue - prediction-augmented
    'A': '#FF9800',      # Orange - baseline
    'Hedge': '#4CAF50',  # Green - Hedge
    'Naive': '#9E9E9E',  # Gray - naive OGD
    'PD': '#F44336',     # Red - primal-dual
    'theory': '#000000', # Black - theoretical bound
}

MARKERS = {'B': 'o', 'A': 's', 'Hedge': '^', 'Naive': 'D', 'PD': 'v'}


def load_results(path=None):
    """Load the experiment results JSON.

    Args:
        path: Optional path (str or Path) to the results file. Defaults to
            ``RESULTS_DIR / 'experiment_results.json'``.

    Returns:
        The decoded JSON object. ``main`` expects a dict keyed by experiment
        block name (e.g. ``'block1_growth_rate'``).
    """
    if path is None:
        path = RESULTS_DIR / 'experiment_results.json'
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which would break on e.g. Windows for non-ASCII labels.
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def fig1_growth_rate(data):
    """Figure 1: CCV growth rate comparison (log-log plot).

    Panel (a) plots each method's mean cumulative constraint violation (CCV)
    against the horizon T with a +/-1 std band (std defaults to 0 when absent),
    plus O(log T) and O(sqrt T) reference curves. Panel (b) plots CCV / T for
    the same methods. Both panels use log-log axes.

    Saves ``fig1_growth_rate.pdf`` and ``fig1_growth_rate.png`` into
    FIGURES_DIR.
    """
    rows = data['block1_growth_rate']
    horizons = [row['T'] for row in rows]

    # (mean key, panel-(a) label, panel-(b) label, COLORS/MARKERS key)
    series = (
        ('ccv_B_mean', 'Sub-policy B (Ours)', 'Sub-policy B', 'B'),
        ('ccv_A_mean', 'Sub-policy A (Baseline)', 'Sub-policy A', 'A'),
        ('ccv_hedge_mean', 'LA-COCO Hedge', 'Hedge', 'Hedge'),
        ('ccv_naive_mean', 'Naive OGD', 'Naive OGD', 'Naive'),
        ('ccv_pd_mean', 'Primal-Dual', 'Primal-Dual', 'PD'),
    )

    fig, (ax_abs, ax_norm) = plt.subplots(1, 2, figsize=(12, 4.5))

    for mean_key, long_label, short_label, style in series:
        color = COLORS[style]
        marker = MARKERS[style]

        # (a) Absolute CCV with a std band; the lower edge is floored so the
        # band stays positive on the log y-axis.
        means = [row[mean_key] for row in rows]
        spreads = [row.get(mean_key.replace('mean', 'std'), 0) for row in rows]
        ax_abs.plot(horizons, means, color=color, marker=marker,
                    label=long_label, markersize=5)
        lower = [max(m - s, 0.01) for m, s in zip(means, spreads)]
        upper = [m + s for m, s in zip(means, spreads)]
        ax_abs.fill_between(horizons, lower, upper, alpha=0.15, color=color)

        # (b) Per-step CCV, i.e. CCV / T.
        per_step = [row[mean_key] / row['T'] for row in rows]
        ax_norm.plot(horizons, per_step, color=color, marker=marker,
                     label=short_label, markersize=5)

    # Reference curves. Their constants are only chosen to keep the guides
    # visually distinct from the data; they are not fitted.
    t_arr = np.array(horizons)
    ax_abs.plot(t_arr, 2.5 * np.log(t_arr), '--', color='gray', alpha=0.5,
                label=r'$O(\log T)$')
    ax_abs.plot(t_arr, 0.8 * np.sqrt(t_arr), ':', color='gray', alpha=0.5,
                label=r'$O(\sqrt{T})$')

    for ax in (ax_abs, ax_norm):
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r'Time horizon $T$')

    ax_abs.set_ylabel(r'Cumulative Constraint Violation')
    ax_abs.set_title('(a) CCV Growth Rate')
    ax_abs.legend(loc='upper left', fontsize=8, ncol=1)
    ax_abs.grid(True, alpha=0.3)

    ax_norm.set_ylabel(r'CCV / $T$')
    ax_norm.set_title('(b) Normalized CCV (convergence rate)')
    ax_norm.legend(loc='upper right', fontsize=8)
    ax_norm.grid(True, alpha=0.3)

    fig.tight_layout()
    for ext in ('pdf', 'png'):
        fig.savefig(FIGURES_DIR / f'fig1_growth_rate.{ext}')
    plt.close(fig)
    print("  Saved fig1_growth_rate")


def _event_e_palette(results):
    """Return one face color per Event-E condition, in ``results`` order.

    Labels containing "ocs" (any case) are purple. Otherwise, conditions with
    ``event_E_rate <= 0.5`` are red and the rest get successively darker
    blues; the darkest blue is reused once the four shades run out.
    """
    blues = ['#90CAF9', '#42A5F5', '#1E88E5', '#1565C0']
    palette = []
    n_blue = 0
    for r in results:
        if 'ocs' in r['label'].lower():  # case-insensitive, also matches 'OCS'
            palette.append('#AB47BC')  # purple
        elif r['event_E_rate'] <= 0.5:
            palette.append('#EF5350')  # red
        else:
            palette.append(blues[min(n_blue, len(blues) - 1)])
            n_blue += 1
    return palette


def _plot_event_e_margins(ax, results, palette, short_labels):
    """Draw panel (a): per-seed E margin as violin + box + jittered points."""
    all_margins = [[raw['V_regret_plus_bonus'] for raw in r['raw']]
                   for r in results]
    positions = np.arange(len(results))

    # Violin bodies only; the box plot below supplies median/quartiles.
    parts = ax.violinplot(all_margins, positions=positions, widths=0.6,
                          showmeans=False, showmedians=False,
                          showextrema=False)
    for body, color in zip(parts['bodies'], palette):
        body.set_facecolor(color)
        body.set_edgecolor('none')
        body.set_alpha(0.35)

    # Thin box plots drawn over the violins.
    bp = ax.boxplot(all_margins, positions=positions, widths=0.18,
                    patch_artist=True, showfliers=False,
                    medianprops=dict(color='white', linewidth=1.5),
                    whiskerprops=dict(color='gray', linewidth=0.8),
                    capprops=dict(color='gray', linewidth=0.8))
    for patch, color in zip(bp['boxes'], palette):
        patch.set_facecolor(color)
        patch.set_edgecolor(color)
        patch.set_alpha(0.85)

    # Individual seeds as a jittered strip (fixed RNG seed => same layout
    # on every run).
    rng = np.random.RandomState(42)
    for i, (margins, color) in enumerate(zip(all_margins, palette)):
        jitter = rng.uniform(-0.22, 0.22, len(margins))
        ax.scatter(np.full(len(margins), i) + jitter, margins,
                   color=color, alpha=0.55, s=12, zorder=5,
                   edgecolors='white', linewidths=0.3)

    # Dashed line marking the E boundary (margin = 0).
    ax.axhline(y=0, color='#D32F2F', linestyle='--', linewidth=1.2, alpha=0.8,
               label=r'$\mathcal{E}$ boundary (margin $= 0$)')

    # E-rate badge per condition: above the highest seed when the rate is
    # > 0.5, below the lowest seed otherwise.
    for i, (r, margins) in enumerate(zip(results, all_margins)):
        e_rate = r['event_E_rate']
        holds = e_rate > 0.5
        txt_color = '#2E7D32' if holds else '#C62828'
        ax.annotate(f'{e_rate:.0%}',
                    xy=(i, max(margins) if holds else min(margins)),
                    textcoords="offset points",
                    xytext=(0, 8 if holds else -10),
                    ha='center', va='bottom' if holds else 'top',
                    fontsize=8, fontweight='bold', color=txt_color,
                    bbox=dict(boxstyle='round,pad=0.25',
                              facecolor='#E8F5E9' if holds else '#FFEBEE',
                              edgecolor=txt_color, alpha=0.85, linewidth=0.6))

    # Very negative margins would flatten every other condition on a linear
    # axis, so switch to symlog in that case.
    all_flat = [m for ms in all_margins for m in ms]
    if min(all_flat) < -1000:
        ax.set_yscale('symlog', linthresh=500)
        # Light band over the symlog's linear region |margin| < linthresh.
        ax.axhspan(-500, 500, color='#FAFAFA', alpha=0.3, zorder=0)

    ax.set_xticks(positions)
    ax.set_xticklabels(short_labels, fontsize=9, ha='center')
    ax.set_ylabel(r'Margin $= V \!\cdot\! \mathrm{Regret}_T + \mathrm{Bonus}_T$',
                  fontsize=12)
    ax.set_title(r'(a) Per-Seed $\mathcal{E}$ Margin Distribution', fontsize=13,
                 fontweight='bold', pad=10)
    ax.legend(fontsize=8, loc='lower left',
              framealpha=0.9, edgecolor='gray')
    ax.grid(True, alpha=0.15, axis='y', linestyle='-')
    ax.set_axisbelow(True)


def _plot_event_e_decomposition(ax, results, palette, short_labels):
    """Draw panel (b): per-condition mean V*Regret and mean Bonus as grouped bars."""
    from matplotlib.patches import Patch

    v_regrets = [np.mean([raw['V'] * raw['regret_B'] for raw in r['raw']])
                 for r in results]
    bonuses = [np.mean([raw['bonus_B'] for raw in r['raw']]) for r in results]

    x = np.arange(len(results))
    bar_w = 0.32

    # Solid bars: V * Regret. Hatched, lighter bars: Bonus. Color = condition.
    ax.bar(x - bar_w / 2, v_regrets, bar_w,
           label=r'$V \!\cdot\! \mathrm{Regret}_T$',
           color=palette, alpha=0.85, edgecolor='white', linewidth=0.8)
    ax.bar(x + bar_w / 2, bonuses, bar_w,
           label=r'$\mathrm{Bonus}_T$',
           color=palette, alpha=0.45, edgecolor=palette,
           linewidth=1.2, hatch='///')

    # Annotate the total margin of each group.
    for i, (v_reg, bon) in enumerate(zip(v_regrets, bonuses)):
        total = v_reg + bon
        if total >= 0:
            # Above the tallest element of the group.
            ax.annotate(f'Σ={total:,.0f}', (i, max(v_reg, bon, total)),
                        textcoords="offset points", xytext=(0, 8),
                        ha='center', fontsize=7.5, fontweight='bold',
                        color='#2E7D32')
        else:
            txt = f'Σ={total:.1e}' if abs(total) > 10000 else f'Σ={total:,.0f}'
            # Below the lowest bar of the group. Either bar may be negative,
            # so anchoring on V*Regret alone could overlap the Bonus bar.
            ax.annotate(txt, (i, min(v_reg, bon, 0)),
                        textcoords="offset points", xytext=(0, -12),
                        ha='center', fontsize=7.5, fontweight='bold',
                        color='#C62828',
                        bbox=dict(boxstyle='round,pad=0.2', facecolor='#FFF3E0',
                                  edgecolor='#E65100', alpha=0.85, linewidth=0.5))

    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.6, alpha=0.5)

    # Symlog only when some mean is strongly negative.
    if min(v_regrets + bonuses) < -5000:
        ax.set_yscale('symlog', linthresh=500)

    ax.set_xticks(x)
    ax.set_xticklabels(short_labels, fontsize=9, ha='center')
    ax.set_ylabel('Value', fontsize=12)
    ax.set_title(r'(b) Margin Decomposition', fontsize=13,
                 fontweight='bold', pad=10)

    # Neutral-gray proxy handles. Bar color encodes the condition, so the
    # legend only distinguishes the two bar kinds (solid vs hatched).
    legend_elements = [
        Patch(facecolor='#666', alpha=0.85, edgecolor='white',
              label=r'$V \!\cdot\! \mathrm{Regret}_T$'),
        Patch(facecolor='#666', alpha=0.4, edgecolor='#666',
              hatch='///', label=r'$\mathrm{Bonus}_T$'),
    ]
    ax.legend(handles=legend_elements, fontsize=9, loc='upper left',
              framealpha=0.9, edgecolor='gray')
    ax.grid(True, alpha=0.15, axis='y', linestyle='-')
    ax.set_axisbelow(True)


def fig5_event_E(data):
    """Figure 5: Event E study.

    (a) Per-seed margin ``V_regret_plus_bonus`` for every condition, drawn as
        violin + box + jittered points, with a badge giving each condition's
        E rate. The y-axis switches to symlog (linthresh 500) when any margin
        is below -1000.
    (b) Grouped bars of the per-condition mean V*Regret and mean Bonus, with
        the summed margin annotated. Symlog when any mean is below -5000.

    Colors: labels containing "ocs" (any case) are purple; otherwise
    conditions with ``event_E_rate <= 0.5`` are red and the rest are blue
    shades.

    Saves ``fig5_event_E.pdf`` and ``fig5_event_E.png`` into FIGURES_DIR.
    """
    results = data['block5_event_E']

    palette = _event_e_palette(results)
    # Shorten "Stochastic X" to a two-line tick label.
    short_labels = [r['label'].replace('Stochastic ', 'Stoch.\n')
                    for r in results]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))
    _plot_event_e_margins(ax1, results, palette, short_labels)
    _plot_event_e_decomposition(ax2, results, palette, short_labels)

    # Compact theory note as a figure suptitle.
    fig.suptitle(
        r'Thm 3: Stochastic $\Rightarrow \mathcal{E}$ holds;'
        r'  Adaptive adversary $\Rightarrow \mathcal{E}$ fails',
        fontsize=9, fontstyle='italic', color='#555', y=0.98)

    # Leave the top 5% of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(FIGURES_DIR / 'fig5_event_E.pdf')
    plt.savefig(FIGURES_DIR / 'fig5_event_E.png')
    plt.close()
    print("  Saved fig5_event_E")


def fig7_upper_bound(data):
    """Figure 7: CCV empirical vs theoretical upper bound.

    Shows that empirical CCV stays below the theoretical worst-case bound,
    which validates Lemma 8(a), T1b, and T1d-i.

    Panel (a): Sub-policies A and B. Each method's mean empirical CCV
               (solid, +/-1 std band) is shown against its own theoretical
               bound (dashed), with every seed in ``raw_results`` as a scatter.
    Panel (b): Ratio (empirical / theoretical) for A, B and Hedge on a log
               scale, which separates curves that otherwise overlap. The
               y-range always contains every positive ratio and the violation
               line at 1, so a ratio above the bound is never clipped out of
               view.
    """
    results = data['block7_upper_bound']

    Ts = np.array([r['T'] for r in results])

    ccv_A = np.array([r['ccv_A_mean'] for r in results])
    ccv_A_std = np.array([r['ccv_A_std'] for r in results])
    theory_A = np.array([r['theory_A'] for r in results])

    ccv_B = np.array([r['ccv_B_mean'] for r in results])
    ccv_B_std = np.array([r['ccv_B_std'] for r in results])
    theory_B = np.array([r['theory_B'] for r in results])

    # Hedge only appears in panel (b), through its precomputed ratio.
    ratio_A = np.array([r['ratio_A'] for r in results])
    ratio_B = np.array([r['ratio_B'] for r in results])
    ratio_H = np.array([r['ratio_H'] for r in results])

    # Per-seed CCV for the scatter overlay.
    seeds_A = [[raw['ccv_A'] for raw in r['raw_results']] for r in results]
    seeds_B = [[raw['ccv_B'] for raw in r['raw_results']] for r in results]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))

    # ═══════════════════════════════════════════════════
    # (a) Empirical vs Theory — paired by method
    # ═══════════════════════════════════════════════════
    # Theory lines (dashed)
    ax1.plot(Ts, theory_A, '--', color=COLORS['A'], linewidth=2.0,
             label=r'Bound A: $O(\sqrt{G^3 D T \ln T / \alpha})$',
             zorder=5)
    ax1.plot(Ts, theory_B, '--', color=COLORS['B'], linewidth=2.0,
             label=r'Bound B: $O(G^2 \ln(Te) / \alpha)$',
             zorder=5)

    # Per-seed scatter (small, semi-transparent, unlabelled so it stays out
    # of the legend). Multiplicative jitter spreads points evenly on the log
    # x-axis.
    rng = np.random.RandomState(0)
    for T_val, runs_A, runs_B in zip(Ts, seeds_A, seeds_B):
        jitter_A = T_val * np.exp(rng.uniform(-0.03, 0.03, len(runs_A)))
        jitter_B = T_val * np.exp(rng.uniform(-0.03, 0.03, len(runs_B)))
        ax1.scatter(jitter_A, runs_A, color=COLORS['A'], alpha=0.35,
                    s=12, zorder=2, marker='s', edgecolors='none')
        ax1.scatter(jitter_B, runs_B, color=COLORS['B'], alpha=0.35,
                    s=12, zorder=2, marker='o', edgecolors='none')

    # Empirical mean lines (solid, with markers) and std bands floored at 0.5
    # so they stay positive on the log y-axis.
    ax1.plot(Ts, ccv_A, '-', color=COLORS['A'], marker=MARKERS['A'],
             markersize=5, linewidth=1.8,
             label='Sub-policy A (empirical)', zorder=4)
    ax1.fill_between(Ts,
                     np.maximum(ccv_A - ccv_A_std, 0.5),
                     ccv_A + ccv_A_std,
                     alpha=0.12, color=COLORS['A'])

    ax1.plot(Ts, ccv_B, '-', color=COLORS['B'], marker=MARKERS['B'],
             markersize=5, linewidth=1.8,
             label='Sub-policy B (empirical)', zorder=4)
    ax1.fill_between(Ts,
                     np.maximum(ccv_B - ccv_B_std, 0.5),
                     ccv_B + ccv_B_std,
                     alpha=0.12, color=COLORS['B'])

    # Double-headed arrows at the last horizon in the results, spanning the
    # gap between each empirical mean and its bound.
    i_last = len(Ts) - 1
    for ccv_val, th_val, color in [(ccv_A[i_last], theory_A[i_last], COLORS['A']),
                                    (ccv_B[i_last], theory_B[i_last], COLORS['B'])]:
        ax1.annotate('', xy=(Ts[i_last], th_val),
                     xytext=(Ts[i_last], ccv_val),
                     arrowprops=dict(arrowstyle='<->', color=color,
                                     lw=1.0, alpha=0.5,
                                     connectionstyle='arc3,rad=0.15'))

    ax1.set_xscale('log')
    ax1.set_yscale('log')
    ax1.set_xlabel(r'Time horizon $T$')
    ax1.set_ylabel(r'CCV$_T$')
    # Plain "vs." -- outside $...$ mathtext would render "\ " literally.
    ax1.set_title(r'(a) Empirical CCV vs. Theoretical Bound')
    ax1.legend(fontsize=7.5, loc='upper left', framealpha=0.9,
               edgecolor='#ccc', handlelength=2.5)
    ax1.grid(True, alpha=0.2, which='both')

    # ═══════════════════════════════════════════════════
    # (b) Ratio on log scale — separates overlapping curves
    # ═══════════════════════════════════════════════════
    ax2.plot(Ts, ratio_B, color=COLORS['B'], marker=MARKERS['B'],
             markersize=6, linewidth=1.8,
             label=r'CCV$_\mathrm{B}$ / Bound$_\mathrm{B}$')
    ax2.plot(Ts, ratio_A, color=COLORS['A'], marker=MARKERS['A'],
             markersize=6, linewidth=1.8,
             label=r'CCV$_\mathrm{A}$ / Bound$_\mathrm{A}$')
    ax2.plot(Ts, ratio_H, color=COLORS['Hedge'], marker=MARKERS['Hedge'],
             markersize=6, linewidth=1.8,
             label=r'CCV$_\mathrm{Hedge}$ / Bound$_\mathrm{Hedge}$')

    # Violation threshold
    ax2.axhline(y=1.0, color='#D32F2F', linestyle='--', linewidth=1.2,
                alpha=0.7, label='Violation threshold')

    # Shade the safe region
    ax2.axhspan(0, 1.0, color='#E8F5E9', alpha=0.25, zorder=0)

    ax2.set_xscale('log')
    ax2.set_yscale('log')
    ax2.set_xlabel(r'Time horizon $T$')
    ax2.set_ylabel(r'Empirical / Theoretical')
    ax2.set_title('(b) Ratio (must be $< 1$)')

    # y-limits on the log axis:
    #  * lower: a decade-aligned floor under the smallest *positive* ratio
    #    (a zero ratio would give log10(0) = -inf and a 0 limit);
    #  * upper: at least 2.0 for headroom above the violation line, but never
    #    below the data. A fixed 2.0 would hide a ratio > 2, i.e. exactly
    #    the bound violation this panel exists to reveal.
    all_ratios = np.concatenate([ratio_A, ratio_B, ratio_H])
    finite = all_ratios[np.isfinite(all_ratios)]
    positive = finite[finite > 0]
    if positive.size:
        y_lo = 10 ** (np.floor(np.log10(positive.min())) - 0.15)
    else:
        y_lo = 1e-3
    y_hi = max(2.0, 1.5 * finite.max()) if finite.size else 2.0
    ax2.set_ylim(y_lo, y_hi)

    ax2.legend(fontsize=8, loc='upper right', framealpha=0.9,
               edgecolor='#ccc')
    ax2.grid(True, alpha=0.2, which='both')

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'fig7_upper_bound.pdf')
    plt.savefig(FIGURES_DIR / 'fig7_upper_bound.png')
    plt.close()
    print("  Saved fig7_upper_bound")


def fig8_decomposition(data):
    """Figure 8: Lemma 3 regret decomposition verification.

    Verifies the Lemma 3 inequality LHS <= RHS, with
    LHS = Phi(Q_T) + V*beta*Regret_T + Bonus_T and
    RHS = Regret'_T + Penalty_T (as labelled on the panel-(a) axes), by
    showing that the gap (RHS - LHS) is non-negative.
    NOTE(review): an earlier docstring wrote the LHS as Q^2(T) + V*Regret +
    Bonus -- confirm that Phi(Q) = Q^2 and the beta factor match the paper.

    Panel (a): mean LHS vs mean RHS per (T, noise) config on log-log axes;
               points above y = x satisfy the inequality.
    Panel (b): mean gap per T, grouped by prediction noise. The axis is log
               scaled unless some gap is <= 0; then it switches to symlog so a
               tight or violated config stays visible instead of being dropped.
    """
    results = data['block8_decomposition']

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))

    # Fixed styling for the noise levels used in the paper. Any other level
    # falls back to gray / circle instead of raising KeyError.
    base_colors = {0.0: '#2196F3', 0.1: '#FF9800', 0.5: '#4CAF50'}
    base_markers = {0.0: 'o', 0.1: 's', 0.5: '^'}
    noise_levels = sorted({r['noise'] for r in results})
    noise_colors = {n: base_colors.get(n, '#9E9E9E') for n in noise_levels}
    noise_markers = {n: base_markers.get(n, 'o') for n in noise_levels}
    # '{:g}' renders 0.0 -> '0', 0.1 -> '0.1', 0.5 -> '0.5'.
    noise_labels = {n: rf'$\sigma_{{\mathrm{{pred}}}}={n:g}$'
                    for n in noise_levels}

    # ── (a) Scatter: LHS vs RHS for ALL configs ──
    # Points above the y = x line mean RHS >= LHS (inequality holds).
    all_lhs, all_rhs = [], []
    for r in results:
        lhs_val = r['decomp_lhs_mean']
        rhs_val = r['decomp_rhs_mean']
        n = r['noise']
        all_lhs.append(lhs_val)
        all_rhs.append(rhs_val)
        ax1.scatter(lhs_val, rhs_val,
                    c=noise_colors[n], marker=noise_markers[n],
                    s=80, edgecolors='white', linewidths=0.6,
                    zorder=3)
        # Annotate T value next to each point
        ax1.annotate(f'$T$={r["T"]}', (lhs_val, rhs_val),
                     textcoords='offset points', xytext=(6, -3),
                     fontsize=7, color='#444')

    # Axis limits come from positive values only, because both axes are log
    # scaled (a non-positive limit would be rejected).
    positive = [v for v in all_lhs + all_rhs if v > 0]
    lo = min(positive) * 0.8 if positive else 1e-3
    hi = max(positive) * 1.15 if positive else 1.0

    # y = x reference line, with the "inequality holds" half shaded.
    ax1.plot([lo, hi], [lo, hi], 'k--', lw=1.0, alpha=0.5, label='$y = x$')
    ax1.fill_between([lo, hi], [lo, hi], [hi, hi],
                     color='#E8F5E9', alpha=0.35, zorder=0)
    ax1.set_xlim(lo, hi)
    ax1.set_ylim(lo, hi)
    ax1.set_xscale('log')
    ax1.set_yscale('log')
    # Raw strings: the old non-raw label relied on invalid escapes ('\P', '\,'),
    # which raise SyntaxWarning on Python 3.12+.
    ax1.set_xlabel(r'LHS: $\Phi(Q_T) + V\beta\,\mathrm{Regret}_T + \mathrm{Bonus}_T$',
                   fontsize=10)
    ax1.set_ylabel(r"RHS: $\mathrm{Regret}'_T + \mathrm{Penalty}_T$",
                   fontsize=10)
    # Plain space -- outside $...$ mathtext renders '~' literally.
    ax1.set_title(r'(a) Lemma 3 verification: RHS $\geq$ LHS')

    # Legend with one marker per noise level actually present, plus y = x.
    legend_handles = [plt.Line2D([0], [0], marker=noise_markers[n], color='w',
                                  markerfacecolor=noise_colors[n],
                                  markeredgecolor='white', markersize=8,
                                  label=noise_labels[n])
                      for n in noise_levels]
    legend_handles.append(plt.Line2D([0], [0], color='k', ls='--', lw=1,
                                      label='$y = x$'))
    ax1.legend(handles=legend_handles, fontsize=8, loc='upper left',
               framealpha=0.9, edgecolor='#ccc')
    ax1.set_aspect('equal', adjustable='box')
    ax1.grid(True, alpha=0.2, which='both')

    # ── (b) Grouped bar chart: Gap by T, colored by noise ──
    T_values = sorted({r['T'] for r in results})
    n_noise = len(noise_levels)
    gap_lookup = {(r['T'], r['noise']): r['decomp_gap_mean'] for r in results}

    x_pos = np.arange(len(T_values))
    bar_width = 0.22
    offsets = np.linspace(-(n_noise - 1) / 2 * bar_width,
                          (n_noise - 1) / 2 * bar_width, n_noise)

    for j, n in enumerate(noise_levels):
        # Only configs that were actually run get a bar. A missing (T, noise)
        # pair used to be drawn as a 0-height bar with a label at y = 0.
        present = [(i, gap_lookup[(T, n)]) for i, T in enumerate(T_values)
                   if (T, n) in gap_lookup]
        xs = [x_pos[i] + offsets[j] for i, _ in present]
        gaps = [g for _, g in present]
        ax2.bar(xs, gaps, bar_width,
                color=noise_colors[n], alpha=0.85,
                edgecolor='white', linewidth=0.6,
                label=noise_labels[n], zorder=3)
        # Value labels on top ('12k' style for values >= 1000).
        for x, g in zip(xs, gaps):
            txt = f'{g / 1000:.0f}k' if g >= 1000 else f'{g:.1f}'
            ax2.annotate(txt, (x, g),
                         textcoords='offset points', xytext=(0, 4),
                         ha='center', fontsize=6.5, color='#333',
                         fontweight='bold')

    all_gaps = list(gap_lookup.values())
    if all_gaps and min(all_gaps) <= 0:
        # A log axis would silently drop a zero or negative gap (the very
        # case that contradicts the lemma), so use symlog and mark zero.
        ax2.set_yscale('symlog', linthresh=1.0)
        ax2.axhline(0, color='black', linewidth=0.6, alpha=0.5)
    else:
        ax2.set_yscale('log')
    ax2.set_xticks(x_pos)
    ax2.set_xticklabels([f'$T$={T}' for T in T_values], fontsize=10)
    ax2.set_xlabel(r'Time horizon $T$')
    ax2.set_ylabel(r'Gap (RHS $-$ LHS)')
    ax2.set_title(r'(b) Decomposition gap $\geq 0$ (all configs)')
    ax2.legend(fontsize=8, loc='upper left', framealpha=0.9, edgecolor='#ccc')
    ax2.grid(True, alpha=0.2, axis='y', which='both')
    ax2.set_axisbelow(True)

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'fig8_decomposition.pdf')
    plt.savefig(FIGURES_DIR / 'fig8_decomposition.png')
    plt.close()
    print("  Saved fig8_decomposition")


def fig9_ocs_growth(data):
    """Figure 9: OCS setting — validates Corollary 3.

    When f_t = 0, Regret = 0 and E holds unconditionally (Theorem 3a).
    Panel (a) plots Sub-policy B's mean CCV (with std error bars) against the
    theoretical bound 16 G^2 ln(Te)/alpha + 6GD on log-log axes. Panel (b)
    plots their ratio in percent, annotating every point.

    Saves ``fig9_ocs_growth.pdf`` and ``fig9_ocs_growth.png`` into
    FIGURES_DIR.
    """
    # Labels are rendered with matplotlib mathtext (text.usetex is not
    # enabled), so LaTeX escapes such as '\ ', '~' or '\%' outside $...$ would
    # appear literally. Plain text is used for those parts.
    results = data['block9_ocs_growth']

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))

    Ts = np.array([r['T'] for r in results])
    ccv_B = np.array([r['ccv_B_mean'] for r in results])
    ccv_B_std = np.array([r['ccv_B_std'] for r in results])
    theory_B = np.array([r['theory_B'] for r in results])
    ratios = ccv_B / theory_B * 100  # percentage

    # ── (a) Empirical CCV vs Theory bound (log-log) ──
    ax1.plot(Ts, theory_B, '--', color=COLORS['theory'], linewidth=2.0,
             label=r'Theory: $\frac{16G^2\ln(Te)}{\alpha}+6GD$', zorder=5)
    # Floor at 1e-4 so zero CCV still appears on the log y-axis.
    ccv_display = np.maximum(ccv_B, 1e-4)
    ax1.errorbar(Ts, ccv_display, yerr=ccv_B_std,
                 color=COLORS['B'], marker=MARKERS['B'], markersize=6,
                 capsize=3, capthick=1.2, linewidth=1.8,
                 markeredgecolor='white', markeredgewidth=0.6,
                 label='Sub-policy B (empirical)', zorder=4)
    # Shade the gap between empirical CCV and the bound.
    ax1.fill_between(Ts, ccv_display, theory_B, alpha=0.08, color=COLORS['B'])

    ax1.set_xscale('log')
    ax1.set_yscale('log')
    ax1.set_xlabel(r'Time horizon $T$')
    ax1.set_ylabel(r'CCV$_T$')
    ax1.set_title(r'(a) OCS ($f_t\!=\!0$): empirical CCV vs. theory bound')
    ax1.legend(fontsize=8, loc='upper left', framealpha=0.9, edgecolor='#ccc')
    ax1.grid(True, alpha=0.2, which='both')

    # NOTE(review): this Pr[E] annotation is hard-coded, not computed from
    # `results`; confirm it against the block-9 data before publishing.
    ax1.text(0.97, 0.05,
             r'$\Pr[\mathcal{E}]=1.0\;\forall\,T$' '\n'
             '(Corollary 3)',
             transform=ax1.transAxes, fontsize=9, ha='right', va='bottom',
             bbox=dict(boxstyle='round,pad=0.35', facecolor='#E8F5E9',
                       edgecolor='#81C784', alpha=0.85))

    # ── (b) Ratio (empirical / theory) in percent ──
    ax2.plot(Ts, ratios, color=COLORS['B'], marker=MARKERS['B'],
             markersize=7, linewidth=1.8,
             markeredgecolor='white', markeredgewidth=0.6, zorder=4)
    ax2.fill_between(Ts, 0, ratios, alpha=0.12, color=COLORS['B'])

    # Annotate each point with its ratio. Labels alternate above/below so
    # neighbours don't collide.
    for i, (t, r_val) in enumerate(zip(Ts, ratios)):
        offset_y = 8 if i % 2 == 0 else -14
        ax2.annotate(f'{r_val:.1f}%', (t, r_val),
                     textcoords='offset points', xytext=(0, offset_y),
                     ha='center', fontsize=8, color='#333', fontweight='bold')

    ax2.set_xscale('log')
    ax2.set_xlabel(r'Time horizon $T$')
    ax2.set_ylabel('Ratio: empirical / theory (%)')
    ax2.set_title(r'(b) Tightness ratio (empirical CCV / theory bound)')
    ax2.grid(True, alpha=0.2, which='both')
    ax2.set_axisbelow(True)

    # 30% headroom above the data. Fall back to [0, 1] when no ratio is
    # positive (an all-zero range would be a singular ylim).
    y_max = ratios.max() * 1.3
    ax2.set_ylim(0, y_max if y_max > 0 else 1.0)

    # Tight bbox and dpi come from rcParams (savefig.bbox / savefig.dpi),
    # as for every other figure in this module.
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'fig9_ocs_growth.pdf')
    plt.savefig(FIGURES_DIR / 'fig9_ocs_growth.png')
    plt.close()
    print("  Saved fig9_ocs_growth")


def fig10_regret_vs_E(data):
    """Figure 10: Regret sign vs Event E — validates Theorem 3(a).

    Shows that Regret >= 0 always implies E holds.

    Panel (a): mean Regret_T per condition. Bars are green where the result's
               ``thm3a_valid`` flag is set and red otherwise, and each bar is
               annotated with its E rate.
    Panel (b): per-seed Regret_T vs V*Regret_T + Bonus_T, colored by
               ``adv_type``, with the Regret >= 0 half-plane shaded inside the
               data view.
    """
    results = data['block10_regret_vs_E']

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))

    # (a) Regret vs E rate across adversary types
    labels = [r['label'] for r in results]
    regrets = [r['regret_mean'] for r in results]
    e_rates = [r['event_E_rate'] for r in results]

    colors = ['#4CAF50' if r['thm3a_valid'] else '#F44336' for r in results]

    x = np.arange(len(labels))

    ax1.bar(x, regrets, color=colors, alpha=0.7)
    ax1.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    ax1.set_xticks(x)
    ax1.set_xticklabels(labels, rotation=25, ha='right', fontsize=8)
    ax1.set_ylabel(r'$\mathrm{Regret}_T$')
    ax1.set_title(r'(a) Regret by Adversary Type')
    ax1.grid(True, alpha=0.3, axis='y')

    # E-rate annotation: above positive bars, below negative ones.
    for i, (reg, e) in enumerate(zip(regrets, e_rates)):
        ax1.annotate(f'E={e:.1f}', (i, reg),
                     textcoords="offset points",
                     xytext=(0, 10 if reg >= 0 else -15),
                     ha='center', fontsize=8, fontweight='bold',
                     color='green' if e > 0.5 else 'red')

    # (b) Scatter: per-seed Regret vs V*Regret+Bonus. Points are grouped by
    # adversary type (first-seen order) and drawn with one scatter call per
    # type, so each type gets exactly one legend entry.
    type_colors = {'ocs': '#9C27B0', 'stochastic': '#2196F3', 'adaptive': '#F44336'}
    points_by_type = {}
    for r in results:
        regs, margins = points_by_type.setdefault(r['adv_type'], ([], []))
        for raw in r['raw']:
            regs.append(raw['regret_B'])
            margins.append(raw['V_regret_plus_bonus'])

    for atype, (regs, margins) in points_by_type.items():
        if regs:
            ax2.scatter(regs, margins, color=type_colors.get(atype, 'gray'),
                        alpha=0.5, s=20, label=atype)

    ax2.axhline(y=0, color='red', linestyle='--', linewidth=1.5,
                label=r'$\mathcal{E}$ boundary')
    ax2.axvline(x=0, color='gray', linestyle=':', linewidth=1)

    # Shade the Regret >= 0 half-plane only within the current data view, then
    # pin the limits. Shading out to a fixed x = 100 would widen the
    # autoscaled x-range and squash the data when regrets are small.
    xlim = ax2.get_xlim()
    ylim = ax2.get_ylim()
    if xlim[1] > 0:
        ax2.axvspan(max(xlim[0], 0), xlim[1], alpha=0.05, color='green')
    ax2.set_xlim(xlim)
    ax2.set_ylim(ylim)
    # Place the note in axes coordinates (upper right, i.e. on the Regret >= 0
    # side) so it can never land outside the visible area.
    ax2.text(0.97, 0.95, 'Thm 3(a):\nRegret≥0 ⟹ E',
             transform=ax2.transAxes,
             fontsize=9, color='green', ha='right', va='top',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    ax2.set_xlabel(r'$\mathrm{Regret}_T$')
    ax2.set_ylabel(r'$V \cdot \mathrm{Regret}_T + \mathrm{Bonus}_T$')
    ax2.set_title(r'(b) Theorem 3(a): Regret $\geq 0 \Rightarrow \mathcal{E}$')
    ax2.legend(fontsize=8)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'fig10_regret_vs_E.pdf')
    plt.savefig(FIGURES_DIR / 'fig10_regret_vs_E.png')
    plt.close()
    print("  Saved fig10_regret_vs_E")


def main():
    """Load the results JSON and render every figure whose data block exists.

    Each figure builder runs only when its source block key is present in the
    loaded results, in figure-number order.
    """
    print("Loading results...")
    data = load_results()

    print("Generating figures...")
    # (results block key, figure builder), in output order.
    builders = (
        ('block1_growth_rate', fig1_growth_rate),
        ('block5_event_E', fig5_event_E),
        ('block7_upper_bound', fig7_upper_bound),
        ('block8_decomposition', fig8_decomposition),
        ('block9_ocs_growth', fig9_ocs_growth),
        ('block10_regret_vs_E', fig10_regret_vs_E),
    )
    for block_key, build in builders:
        if block_key in data:
            build(data)

    print(f"\nAll figures saved to {FIGURES_DIR}")


if __name__ == '__main__':
    main()
