"""
MC-COCO Experiment Plotting (v4)

Generates all figures from experiment results.
Uses several chart types: grouped bar charts, log-scale line plots, scatter plots, and filled-area charts.
"""

import numpy as np
import json
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import FancyArrowPatch
from pathlib import Path
import argparse

# Default input/output locations. Both live one level above the directory
# that contains this file.
_BASE_DIR = Path(__file__).parent.parent
RESULTS_DIR = _BASE_DIR / "results"
FIGURES_DIR = _BASE_DIR / "figures"
# Created at import time so the default output directory exists.
FIGURES_DIR.mkdir(exist_ok=True)

# Global matplotlib style applied to every figure produced by this module.
_PLOT_STYLE = {
    # Typography
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'legend.fontsize': 10,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    # Lines and canvas
    'lines.linewidth': 2.2,
    'figure.figsize': (7, 5),
    'figure.dpi': 150,
    # Hide the top/right axes spines
    'axes.spines.top': False,
    'axes.spines.right': False,
}
plt.rcParams.update(_PLOT_STYLE)

# Named palette (hex codes); described by the author as colorblind-friendly.
C_MC1 = '#2171B5'       # Blue
C_NAIVE = '#CB181D'     # Red
C_THEORY = '#238B45'    # Green
C_ORANGE = '#E6550D'    # Orange
C_PURPLE = '#6A51A3'    # Purple
C_GRAY = '#969696'      # Gray

# Per-beta styling, keyed by the beta value. The colors reuse the palette above.
BETA_COLORS = {0.5: C_MC1, 0.7: C_ORANGE, 0.9: C_THEORY, 1.0: C_PURPLE}
BETA_HATCHES = {0.5: '', 0.7: '//', 0.9: '..', 1.0: 'xx'}


