#!/usr/bin/env python3
"""
Parallel Synthetic Experiments for SP-UCB-OLP

Runs S1, S2, S3 experiments with 10 parallel workers.
Each worker handles 2 seeds (20 seeds total).

Usage:
    # Run all workers sequentially (for testing)
    python run_parallel_synthetic.py --worker all

    # Run specific worker
    python run_parallel_synthetic.py --worker 0
    python run_parallel_synthetic.py --worker 5

    # Run with GNU parallel (recommended)
    seq 0 9 | parallel -j 10 python run_parallel_synthetic.py --worker {}
"""

import argparse
import json
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import sys

from sp_ucb_olp.data import S1ComplementarityLoader, S2NoisyLoader, S3DominantLoader
from sp_ucb_olp.algorithms import get_algorithm

# =============================================================================
# EXPERIMENT CONFIGURATION
# =============================================================================

def _make_experiment_config(T, rho_values, seeds, worker_seeds):
    """Build an experiment config dict sharing the common scenario/algorithm setup.

    Only the horizon ``T``, the ``rho`` grid, the seed list, and the
    worker -> seeds assignment differ between the full run and the smoke test.
    Every call builds fresh dicts/lists, so the returned configs never alias.
    """
    # (scenario name, loader class, K) -- every scenario uses d=3.
    scenario_specs = (
        ('S1', S1ComplementarityLoader, 4),
        ('S2', S2NoisyLoader, 4),
        ('S3', S3DominantLoader, 2),
    )
    return {
        'scenarios': {
            name: {'loader': loader_cls, 'K': k, 'T': T, 'd': 3}
            for name, loader_cls, k in scenario_specs
        },
        'algorithms': ['SP-UCB-OLP', 'Greedy', 'OneHot', 'Oracle', 'Random'],
        'rho_values': list(rho_values),
        'seeds': list(seeds),
        'worker_seeds': worker_seeds,
        'algorithm_configs': {
            'SP-UCB-OLP': {'alpha': 0.1, 'warm_start': True},
            'Greedy': {},
            'OneHot': {'alpha': 0.1},
            'Oracle': {},
            'Random': {},
        },
    }


# Full experiment config (T=5000, d=3, 10 seeds - one per worker:
# worker i gets seed 42 + i)
FULL_CONFIG = _make_experiment_config(
    T=5000,
    rho_values=[0.3, 0.5, 0.7, 0.9, 1.2],
    seeds=range(42, 52),
    worker_seeds={worker: [seed] for worker, seed in enumerate(range(42, 52))},
)

# Quick smoke test config (T=500, d=3, 2 seeds, 2 rho values), single worker
SMOKE_CONFIG = _make_experiment_config(
    T=500,
    rho_values=[0.5, 1.0],
    seeds=[42, 43],
    worker_seeds={0: [42, 43]},
)

# Default to full config
EXPERIMENT_CONFIG = FULL_CONFIG

# =============================================================================
# SINGLE RUN FUNCTION
# =============================================================================

