#!/usr/bin/env python3
"""
Alpha Sweep Experiment for SP-UCB-OLP

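Tests SP-UCB-OLP with different exploration rates (alpha values).
The 10 seeds are split across 5 workers (2 seeds each). Workers can be launched
as separate processes (--worker 0..4) to run in parallel, or sequentially in a
single process (--worker all); their outputs are merged with --worker combine.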

Alpha values tested: 0.01, 0.05, 0.1 (existing), 0.5, 1.0
"""

import argparse
import json
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import sys

from sp_ucb_olp.data import S1ComplementarityLoader, S2NoisyLoader, S3DominantLoader
from sp_ucb_olp.algorithms import get_algorithm

# =============================================================================
# EXPERIMENT CONFIGURATION
# =============================================================================

# Exploration rates (alpha) handed to SP-UCB-OLP; every run uses one of these.
ALPHA_VALUES = [0.01, 0.05, 0.1, 0.5, 1.0]

# Sweep grid: per-scenario loader/size settings, budget levels (rho) and seeds.
ALPHA_CONFIG = {
    'scenarios': {
        'S1': dict(loader=S1ComplementarityLoader, K=4, T=5000, d=3),
        'S2': dict(loader=S2NoisyLoader, K=4, T=5000, d=3),
        'S3': dict(loader=S3DominantLoader, K=2, T=5000, d=3),
    },
    'rho_values': [0.3, 0.5, 0.7, 0.9, 1.2],
    'seeds': list(range(42, 52)),  # 10 seeds
}

# Five workers, each owning a consecutive pair of the seeds above
# (worker i -> [42 + 2*i, 43 + 2*i]).
ALPHA_CONFIG['worker_seeds'] = {
    worker: ALPHA_CONFIG['seeds'][2 * worker:2 * worker + 2]
    for worker in range(5)
}

# =============================================================================
# SINGLE RUN FUNCTION
# =============================================================================

def _to_builtin(value):
    """Return *value* with NumPy arrays and scalars converted to Python builtins.

    ``json`` cannot serialize ``np.ndarray`` or NumPy integer scalars, so
    values taken from the algorithm/loader are normalised before they go into
    a result record. Lists, tuples and dict values are converted recursively
    (tuples become lists); dict keys and any other object are left unchanged.
    """
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    if isinstance(value, dict):
        return {key: _to_builtin(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_builtin(item) for item in value]
    return value


def run_single_experiment(
    scenario_name: str,
    alpha: float,
    rho: float,
    seed: int,
    config: dict,
    loader=None,
    oracle_values=None
) -> dict:
    """Run a single SP-UCB-OLP experiment with given alpha.

    Args:
        scenario_name: Key into ``config['scenarios']`` (e.g. ``'S1'``).
        alpha: Exploration rate passed to the algorithm as ``alg_config['alpha']``.
        rho: Budget level; forwarded to ``loader.get_budget`` and
            ``loader.get_oracle_values``.
        seed: Seed used to build a loader when none is supplied; also recorded.
        config: Sweep configuration with a ``'scenarios'`` mapping holding
            ``'loader'``, ``'K'``, ``'T'`` and optionally ``'d'`` (default 3).
        loader: Optional pre-built loader to reuse; built from the scenario
            config when ``None``.
        oracle_values: Optional precomputed oracle dict (must contain
            ``'V_mix'`` and ``'V_star'``); computed from the loader when ``None``.

    Returns:
        A flat result record whose values are plain Python objects (NumPy
        values are converted), so it can be passed straight to ``json``.
        ``competitive_ratio`` is ``total_reward / (T * V_mix)``, or 0.0 when
        ``V_mix <= 0``.
    """
    scenario_config = config['scenarios'][scenario_name]
    K = scenario_config['K']
    T = scenario_config['T']
    d = scenario_config.get('d', 3)

    # Use provided loader or create new one
    if loader is None:
        LoaderClass = scenario_config['loader']
        loader = LoaderClass(K=K, T=T, seed=seed, d=d)
    # NOTE(review): callers may reuse one loader across several alphas/rhos;
    # that assumes get_arrival(theta, t) does not depend on state left over
    # from earlier runs -- confirm in the loader implementation.

    # The loader's own dimension takes precedence over the configured one.
    d = loader.d
    B = loader.get_budget(rho)

    # Use provided oracle values or compute new ones
    if oracle_values is None:
        oracle_values = loader.get_oracle_values(rho=rho)
    V_mix = oracle_values['V_mix']

    # Setup algorithm config with specific alpha
    alg_config = {'alpha': alpha, 'warm_start': True}
    algorithm = get_algorithm('SP-UCB-OLP', K, d, T, B, alg_config)

    # Run simulation: choose a config, observe the arrival, decide, learn.
    start_time = time.perf_counter()
    for t in range(T):
        theta, _, p = algorithm.select_config(t)
        r, a = loader.get_arrival(theta, t)
        accept = algorithm.decide_admission(t, theta, r, a, p)
        algorithm.update(t, theta, r, a, accept)
    elapsed_time = time.perf_counter() - start_time

    # Collect results
    stats = algorithm.get_statistics()
    total_reward = _to_builtin(stats['total_reward'])

    # Compute competitive ratio against T rounds of the V_mix benchmark.
    competitive_ratio = _to_builtin(
        total_reward / (T * V_mix) if V_mix > 0 else 0.0
    )

    return {
        # Identifiers
        'scenario': scenario_name,
        'algorithm': f'SP-UCB-OLP-alpha{alpha}',
        'alpha': alpha,
        'rho': rho,
        'seed': seed,

        # Key metrics
        'total_reward': total_reward,
        'competitive_ratio': competitive_ratio,
        'acceptance_rate': _to_builtin(stats['acceptance_rate']),
        'total_accepts': _to_builtin(stats['total_accepts']),

        # Budget info (a small negative tolerance absorbs float round-off)
        'B_remaining': _to_builtin(stats['B_remaining']),
        'budget_utilization': _to_builtin(list(stats['budget_utilization'])),
        'budget_violation': bool(np.any(np.array(stats['B_remaining']) < -1e-9)),

        # Config usage
        'N_theta': _to_builtin(stats['N_theta']),

        # Oracle values
        'V_mix': _to_builtin(V_mix),
        'V_star': _to_builtin(oracle_values['V_star']),

        # Metadata
        'K': K,
        'd': d,
        'T': T,
        'elapsed_time': elapsed_time,
    }

# =============================================================================
# WORKER FUNCTION
# =============================================================================

def _json_default(obj):
    """Fallback for ``json.dumps`` that converts NumPy arrays/scalars to builtins.

    Raises:
        TypeError: for any other object, matching ``json``'s behaviour when
            no ``default`` is given.
    """
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def run_worker(worker_id: int, results_dir: Path, config: dict, verbose: bool = True):
    """Run all experiments for a single worker and save them to JSON.

    The worker iterates over its seeds (``config['worker_seeds'][worker_id]``),
    every scenario, every rho and every value in ``ALPHA_VALUES``. A failing
    run is recorded as an error entry instead of aborting the worker.

    Args:
        worker_id: Key into ``config['worker_seeds']``.
        results_dir: Directory receiving ``alpha_worker_<id>_results.json``.
        config: Sweep configuration (see ``ALPHA_CONFIG``).
        verbose: Print progress messages. Run errors are printed regardless.

    Returns:
        The list of result/error records, in execution order.
    """
    seeds = config['worker_seeds'][worker_id]
    scenarios = list(config['scenarios'].keys())
    rho_values = config['rho_values']

    total_runs = len(seeds) * len(scenarios) * len(ALPHA_VALUES) * len(rho_values)

    if verbose:
        print(f"Worker {worker_id}: Starting {total_runs} runs")
        print(f"  Seeds: {seeds}")
        print(f"  Alpha values: {ALPHA_VALUES}")

    results = []
    n_failed = 0
    run_count = 0
    start_time = time.time()

    for seed in seeds:
        for scenario in scenarios:
            # Create loader once per (scenario, seed); it is shared by all
            # rho/alpha runs below.
            scenario_config = config['scenarios'][scenario]
            LoaderClass = scenario_config['loader']
            K = scenario_config['K']
            T = scenario_config['T']
            d = scenario_config.get('d', 3)
            loader = LoaderClass(K=K, T=T, seed=seed, d=d)

            for rho in rho_values:
                # Compute oracle once per (scenario, seed, rho)
                oracle_values = loader.get_oracle_values(rho=rho)

                for alpha in ALPHA_VALUES:
                    try:
                        result = run_single_experiment(
                            scenario, alpha, rho, seed, config,
                            loader=loader, oracle_values=oracle_values
                        )
                    except Exception as e:
                        # Boundary handler: record the failure and keep going
                        # so the remaining runs are still produced and saved.
                        n_failed += 1
                        print(f"  ERROR Worker {worker_id}: {scenario}/alpha={alpha}/rho={rho}/seed={seed}: {e}")
                        result = {
                            'scenario': scenario,
                            'alpha': alpha,
                            'rho': rho,
                            'seed': seed,
                            'error': str(e),
                            'error_type': type(e).__name__,
                        }
                    results.append(result)

                    # Count only finished runs so the ETA is based on
                    # completed work, not on the run about to start.
                    run_count += 1
                    if verbose and run_count % 10 == 0:
                        elapsed = time.time() - start_time
                        eta = elapsed / run_count * (total_runs - run_count)
                        print(f"  Worker {worker_id}: {run_count}/{total_runs} "
                              f"({100*run_count/total_runs:.0f}%) ETA: {eta:.0f}s")

    total_time = time.time() - start_time

    if verbose:
        print(f"Worker {worker_id}: Completed {total_runs} runs in {total_time:.1f}s")
        if n_failed:
            print(f"Worker {worker_id}: {n_failed} run(s) failed")

    # Serialize fully before touching the file so a serialization error
    # cannot leave a truncated results file behind.
    payload = json.dumps({
        'worker_id': worker_id,
        'seeds': seeds,
        'alpha_values': ALPHA_VALUES,
        'total_runs': total_runs,
        'total_time': total_time,
        'results': results,
    }, indent=2, default=_json_default)

    output_file = results_dir / f"alpha_worker_{worker_id}_results.json"
    output_file.write_text(payload)

    if verbose:
        print(f"Worker {worker_id}: Saved to {output_file}")

    return results

# =============================================================================
# COMBINE RESULTS
# =============================================================================

def combine_worker_results(results_dir: Path, verbose: bool = True, worker_ids=None):
    """Combine all worker results into a single file and print a summary.

    Reads ``alpha_worker_<id>_results.json`` for each worker ID and writes
    ``alpha_combined_results.json`` to *results_dir*. If any worker file is
    absent, a warning is printed to stderr and the missing IDs are stored
    under ``'missing_workers'`` in the combined file.

    Args:
        results_dir: Directory containing the per-worker files.
        verbose: Print per-worker load counts and the output path.
        worker_ids: Worker IDs to merge, in order; defaults to 0-4.

    Returns:
        The concatenated list of result records (including error records).
    """
    if worker_ids is None:
        worker_ids = range(5)

    all_results = []
    missing = []

    for worker_id in worker_ids:
        worker_file = results_dir / f"alpha_worker_{worker_id}_results.json"
        if not worker_file.exists():
            missing.append(worker_id)
            continue
        with open(worker_file) as f:
            data = json.load(f)
        worker_results = data['results']
        all_results.extend(worker_results)
        if verbose:
            print(f"Loaded {len(worker_results)} results from worker {worker_id}")

    if missing:
        # Surface incomplete sweeps instead of silently summarising a subset.
        print(f"WARNING: no results file for worker(s) {missing}; "
              f"combined results are incomplete", file=sys.stderr)

    # Save combined results
    combined_file = results_dir / "alpha_combined_results.json"
    with open(combined_file, 'w') as f:
        json.dump({
            'alpha_values': ALPHA_VALUES,
            'total_results': len(all_results),
            'missing_workers': missing,
            'results': all_results,
        }, f, indent=2)

    if verbose:
        print(f"\nCombined {len(all_results)} results into {combined_file}")

    # Create summary
    create_summary(all_results, results_dir, verbose)

    return all_results

def create_summary(results: list, results_dir: Path, verbose: bool = True):
    """Print and persist competitive-ratio summary tables.

    Records carrying an ``'error'`` key are ignored. The mean competitive
    ratio per (scenario, rho) and alpha is printed and written to
    ``alpha_summary.csv`` in *results_dir*. Then the per-scenario mean over
    every rho and seed is printed for each alpha.

    Args:
        results: Per-run result records produced by the workers.
        results_dir: Directory that receives ``alpha_summary.csv``.
        verbose: Also report where the CSV was written. The tables
            themselves are printed either way.
    """
    import pandas as pd

    banner = "=" * 80
    valid_rows = [row for row in results if 'error' not in row]
    frame = pd.DataFrame(valid_rows)

    if len(frame.index) == 0:
        print("No valid results to summarize")
        return

    # Table 1: mean competitive ratio for every (scenario, rho) x alpha cell.
    print(f"\n{banner}")
    print("ALPHA SWEEP SUMMARY: Mean Competitive Ratio")
    print(banner)
    by_rho = frame.pivot_table(
        values='competitive_ratio',
        index=['scenario', 'rho'],
        columns='alpha',
        aggfunc='mean',
    )
    by_rho = by_rho.round(4)
    print(by_rho.to_string())

    csv_path = results_dir / "alpha_summary.csv"
    by_rho.to_csv(csv_path)
    if verbose:
        print(f"\nSummary saved to {csv_path}")

    # Table 2: per-scenario mean, pooled over all rho values and seeds.
    print(f"\n{banner}")
    print("SCENARIO AVERAGES BY ALPHA")
    print(banner)
    grouped = frame.groupby(['scenario', 'alpha'])['competitive_ratio']
    per_scenario = grouped.mean().unstack('alpha')
    print(per_scenario.round(4).to_string())

# =============================================================================
# MAIN
# =============================================================================

def main():
    """Command-line entry point for the alpha sweep.

    ``--worker N`` runs a single worker, ``--worker all`` runs every worker in
    this process one after another and then combines the results, and
    ``--worker combine`` only merges existing worker files. Valid worker IDs
    are the keys of ``ALPHA_CONFIG['worker_seeds']``.
    """
    parser = argparse.ArgumentParser(description='Run alpha sweep experiments')
    parser.add_argument('--worker', type=str, default='0',
                       help='Worker ID (0-4) or "all" or "combine"')
    parser.add_argument('--results-dir', type=str, default='./results/alpha_sweep',
                       help='Results directory')
    parser.add_argument('--quiet', action='store_true', help='Suppress output')

    args = parser.parse_args()

    results_dir = Path(args.results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    verbose = not args.quiet

    config = ALPHA_CONFIG
    # Derive worker IDs from the config instead of hard-coding 0-4.
    worker_ids = sorted(config['worker_seeds'])

    if verbose:
        n_seeds = len(config['seeds'])
        n_scenarios = len(config['scenarios'])
        n_rho = len(config['rho_values'])
        n_alpha = len(ALPHA_VALUES)
        total = n_scenarios * n_seeds * n_rho * n_alpha
        print(f"Alpha Sweep: {n_alpha} alpha values × {n_seeds} seeds × "
              f"{n_scenarios} scenarios × {n_rho} rho")
        print(f"Alpha values: {ALPHA_VALUES}")
        print(f"Total runs: {total}")

    if args.worker == 'combine':
        combine_worker_results(results_dir, verbose)
    elif args.worker == 'all':
        if verbose:
            print("="*70)
            print("RUNNING ALL WORKERS SEQUENTIALLY")
            print("="*70)

        start = time.time()
        for worker_id in worker_ids:
            run_worker(worker_id, results_dir, config, verbose)

        if verbose:
            print(f"\nCompleted in {time.time()-start:.1f}s")

        combine_worker_results(results_dir, verbose)
    else:
        # Non-numeric input (e.g. a typo of "all") gets the same usage error
        # as an out-of-range ID instead of a ValueError traceback.
        try:
            worker_id = int(args.worker)
        except ValueError:
            worker_id = None
        if worker_id not in config['worker_seeds']:
            print(f"Error: worker must be {worker_ids[0]}-{worker_ids[-1]}, "
                  f"got {args.worker}", file=sys.stderr)
            sys.exit(1)
        run_worker(worker_id, results_dir, config, verbose)

if __name__ == '__main__':
    main()