def plot_block2_k_dependence(results: list, save_path: Path):
    """Fig 1: K-dependence of the max per-constraint CCV at beta ~= 0.9.

    Left: grouped bar chart (MC-1 vs Naive) with the Naive/MC-1 ratio
    annotated above each MC-1 bar.
    Right: log-scale line plot of MC-1 with a +/- 2 sigma band, the
    theoretical values, and a least-squares reference curve of the form
    c * ln(K + 1)**p.

    Args:
        results: Result records (dicts). Records whose 'beta' (0.5 when the
            key is absent) is not within 0.05 of 0.9 are ignored. Each
            remaining record must provide 'K', 'mc1' (with 'max_ccv_mean',
            'max_ccv_std', 'theoretical_per_ccv') and 'naive' (with
            'max_ccv_mean', 'max_ccv_std').
        save_path: Destination file for the figure.

    Returns:
        None. Nothing is drawn or written when no record matches.
    """
    beta_09 = sorted(
        [r for r in results if abs(r.get('beta', 0.5) - 0.9) < 0.05],
        key=lambda r: r['K'])

    if not beta_09:
        return

    Ks = [r['K'] for r in beta_09]
    mc1_arr = np.array([r['mc1']['max_ccv_mean'] for r in beta_09], dtype=float)
    mc1_std_arr = np.array([r['mc1']['max_ccv_std'] for r in beta_09], dtype=float)
    naive_arr = np.array([r['naive']['max_ccv_mean'] for r in beta_09], dtype=float)
    naive_std_arr = np.array([r['naive']['max_ccv_std'] for r in beta_09], dtype=float)
    theory_ccvs = [r['mc1']['theoretical_per_ccv'] for r in beta_09]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))

    # ---- Left panel: Grouped bar chart ----
    ax = axes[0]
    x = np.arange(len(Ks))
    width = 0.32

    ax.bar(x - width/2, mc1_arr, width, yerr=mc1_std_arr,
           color=C_MC1, alpha=0.85, capsize=4,
           label='MC-1 (Ours)', edgecolor='white', linewidth=0.8)
    ax.bar(x + width/2, naive_arr, width, yerr=naive_std_arr,
           color=C_NAIVE, alpha=0.85, capsize=4,
           label='Naive Independent', edgecolor='white', linewidth=0.8)

    # Annotate the Naive/MC-1 ratio just above each MC-1 error bar. The gap is
    # given in points so it does not depend on the magnitude of the data.
    for xi, mc, nv, sd in zip(x, mc1_arr, naive_arr, mc1_std_arr):
        if mc == 0:
            # Ratio is undefined for a zero MC-1 value; leave the bar unlabelled.
            continue
        ax.annotate(f'{nv / mc:.1f}×',
                    xy=(xi - width/2, mc + sd),
                    xytext=(0, 4), textcoords='offset points',
                    ha='center', va='bottom',
                    fontsize=9, fontweight='bold', color=C_MC1)

    ax.set_xlabel('Number of Constraints $K$')
    ax.set_ylabel('Max Per-Constraint CCV')
    ax.set_title(r'MC-1 vs Naive ($\beta=0.9$)')
    ax.set_xticks(x)
    ax.set_xticklabels([str(k) for k in Ks])
    ax.legend(loc='upper left')
    ax.grid(axis='y', alpha=0.3)

    # ---- Right panel: Log-scale line + shaded region + ln(K) fit ----
    ax = axes[1]

    ax.semilogy(Ks, mc1_arr, 'o-', color=C_MC1, markersize=8,
                label='MC-1 (Empirical)', zorder=3)
    # NOTE(review): the lower edge (mean - 2*std) can be <= 0, which a log
    # axis cannot represent -- confirm the band renders as intended then.
    ax.fill_between(Ks, mc1_arr - 2*mc1_std_arr, mc1_arr + 2*mc1_std_arr,
                    color=C_MC1, alpha=0.15, label=r'$\pm 2\sigma$ CI')
    ax.semilogy(Ks, theory_ccvs, '^--', color=C_THEORY, markersize=7,
                alpha=0.7, label='Theory bound (Thm 2)')

    # Reference curve: straight-line fit of log(CCV) against log(ln(K + 1)).
    # The fit needs at least two distinct K values and strictly positive
    # inputs to both logarithms.
    K_arr = np.asarray(Ks, dtype=float)
    fit_possible = (np.unique(K_arr).size >= 2
                    and bool(np.all(K_arr > 0))
                    and bool(np.all(mc1_arr > 0)))
    if not fit_possible:
        print("  Skipping ln(K) fit: need >= 2 distinct positive K "
              "with positive CCV values")
    else:
        try:
            coeffs = np.polyfit(np.log(np.log(K_arr + 1)), np.log(mc1_arr), 1)
        except (np.linalg.LinAlgError, ValueError) as exc:
            # The fit is decorative; report the failure and keep the figure.
            print(f"  Skipping ln(K) fit: {exc}")
        else:
            # Span the plotted K range with 5% head-room on the right.
            K_ref = np.linspace(K_arr.min(), 1.05 * K_arr.max(), 200)
            fitted = np.exp(np.polyval(coeffs, np.log(np.log(K_ref + 1))))
            ax.semilogy(K_ref, fitted, ':', color=C_GRAY, linewidth=2,
                        alpha=0.7, label=r'$\sim \ln(K)$ fit')

    ax.set_xlabel('Number of Constraints $K$')
    ax.set_ylabel('Max Per-Constraint CCV')
    ax.set_title(r'Logarithmic Scaling ($\beta=0.9$, log-scale)')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, which='both')

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {save_path}")


def _value_with_fallback(record: dict, primary: str, fallback: str):
    """Return record[primary] if that key exists, otherwise record[fallback]."""
    return record[primary] if primary in record else record[fallback]