def run_single_experiment(
    scenario_name: str,
    algorithm_name: str,
    rho: float,
    seed: int,
    config: dict,
    loader=None,
    oracle_values=None
) -> dict:
    """
    Run a single (scenario, algorithm, rho, seed) simulation and return raw results.

    Parameters
    ----------
    scenario_name : str
        Key into ``config['scenarios']`` (supplies K, T, d and the loader class).
    algorithm_name : str
        Name passed to ``get_algorithm``; also selects the entry of
        ``config['algorithm_configs']`` (a missing entry defaults to ``{}``).
    rho : float
        Passed to ``loader.get_budget`` and, when computed here,
        ``loader.get_oracle_values``.
    seed : int
        Seed used to build a fresh loader; always recorded in the result.
    config : dict
        Experiment configuration (same layout as FULL_CONFIG / SMOKE_CONFIG).
    loader : optional
        Pre-created loader (avoids recreating for each algorithm). It is used
        as-is, so the caller must build it for the same scenario and seed.
    oracle_values : optional
        Pre-computed ``loader.get_oracle_values(rho=rho)`` result (avoids
        recomputing for each algorithm). Must contain 'V_mix' and 'V_star',
        plus 'w_star' and 'p_star' when ``algorithm_name == 'Oracle'``.

    Returns
    -------
    dict
        Identifiers, key metrics, budget info, config usage counts, oracle
        values and metadata for this run. ``competitive_ratio`` is
        ``total_reward / (T * V_mix)``, or 0.0 when that denominator is not
        positive.
    """
    scenario_config = config['scenarios'][scenario_name]
    K = scenario_config['K']
    T = scenario_config['T']
    d = scenario_config.get('d', 3)

    # Use provided loader or create new one
    if loader is None:
        LoaderClass = scenario_config['loader']
        loader = LoaderClass(K=K, T=T, seed=seed, d=d)

    # Take d from the loader itself (it is what the algorithm must match).
    d = loader.d
    B = loader.get_budget(rho)

    # Use provided oracle values or compute new ones
    if oracle_values is None:
        oracle_values = loader.get_oracle_values(rho=rho)
    V_mix = oracle_values['V_mix']

    # Copy so the Oracle-only additions below never leak into the shared config.
    alg_config = config['algorithm_configs'].get(algorithm_name, {}).copy()
    if algorithm_name == 'Oracle':
        alg_config['w_star'] = oracle_values['w_star']
        alg_config['p_star'] = oracle_values['p_star']

    algorithm = get_algorithm(algorithm_name, K, d, T, B, alg_config)

    # perf_counter is monotonic, so the measured duration is immune to
    # wall-clock adjustments during long runs.
    start_time = time.perf_counter()

    for t in range(T):
        theta, w, p = algorithm.select_config(t)
        r, a = loader.get_arrival(theta, t)
        accept = algorithm.decide_admission(t, theta, r, a, p)
        algorithm.update(t, theta, r, a, accept)

    elapsed_time = time.perf_counter() - start_time

    stats = algorithm.get_statistics()

    # V_mix is scaled by T, so it presumably is a per-step benchmark value.
    # Guard the full denominator so T == 0 cannot cause a division by zero.
    benchmark = T * V_mix
    competitive_ratio = stats['total_reward'] / benchmark if benchmark > 0 else 0.0

    result = {
        # Identifiers
        'scenario': scenario_name,
        'algorithm': algorithm_name,
        'rho': rho,
        'seed': seed,

        # Key metrics
        'total_reward': stats['total_reward'],
        'competitive_ratio': competitive_ratio,
        'acceptance_rate': stats['acceptance_rate'],
        'total_accepts': stats['total_accepts'],

        # Budget info (a small tolerance absorbs float round-off)
        'B_remaining': stats['B_remaining'],
        'budget_utilization': list(stats['budget_utilization']),
        'budget_violation': bool(np.any(np.array(stats['B_remaining']) < -1e-9)),

        # Config usage
        'N_theta': stats['N_theta'],

        # Oracle values
        'V_mix': V_mix,
        'V_star': oracle_values['V_star'],

        # Metadata
        'K': K,
        'd': d,
        'T': T,
        'elapsed_time': elapsed_time,
    }

    return result

# =============================================================================
# WORKER FUNCTION
# =============================================================================

