#!/usr/bin/env python3
"""
Comprehensive Experiment Runner for SP-UCB-OLP ICML Paper

Runs all experiments needed for strong ICML acceptance:

E1: S4 Benchmark Sanity Check (validates V^mix > V* pathology)
    - 5 algorithms × 5 ρ values × 10 seeds = 250 runs

E2: Ablation Experiments (isolates algorithm components)
    - (full algorithm + 4 ablation variants) × 2 scenarios × 3 ρ values × 10 seeds = 300 runs

E3: K-Sweep (validates √K scaling)
    - 5 algorithms × 4 K values × 10 seeds = 200 runs

E4: Synthetic Scenarios S1/S2/S3 (already exists, updated with redesigned scenarios; not run by this script)
    - 5 algorithms × 3 scenarios × 5 ρ values × 10 seeds = 750 runs

E5: Alibaba Trace Experiments (already exists)
    - Runs separately via run_alibaba_experiments.py

Total new runs (E1–E3): 250 + 300 + 200 = 750 runs (~1-2 hours)

Usage:
    # Run all experiments
    python run_all_experiments.py

    # Run specific experiment
    python run_all_experiments.py --experiment E1
    python run_all_experiments.py --experiment E2
    python run_all_experiments.py --experiment E3

    # Smoke test (quick validation)
    python run_all_experiments.py --smoke

    # Run with multiple parallel workers
    python run_all_experiments.py --workers 4
"""

import argparse
import json
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, Any, List, Tuple
import multiprocessing

from sp_ucb_olp.data import (
    S1ComplementarityLoader,
    S2NoisyLoader,
    S3DominantLoader,
    S4ComplementarityLoader,
)
from sp_ucb_olp.algorithms import get_algorithm


# =============================================================================
# EXPERIMENT CONFIGURATIONS
# =============================================================================

# E1: S4 benchmark sanity check. One complementarity scenario (S4, K=2, d=2)
# swept over five budget ratios rho; SP-UCB-OLP is compared against four
# baselines, with 10 seeds per (algorithm, rho) cell.
E1_CONFIG = {
    'name': 'E1_Benchmark_Sanity',
    'description': 'Validates V^mix > V* pathology on complementarity scenario',
    'scenario': 'S4',
    'loader': S4ComplementarityLoader,
    'K': 2,
    'd': 2,
    'T': 5000,
    'algorithms': ['SP-UCB-OLP', 'Greedy', 'OneHot', 'Oracle', 'Random'],
    'rho_values': [0.3, 0.5, 0.7, 0.9, 1.2],
    'seeds': list(range(42, 52)),  # ten seeds: 42, 43, ..., 51
    # Per-algorithm config dicts handed to get_algorithm(); the Oracle entry is
    # extended with w_star / p_star at run time by run_single_experiment.
    'algorithm_configs': {
        'SP-UCB-OLP': dict(alpha=0.1, warm_start=True),
        'Greedy': dict(),
        'OneHot': dict(alpha=0.1),
        'Oracle': dict(),
        'Random': dict(),
    },
}

# E2: ablations. The full algorithm plus four ablation variants, each run on
# two scenarios (S3, S4) and three budget ratios, 10 seeds each.
E2_CONFIG = {
    'name': 'E2_Ablations',
    'description': 'Isolates algorithm components via ablation variants',
    'scenarios': {
        'S3': dict(loader=S3DominantLoader, K=2, d=3),
        'S4': dict(loader=S4ComplementarityLoader, K=2, d=2),
    },
    'T': 5000,
    'algorithms': [
        'SP-UCB-OLP',         # full algorithm, the reference point
        'EnvelopeGreedy',     # ablation: no mixture sampling
        'MixtureLocalPrice',  # ablation: per-config prices
        'NoSlack',            # ablation: slack epsilon = 0
        'AcceptedOnly',       # ablation: selection bias
    ],
    'rho_values': [0.5, 0.7, 1.0],
    'seeds': list(range(42, 52)),
    # Every variant uses the same hyperparameters (a separate dict per entry).
    'algorithm_configs': {
        variant: dict(alpha=0.1, warm_start=True)
        for variant in (
            'SP-UCB-OLP',
            'EnvelopeGreedy',
            'MixtureLocalPrice',
            'NoSlack',
            'AcceptedOnly',
        )
    },
}