def plot_block3_t_scaling(results: list, save_path: Path):
    """Fig 2: T-scaling.

    Left: log-log line plot of |Regret| vs T per beta, with two gray
    reference slopes (~T and ~T^0.5).
    Right: log-log plot of max per-constraint CCV vs T per beta, drawn as
    overlaid filled curves together with the theoretical values.

    Args:
        results: Result records (dicts). The beta of a record is taken from
            'beta_actual' when present, else 'beta'; the horizon from
            'T_actual' when present, else 'T'. Only beta values within 0.01
            of 0.5, 0.7 or 0.9 are plotted. Each plotted record must provide
            'regret_mean', 'max_ccv_mean' and 'theoretical_per_ccv'.
        save_path: Destination file for the figure.

    Returns:
        None. The figure is always written, even if no record matches.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))

    def beta_of(r):
        return _value_with_fallback(r, 'beta_actual', 'beta')

    def horizon_of(r):
        return _value_with_fallback(r, 'T_actual', 'T')

    # Bucket the records by target beta, each bucket ordered by horizon.
    beta_data = {}
    for beta_val in [0.5, 0.7, 0.9]:
        subset = sorted(
            [r for r in results if abs(beta_of(r) - beta_val) < 0.01],
            key=horizon_of)
        if subset:
            beta_data[beta_val] = subset

    # ---- Left panel: |Regret| vs T (loglog) ----
    ax = axes[0]
    for beta_val, subset in beta_data.items():
        Ts = [horizon_of(r) for r in subset]
        regrets = [abs(r['regret_mean']) for r in subset]
        ax.loglog(Ts, regrets, 'o-', color=BETA_COLORS[beta_val], markersize=7,
                  label=rf'$\beta={beta_val}$')

    # Reference slopes. The T range and the prefactors are fixed constants,
    # not derived from the plotted data.
    T_ref = np.array([80, 40000])
    ax.loglog(T_ref, 0.35 * T_ref, ':', color=C_GRAY, alpha=0.5,
              label=r'$\sim T$ ref', linewidth=1.5)
    ax.loglog(T_ref, 3.5 * T_ref**0.5, '--', color=C_GRAY, alpha=0.4,
              label=r'$\sim T^{0.5}$ ref', linewidth=1.5)

    ax.set_xlabel('Time Horizon $T$')
    ax.set_ylabel('$|$Regret$|$')
    ax.set_title('$|$Regret$|$ vs $T$ (log-log)')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, which='both')

    # ---- Right panel: filled CCV curves ----
    ax = axes[1]

    # Draw order is 0.9, 0.7, 0.5, so later fills are painted over earlier ones.
    for beta_val in [0.9, 0.7, 0.5]:
        if beta_val not in beta_data:
            continue
        subset = beta_data[beta_val]
        Ts = [horizon_of(r) for r in subset]
        ccvs = [r['max_ccv_mean'] for r in subset]
        theo = [r['theoretical_per_ccv'] for r in subset]
        c = BETA_COLORS[beta_val]

        ax.fill_between(Ts, 0, ccvs, color=c, alpha=0.25)
        ax.plot(Ts, ccvs, 'o-', color=c, markersize=6,
                label=rf'Empirical $\beta={beta_val}$')
        ax.plot(Ts, theo, '^--', color=c, alpha=0.4, markersize=5,
                label=rf'Theory $\beta={beta_val}$')

    ax.set_xlabel('Time Horizon $T$')
    ax.set_ylabel('Max Per-Constraint CCV')
    ax.set_title('Per-Constraint CCV vs $T$')
    ax.set_xscale('log')
    ax.set_yscale('log')
    if beta_data:
        # Only request a legend when at least one labelled curve was drawn.
        ax.legend(fontsize=8, ncol=2, loc='upper left')
    ax.grid(True, alpha=0.3, which='both')

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {save_path}")


def plot_block5_heterogeneous(results: list, save_path: Path):
    """Fig 3: Heterogeneous prioritization.

    One panel per beta value (rounded to one decimal). Within a panel the
    x-axis is the constraint index; each configuration contributes one bar
    per constraint showing its per-constraint CCV with an error bar. Bar
    opacity grows with the configuration's alpha_k (more opaque = larger
    alpha_k).

    Args:
        results: Result records (dicts). Each must provide 'alphas',
            'per_ccv_means' and 'per_ccv_stds'; 'beta' defaults to 0.9 and
            'config_name' to 'unknown' when absent. The number of
            constraints in a panel is taken from the first record of that
            beta group.
        save_path: Destination file for the figure.

    Returns:
        None. Nothing is drawn or written when ``results`` is empty.
    """
    if not results:
        # A figure with zero panels cannot be created.
        print("  Skipping heterogeneous plot: no results")
        return

    # Group by beta
    beta_groups = {}
    for r in results:
        beta_groups.setdefault(round(r.get('beta', 0.9), 1), []).append(r)

    n_betas = len(beta_groups)
    # squeeze=False always yields a 2-D axes array, even for a single panel.
    fig, axes = plt.subplots(1, n_betas, figsize=(7 * n_betas, 5.5),
                             squeeze=False)

    # Base color per known configuration; anything else falls back to gray.
    config_colors_map = {
        'uniform': '#2171B5',
        'geometric': '#E6550D',
        'one_critical': '#6A51A3',
    }
    total_width = 0.75

    for ax, (beta_val, group) in zip(axes[0], sorted(beta_groups.items())):
        n_configs = len(group)
        K = len(group[0]['alphas'])

        x = np.arange(K)
        bar_width = total_width / n_configs

        for ci, res in enumerate(group):
            alphas = res['alphas']
            per_ccv = res['per_ccv_means']
            per_ccv_std = res['per_ccv_stds']
            name = res.get('config_name', 'unknown')
            base_color = config_colors_map.get(name, C_GRAY)

            # Center the cluster of config bars on each constraint tick.
            positions = x + (ci - n_configs/2 + 0.5) * bar_width

            for j in range(K):
                # Map alpha_k to bar opacity. Clamp to [0, 1] because
                # matplotlib rejects opacities outside that range.
                alpha_shade = min(1.0, max(0.0, 0.4 + 0.5 * alphas[j]))
                # Only the first bar of a config carries the legend label.
                ax.bar(positions[j], per_ccv[j], bar_width * 0.9,
                       yerr=per_ccv_std[j], capsize=3,
                       color=base_color, alpha=alpha_shade,
                       edgecolor='white', linewidth=0.6,
                       label=name.replace('_', ' ').title() if j == 0 else '')

        ax.set_xlabel('Constraint Index $k$')
        ax.set_ylabel(r'CCV$_k$')
        ax.set_title(rf'Heterogeneous Prioritization ($\beta={beta_val}$)')
        ax.set_xticks(x)
        ax.set_xticklabels([f'$k={i+1}$' for i in range(K)])
        ax.legend(fontsize=9)
        ax.grid(axis='y', alpha=0.3)

        # Add a second x-axis on top showing alphas for the geometric config
        geo = [r for r in group if r.get('config_name') == 'geometric']
        if geo:
            ax2 = ax.twiny()
            ax2.set_xlim(ax.get_xlim())
            ax2.set_xticks(x)
            ax2.set_xticklabels([rf'$\alpha={a}$' for a in geo[0]['alphas']],
                                fontsize=8, color=C_ORANGE)
            ax2.set_xlabel(r'Priority weights $\alpha_k$ (geometric config)',
                           fontsize=9, color=C_ORANGE)

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {save_path}")


def plot_block6_tradeoff(results: list, save_path: Path):
    """Fig 4: Regret x CCV trade-off.

    Left: scatter of (|Regret|, total CCV) for the records with K == 10 and
    T == 10000, colored and labelled by beta, with a dashed line joining the
    points in increasing-beta order.
    Right: bar chart of the normalized product vs K for the records with
    beta within 0.01 of 1.0 and T == 10000.

    Args:
        results: Result records (dicts). Left-panel records must provide
            'beta', 'regret_mean' and 'total_ccv_mean'; right-panel records
            must provide 'K' and 'product_normalized'.
        save_path: Destination file for the figure.

    Returns:
        None. The figure is always written; a panel whose filter matches no
        record is left empty (axes labels and title only).
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))

    # Separate results
    beta_results = sorted(
        [r for r in results if r.get('K') == 10 and r.get('T') == 10000],
        key=lambda r: r['beta'])
    k_results_at_b1 = sorted(
        [r for r in results if abs(r.get('beta', 0) - 1.0) < 0.01 and r.get('T') == 10000],
        key=lambda r: r['K'])

    # ---- Left panel: Scatter plot with trade-off line ----
    ax = axes[0]

    if beta_results:
        betas = [r['beta'] for r in beta_results]
        regrets_abs = [abs(r['regret_mean']) for r in beta_results]
        ccvs_total = [r['total_ccv_mean'] for r in beta_results]

        # Color map by beta
        scatter = ax.scatter(regrets_abs, ccvs_total,
                             c=betas, cmap='RdYlGn_r', s=180, zorder=5,
                             edgecolors='black', linewidths=1.2,
                             vmin=0.2, vmax=1.1)

        # Label each point with beta. The last point's label is pushed to
        # the left and connected with an arrow; the others sit to the right,
        # alternating below/above the marker.
        for i, b in enumerate(betas):
            offset_x = 8 if i < len(betas) - 1 else -30
            offset_y = -15 if i % 2 == 0 else 10
            ax.annotate(rf'$\beta={b}$',
                        xy=(regrets_abs[i], ccvs_total[i]),
                        xytext=(offset_x, offset_y),
                        textcoords='offset points',
                        fontsize=9, fontweight='bold',
                        arrowprops=dict(arrowstyle='->', color=C_GRAY, lw=0.8)
                        if abs(offset_x) > 20 else None)

        # Connect the points in increasing-beta order
        ax.plot(regrets_abs, ccvs_total, '--', color=C_GRAY, alpha=0.5,
                linewidth=1.5, zorder=2, label='Trade-off frontier')

        # Direction-of-increasing-beta arrow. It only makes sense when there
        # are two distinct end points to join.
        if len(beta_results) > 1:
            # The +20 shift is expressed in |Regret| data units.
            ax.annotate('', xy=(regrets_abs[-1] + 20, ccvs_total[-1]),
                        xytext=(regrets_abs[0] + 20, ccvs_total[0]),
                        arrowprops=dict(arrowstyle='->', color=C_ORANGE, lw=2))
            ax.text(max(regrets_abs) * 0.7, max(ccvs_total) * 0.5,
                    r'$\beta \uparrow$', fontsize=14, color=C_ORANGE,
                    fontweight='bold', ha='center')

        cbar = fig.colorbar(scatter, ax=ax, pad=0.02, shrink=0.85)
        cbar.set_label(r'$\beta$', fontsize=12)

        # A legend is only requested when a labelled artist exists.
        ax.legend(loc='upper right', fontsize=9)

    ax.set_xlabel(r'$|$Regret$|$')
    ax.set_ylabel(r'CCV$_{\mathrm{total}}$')
    ax.set_title(r'Regret--CCV Trade-off ($K=10$, $T=10{,}000$)')
    ax.grid(True, alpha=0.3)

    # ---- Right panel: Bar chart of normalized product vs K ----
    ax = axes[1]

    if k_results_at_b1:
        Ks = [r['K'] for r in k_results_at_b1]
        normed = [r['product_normalized'] for r in k_results_at_b1]

        # Color bars by value (gradient)
        norm = mcolors.Normalize(vmin=min(normed), vmax=max(normed))
        bar_colors = [plt.cm.Blues_r(norm(v)) for v in normed]

        bars = ax.bar(range(len(Ks)), normed, color=bar_colors,
                      edgecolor='white', linewidth=1, width=0.65)

        # Value labels just above each bar. The gap is given in points so it
        # does not depend on the magnitude of the data.
        for bar, val in zip(bars, normed):
            ax.annotate(f'{val:.2f}',
                        xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                        xytext=(0, 3), textcoords='offset points',
                        ha='center', va='bottom',
                        fontsize=9, fontweight='bold')

        ax.set_xticks(range(len(Ks)))
        ax.set_xticklabels([str(k) for k in Ks])

        # Reference line at O(1)
        ax.axhline(y=1.0, linestyle='--', color=C_GRAY, alpha=0.6,
                   linewidth=1.5, label=r'$O(1)$ reference')

        ax.legend(fontsize=10)

    ax.set_xlabel('Number of Constraints $K$')
    ax.set_ylabel(r'$|R| \times V_{\mathrm{total}}$ / ($KT$)')
    ax.set_title(r'Normalized Product at $\beta=1$')
    ax.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {save_path}")


