#!/usr/bin/env python3
"""
Alibaba Experiments Visualization: Boxplots

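Creates boxplot visualizations similar to figA2_boxplots.png showing
competitive ratio distributions across algorithms for each Alibaba scenario,
plus per-rho boxplots, a mean-competitive-ratio heatmap, regret boxplots, and
a printed summary table. Figures are written to <results-dir>/figures/.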

Usage:
    python visualize_alibaba_boxplots.py --results-dir ./results/alibaba
    python visualize_alibaba_boxplots.py --results-dir ./results/alibaba/smoke_test
    python visualize_alibaba_boxplots.py --results-dir ./results/alibaba --rho 0.7
"""

import argparse
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path


# Display order of algorithms in every figure/table, the short tick label for
# each (parallel list: ALGORITHM_LABELS[i] labels ALGORITHM_ORDER[i]), and the
# box color per algorithm -- all matching the synthetic experiments.
ALGORITHM_ORDER = [
    'SP-UCB-α=0',
    'SP-UCB-α=0.01',
    'SP-UCB-α=0.1',
    'OneHot',
    'Oracle',
]
ALGORITHM_LABELS = ['α=0', 'α=0.01', 'α=0.1', 'OneHot', 'Oracle']

# Matplotlib's default color cycle: blue, orange, green, red, purple.
ALGORITHM_COLORS = dict(zip(
    ALGORITHM_ORDER,
    ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'],
))

# Scenario key (as stored in the results 'scenario' column) -> panel title.
# Insertion order is the panel order used by the plotting functions.
SCENARIO_LABELS = dict(
    quant_8bit='Quant-8bit',
    quant_4bit='Quant-4bit',
    batching='Batching',
)


def load_results(results_dir: Path) -> pd.DataFrame:
    """Load experiment results from *results_dir*.

    The per-seed CSV ``alibaba_per_seed.csv`` is preferred. If it is absent,
    ``alibaba_combined_results.json`` is read instead; its ``'results'`` list
    is used and any record that has an ``'error'`` key is dropped.

    Parameters
    ----------
    results_dir : Path
        Directory containing the results files (a ``str`` is also accepted).

    Returns
    -------
    pd.DataFrame
        One row per result record.

    Raises
    ------
    FileNotFoundError
        If neither results file exists in *results_dir*.
    """
    results_dir = Path(results_dir)

    # Per-seed CSV first (more reliable).
    csv_file = results_dir / "alibaba_per_seed.csv"
    if csv_file.exists():
        print(f"Loading from {csv_file}")
        return pd.read_csv(csv_file, encoding="utf-8")

    # Fall back to the combined JSON.
    json_file = results_dir / "alibaba_combined_results.json"
    if json_file.exists():
        print(f"Loading from {json_file}")
        # Algorithm names contain non-ASCII characters ('α'); decode as UTF-8
        # explicitly rather than relying on the platform's default encoding.
        with open(json_file, encoding="utf-8") as f:
            data = json.load(f)
        valid = [r for r in data['results'] if 'error' not in r]
        return pd.DataFrame(valid)

    raise FileNotFoundError(f"No results found in {results_dir}")


def create_boxplots_by_scenario(df: pd.DataFrame, output_dir: Path, rho_filter: float = None):
    """
    Create a 1xN boxplot figure of competitive ratio, one panel per scenario.

    Scenarios are drawn in ``SCENARIO_LABELS`` order and algorithms in
    ``ALGORITHM_ORDER``. Scenarios absent from *df* get no panel; algorithms
    with no finite values in a panel are omitted. The figure is saved as both
    PNG and PDF.

    Parameters
    ----------
    df : pd.DataFrame
        Results with ``scenario``, ``algorithm`` and ``competitive_ratio``
        columns (plus ``rho`` when *rho_filter* is set).
    output_dir : Path
        Output directory for figures.
    rho_filter : float, optional
        If set, keep only rows whose ``rho`` is numerically close to this
        value; otherwise all rho values are pooled.
    """
    present = set(df['scenario'].unique())
    available_scenarios = [s for s in SCENARIO_LABELS if s in present]
    n_scenarios = len(available_scenarios)

    if n_scenarios == 0:
        print("No scenario data found!")
        return

    if rho_filter is not None:
        # Tolerant comparison: the CLI value and the stored value can differ in
        # the last bits after float parsing/serialization.
        df = df[np.isclose(df['rho'], rho_filter)]
        rho_suffix = f"_rho{rho_filter}"
        if df.empty:
            print(f"Warning: no results with rho={rho_filter}")
    else:
        rho_suffix = "_all_rho"

    # Keep 0.3 as the default floor, but lower it if any value would otherwise
    # be clipped out of view (Series.min skips NaN; NaN when no data).
    cr_min = df['competitive_ratio'].min()
    y_low = 0.3 if pd.isna(cr_min) else min(0.3, cr_min - 0.02)

    # squeeze=False always yields a 2-D array, so no special case for N == 1.
    fig, axes = plt.subplots(1, n_scenarios, figsize=(4 * n_scenarios, 5), squeeze=False)
    axes = axes[0]

    for ax, scenario in zip(axes, available_scenarios):
        scenario_df = df[df['scenario'] == scenario]

        data_by_alg, labels, colors = [], [], []
        for alg, label in zip(ALGORITHM_ORDER, ALGORITHM_LABELS):
            values = (scenario_df.loc[scenario_df['algorithm'] == alg, 'competitive_ratio']
                      .dropna().to_numpy())
            if values.size:
                data_by_alg.append(values)
                labels.append(label)
                colors.append(ALGORITHM_COLORS[alg])

        if not data_by_alg:
            ax.set_title(f"{SCENARIO_LABELS[scenario]}\n(no data)")
            continue

        # Tick labels are set explicitly: boxplot's ``labels=`` keyword was
        # renamed to ``tick_labels`` in Matplotlib 3.9 and is deprecated.
        bp = ax.boxplot(data_by_alg, patch_artist=True)
        ax.set_xticks(range(1, len(labels) + 1))
        ax.set_xticklabels(labels)

        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

        ax.set_title(SCENARIO_LABELS[scenario], fontsize=14, fontweight='bold')
        ax.set_ylim(y_low, 1.05)
        ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=0.8)
        ax.grid(axis='y', alpha=0.3)
        ax.tick_params(axis='x', rotation=45)

    # Label the leftmost panel even if it ended up with no data.
    axes[0].set_ylabel('Competitive Ratio')

    fig.tight_layout()

    # PNG for quick viewing, PDF for the paper.
    for ext, extra in (("png", {"dpi": 150}), ("pdf", {})):
        output_file = output_dir / f"alibaba_boxplots{rho_suffix}.{ext}"
        fig.savefig(output_file, bbox_inches='tight', **extra)
        print(f"Saved: {output_file}")

    plt.close(fig)


def create_boxplots_by_rho(df: pd.DataFrame, output_dir: Path):
    """
    Create one figure per scenario with one boxplot panel per rho value.

    Every figure gets a panel for each (non-NaN) rho value present anywhere in
    *df*, so panels line up across scenarios; a panel with no data for that
    scenario is titled "(no data)". Panels share the y axis. Each figure is
    saved as ``alibaba_boxplots_<scenario>_by_rho.png`` in *output_dir*.
    """
    present = set(df['scenario'].unique())
    available_scenarios = [s for s in SCENARIO_LABELS if s in present]
    rho_values = sorted(df['rho'].dropna().unique())

    if not available_scenarios or not rho_values:
        # plt.subplots(1, 0) would raise; nothing to plot anyway.
        print("No scenario/rho data found for per-rho boxplots!")
        return

    n_rho = len(rho_values)

    for scenario in available_scenarios:
        scenario_df = df[df['scenario'] == scenario]

        # Default floor 0.3, lowered if any value would be clipped.
        cr_min = scenario_df['competitive_ratio'].min()
        y_low = 0.3 if pd.isna(cr_min) else min(0.3, cr_min - 0.02)

        fig, axes = plt.subplots(1, n_rho, figsize=(3.5 * n_rho, 5), sharey=True, squeeze=False)
        axes = axes[0]

        for ax, rho in zip(axes, rho_values):
            # rho comes from the data itself, so exact equality is safe here.
            rho_df = scenario_df[scenario_df['rho'] == rho]

            data_by_alg, labels, colors = [], [], []
            for alg, label in zip(ALGORITHM_ORDER, ALGORITHM_LABELS):
                values = (rho_df.loc[rho_df['algorithm'] == alg, 'competitive_ratio']
                          .dropna().to_numpy())
                if values.size:
                    data_by_alg.append(values)
                    labels.append(label)
                    colors.append(ALGORITHM_COLORS[alg])

            if not data_by_alg:
                ax.set_title(f"ρ={rho}\n(no data)")
                continue

            # Explicit tick labels: boxplot's ``labels=`` is deprecated since
            # Matplotlib 3.9 (renamed ``tick_labels``).
            bp = ax.boxplot(data_by_alg, patch_artist=True)
            ax.set_xticks(range(1, len(labels) + 1))
            ax.set_xticklabels(labels)

            for patch, color in zip(bp['boxes'], colors):
                patch.set_facecolor(color)
                patch.set_alpha(0.7)

            ax.set_title(f'ρ={rho}', fontsize=12)
            ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=0.8)
            ax.grid(axis='y', alpha=0.3)
            ax.tick_params(axis='x', rotation=45)

        # sharey=True: limits set on one panel apply to all of them.
        axes[0].set_ylim(y_low, 1.05)
        axes[0].set_ylabel('Competitive Ratio')

        fig.suptitle(f'{SCENARIO_LABELS[scenario]} - Competitive Ratio by Budget Tightness',
                     fontsize=14, fontweight='bold', y=1.02)
        fig.tight_layout()

        output_file = output_dir / f"alibaba_boxplots_{scenario}_by_rho.png"
        fig.savefig(output_file, dpi=150, bbox_inches='tight')
        print(f"Saved: {output_file}")

        plt.close(fig)


def create_combined_heatmap(df: pd.DataFrame, output_dir: Path):
    """Create a heatmap of mean competitive ratio.

    Rows are (scenario, rho) pairs, ordered by ``SCENARIO_LABELS`` (scenarios
    not listed there come last, alphabetically) and then by rho. Columns are
    the algorithms of ``ALGORITHM_ORDER`` that appear in *df*. Each cell is
    annotated with its mean; combinations with no data are annotated "n/a".
    The figure is saved as ``alibaba_heatmap.png`` in *output_dir*. Nothing is
    drawn if there is no data for any known algorithm.
    """
    pivot = df.pivot_table(
        values='competitive_ratio',
        index=['scenario', 'rho'],
        columns='algorithm',
        aggfunc='mean',
    )

    cols = [c for c in ALGORITHM_ORDER if c in pivot.columns]
    if pivot.empty or not cols:
        # imshow cannot render a 0-row or 0-column image.
        print("No data for heatmap; skipping.")
        return

    # pivot_table sorts scenarios alphabetically; use the same scenario order
    # as the boxplot figures instead.
    scenario_rank = {s: i for i, s in enumerate(SCENARIO_LABELS)}
    row_order = sorted(
        pivot.index,
        key=lambda key: (scenario_rank.get(key[0], len(scenario_rank)), str(key[0]), key[1]),
    )
    pivot = pivot.loc[row_order, cols]
    values = pivot.to_numpy(dtype=float)

    label_for = dict(zip(ALGORITHM_ORDER, ALGORITHM_LABELS))

    fig, ax = plt.subplots(figsize=(10, 8))

    im = ax.imshow(values, cmap='RdYlGn', aspect='auto', vmin=0.4, vmax=1.0)

    ax.set_xticks(range(len(cols)))
    ax.set_xticklabels([label_for[c] for c in cols], rotation=45, ha='right')

    ax.set_yticks(range(len(row_order)))
    ax.set_yticklabels([f"{SCENARIO_LABELS.get(s, s)} ρ={r}" for s, r in row_order])

    for (i, j), val in np.ndenumerate(values):
        if np.isnan(val):
            # No runs for this (scenario, rho, algorithm) combination.
            ax.text(j, i, 'n/a', ha='center', va='center', color='gray', fontsize=9)
            continue
        # White text on the dark ends of RdYlGn, black in the light middle.
        color = 'white' if val < 0.6 or val > 0.9 else 'black'
        ax.text(j, i, f'{val:.3f}', ha='center', va='center', color=color, fontsize=9)

    ax.set_title('Alibaba Experiments: Mean Competitive Ratio', fontsize=14, fontweight='bold')

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Competitive Ratio')

    fig.tight_layout()

    output_file = output_dir / "alibaba_heatmap.png"
    fig.savefig(output_file, dpi=150, bbox_inches='tight')
    print(f"Saved: {output_file}")

    plt.close(fig)


def create_regret_boxplots(df: pd.DataFrame, output_dir: Path, rho_filter: float = 0.7):
    """Create a 1xN boxplot figure of regret at one rho, one panel per scenario.

    Parameters
    ----------
    df : pd.DataFrame
        Results with ``scenario``, ``algorithm``, ``rho`` and ``regret``
        columns.
    output_dir : Path
        Output directory; the figure is saved as
        ``alibaba_regret_boxplots_rho<rho_filter>.png``.
    rho_filter : float, default 0.7
        Keep only rows whose ``rho`` is numerically close to this value.
    """
    present = set(df['scenario'].unique())
    available_scenarios = [s for s in SCENARIO_LABELS if s in present]
    n_scenarios = len(available_scenarios)

    if n_scenarios == 0:
        print("No scenario data found!")
        return

    # Tolerant float comparison instead of exact ==.
    df_filtered = df[np.isclose(df['rho'], rho_filter)]

    fig, axes = plt.subplots(1, n_scenarios, figsize=(4 * n_scenarios, 5), squeeze=False)
    axes = axes[0]

    for ax, scenario in zip(axes, available_scenarios):
        scenario_df = df_filtered[df_filtered['scenario'] == scenario]

        data_by_alg, labels, colors = [], [], []
        for alg, label in zip(ALGORITHM_ORDER, ALGORITHM_LABELS):
            values = scenario_df.loc[scenario_df['algorithm'] == alg, 'regret'].dropna().to_numpy()
            if values.size:
                data_by_alg.append(values)
                labels.append(label)
                colors.append(ALGORITHM_COLORS[alg])

        if not data_by_alg:
            # Same empty-panel convention as the competitive-ratio boxplots.
            ax.set_title(f"{SCENARIO_LABELS[scenario]}\n(no data)")
            continue

        # Explicit tick labels: boxplot's ``labels=`` is deprecated since
        # Matplotlib 3.9 (renamed ``tick_labels``).
        bp = ax.boxplot(data_by_alg, patch_artist=True)
        ax.set_xticks(range(1, len(labels) + 1))
        ax.set_xticklabels(labels)

        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

        ax.set_title(SCENARIO_LABELS[scenario], fontsize=14, fontweight='bold')
        ax.grid(axis='y', alpha=0.3)
        ax.tick_params(axis='x', rotation=45)

    # Label the leftmost panel even if it ended up with no data.
    axes[0].set_ylabel('Regret')

    fig.suptitle(f'Alibaba Experiments: Regret Distribution (ρ={rho_filter})',
                 fontsize=14, fontweight='bold', y=1.02)
    fig.tight_layout()

    output_file = output_dir / f"alibaba_regret_boxplots_rho{rho_filter}.png"
    fig.savefig(output_file, dpi=150, bbox_inches='tight')
    print(f"Saved: {output_file}")

    plt.close(fig)


def print_summary_table(df: pd.DataFrame):
    """Print competitive-ratio summary tables to stdout.

    Three tables are printed, each rounded to 4 decimals with columns in
    ``ALGORITHM_ORDER`` (algorithms not listed there are omitted):
    mean per scenario, mean per (scenario, rho), and standard deviation per
    scenario.
    """
    rule = "=" * 80

    def ordered_pivot(index, aggfunc):
        # Aggregate competitive_ratio, round, then keep known algorithms in order.
        table = df.pivot_table(
            values='competitive_ratio',
            index=index,
            columns='algorithm',
            aggfunc=aggfunc,
        ).round(4)
        return table[[alg for alg in ALGORITHM_ORDER if alg in table.columns]]

    print("\n" + rule)
    print("ALIBABA EXPERIMENT SUMMARY")
    print(rule)

    sections = (
        ("\nMean Competitive Ratio by Scenario and Algorithm:", 'scenario', 'mean'),
        ("\nMean Competitive Ratio by Scenario, ρ, and Algorithm:", ['scenario', 'rho'], 'mean'),
        ("\nStd Dev of Competitive Ratio by Scenario and Algorithm:", 'scenario', 'std'),
    )
    for heading, index, aggfunc in sections:
        print(heading)
        print(ordered_pivot(index, aggfunc).to_string())

    print("\n" + rule)


def main():
    """Command-line entry point: load Alibaba results and write every figure.

    Figures go to ``<results-dir>/figures``. If the results directory or the
    results files are missing, or no valid results remain after loading, the
    process exits with status 1 (via ``SystemExit``). Previously it returned
    normally, so scripts and CI could not detect the failure.
    """
    parser = argparse.ArgumentParser(description='Visualize Alibaba experiment results')
    parser.add_argument('--results-dir', type=str, default='./results/alibaba',
                        help='Results directory')
    parser.add_argument('--rho', type=float, default=None,
                        help='Filter to specific rho value for main boxplot')

    args = parser.parse_args()

    results_dir = Path(args.results_dir)

    if not results_dir.exists():
        # A string argument to SystemExit is printed to stderr; exit status is 1.
        raise SystemExit(
            f"Error: Results directory {results_dir} does not exist\n"
            "Run the experiments first with: python run_alibaba_experiments.py --worker all"
        )

    try:
        df = load_results(results_dir)
    except FileNotFoundError as e:
        raise SystemExit(f"Error: {e}") from e

    # For example, every record in the combined JSON carried an 'error' key.
    # Without this check, the column accesses below would raise KeyError.
    if df.empty:
        raise SystemExit(f"Error: No valid results found in {results_dir}")

    print(f"Loaded {len(df)} results")
    print(f"Scenarios: {df['scenario'].unique().tolist()}")
    print(f"Algorithms: {df['algorithm'].unique().tolist()}")
    print(f"Rho values: {sorted(df['rho'].unique().tolist())}")
    print(f"Seeds: {sorted(df['seed'].unique().tolist())}")

    figures_dir = results_dir / "figures"
    figures_dir.mkdir(exist_ok=True)

    print("\nGenerating visualizations...")

    # Main boxplots (all rho values pooled, or filtered by --rho).
    create_boxplots_by_scenario(df, figures_dir, rho_filter=args.rho)

    # One figure per scenario, one panel per rho.
    create_boxplots_by_rho(df, figures_dir)

    # Mean competitive-ratio heatmap.
    create_combined_heatmap(df, figures_dir)

    # Regret boxplots for the typical rho values that are present.
    for rho in [0.7, 1.0]:
        if rho in df['rho'].values:
            create_regret_boxplots(df, figures_dir, rho_filter=rho)

    print_summary_table(df)

    print(f"\nAll figures saved to: {figures_dir}")


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status
    # (None -> 0).
    raise SystemExit(main())
