#!/usr/bin/env python3
"""
Plot Alpha=1.5 Regret Scaling Results

Generates a figure showing mean regret vs. horizon T for SP-UCB with α=1.5.
Validates the √T scaling predicted by Theorem 5 by comparing the measured
regret against a fitted c·√(T log T) reference curve.

Style matches generate_paper_figures.py (fig7_cr_vs_T).
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.optimize import curve_fit

# Publication-quality defaults (matching generate_paper_figures.py).
# NOTE: this runs at import time and mutates matplotlib's global rcParams,
# so it affects every figure created after this module is loaded.
plt.rcParams.update({
    # Text sizes (points)
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 10,
    # Resolution: on-screen figures vs. files written by savefig
    'figure.dpi': 150,
    'savefig.dpi': 300,
    # Crop saved figures to their content
    'savefig.bbox': 'tight',
})


def sqrt_func(T, c):
    """Evaluate the model c·√T; signature suits scipy's curve_fit (x first, then parameter)."""
    root_T = np.sqrt(T)
    return c * root_T


def sqrt_T_logT_func(T, c):
    """Evaluate the model c·√(T·log T), where log is the natural logarithm."""
    t_log_t = T * np.log(T)
    return c * np.sqrt(t_log_t)


def plot_regret_scaling_fig7_style(df: pd.DataFrame, output_dir: Path):
    """
    Generate regret scaling figure in the style of fig7_cr_vs_T.

    Single panel: mean regret per horizon drawn as a line with ±1.96·SE
    confidence bands, plus a least-squares fitted c·√(T log T) reference
    curve. Both axes use a linear scale.

    Args:
        df: Per-seed results; must contain a 'T' column (horizon) and a
            'regret' column. Rows are aggregated per distinct T.
        output_dir: Directory that receives fig_regret_scaling_50seeds.png
            and fig_regret_scaling_50seeds.pdf. Created if missing.

    Returns:
        The fitted coefficient c of the c·√(T log T) reference curve.

    Raises:
        ValueError: If df has no rows to aggregate.
    """
    # Colors matching fig7
    color_data = '#1f77b4'   # Blue (same as SP-UCB-α=0 in fig7)
    color_ref = '#d62728'    # Red for reference line

    # Group by T (groupby sorts the index, so T_values is ascending)
    grouped = df.groupby('T')['regret'].agg(['mean', 'std', 'count'])
    if grouped.empty:
        # Fail with a clear message instead of an opaque curve_fit error.
        raise ValueError("No regret records to plot: input DataFrame has no rows")

    T_values = grouped.index.values
    means = grouped['mean'].values
    stds = grouped['std'].values
    counts = grouped['count'].values

    # Standard error of the mean for each horizon
    se = stds / np.sqrt(counts)

    # Fit √(T log T) curve to the per-horizon means
    popt, _ = curve_fit(sqrt_T_logT_func, T_values, means)
    c_fit = popt[0]

    # Generate smooth fitted curve over the observed horizon range
    T_smooth = np.linspace(T_values.min(), T_values.max(), 100)
    fitted_curve = sqrt_T_logT_func(T_smooth, c_fit)

    # Create figure (single panel, similar aspect to one panel of fig7)
    fig, ax = plt.subplots(1, 1, figsize=(7, 5))

    # Plot fitted √(T log T) reference line first (behind data).
    # The label must name the function that was actually fitted.
    ax.plot(T_smooth, fitted_curve, '--', color=color_ref, linewidth=2,
            label=f'Fitted $\\sqrt{{T \\log T}}$ (c={c_fit:.2f})')

    # Plot data with confidence bands (fig7 style: line + fill_between)
    ax.plot(T_values, means, 'o-', color=color_data,
            linewidth=2, markersize=6,
            label='SP-UCB (α=1.5)')
    ax.fill_between(T_values, means - 1.96*se, means + 1.96*se,
                    color=color_data, alpha=0.2)

    # Style matching fig7_cr_vs_T (but with linear scale)
    ax.set_xlabel('Time Horizon $T$')
    ax.set_ylabel('Regret')
    ax.set_title('Regret Scaling: SP-UCB-OLP with α=1.5 (50 Seeds)',
                 fontsize=14, fontweight='bold')

    # Grid (matching fig7)
    ax.grid(alpha=0.3)

    # Legend
    ax.legend(loc='upper left', frameon=True, fancybox=True, framealpha=0.9)

    # Add annotation box for n and error bar info (matching fig7 style).
    # n is taken from the first horizon's group.
    textstr = f'n = {int(counts[0])} seeds per $T$\nError bands: ±1.96 SE'
    props = dict(boxstyle='round', facecolor='white', alpha=0.8, edgecolor='gray')
    ax.text(0.97, 0.03, textstr, transform=ax.transAxes, fontsize=9,
            verticalalignment='bottom', horizontalalignment='right', bbox=props)

    fig.tight_layout()

    # Save PNG and PDF; operate on `fig` explicitly rather than on pyplot's
    # implicit "current figure" so other open figures cannot interfere.
    output_dir.mkdir(parents=True, exist_ok=True)
    output_png = output_dir / "fig_regret_scaling_50seeds.png"
    output_pdf = output_dir / "fig_regret_scaling_50seeds.pdf"

    fig.savefig(output_png, dpi=300, bbox_inches='tight')
    print(f"Saved: {output_png}")

    fig.savefig(output_pdf, bbox_inches='tight')
    print(f"Saved: {output_pdf}")

    plt.close(fig)

    return c_fit


