#!/usr/bin/env python3
"""
Paper Experiments for SP-UCB-OLP

This script runs all experiments needed for the ICML paper on
switching-aware hybrid bandits for joint configuration selection
and admission control.

Experiments:
1. Synthetic S1: Pure Complementarity (V^mix - V* gap demonstration)
2. Synthetic S2: Noisy Complementarity (statistical learning)
3. Synthetic S3: Non-Complementarity (control case)
4. Alibaba: Real-world ML serving scenarios

Usage:
    python run_paper_experiments.py --experiment all
    python run_paper_experiments.py --experiment all --config paper  # Full paper settings
    python run_paper_experiments.py --experiment synthetic
    python run_paper_experiments.py --experiment alibaba
    python run_paper_experiments.py --experiment quick  # Fast test run
"""

import argparse
import json
import numpy as np
from pathlib import Path
from datetime import datetime
from sp_ucb_olp import ExperimentRunner
from sp_ucb_olp.data import (
    S1ComplementarityLoader,
    S2NoisyLoader,
    S3DominantLoader,
    AlibabaDataLoader,
)
from sp_ucb_olp.visualization import (
    plot_competitive_ratio_comparison,
    create_summary_figure,
    save_figure,
)


# Default experiment configuration.
# Keys: 'algorithms' (names handed to the runner), 'seeds' (candidate seeds),
# 'n_seeds' (how many leading entries of 'seeds' are used), and
# 'rho_values' (budget factors swept per scenario).
DEFAULT_CONFIG = dict(
    algorithms=['SP-UCB-OLP', 'Greedy', 'OneHot', 'Oracle', 'Random'],
    n_seeds=10,
    seeds=[*range(42, 52)],
    rho_values=[0.5, 0.7, 1.0, 1.2],
)

# Paper configuration (full experiments): more seeds and a wider sweep of
# budget factors than the default profile.
PAPER_CONFIG = dict(
    algorithms=['SP-UCB-OLP', 'Greedy', 'OneHot', 'Oracle', 'Random'],
    n_seeds=20,
    seeds=[*range(42, 62)],
    rho_values=[0.3, 0.5, 0.7, 0.9, 1.2],
)

# Quick test configuration: three algorithms, three seeds, and a single
# budget factor so a smoke run finishes fast.
QUICK_CONFIG = dict(
    algorithms=['SP-UCB-OLP', 'Greedy', 'Random'],
    n_seeds=3,
    seeds=[*range(42, 45)],
    rho_values=[1.0],
)


def run_synthetic_experiments(
    results_dir: Path,
    config: dict,
    T: int = 10000,
    verbose: bool = True
):
    """
    Run the synthetic scenarios (S1, S2, S3) across all budget factors.

    For every scenario and every ``rho`` in ``config['rho_values']`` a new
    loader (always constructed with ``seed=42``) and a new
    ``ExperimentRunner`` are built, the configured algorithms are compared
    over the first ``config['n_seeds']`` entries of ``config['seeds']``, and
    the outcome is passed to ``runner.save_results`` under a
    per-scenario/per-rho filename. A combined JSON file,
    ``synthetic_combined_results.json``, is written into ``results_dir``
    once all scenarios have finished.

    Parameters
    ----------
    results_dir : Path
        Directory to save results (created if missing)
    config : dict
        Experiment configuration; the keys read here are ``'rho_values'``,
        ``'algorithms'``, ``'seeds'`` and ``'n_seeds'``
    T : int
        Time horizon forwarded to each loader
    verbose : bool
        Print progress and per-algorithm competitive-ratio summaries

    Returns
    -------
    dict
        ``{scenario_name: {rho: results.to_dict()}}``
    """
    results_dir.mkdir(parents=True, exist_ok=True)
    banner = '=' * 60

    # Scenario label -> (loader class, keyword arguments for that loader).
    scenario_specs = {
        'S1': (S1ComplementarityLoader, {'K': 4, 'T': T}),  # Envelope learning
        'S2': (S2NoisyLoader, {'K': 4, 'T': T}),            # Deceptive arms
        'S3': (S3DominantLoader, {'K': 2, 'T': T}),         # Selective admission
    }

    combined = {}

    for label, (loader_cls, loader_args) in scenario_specs.items():
        if verbose:
            print(f"\n{banner}")
            print(f"Running {label} Experiments (T={T})")
            print(f"{banner}")

        per_rho = {}

        for rho in config['rho_values']:
            if verbose:
                print(f"\n--- Budget factor rho={rho} ---")

            # A fresh loader and runner per budget factor.
            runner = ExperimentRunner(
                loader_cls(seed=42, **loader_args),
                rho=rho,
                results_dir=str(results_dir),
            )

            # The seed list is truncated to the configured number of seeds.
            results = runner.run_comparison(
                config['algorithms'],
                config['seeds'][:config['n_seeds']],
                verbose=verbose
            )

            per_rho[rho] = results.to_dict()
            runner.save_results(results, f"{label}_rho{rho:.1f}_results.json")

            if verbose:
                print(f"\nResults for {label} (rho={rho}):")
                for alg in results.algorithms:
                    stats = results.get_summary(alg)
                    mean_ratio = stats['mean_competitive_ratio']
                    std_ratio = stats['std_competitive_ratio']
                    print(f"  {alg:15} ratio={mean_ratio:.4f} +/- {std_ratio:.4f}")

        combined[label] = per_rho

    # Persist every scenario/rho result in a single file.
    combined_path = results_dir / 'synthetic_combined_results.json'
    with combined_path.open('w') as handle:
        json.dump(combined, handle, indent=2)

    if verbose:
        print(f"\n\nSaved combined results to: {combined_path}")

    return combined