def plot_block1_sanity(results: list, save_path: Path):
    """Sanity check: max per-constraint CCV vs beta with the theory values.

    Draws the empirical CCV (with error bars) and the theoretical per-
    constraint CCV against beta on a log-scaled y-axis.

    Args:
        results: Result records (dicts), each providing 'beta',
            'max_ccv_mean', 'max_ccv_std' and 'theoretical_per_ccv'.
        save_path: Destination file for the figure.

    Returns:
        None.
    """
    # Sort by beta so the connecting lines run left to right regardless of
    # the order in which the records were stored.
    ordered = sorted(results, key=lambda r: r['beta'])

    betas = [r['beta'] for r in ordered]
    ccvs = [r['max_ccv_mean'] for r in ordered]
    ccv_stds = [r['max_ccv_std'] for r in ordered]
    theo_ccvs = [r['theoretical_per_ccv'] for r in ordered]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.errorbar(betas, ccvs, yerr=ccv_stds, fmt='o-', color=C_MC1,
                capsize=5, markersize=8, label='Empirical CCV')
    ax.plot(betas, theo_ccvs, '^--', color=C_THEORY,
            markersize=7, label='Theory bound (Thm 2)')

    ax.set_xlabel(r'Trade-off parameter $\beta$')
    ax.set_ylabel('Max Per-Constraint CCV')
    # The experiment settings in the title are hard-coded, not read from results.
    ax.set_title('Sanity Check: CCV vs β (N=50, K=5, T=10000)')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3, which='both')

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {save_path}")