# E3: K-sweep. Only metadata lives here; main() runs the experiment by
# launching run_k_sweep.py in a subprocess.
E3_CONFIG = {
    'name': 'E3_K_Sweep',
    'description': 'Validates sqrt(K) scaling in regret bound',
}

# Smoke-test variants: shorter horizon, fewer rho values and seeds. These are
# shallow copies, so nested values (loaders, algorithm_configs) are shared with
# the full configs.
E1_SMOKE = dict(E1_CONFIG, T=500, rho_values=[0.5, 1.0], seeds=[42, 43])
E2_SMOKE = dict(E2_CONFIG, T=500, rho_values=[0.7], seeds=[42, 43])


# =============================================================================
# SINGLE RUN FUNCTION
# =============================================================================

def run_single_experiment(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run a single experiment with given parameters.

    Parameters
    ----------
    params : dict
        Must contain: scenario, algorithm, rho, seed, T, K, d, loader_class, alg_config

    Returns
    -------
    dict : Result dictionary with all metrics
    """
    scenario = params['scenario']
    algorithm_name = params['algorithm']
    rho = params['rho']
    seed = params['seed']
    T = params['T']
    K = params['K']
    d = params['d']
    LoaderClass = params['loader_class']
    alg_config = params['alg_config'].copy()

    try:
        # Create loader
        loader = LoaderClass(K=K, T=T, seed=seed, d=d)
        d = loader.d  # May be overridden by loader

        B = loader.get_budget(rho)

        # Get oracle values
        oracle_values = loader.get_oracle_values(rho=rho)
        V_mix = oracle_values['V_mix']
        V_star = oracle_values['V_star']

        # Setup oracle config if needed
        if algorithm_name == 'Oracle':
            alg_config['w_star'] = oracle_values['w_star']
            alg_config['p_star'] = oracle_values['p_star']

        # Create algorithm
        algorithm = get_algorithm(algorithm_name, K, d, T, B, alg_config)

        # Run simulation
        start_time = time.time()

        for t in range(T):
            theta, w, p = algorithm.select_config(t)
            r, a = loader.get_arrival(theta, t)
            accept = algorithm.decide_admission(t, theta, r, a, p)
            algorithm.update(t, theta, r, a, accept)

        elapsed_time = time.time() - start_time

        # Collect results
        stats = algorithm.get_statistics()

        # Compute metrics
        CR_mix = stats['total_reward'] / (T * V_mix) if V_mix > 0 else 0.0
        CR_star = stats['total_reward'] / (T * V_star) if V_star > 0 else 0.0
        regret_mix = T * V_mix - stats['total_reward']

        result = {
            # Identifiers
            'experiment': params.get('experiment', 'unknown'),
            'scenario': scenario,
            'algorithm': algorithm_name,
            'rho': rho,
            'seed': seed,
            'K': K,
            'd': d,
            'T': T,

            # Key metrics
            'total_reward': stats['total_reward'],
            'CR_mix': CR_mix,
            'CR_star': CR_star,
            'regret_mix': regret_mix,
            'acceptance_rate': stats['acceptance_rate'],

            # Oracle values
            'V_mix': V_mix,
            'V_star': V_star,
            'complementarity_gap': V_mix / V_star if V_star > 0 else float('inf'),

            # Budget info
            'B_remaining': list(stats['B_remaining']),
            'budget_utilization': list(stats['budget_utilization']),

            # Config usage
            'N_theta': stats['N_theta'],

            # Metadata
            'elapsed_time': elapsed_time,
            'status': 'success',
        }

        return result

    except Exception as e:
        import traceback
        return {
            'experiment': params.get('experiment', 'unknown'),
            'scenario': scenario,
            'algorithm': algorithm_name,
            'rho': rho,
            'seed': seed,
            'status': 'error',
            'error': str(e),
            'traceback': traceback.format_exc(),
        }


# =============================================================================
# EXPERIMENT RUNNERS
# =============================================================================

def run_e1(config: Dict, results_dir: Path, n_workers: int = 4, verbose: bool = True) -> List[Dict]:
    """Run experiment E1 (benchmark sanity check on the S4 scenario).

    One run is built for every (seed, rho, algorithm) combination in
    ``config``. The runs are executed by ``run_experiments_parallel``, the raw
    results are written to ``E1_benchmark_sanity.json`` inside ``results_dir``,
    and ``summarize_e1`` writes/prints the summary.

    Returns the list of per-run result dicts.
    """
    if verbose:
        banner = "=" * 70
        print("\n" + banner)
        print("E1: BENCHMARK SANITY CHECK (S4 Complementarity)")
        print("Purpose: Validate that V^mix > V* and CR^* can exceed 1")
        print(banner)

    # Seed is the outermost loop and algorithm the innermost.
    run_params = [
        {
            'experiment': 'E1',
            'scenario': config['scenario'],
            'algorithm': name,
            'rho': rho,
            'seed': seed,
            'T': config['T'],
            'K': config['K'],
            'd': config['d'],
            'loader_class': config['loader'],
            'alg_config': config['algorithm_configs'].get(name, {}),
        }
        for seed in config['seeds']
        for rho in config['rho_values']
        for name in config['algorithms']
    ]

    if verbose:
        print(f"Total runs: {len(run_params)}")

    results = run_experiments_parallel(run_params, n_workers, verbose)
    save_results(results, results_dir / 'E1_benchmark_sanity.json')
    summarize_e1(results, results_dir, verbose)
    return results


def run_e2(config: Dict, results_dir: Path, n_workers: int = 4, verbose: bool = True) -> List[Dict]:
    """Run experiment E2 (ablations).

    One run is built for every (scenario, seed, rho, algorithm) combination.
    K, d and the loader class come from the scenario's entry in
    ``config['scenarios']``. The raw results are written to
    ``E2_ablations.json`` inside ``results_dir``, and ``summarize_e2``
    writes/prints the per-scenario summary.

    Returns the list of per-run result dicts.
    """
    if verbose:
        rule = "=" * 70
        for line in (
            "\n" + rule,
            "E2: ABLATION EXPERIMENTS",
            "Purpose: Isolate necessity of mixture sampling and global prices",
            rule,
        ):
            print(line)

    run_params = [
        {
            'experiment': 'E2',
            'scenario': scenario_name,
            'algorithm': name,
            'rho': rho,
            'seed': seed,
            'T': config['T'],
            'K': scenario['K'],
            'd': scenario['d'],
            'loader_class': scenario['loader'],
            'alg_config': config['algorithm_configs'].get(name, {}),
        }
        for scenario_name, scenario in config['scenarios'].items()
        for seed in config['seeds']
        for rho in config['rho_values']
        for name in config['algorithms']
    ]

    if verbose:
        print(f"Total runs: {len(run_params)}")

    results = run_experiments_parallel(run_params, n_workers, verbose)
    save_results(results, results_dir / 'E2_ablations.json')
    summarize_e2(results, results_dir, verbose)
    return results


def run_experiments_parallel(
    run_params: List[Dict],
    n_workers: int = 4,
    verbose: bool = True
) -> List[Dict]:
    """Run experiments in parallel using ProcessPoolExecutor."""
    results = []
    total = len(run_params)
    completed = 0
    start_time = time.time()

    if n_workers == 1:
        # Sequential execution for debugging
        for params in run_params:
            result = run_single_experiment(params)
            results.append(result)
            completed += 1
            if verbose and completed % 10 == 0:
                elapsed = time.time() - start_time
                eta = elapsed / completed * (total - completed)
                print(f"  Progress: {completed}/{total} ({100*completed/total:.0f}%) ETA: {eta:.0f}s")
    else:
        # Parallel execution
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = {executor.submit(run_single_experiment, p): p for p in run_params}

            for future in as_completed(futures):
                result = future.result()
                results.append(result)
                completed += 1

                if verbose and completed % 20 == 0:
                    elapsed = time.time() - start_time
                    eta = elapsed / completed * (total - completed)
                    print(f"  Progress: {completed}/{total} ({100*completed/total:.0f}%) ETA: {eta:.0f}s")

    total_time = time.time() - start_time
    if verbose:
        n_errors = sum(1 for r in results if r.get('status') == 'error')
        print(f"Completed {total} runs in {total_time:.1f}s ({n_errors} errors)")

    return results


def _json_default(obj: Any) -> Any:
    """Fallback encoder for ``json.dump``.

    NumPy arrays and NumPy scalars (anything with a callable ``tolist()``)
    become JSON lists/numbers; any other unencodable object is written as its
    ``str()`` form.
    """
    tolist = getattr(obj, 'tolist', None)
    if callable(tolist):
        return tolist()
    return str(obj)


def save_results(results: List[Dict], filepath: Path):
    """Write ``results`` to ``filepath`` as one JSON document and report it.

    The document has the form
    ``{'total_results': n, 'timestamp': <ISO-8601 local time>, 'results': [...]}``
    and is indented by 2 spaces. Parent directories are created as needed.
    NumPy values are written as real JSON numbers/lists rather than strings.
    Always prints a one-line confirmation.
    """
    filepath.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        'total_results': len(results),
        'timestamp': datetime.now().isoformat(),
        'results': results,
    }
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2, default=_json_default)
    print(f"Saved {len(results)} results to {filepath}")
    print(f"Saved {len(results)} results to {filepath}")


# =============================================================================
# SUMMARY FUNCTIONS
# =============================================================================

def summarize_e1(results: List[Dict], results_dir: Path, verbose: bool = True):
    """Summarize E1 results into ``E1_summary.csv`` and optionally print pivots.

    Only runs with ``status == 'success'`` are used. The CSV holds, per
    (algorithm, rho), the mean and std of CR_mix and CR_star plus the mean
    complementarity gap, rounded to 4 decimals. With ``verbose``, the mean
    CR^* and CR^mix tables (rho x algorithm) and the overall mean gap are
    printed. With no successful runs, a message is printed and nothing is
    written.
    """
    import pandas as pd

    ok_rows = [row for row in results if row.get('status') == 'success']
    df = pd.DataFrame(ok_rows)
    if df.empty:
        print("No successful E1 results to summarize")
        return

    # Headline check: CR^* above 1 for SP-UCB-OLP shows the fixed oracle V*
    # can be beaten.
    aggregations = {
        'CR_mix': ['mean', 'std'],
        'CR_star': ['mean', 'std'],
        'complementarity_gap': ['mean'],
    }
    summary = df.groupby(['algorithm', 'rho']).agg(aggregations).round(4)
    summary.to_csv(results_dir / 'E1_summary.csv')

    if not verbose:
        return

    separator = "-" * 70
    print("\n" + separator)
    print("E1 SUMMARY: Competitive Ratios (should show CR^* > 1 for SP-UCB)")
    print(separator)

    for metric, heading in (
        ('CR_star', "\nCR^* (vs fixed oracle V*):"),
        ('CR_mix', "\nCR^mix (vs switching oracle V^mix):"),
    ):
        table = df.pivot_table(
            values=metric, index='rho', columns='algorithm', aggfunc='mean'
        ).round(4)
        print(heading)
        print(table.to_string())

    gap = df['complementarity_gap'].mean()
    print(f"\nMean Complementarity Gap (V^mix/V*): {gap:.2f} (expected ~2.0)")


def summarize_e2(results: List[Dict], results_dir: Path, verbose: bool = True):
    """Summarize E2 ablation results into ``E2_summary.csv``, optionally printing tables.

    Only runs with ``status == 'success'`` are used. The CSV holds, per
    (scenario, algorithm, rho), the mean and std of CR_mix and regret_mix,
    rounded to 4 decimals. With ``verbose``, one mean-CR_mix table
    (rho x algorithm) is printed per scenario, in order of first appearance.
    With no successful runs, a message is printed and nothing is written.
    """
    import pandas as pd

    df = pd.DataFrame([row for row in results if row.get('status') == 'success'])
    if df.empty:
        print("No successful E2 results to summarize")
        return

    # The full algorithm is compared against each ablation variant.
    stats_table = (
        df.groupby(['scenario', 'algorithm', 'rho'])
        .agg({'CR_mix': ['mean', 'std'], 'regret_mix': ['mean', 'std']})
        .round(4)
    )
    stats_table.to_csv(results_dir / 'E2_summary.csv')

    if not verbose:
        return

    rule = "-" * 70
    print("\n" + rule)
    print("E2 SUMMARY: Ablation Results")
    print("(SP-UCB-OLP should outperform all ablations)")
    print(rule)

    for scenario in df['scenario'].unique():
        print(f"\nScenario: {scenario}")
        subset = df.loc[df['scenario'] == scenario]
        cr_table = subset.pivot_table(
            values='CR_mix', index='rho', columns='algorithm', aggfunc='mean'
        )
        print(cr_table.round(4).to_string())


# =============================================================================
# MAIN
# =============================================================================

def _run_k_sweep(results_dir: Path, smoke: bool, quiet: bool, verbose: bool) -> None:
    """Run E3 (K-sweep) by launching run_k_sweep.py, located next to this file.

    The script runs in a child process using the current interpreter, with
    results under ``results_dir / 'k_sweep'``. ``--smoke`` and ``--quiet`` are
    forwarded when set. Raises ``subprocess.CalledProcessError`` if the script
    exits with a non-zero status.
    """
    import subprocess

    if verbose:
        print("\n" + "="*70)
        print("E3: K-SWEEP (run via run_k_sweep.py)")
        print("="*70)

    k_sweep_script = Path(__file__).parent / 'run_k_sweep.py'
    cmd = [
        sys.executable, str(k_sweep_script),
        '--worker', 'all',
        '--results-dir', str(results_dir / 'k_sweep'),
    ]
    if smoke:
        cmd.append('--smoke')
    if quiet:
        cmd.append('--quiet')

    subprocess.run(cmd, check=True)


def main():
    """Command-line entry point: parse arguments and run the selected experiments.

    E1 and E2 run in this process (optionally across worker processes); E3 is
    delegated to run_k_sweep.py. With ``--smoke`` the reduced configs are used
    and results go to a ``smoke_test`` subdirectory of ``--results-dir``.
    """
    parser = argparse.ArgumentParser(description='Run all SP-UCB-OLP experiments')
    parser.add_argument('--experiment', type=str, default='all',
                       choices=['all', 'E1', 'E2', 'E3'],
                       help='Which experiment to run')
    parser.add_argument('--results-dir', type=str, default='./results/comprehensive',
                       help='Results directory')
    parser.add_argument('--smoke', action='store_true',
                       help='Run quick smoke test')
    parser.add_argument('--workers', type=int, default=4,
                       help='Number of parallel workers')
    parser.add_argument('--quiet', action='store_true', help='Suppress output')

    args = parser.parse_args()

    # A worker count below 1 is meaningless; reject it before any work starts.
    if args.workers < 1:
        parser.error(f"--workers must be at least 1 (got {args.workers})")

    results_dir = Path(args.results_dir)
    if args.smoke:
        results_dir = results_dir / 'smoke_test'
    results_dir.mkdir(parents=True, exist_ok=True)

    verbose = not args.quiet
    n_workers = args.workers

    if verbose:
        print("="*70)
        print("SP-UCB-OLP COMPREHENSIVE EXPERIMENT RUNNER")
        print(f"Results directory: {results_dir}")
        print(f"Workers: {n_workers}")
        print("="*70)

    all_results = []
    start_time = time.time()

    # E1: Benchmark Sanity Check
    if args.experiment in ['all', 'E1']:
        config = E1_SMOKE if args.smoke else E1_CONFIG
        all_results.extend(run_e1(config, results_dir, n_workers, verbose))

    # E2: Ablation Experiments
    if args.experiment in ['all', 'E2']:
        config = E2_SMOKE if args.smoke else E2_CONFIG
        all_results.extend(run_e2(config, results_dir, n_workers, verbose))

    # E3: K-Sweep (separate script)
    if args.experiment in ['all', 'E3']:
        _run_k_sweep(results_dir, args.smoke, args.quiet, verbose)

    # Final summary. The run count covers E1/E2 only; E3 runs happen inside
    # the run_k_sweep.py subprocess.
    total_time = time.time() - start_time
    if verbose:
        print("\n" + "="*70)
        print("ALL EXPERIMENTS COMPLETED")
        print(f"Total time: {total_time:.1f}s ({total_time/60:.1f} min)")
        print(f"Total runs: {len(all_results)}")
        print(f"Results saved to: {results_dir}")
        print("="*70)


if __name__ == '__main__':
    main()