def _json_default(obj):
    """JSON fallback that converts NumPy values ``json`` cannot serialize natively.

    Arrays become (nested) lists and NumPy scalars become the matching Python
    scalar. Anything else raises ``TypeError``, exactly as ``json`` would.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def run_worker(worker_id: int, results_dir: Path, config: dict, verbose: bool = True):
    """
    Run every (seed, scenario, rho, algorithm) combination assigned to one worker.

    The worker's seeds come from ``config['worker_seeds'][worker_id]``: one seed
    per worker in FULL_CONFIG, two seeds for the single SMOKE_CONFIG worker.
    A failure in a single run, or while building that run's loader or oracle
    values, is recorded as a result carrying an ``'error'`` key instead of
    aborting the worker. Results are saved to
    ``results_dir / f"worker_{worker_id}_results.json"``.

    Returns the list of per-run result dicts (including error entries).
    """
    seeds = config['worker_seeds'][worker_id]
    scenarios = list(config['scenarios'].keys())
    algorithms = config['algorithms']
    rho_values = config['rho_values']

    total_runs = len(seeds) * len(scenarios) * len(algorithms) * len(rho_values)

    if verbose:
        print(f"Worker {worker_id}: Starting {total_runs} runs")
        print(f"  Seeds: {seeds}")
        print(f"  Scenarios: {scenarios}")
        print(f"  Algorithms: {algorithms}")
        print(f"  Rho values: {rho_values}")

    results = []
    run_count = 0
    start_time = time.perf_counter()

    for seed in seeds:
        for scenario in scenarios:
            scenario_config = config['scenarios'][scenario]
            # Build the loader once per (scenario, seed). If this fails, the error is
            # recorded against every run it would have served, rather than aborting
            # the worker and losing the results already collected.
            loader, setup_error = None, None
            try:
                LoaderClass = scenario_config['loader']
                loader = LoaderClass(
                    K=scenario_config['K'],
                    T=scenario_config['T'],
                    seed=seed,
                    d=scenario_config.get('d', 3),
                )
            except Exception as e:
                setup_error = e

            for rho in rho_values:
                # Compute oracle once per (scenario, seed, rho)
                oracle_values, rho_error = None, setup_error
                if rho_error is None:
                    try:
                        oracle_values = loader.get_oracle_values(rho=rho)
                    except Exception as e:
                        rho_error = e

                for algorithm in algorithms:
                    run_count += 1

                    if verbose and run_count % 25 == 0:
                        elapsed = time.perf_counter() - start_time
                        eta = elapsed / run_count * (total_runs - run_count)
                        print(f"  Worker {worker_id}: {run_count}/{total_runs} "
                              f"({100*run_count/total_runs:.0f}%) ETA: {eta:.0f}s")

                    error = rho_error
                    if error is None:
                        try:
                            result = run_single_experiment(
                                scenario, algorithm, rho, seed, config,
                                loader=loader, oracle_values=oracle_values
                            )
                        except Exception as e:
                            error = e
                        else:
                            results.append(result)
                            continue

                    print(f"  ERROR Worker {worker_id}: {scenario}/{algorithm}/rho={rho}/seed={seed}: {error}",
                          file=sys.stderr)
                    results.append({
                        'scenario': scenario,
                        'algorithm': algorithm,
                        'rho': rho,
                        'seed': seed,
                        'error': str(error),
                    })

    total_time = time.perf_counter() - start_time

    if verbose:
        print(f"Worker {worker_id}: Completed {total_runs} runs in {total_time:.1f}s")

    # Save results. Dump into a sibling temp file and rename it into place, so a
    # failed or interrupted write never leaves a truncated results file behind.
    # _json_default converts NumPy arrays/scalars that algorithm stats may contain.
    output_file = results_dir / f"worker_{worker_id}_results.json"
    tmp_file = output_file.with_name(output_file.name + '.tmp')
    with open(tmp_file, 'w') as f:
        json.dump({
            'worker_id': worker_id,
            'seeds': seeds,
            'total_runs': total_runs,
            'total_time': total_time,
            'results': results,
        }, f, indent=2, default=_json_default)
    tmp_file.replace(output_file)

    if verbose:
        print(f"Worker {worker_id}: Saved to {output_file}")

    return results

# =============================================================================
# COMBINE RESULTS
# =============================================================================

def combine_worker_results(results_dir: Path, verbose: bool = True, n_workers: int = 10):
    """
    Combine per-worker result files into a single file and summarize them.

    Reads ``worker_{i}_results.json`` for i in ``range(n_workers)`` (files that
    do not exist are skipped) and writes all runs to ``combined_results.json``
    in *results_dir*. Then it calls ``create_summary``, unless no results were
    found. In that case a warning goes to stderr and the summary step is
    skipped, since there is nothing to aggregate.

    Returns the combined list of per-run result dicts.
    """
    all_results = []
    n_loaded = 0

    for worker_id in range(n_workers):
        worker_file = results_dir / f"worker_{worker_id}_results.json"
        if not worker_file.exists():
            continue
        with open(worker_file) as f:
            data = json.load(f)
        all_results.extend(data['results'])
        n_loaded += 1
        if verbose:
            print(f"Loaded {len(data['results'])} results from worker {worker_id}")

    # Save combined results
    combined_file = results_dir / "combined_results.json"
    with open(combined_file, 'w') as f:
        json.dump({
            'total_results': len(all_results),
            'results': all_results,
        }, f, indent=2)

    if verbose:
        print(f"\nCombined {len(all_results)} results from {n_loaded} worker file(s) "
              f"into {combined_file}")

    # An empty result set has none of the grouping columns, so summarizing it
    # would only crash; make the (likely misconfigured) situation visible instead.
    if not all_results:
        print(f"WARNING: no worker results found in {results_dir}; skipping summary",
              file=sys.stderr)
        return all_results

    # Create summary statistics
    create_summary(all_results, results_dir, verbose)

    return all_results

def create_summary(results: list, results_dir: Path, verbose: bool = True):
    """
    Create summary statistics from raw per-run results.

    Runs that carry an ``'error'`` key are excluded, and their count is reported
    on stderr. Successful runs are grouped by (scenario, algorithm, rho), and
    the aggregate table is written to ``results_dir / "summary_statistics.csv"``.
    When *verbose*, a (scenario, rho) x algorithm pivot of the mean competitive
    ratio is also printed.

    Returns the summary DataFrame, or ``None`` when there are no successful
    runs (nothing is written in that case).
    """
    import pandas as pd

    ok_results = [r for r in results if 'error' not in r]
    n_failed = len(results) - len(ok_results)
    if n_failed:
        print(f"WARNING: {n_failed} failed run(s) excluded from summary", file=sys.stderr)

    # With no successful runs the DataFrame lacks the grouping columns and
    # groupby would raise KeyError; report the situation instead of crashing.
    if not ok_results:
        print("WARNING: no successful runs to summarize; skipping summary statistics",
              file=sys.stderr)
        return None

    df = pd.DataFrame(ok_results)

    # Group by scenario, algorithm, rho
    summary = df.groupby(['scenario', 'algorithm', 'rho']).agg({
        'competitive_ratio': ['mean', 'std', 'min', 'max'],
        'total_reward': ['mean', 'std'],
        'acceptance_rate': ['mean', 'std'],
        'elapsed_time': ['mean', 'sum'],
    }).round(4)

    # Save summary
    summary_file = results_dir / "summary_statistics.csv"
    summary.to_csv(summary_file)

    if verbose:
        print(f"\nSummary saved to {summary_file}")
        print("\n" + "="*70)
        print("SUMMARY: Mean Competitive Ratio by Scenario/Algorithm/Rho")
        print("="*70)

        pivot = df.pivot_table(
            values='competitive_ratio',
            index=['scenario', 'rho'],
            columns='algorithm',
            aggfunc='mean'
        ).round(4)
        print(pivot.to_string())

    return summary

# =============================================================================
# MAIN
# =============================================================================

def main():
    """
    Command-line entry point: run one worker, all workers, or combine results.

    ``--worker`` accepts a worker ID present in the selected config's
    ``worker_seeds``, ``"all"`` (run every worker sequentially, then combine),
    or ``"combine"``. ``--smoke`` switches to SMOKE_CONFIG, writes under
    ``<results-dir>/smoke_test``, and runs every smoke worker followed by a
    combine (unless ``--worker combine`` is given).
    """
    parser = argparse.ArgumentParser(description='Run parallel synthetic experiments')
    parser.add_argument('--worker', type=str, default='0',
                        help='Worker ID (0-9) or "all" or "combine"')
    parser.add_argument('--results-dir', type=str, default='./results/parallel',
                        help='Results directory')
    parser.add_argument('--smoke', action='store_true',
                        help='Run quick smoke test (T=500, 2 seeds)')
    parser.add_argument('--quiet', action='store_true', help='Suppress output')

    args = parser.parse_args()

    # Select config
    if args.smoke:
        config = SMOKE_CONFIG
        results_dir = Path(args.results_dir) / 'smoke_test'
    else:
        config = FULL_CONFIG
        results_dir = Path(args.results_dir)

    results_dir.mkdir(parents=True, exist_ok=True)
    verbose = not args.quiet

    # Derive worker IDs from the config instead of hard-coding 10 (full) / 1 (smoke).
    worker_ids = sorted(config['worker_seeds'])

    if verbose:
        T = next(iter(config['scenarios'].values()))['T']
        n_scenarios = len(config['scenarios'])
        n_seeds = len(config['seeds'])
        n_rho = len(config['rho_values'])
        n_alg = len(config['algorithms'])
        total = n_scenarios * n_seeds * n_rho * n_alg
        print(f"Config: T={T}, seeds={n_seeds}, rho_values={n_rho}, algorithms={n_alg}")
        print(f"Total runs: {total}")

    if args.worker == 'combine':
        combine_worker_results(results_dir, verbose)
    elif args.worker == 'all' or args.smoke:
        if verbose:
            print("="*70)
            print("RUNNING ALL WORKERS" + (" (SMOKE TEST)" if args.smoke else ""))
            print("="*70)

        start = time.perf_counter()
        for worker_id in worker_ids:
            run_worker(worker_id, results_dir, config, verbose)

        if verbose:
            print(f"\nCompleted in {time.perf_counter()-start:.1f}s")

        combine_worker_results(results_dir, verbose)
    else:
        # Reject non-integers with a clear message instead of a ValueError traceback.
        try:
            worker_id = int(args.worker)
        except ValueError:
            worker_id = None
        if worker_id not in config['worker_seeds']:
            print(f"Error: worker must be one of {worker_ids}, 'all' or 'combine'; "
                  f"got {args.worker!r}", file=sys.stderr)
            sys.exit(1)
        run_worker(worker_id, results_dir, config, verbose)

if __name__ == "__main__":
    # Only launch the experiments when executed as a script, not on import.
    main()