def main():
    """Command-line entry point: load the results JSON and render the figures.

    Options:
        --results-file: Path to the results JSON
            (default: RESULTS_DIR / 'experiment_results.json').
        --output-dir: Directory for the generated figures
            (default: FIGURES_DIR); created if missing.

    A figure is produced only for the experiment blocks present as keys in
    the loaded JSON object.
    """
    parser = argparse.ArgumentParser(description='MC-COCO Experiment Plotting')
    parser.add_argument('--results-file', type=str,
                        default=str(RESULTS_DIR / 'experiment_results.json'))
    parser.add_argument('--output-dir', type=str, default=str(FIGURES_DIR))
    args = parser.parse_args()

    # Read as UTF-8 explicitly rather than relying on the platform default.
    with open(args.results_file, encoding='utf-8') as f:
        results = json.load(f)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Generating plots...")

    # (results key, plotting function, output file name), in rendering order.
    plotters = (
        ('block1_sanity', plot_block1_sanity, 'fig_sanity.png'),
        ('block2_k_dependence', plot_block2_k_dependence, 'fig_k_dependence.png'),
        ('block3_t_scaling', plot_block3_t_scaling, 'fig_t_scaling.png'),
        ('block5_heterogeneous', plot_block5_heterogeneous, 'fig_heterogeneous.png'),
        ('block6_tradeoff', plot_block6_tradeoff, 'fig_tradeoff.png'),
    )
    for key, plot_fn, filename in plotters:
        if key in results:
            plot_fn(results[key], output_dir / filename)

    print("Done!")


if __name__ == '__main__':
    main()
