#!/usr/bin/env python3
"""
Generate Publication-Ready Figures

Generates figures for the paper without "Figure X:" prefixes in titles.

Figures:
1. Synthetic boxplots (replaces figA2_boxplots.png)
2. CR vs T line plots (replaces fig7_cr_vs_T.png)
3. Alibaba boxplots (new)

Usage:
    python generate_paper_figures.py
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# Publication-quality Matplotlib defaults, set one rc group at a time.
# Text sizes
plt.rc('font', size=12)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('xtick', labelsize=11)
plt.rc('ytick', labelsize=11)
plt.rc('legend', fontsize=10)
# Resolution: moderate for on-screen figures, high for files written to disk
plt.rc('figure', dpi=150)
plt.rc('savefig', dpi=300, bbox='tight')

# Algorithm identifiers as they appear in the result tables, in the
# left-to-right order used on plot axes and in legends.
ALGORITHM_ORDER = [
    'SP-UCB-α=0',
    'SP-UCB-α=0.01',
    'SP-UCB-α=0.1',
    'OneHot',
    'Oracle',
]

# Compact tick labels, index-aligned with ALGORITHM_ORDER.
ALGORITHM_LABELS = [
    'α=0',
    'α=0.01',
    'α=0.1',
    'OneHot',
    'Oracle',
]

# Colors matching original figures, index-aligned with ALGORITHM_ORDER.
ALGORITHM_COLORS = dict(zip(ALGORITHM_ORDER, (
    '#1f77b4',  # Blue
    '#ff7f0e',  # Orange
    '#2ca02c',  # Green
    '#e377c2',  # Pink
    '#17becf',  # Cyan
)))

# Short panel titles for the synthetic scenarios (each key maps to itself).
SCENARIO_NAMES = {key: key for key in ('S1', 'S2', 'S3')}

# Descriptive panel titles for the synthetic scenarios.
SCENARIO_FULL_NAMES = dict(
    S1='S1: High-Variance',
    S2='S2: Deceptive Arms',
    S3='S3: Selective Admission',
)

# Panel titles for the Alibaba scenarios, keyed by the scenario identifier
# used in the result table.
ALIBABA_SCENARIO_NAMES = dict(
    quant_8bit='Quant-8bit',
    quant_4bit='Quant-4bit',
    batching='Batching',
)


def load_synthetic_data(results_dir: Path) -> pd.DataFrame:
    """
    Read the synthetic experiment table.

    Args:
        results_dir: Directory expected to contain complete_per_seed_data.csv.

    Returns:
        The CSV contents as a DataFrame.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
    """
    source = results_dir / "complete_per_seed_data.csv"
    if not source.exists():
        raise FileNotFoundError(f"Synthetic data not found at {source}")
    return pd.read_csv(source)


def load_regret_data(results_dir: Path) -> pd.DataFrame:
    """
    Read the regret scaling table.

    Args:
        results_dir: Directory expected to contain
            regret_analysis/regret_per_seed.csv.

    Returns:
        The CSV contents as a DataFrame.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
    """
    source = results_dir / "regret_analysis" / "regret_per_seed.csv"
    if not source.exists():
        raise FileNotFoundError(f"Regret data not found at {source}")
    return pd.read_csv(source)


def load_alibaba_data(results_dir: Path) -> pd.DataFrame:
    """
    Read the Alibaba experiment table.

    Args:
        results_dir: Directory expected to contain alibaba/alibaba_per_seed.csv.

    Returns:
        The CSV contents as a DataFrame.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
    """
    source = results_dir / "alibaba" / "alibaba_per_seed.csv"
    if not source.exists():
        raise FileNotFoundError(f"Alibaba data not found at {source}")
    return pd.read_csv(source)


def generate_synthetic_boxplots(df: pd.DataFrame, output_dir: Path, rho_filter: float = 0.7):
    """
    Generate boxplots for synthetic experiments (replaces figA2_boxplots.png).

    Style: 3-panel (S1, S2, S3), boxplots by algorithm, y-axis = competitive ratio

    Args:
        df: Result table; the columns 'rho', 'scenario', 'algorithm' and
            'competitive_ratio' are read.
        output_dir: Directory that receives synthetic_boxplots.png and
            synthetic_boxplots.pdf. It is not created here.
        rho_filter: Only rows whose 'rho' equals this value (up to a small
            absolute tolerance) are plotted.
    """
    scenarios = ['S1', 'S2', 'S3']

    # Compare with a tolerance rather than `==`: a rho that went through
    # arithmetic or a CSV round trip may differ from the literal in the last
    # bits, and exact equality would silently drop those rows.
    df_filtered = df[np.isclose(df['rho'], rho_filter, rtol=0.0, atol=1e-9)]
    if df_filtered.empty:
        print(f"   Warning: no rows with rho={rho_filter}; boxplots will be empty")

    fig, axes = plt.subplots(1, 3, figsize=(14, 5))

    try:
        for panel_idx, (ax, scenario) in enumerate(zip(axes, scenarios)):
            scenario_df = df_filtered[df_filtered['scenario'] == scenario]

            data_by_alg = []
            colors = []

            for alg in ALGORITHM_ORDER:
                alg_rows = scenario_df['algorithm'] == alg
                if alg_rows.any():
                    data_by_alg.append(scenario_df.loc[alg_rows, 'competitive_ratio'].values)
                    colors.append(ALGORITHM_COLORS[alg])
                else:
                    # Keep an empty slot so every panel shows the same
                    # algorithms at the same x positions.
                    data_by_alg.append([])
                    colors.append('#999999')

            bp = ax.boxplot(data_by_alg, patch_artist=True)

            # Label the ticks explicitly instead of passing `tick_labels=`,
            # which only exists in Matplotlib >= 3.9. Boxes sit at the default
            # positions 1..N.
            ax.set_xticks(list(range(1, len(ALGORITHM_LABELS) + 1)))
            ax.set_xticklabels(ALGORITHM_LABELS)

            for patch, color in zip(bp['boxes'], colors):
                patch.set_facecolor(color)
                patch.set_alpha(0.7)

            ax.set_title(SCENARIO_NAMES[scenario], fontsize=14, fontweight='bold')
            # Only the left-most panel carries the y-axis label.
            ax.set_ylabel('Competitive Ratio' if panel_idx == 0 else '')
            ax.set_ylim(0.4, 1.05)
            ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=0.8)
            ax.grid(axis='y', alpha=0.3)
            ax.tick_params(axis='x', rotation=45)

        fig.suptitle('Distribution of Competitive Ratios (10 seeds)', fontsize=16, fontweight='bold', y=1.02)
        fig.tight_layout()

        # Save through the figure object so the output never depends on which
        # figure pyplot currently considers "current".
        output_file = output_dir / "synthetic_boxplots.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Saved: {output_file}")

        output_pdf = output_dir / "synthetic_boxplots.pdf"
        fig.savefig(output_pdf, bbox_inches='tight')
        print(f"Saved: {output_pdf}")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


def generate_cr_vs_T(df: pd.DataFrame, output_dir: Path):
    """
    Generate CR vs Horizon T plot (replaces fig7_cr_vs_T.png).

    Style: 3-panel, line plots with confidence intervals, log x-axis

    For each scenario and algorithm, the mean competitive ratio per horizon is
    drawn as a line with a shaded band of +/- 1.96 standard errors.

    Args:
        df: Result table; the columns 'scenario', 'algorithm', 'T' and
            'competitive_ratio' are read.
        output_dir: Directory that receives cr_vs_T.png and cr_vs_T.pdf.
            It is not created here.
    """
    scenarios = ['S1', 'S2', 'S3']

    fig, axes = plt.subplots(1, 3, figsize=(14, 4.5), sharey=True)

    try:
        # One legend handle per algorithm, taken from the first panel that
        # plots it, so an algorithm missing from S1 still gets an entry.
        legend_entries = {}

        for panel_idx, (ax, scenario) in enumerate(zip(axes, scenarios)):
            scenario_df = df[df['scenario'] == scenario]

            for alg in ALGORITHM_ORDER:
                alg_df = scenario_df[scenario_df['algorithm'] == alg]
                if alg_df.empty:
                    continue

                # Group by T and compute mean/std
                grouped = alg_df.groupby('T')['competitive_ratio'].agg(['mean', 'std', 'count'])
                T_values = grouped.index.values
                means = grouped['mean'].values
                counts = grouped['count'].values

                # The sample std of a single observation is NaN; use a
                # zero-width band there instead of a gap in the shading.
                stds = grouped['std'].mask(grouped['count'] == 1, 0.0).values

                # Half-width of the band: 1.96 standard errors
                half_width = 1.96 * stds / np.sqrt(counts)

                color = ALGORITHM_COLORS[alg]

                line, = ax.plot(T_values, means, 'o-', color=color, label=alg,
                                linewidth=2, markersize=5)
                ax.fill_between(T_values, means - half_width, means + half_width,
                                color=color, alpha=0.2)

                legend_entries.setdefault(alg, line)

            ax.set_xscale('log')
            ax.set_xlabel('Horizon T')
            # Only the left-most panel carries the y-axis label.
            ax.set_ylabel('Competitive Ratio' if panel_idx == 0 else '')
            ax.set_title(SCENARIO_FULL_NAMES[scenario], fontsize=12, fontweight='bold')
            ax.set_ylim(0.4, 1.05)
            ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=0.8)
            ax.grid(alpha=0.3)

            # Format x-axis ticks
            ax.set_xticks([100, 200, 500, 1000, 2000, 5000, 10000])
            ax.set_xticklabels(['100', '200', '500', '1K', '2K', '5K', '10K'])

        # Add shared legend at top, in the canonical algorithm order
        legend_labels = [alg for alg in ALGORITHM_ORDER if alg in legend_entries]
        if legend_labels:
            fig.legend([legend_entries[alg] for alg in legend_labels], legend_labels,
                       loc='upper center', ncol=len(ALGORITHM_ORDER),
                       bbox_to_anchor=(0.5, 1.02), frameon=False)

        fig.suptitle('Competitive Ratio vs Horizon T', fontsize=16, fontweight='bold', y=1.08)
        fig.tight_layout()

        # Save through the figure object so the output never depends on which
        # figure pyplot currently considers "current".
        output_file = output_dir / "cr_vs_T.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Saved: {output_file}")

        output_pdf = output_dir / "cr_vs_T.pdf"
        fig.savefig(output_pdf, bbox_inches='tight')
        print(f"Saved: {output_pdf}")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


def generate_alibaba_boxplots(df: pd.DataFrame, output_dir: Path):
    """
    Generate boxplots for Alibaba experiments.

    Style: 3-panel (quant_8bit, quant_4bit, batching), matching synthetic style

    Args:
        df: Result table; the columns 'scenario', 'algorithm' and
            'competitive_ratio' are read.
        output_dir: Directory that receives alibaba_boxplots.png and
            alibaba_boxplots.pdf. It is not created here.
    """
    scenarios = ['quant_8bit', 'quant_4bit', 'batching']

    fig, axes = plt.subplots(1, 3, figsize=(14, 5))

    try:
        for panel_idx, (ax, scenario) in enumerate(zip(axes, scenarios)):
            scenario_df = df[df['scenario'] == scenario]

            data_by_alg = []
            colors = []

            for alg in ALGORITHM_ORDER:
                alg_rows = scenario_df['algorithm'] == alg
                if alg_rows.any():
                    data_by_alg.append(scenario_df.loc[alg_rows, 'competitive_ratio'].values)
                    colors.append(ALGORITHM_COLORS[alg])
                else:
                    # Keep an empty slot so every panel shows the same
                    # algorithms at the same x positions.
                    data_by_alg.append([])
                    colors.append('#999999')

            bp = ax.boxplot(data_by_alg, patch_artist=True)

            # Label the ticks explicitly instead of passing `tick_labels=`,
            # which only exists in Matplotlib >= 3.9. Boxes sit at the default
            # positions 1..N.
            ax.set_xticks(list(range(1, len(ALGORITHM_LABELS) + 1)))
            ax.set_xticklabels(ALGORITHM_LABELS)

            for patch, color in zip(bp['boxes'], colors):
                patch.set_facecolor(color)
                patch.set_alpha(0.7)

            ax.set_title(ALIBABA_SCENARIO_NAMES[scenario], fontsize=14, fontweight='bold')
            # Only the left-most panel carries the y-axis label.
            ax.set_ylabel('Competitive Ratio' if panel_idx == 0 else '')
            ax.set_ylim(0, 1.05)
            ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=0.8)
            ax.grid(axis='y', alpha=0.3)
            ax.tick_params(axis='x', rotation=45)

        fig.suptitle('Alibaba Experiments: Distribution of Competitive Ratios (10 seeds)',
                     fontsize=16, fontweight='bold', y=1.02)
        fig.tight_layout()

        # Save through the figure object so the output never depends on which
        # figure pyplot currently considers "current".
        output_file = output_dir / "alibaba_boxplots.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Saved: {output_file}")

        output_pdf = output_dir / "alibaba_boxplots.pdf"
        fig.savefig(output_pdf, bbox_inches='tight')
        print(f"Saved: {output_pdf}")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


def main():
    """
    Generate every paper figure whose input data is available.

    Paths are taken relative to the parent of the directory containing this
    script: tables are read from <root>/results and figures are written to
    <root>/figures, which is created if missing. A step that raises
    FileNotFoundError is reported as skipped and the remaining steps still run.
    """
    # Resolve first: when __file__ is a bare relative name such as
    # "generate_paper_figures.py", both .parent and .parent.parent are ".",
    # which would place results/ and figures/ next to the script.
    project_root = Path(__file__).resolve().parent.parent
    results_dir = project_root / "results"
    figures_dir = project_root / "figures"
    figures_dir.mkdir(parents=True, exist_ok=True)

    banner = "=" * 60
    print(banner)
    print("Generating Publication-Ready Figures")
    print(banner)

    # (description, loader, renderer) for each figure, in output order
    steps = [
        # replaces figA2_boxplots.png
        ("synthetic boxplots", load_synthetic_data, generate_synthetic_boxplots),
        # replaces fig7_cr_vs_T.png
        ("CR vs T plots", load_regret_data, generate_cr_vs_T),
        ("Alibaba boxplots", load_alibaba_data, generate_alibaba_boxplots),
    ]

    generated = 0
    for number, (description, load, render) in enumerate(steps, start=1):
        print(f"\n{number}. Generating {description}...")
        try:
            render(load(results_dir), figures_dir)
        except FileNotFoundError as e:
            print(f"   Skipped: {e}")
        else:
            generated += 1

    print("\n" + banner)
    # Only claim "all" when no step was skipped.
    if generated == len(steps):
        print(f"All figures saved to: {figures_dir}")
    else:
        print(f"{generated} of {len(steps)} figures saved to: {figures_dir}")
    print(banner)


if __name__ == '__main__':
    main()
