#!/usr/bin/env python3
"""
Regenerate Fig5: Regret Scaling with Horizon T

Changes from original:
- Title: "Regret Scaling with Horizon T" (removed "Figure 5:" and "(linear scale)")
- Removed √T reference dashed line
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# Publication-quality defaults, configured one rc group at a time.
# Each call sets rcParams['<group>.<key>'] for the keyword arguments given.
plt.rc('font', size=12)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('xtick', labelsize=11)
plt.rc('ytick', labelsize=11)
plt.rc('legend', fontsize=10)
plt.rc('figure', dpi=150)            # on-screen resolution
plt.rc('savefig', dpi=300, bbox='tight')  # resolution / cropping of saved files

# Line colour per algorithm (matching the original figure).
# The insertion order of this mapping is also the plotting / legend order.
ALGORITHM_COLORS = {
    'SP-UCB-α=0': '#1f77b4',      # Blue
    'SP-UCB-α=0.01': '#2ca02c',   # Green
    'SP-UCB-α=0.1': '#ff7f0e',    # Orange
    'OneHot': '#e377c2',          # Pink
    'Oracle': '#17becf',          # Cyan
}

# Order in which algorithms are drawn; taken from the colour map's key order
# so the two can never drift apart.
ALGORITHM_ORDER = list(ALGORITHM_COLORS)

# Panel titles, keyed by the scenario id used in the data.
SCENARIO_FULL_NAMES = {
    'S1': 'S1: High-Variance',
    'S2': 'S2: Deceptive Arms',
    'S3': 'S3: Selective Admission',
}


def generate_regret_scaling_figure(df: pd.DataFrame, output_dir: Path):
    """
    Generate the regret scaling figure (one panel per scenario).

    For each scenario (S1, S2, S3) and each algorithm in ALGORITHM_ORDER the
    rows of `df` are grouped by horizon T; the mean regret is drawn as a line
    and mean ± 1.96·SE as a shaded band. A single legend shared by all panels
    is placed above them.

    Changes from the original figure:
    - Title: "Regret Scaling with Horizon T"
    - No √T reference line

    Args:
        df: Results table with the columns 'scenario', 'algorithm', 'T' and
            'regret' (presumably one row per seed/run — confirm with the
            producer of the CSV).
        output_dir: Directory that receives fig5_regret_scaling.png and
            fig5_regret_scaling.pdf. It is created if it does not exist.
    """
    scenarios = ['S1', 'S2', 'S3']
    z_score = 1.96  # multiplier applied to the standard error for the band

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))

    try:
        # First line drawn per algorithm, gathered across ALL panels so an
        # algorithm that is absent from S1 still gets a legend entry.
        legend_lines = {}

        for panel_idx, (ax, scenario) in enumerate(zip(axes, scenarios)):
            scenario_df = df[df['scenario'] == scenario]

            for alg in ALGORITHM_ORDER:
                alg_df = scenario_df[scenario_df['algorithm'] == alg]
                if alg_df.empty:
                    continue

                # Group by T and compute mean/std/count per horizon
                grouped = alg_df.groupby('T')['regret'].agg(['mean', 'std', 'count'])
                T_values = grouped.index.values
                means = grouped['mean'].values

                # Standard error of the mean.
                # NOTE(review): with a single row per T the sample std is NaN
                # (pandas default ddof=1), so the band is expected to be
                # missing at that horizon — confirm this is acceptable.
                se = grouped['std'].values / np.sqrt(grouped['count'].values)
                half_width = z_score * se

                color = ALGORITHM_COLORS[alg]

                line, = ax.plot(T_values, means, 'o-', color=color, label=alg,
                                linewidth=2, markersize=5)
                ax.fill_between(T_values, means - half_width, means + half_width,
                                color=color, alpha=0.2)

                legend_lines.setdefault(alg, line)

            # NOTE: √T reference line removed

            ax.set_xlabel('Horizon T')
            # Only the left-most panel carries the y-axis label
            ax.set_ylabel('Regret' if panel_idx == 0 else '')
            ax.set_title(SCENARIO_FULL_NAMES[scenario], fontsize=12, fontweight='bold')
            ax.grid(alpha=0.3)

        # Shared legend at the top, always in ALGORITHM_ORDER regardless of
        # which panel an algorithm first appeared in.
        legend_labels = [alg for alg in ALGORITHM_ORDER if alg in legend_lines]
        if legend_labels:
            fig.legend([legend_lines[alg] for alg in legend_labels], legend_labels,
                       loc='upper center', ncol=len(ALGORITHM_ORDER),
                       bbox_to_anchor=(0.5, 1.02), frameon=False)

        # Updated title (no "Figure 5:" prefix, no "(linear scale)")
        fig.suptitle('Regret Scaling with Horizon T', fontsize=16, fontweight='bold', y=1.08)
        fig.tight_layout()

        # Save through the figure object so that other open figures cannot be
        # written by mistake.
        output_file = output_dir / "fig5_regret_scaling.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Saved: {output_file}")

        output_pdf = output_dir / "fig5_regret_scaling.pdf"
        fig.savefig(output_pdf, bbox_inches='tight')
        print(f"Saved: {output_pdf}")
    finally:
        # Release the figure even if plotting or saving failed
        plt.close(fig)


def main():
    """
    Load the regret results and regenerate Fig5.

    Reads <project root>/results/regret_analysis/regret_per_seed.csv and writes
    the figure files into <project root>/figures, where the project root is
    the parent of the directory containing this script.

    Returns:
        0 on success, 1 if the data file is missing (used as the process exit
        status when the module is run as a script).
    """
    import sys  # local import: only needed to report the error on stderr

    # Paths. resolve() keeps .parent.parent correct when __file__ is relative
    # (Path('script.py').parent.parent would otherwise collapse to '.').
    project_root = Path(__file__).resolve().parent.parent
    results_dir = project_root / "results" / "regret_analysis"
    figures_dir = project_root / "figures"

    print("=" * 60)
    print("Regenerating Fig5: Regret Scaling with Horizon T")
    print("=" * 60)

    # Load data
    csv_file = results_dir / "regret_per_seed.csv"
    if not csv_file.exists():
        print(f"Error: Data file not found at {csv_file}", file=sys.stderr)
        return 1

    df = pd.read_csv(csv_file)
    print(f"Loaded {len(df)} records from {csv_file}")

    # Create the output directory only once there is something to write
    figures_dir.mkdir(parents=True, exist_ok=True)

    # Generate figure
    generate_regret_scaling_figure(df, figures_dir)

    print("\n" + "=" * 60)
    print(f"Figure saved to: {figures_dir}")
    print("=" * 60)
    return 0


if __name__ == '__main__':
    # Propagate main()'s status so shell scripts / Makefiles can detect failure
    raise SystemExit(main())