def plot_regret_scaling(df: pd.DataFrame, output_dir: Path):
    """
    Draw the two-panel α=1.5 regret figure and save it as PNG and PDF.

    Left panel:  mean regret per horizon T with ±1.96·SE error bars and a
                 c·√T reference anchored at the middle horizon.
    Right panel: mean of the 'regret_normalized' column per T with the same
                 error bars and a dashed line at the average of those means
                 (flat if √T scaling holds).

    Reads the 'T', 'regret' and 'regret_normalized' columns of df and writes
    fig_alpha15_regret.png / fig_alpha15_regret.pdf into output_dir.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
    left_ax, right_ax = axes[0], axes[1]

    green = '#2ca02c'  # α=1.5 series
    gray = '#888888'   # reference lines

    # Error-bar styling shared by both panels
    series_style = dict(fmt='o-', color=green, linewidth=2, markersize=8,
                        capsize=4, label='SP-UCB-α=1.5')
    by_T = df.groupby('T')

    # ---- Left panel: raw regret ----
    regret_stats = by_T['regret'].agg(['mean', 'std', 'count'])
    T_values = regret_stats.index.values
    means = regret_stats['mean'].values
    se = regret_stats['std'].values / np.sqrt(regret_stats['count'].values)

    left_ax.errorbar(T_values, means, yerr=1.96*se, **series_style)

    # √T reference, scaled so it passes through the middle data point
    anchor = len(T_values) // 2
    scale = means[anchor] / np.sqrt(T_values[anchor])
    T_ref = np.linspace(T_values[0], T_values[-1], 100)
    left_ax.plot(T_ref, scale * np.sqrt(T_ref), '--', color=gray,
                 linewidth=1.5, alpha=0.7, label=r'$c \cdot \sqrt{T}$ reference')

    left_ax.set_xlabel('Horizon T')
    left_ax.set_ylabel('Regret')
    left_ax.set_title('Regret vs Horizon', fontweight='bold')
    left_ax.legend(loc='upper left')
    left_ax.grid(alpha=0.3)

    # ---- Right panel: normalized regret ----
    norm_stats = by_T['regret_normalized'].agg(['mean', 'std', 'count'])
    means_norm = norm_stats['mean'].values
    se_norm = norm_stats['std'].values / np.sqrt(norm_stats['count'].values)

    right_ax.errorbar(T_values, means_norm, yerr=1.96*se_norm, **series_style)

    # Dashed horizontal line at the average of the per-horizon means
    mean_all = np.mean(means_norm)
    right_ax.axhline(mean_all, color=gray, linestyle='--', linewidth=1.5,
                     alpha=0.7, label=f'Mean = {mean_all:.2f}')

    right_ax.set_xlabel('Horizon T')
    right_ax.set_ylabel(r'Regret / $\sqrt{T}$')
    right_ax.set_title(r'Normalized Regret (should be $\approx$ constant)', fontweight='bold')
    right_ax.legend(loc='upper right')
    right_ax.grid(alpha=0.3)

    # Window of three 30%-margins either side of the mean, floored at zero
    y_margin = 0.3 * mean_all
    lower = max(0, mean_all - y_margin * 3)
    upper = mean_all + y_margin * 3
    right_ax.set_ylim(lower, upper)

    plt.suptitle(r'SP-UCB with $\alpha = 1.5$ (Theory-Compliant): $\sqrt{T}$ Regret Scaling',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()

    # PNG is written with an explicit dpi; the PDF uses the savefig default
    for filename, extra in (("fig_alpha15_regret.png", {'dpi': 300}),
                            ("fig_alpha15_regret.pdf", {})):
        target = output_dir / filename
        plt.savefig(target, bbox_inches='tight', **extra)
        print(f"Saved: {target}")

    plt.close()


def plot_simple_regret(df: pd.DataFrame, output_dir: Path):
    """
    Draw a single-panel regret-vs-horizon figure and save it as PNG and PDF.

    Plots mean 'regret' per 'T' with ±1.96·SE error bars plus a c·√T
    reference anchored at the middle horizon, then writes
    fig_alpha15_simple.png / fig_alpha15_simple.pdf into output_dir.
    """
    fig, ax = plt.subplots(figsize=(7, 5))

    green = '#2ca02c'  # data series
    gray = '#888888'   # reference curve

    # Per-horizon mean and standard error of the regret
    stats = df.groupby('T')['regret'].agg(['mean', 'std', 'count'])
    T_values = stats.index.values
    means = stats['mean'].values
    se = stats['std'].values / np.sqrt(stats['count'].values)

    # Measured regret
    ax.errorbar(T_values, means, yerr=1.96*se, fmt='o-', color=green,
                linewidth=2, markersize=8, capsize=4, label=r'SP-UCB-$\alpha$=1.5')

    # √T reference, scaled so it passes through the middle data point
    anchor = len(T_values) // 2
    scale = means[anchor] / np.sqrt(T_values[anchor])
    T_ref = np.linspace(T_values[0], T_values[-1], 100)
    ax.plot(T_ref, scale * np.sqrt(T_ref), '--', color=gray,
            linewidth=1.5, alpha=0.7, label=r'$c \cdot \sqrt{T}$')

    ax.set_xlabel('Horizon T')
    ax.set_ylabel('Regret')
    ax.set_title(r'Regret Scaling with $\alpha = 1.5$ (Theory-Compliant)',
                 fontsize=14, fontweight='bold')
    ax.legend(loc='upper left', fontsize=11)
    ax.grid(alpha=0.3)

    plt.tight_layout()

    # PNG is written with an explicit dpi; the PDF uses the savefig default
    for filename, extra in (("fig_alpha15_simple.png", {'dpi': 300}),
                            ("fig_alpha15_simple.pdf", {})):
        target = output_dir / filename
        plt.savefig(target, bbox_inches='tight', **extra)
        print(f"Saved: {target}")

    plt.close()


def main():
    """
    Load the per-seed α=1.5 regret results and render the fig7-style figure.

    Reads results/alpha15_regret/alpha15_regret_per_seed.csv, located relative
    to the directory one level above this script's directory, and writes the
    figure files next to it. If the CSV is missing, prints a hint and returns
    without plotting.
    """
    # Paths: two .parent hops = the directory above the one holding this script
    project_root = Path(__file__).parent.parent
    results_dir = project_root / "results" / "alpha15_regret"
    figures_dir = results_dir  # Save in same directory as data

    print("=" * 60)
    print("Plotting Alpha=1.5 Regret Scaling Results (Fig7 Style)")
    print("=" * 60)

    # Load data
    csv_file = results_dir / "alpha15_regret_per_seed.csv"
    if not csv_file.exists():
        print(f"Error: Results file not found at {csv_file}")
        print("Run the experiments first: python run_alpha15_regret.py --seed 42")
        return

    df = pd.read_csv(csv_file)
    print(f"Loaded {len(df)} records from {csv_file}")

    # Generate main figure (fig7 style)
    c_fit = plot_regret_scaling_fig7_style(df, figures_dir)
    # The coefficient comes from a c·√(T log T) fit, so report it as such.
    print(f"\nFitted coefficient: Regret = {c_fit:.3f} × √(T log T)")

    print("\n" + "=" * 60)
    print(f"Figure saved to: {figures_dir}")
    print("=" * 60)


# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