def run_alibaba_experiments(
    results_dir: Path,
    config: dict,
    T: int = 10000,
    verbose: bool = True
):
    """
    Run the Alibaba cluster trace scenarios across all budget factors.

    The scenarios ``quant_8bit``, ``quant_4bit`` and ``batching`` are each
    evaluated for every ``rho`` in ``config['rho_values']``. Each run builds
    a new ``AlibabaDataLoader`` (always with ``seed=42``) and a new
    ``ExperimentRunner``, compares the configured algorithms over the first
    ``config['n_seeds']`` entries of ``config['seeds']``, and passes the
    outcome to ``runner.save_results``. A combined JSON file,
    ``alibaba_combined_results.json``, is written into ``results_dir`` at
    the end.

    Parameters
    ----------
    results_dir : Path
        Directory to save results (created if missing)
    config : dict
        Experiment configuration; the keys read here are ``'rho_values'``,
        ``'algorithms'``, ``'seeds'`` and ``'n_seeds'``
    T : int
        Time horizon forwarded to the loader
    verbose : bool
        Print progress and per-algorithm competitive-ratio summaries

    Returns
    -------
    dict
        ``{scenario: {rho: results.to_dict()}}``
    """
    results_dir.mkdir(parents=True, exist_ok=True)
    banner = '=' * 60

    collected = {}

    for scenario in ('quant_8bit', 'quant_4bit', 'batching'):
        if verbose:
            print(f"\n{banner}")
            print(f"Running Alibaba {scenario} Experiments (T={T})")
            print(f"{banner}")

        by_rho = {}

        for rho in config['rho_values']:
            if verbose:
                print(f"\n--- Budget factor rho={rho} ---")

            # A fresh loader and runner per budget factor.
            trace_loader = AlibabaDataLoader(T=T, seed=42, scenario=scenario)
            runner = ExperimentRunner(
                trace_loader, rho=rho, results_dir=str(results_dir)
            )

            # The seed list is truncated to the configured number of seeds.
            active_seeds = config['seeds'][:config['n_seeds']]
            results = runner.run_comparison(
                config['algorithms'], active_seeds, verbose=verbose
            )

            by_rho[rho] = results.to_dict()

            out_name = f"Alibaba_{scenario}_rho{rho:.1f}_results.json"
            runner.save_results(results, out_name)

            if not verbose:
                continue

            print(f"\nResults for Alibaba-{scenario} (rho={rho}):")
            for alg in results.algorithms:
                stats = results.get_summary(alg)
                mean_ratio = stats['mean_competitive_ratio']
                std_ratio = stats['std_competitive_ratio']
                print(f"  {alg:15} ratio={mean_ratio:.4f} +/- {std_ratio:.4f}")

        collected[scenario] = by_rho

    # Persist every scenario/rho result in a single file.
    combined_path = results_dir / 'alibaba_combined_results.json'
    with combined_path.open('w') as handle:
        json.dump(collected, handle, indent=2)

    if verbose:
        print(f"\n\nSaved combined results to: {combined_path}")

    return collected


def run_all_experiments(
    results_dir: Path,
    config: dict,
    T_synthetic: int = 10000,
    T_alibaba: int = 10000,
    verbose: bool = True
):
    """
    Run all paper experiments (synthetic first, then Alibaba).

    Synthetic results are written under ``results_dir / 'synthetic'`` and
    Alibaba results under ``results_dir / 'alibaba'``. When ``verbose`` is
    set, a configuration header is printed before the runs and a summary of
    mean competitive ratios afterwards.

    The summary is reported at a single reference budget factor: the entry
    of ``config['rho_values']`` closest to 1.0 (1.0 itself when it is
    configured). A fixed lookup of rho=1.0 would print nothing for
    configurations that do not sweep that exact value.

    Parameters
    ----------
    results_dir : Path
        Parent directory for the two result subdirectories
    config : dict
        Experiment configuration; ``'algorithms'``, ``'n_seeds'`` and
        ``'rho_values'`` are read here, and the whole dict is forwarded to
        the per-family runners
    T_synthetic : int
        Time horizon for the synthetic experiments
    T_alibaba : int
        Time horizon for the Alibaba experiments
    verbose : bool
        Print the configuration header, progress, and final summary

    Returns
    -------
    dict
        ``{'synthetic': <synthetic results>, 'alibaba': <alibaba results>}``
    """
    rule = "=" * 70

    if verbose:
        print("\n" + rule)
        print("SP-UCB-OLP PAPER EXPERIMENTS")
        print(rule)
        print("\nConfiguration:")
        print(f"  Algorithms: {config['algorithms']}")
        print(f"  Seeds: {config['n_seeds']}")
        print(f"  Rho values: {config['rho_values']}")
        print(f"  T (synthetic): {T_synthetic}")
        print(f"  T (Alibaba): {T_alibaba}")

    # Run synthetic experiments
    synthetic_results = run_synthetic_experiments(
        results_dir / 'synthetic',
        config,
        T=T_synthetic,
        verbose=verbose
    )

    # Run Alibaba experiments
    alibaba_results = run_alibaba_experiments(
        results_dir / 'alibaba',
        config,
        T=T_alibaba,
        verbose=verbose
    )

    # Generate summary
    if verbose:
        print("\n" + rule)
        print("EXPERIMENT SUMMARY")
        print(rule)

        # Per-scenario results are keyed by the rho objects taken from
        # config['rho_values'], so the chosen value is a valid lookup key.
        # None means no budget factor was configured at all.
        reference_rho = min(
            config['rho_values'],
            key=lambda rho: abs(rho - 1.0),
            default=None,
        )

        for heading, grouped in (
            ("Synthetic", synthetic_results),
            ("Alibaba", alibaba_results),
        ):
            print(f"\n--- {heading} Experiments ---")
            if reference_rho is None:
                continue

            for name, per_rho in grouped.items():
                entry = per_rho.get(reference_rho, {})
                # NOTE(review): assumes results.to_dict() exposes a
                # 'summaries' mapping of algorithm -> stats; entries
                # without it are skipped, as before.
                if 'summaries' not in entry:
                    continue
                print(f"\n{name} (rho={reference_rho}):")
                for alg, stats in entry['summaries'].items():
                    print(f"  {alg:15} ratio={stats['mean_competitive_ratio']:.4f}")

    return {
        'synthetic': synthetic_results,
        'alibaba': alibaba_results
    }


def main():
    """
    Main entry point: parse CLI arguments and run the selected experiment.

    The configuration profile comes from ``--config``, except that
    ``--experiment quick`` always forces the quick profile. ``--T``
    overrides the profile's default time horizon and must be a positive
    integer; invalid values abort with an argparse usage error before any
    directory is created. Results go to a timestamped subdirectory of
    ``--results-dir``.
    """
    parser = argparse.ArgumentParser(
        description='Run SP-UCB-OLP paper experiments',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python run_paper_experiments.py --experiment quick
  python run_paper_experiments.py --experiment synthetic --T 10000
  python run_paper_experiments.py --experiment alibaba --T 20000
  python run_paper_experiments.py --experiment all --config paper
        """
    )

    parser.add_argument(
        '--experiment',
        type=str,
        choices=['quick', 'synthetic', 'alibaba', 'all'],
        default='quick',
        help='Experiment to run (default: quick)'
    )

    parser.add_argument(
        '--config',
        type=str,
        choices=['quick', 'default', 'paper'],
        default='default',
        help='Configuration profile (default: default)'
    )

    parser.add_argument(
        '--T',
        type=int,
        default=None,
        help='Time horizon (overrides config default)'
    )

    parser.add_argument(
        '--results-dir',
        type=str,
        default='./results',
        help='Results directory (default: ./results)'
    )

    parser.add_argument(
        '--quiet',
        action='store_true',
        help='Suppress verbose output'
    )

    args = parser.parse_args()

    # Validate the horizon explicitly: a truthiness test would silently
    # replace an explicit `--T 0` with the profile default and let negative
    # horizons through to the loaders.
    if args.T is not None and args.T <= 0:
        parser.error(f"--T must be a positive integer (got {args.T})")

    # Select configuration ('--experiment quick' overrides '--config')
    if args.config == 'quick' or args.experiment == 'quick':
        config = QUICK_CONFIG.copy()
        T_default = 1000
    elif args.config == 'paper':
        config = PAPER_CONFIG.copy()
        T_default = 10000
    else:
        config = DEFAULT_CONFIG.copy()
        T_default = 10000

    T = T_default if args.T is None else args.T

    # Setup results directory (one timestamped folder per invocation)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    results_dir = Path(args.results_dir) / f'experiment_{timestamp}'

    verbose = not args.quiet

    if verbose:
        print(f"Results will be saved to: {results_dir}")

    # Run experiments; 'quick' is the synthetic suite under the quick profile
    if args.experiment in ('quick', 'synthetic'):
        run_synthetic_experiments(results_dir, config, T=T, verbose=verbose)
    elif args.experiment == 'alibaba':
        run_alibaba_experiments(results_dir, config, T=T, verbose=verbose)
    else:  # all
        run_all_experiments(results_dir, config, T_synthetic=T, T_alibaba=T, verbose=verbose)

    if verbose:
        print(f"\n\nAll results saved to: {results_dir}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
