#!/usr/bin/env python3
"""
Regret Analysis Experiment

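Study how regret, T · V_mix − R_T, scales with the horizon T, where R_T is the total reward collected over T rounds and V_mix is the oracle's per-round benchmark value.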
Theory predicts regret = O(√(KT·d·log T)) + O(√(T·log T))

Experiment design:
- T values: 100, 200, 500, 1000, 2000, 5000, 10000
- Algorithms: SP-UCB-α=0, SP-UCB-α=0.01, SP-UCB-α=0.1, OneHot, Oracle
- Scenarios: S1, S2, S3
- Seeds: 10 (42-51)
- ρ: 0.7 (fixed, middle budget level)

Total runs: 7 T values × 5 algorithms × 3 scenarios × 10 seeds = 1050 runs
10 workers → 105 runs per worker (21 (T, scenario, seed) combinations × 5 algorithms)
"""

import os
import json
import argparse
import time
import numpy as np
from pathlib import Path
from datetime import datetime

from sp_ucb_olp.data import S1ComplementarityLoader, S2NoisyLoader, S3DominantLoader
from sp_ucb_olp.algorithms import get_algorithm


# Horizons T at which cumulative regret is measured.
_T_VALUES = [100, 200, 500, 1000, 2000, 5000, 10000]

# Display name -> algorithm spec. 'type' is the name handed to get_algorithm();
# 'alpha', when present, is forwarded in the algorithm config.
_ALGORITHMS = {
    'SP-UCB-α=0': {'type': 'SP-UCB-OLP', 'alpha': 0.0},
    'SP-UCB-α=0.01': {'type': 'SP-UCB-OLP', 'alpha': 0.01},
    'SP-UCB-α=0.1': {'type': 'SP-UCB-OLP', 'alpha': 0.1},
    'OneHot': {'type': 'OneHot', 'alpha': 0.1},
    'Oracle': {'type': 'Oracle'},
}

# Scenario name -> data loader class plus the K and d values passed to both
# the loader and the algorithm constructors.
_SCENARIOS = {
    'S1': {'loader_class': S1ComplementarityLoader, 'K': 4, 'd': 3},
    'S2': {'loader_class': S2NoisyLoader, 'K': 4, 'd': 3},
    'S3': {'loader_class': S3DominantLoader, 'K': 2, 'd': 3},
}

# Experiment configuration
REGRET_CONFIG = {
    'T_values': _T_VALUES,
    'algorithms': _ALGORITHMS,
    'scenarios': _SCENARIOS,
    'rho': 0.7,  # Fixed budget level
    'seeds': list(range(42, 42 + 10)),  # 10 consecutive seeds: 42..51
    'n_workers': 10,
}


def get_worker_assignments():
    """
    Split the (T, scenario, seed) grid round-robin across the configured workers.

    The grid is enumerated with T outermost and seed innermost; the i-th
    combination goes to worker ``i % n_workers``. With the default
    configuration that is 7 T × 3 scenarios × 10 seeds = 210 combinations,
    each of which later runs every configured algorithm.

    Returns:
        dict mapping each worker id (0 .. n_workers-1) to its list of
        (T, scenario, seed) tuples, in grid order.
    """
    grid = [
        (T, scenario, seed)
        for T in REGRET_CONFIG['T_values']
        for scenario in REGRET_CONFIG['scenarios']
        for seed in REGRET_CONFIG['seeds']
    ]

    workers = REGRET_CONFIG['n_workers']
    # A stride-`workers` slice starting at `w` picks exactly the indices
    # with idx % workers == w, i.e. a round-robin deal.
    return {w: grid[w::workers] for w in range(workers)}


def _to_builtin(value):
    """Convert a NumPy scalar or 0-d array to the equivalent built-in Python value.

    Any other value (Python numbers, None, ...) is returned unchanged. This keeps
    result dicts JSON-serializable: ``json.dump`` rejects e.g. ``np.int64`` and
    ``np.float32``.
    """
    if isinstance(value, np.generic) or (isinstance(value, np.ndarray) and value.ndim == 0):
        return value.item()
    return value


def run_single_experiment(T: int, scenario: str, seed: int, results_dir: Path) -> list:
    """Run every configured algorithm once on a single (T, scenario, seed) combination.

    The global NumPy RNG is re-seeded with ``seed`` before the loader is built
    and again before each algorithm, so each algorithm starts from the same
    global RNG state.

    Args:
        T: Horizon (number of rounds simulated per algorithm).
        scenario: Key into ``REGRET_CONFIG['scenarios']``.
        seed: Seed for the loader and for the global NumPy RNG.
        results_dir: Not used by this function; kept for interface compatibility.

    Returns:
        One result dict per algorithm (in config order). Each contains the total
        reward, regret ``T * V_mix - total_reward``, competitive ratio
        ``total_reward / (T * V_mix)`` (0 when ``T * V_mix <= 0``) and run
        metadata. NumPy scalars are converted to built-in types so the list can
        be passed straight to ``json.dump``.
    """
    np.random.seed(seed)

    config = REGRET_CONFIG['scenarios'][scenario]
    loader_class = config['loader_class']
    K = config['K']
    d = config['d']
    rho = REGRET_CONFIG['rho']

    # Create loader with specified T
    loader = loader_class(K=K, T=T, d=d, seed=seed)
    B = loader.get_budget(rho)

    # Oracle values configure the Oracle algorithm and supply V_mix, the
    # per-round benchmark used in the regret.
    oracle_values = loader.get_oracle_values(rho=rho)
    V_mix = _to_builtin(oracle_values['V_mix'])
    T_V_mix = T * V_mix

    results = []

    for alg_name, alg_params in REGRET_CONFIG['algorithms'].items():
        # Reset the global NumPy RNG so every algorithm starts from the same state.
        # NOTE(review): the loader instance is shared by all algorithms; if it keeps
        # its own RNG or other per-round state, later algorithms may see a different
        # arrival sequence than the first one - confirm in the loader implementation.
        np.random.seed(seed)

        alg_type = alg_params['type']
        alg_config = {}
        if 'alpha' in alg_params:
            alg_config['alpha'] = alg_params['alpha']
        if alg_type == 'Oracle':
            alg_config['w_star'] = oracle_values['w_star']
            alg_config['p_star'] = oracle_values['p_star']

        algorithm = get_algorithm(alg_type, K, d, T, B, alg_config)

        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        start_time = time.perf_counter()
        for t in range(T):
            theta, w, p = algorithm.select_config(t)
            r, a = loader.get_arrival(theta, t)
            accept = algorithm.decide_admission(t, theta, r, a, p)
            algorithm.update(t, theta, r, a, accept)
        elapsed_time = time.perf_counter() - start_time

        stats = algorithm.get_statistics()
        total_reward = _to_builtin(stats['total_reward'])

        regret = T_V_mix - total_reward
        competitive_ratio = total_reward / T_V_mix if T_V_mix > 0 else 0

        results.append({
            'scenario': scenario,
            'algorithm': alg_name,
            'T': T,
            'seed': seed,
            'rho': rho,
            'total_reward': total_reward,
            'regret': regret,
            'competitive_ratio': competitive_ratio,
            'acceptance_rate': _to_builtin(stats['acceptance_rate']),
            'V_mix': V_mix,
            'T_V_mix': T_V_mix,
            'K': K,
            'd': d,
            'elapsed_time': elapsed_time,
        })

        print(f"  {alg_name}: R={total_reward:.1f}, Regret={regret:.1f}, CR={competitive_ratio:.3f} ({elapsed_time:.1f}s)")

    return results


def _dump_json_atomic(obj, path: Path):
    """Serialize ``obj`` to ``path`` as indented JSON, replacing the file atomically.

    The data is first written to a sibling ``<name>.tmp`` file and then moved
    over ``path`` with ``os.replace``. An interrupted or failed write therefore
    never leaves ``path`` truncated. On failure the temp file is removed and
    the exception is re-raised.
    """
    tmp_path = path.with_name(path.name + '.tmp')
    try:
        with open(tmp_path, 'w') as f:
            json.dump(obj, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def run_worker(worker_id: int, results_dir: Path):
    """Run every (T, scenario, seed) combination assigned to one worker.

    Results are checkpointed to ``regret_worker_{id}_partial.json`` after every
    5 combinations. At the end they are written to ``regret_worker_{id}.json``,
    which is the file ``combine_results`` reads.

    Args:
        worker_id: Key into the dict returned by ``get_worker_assignments``.
            An unknown ID raises ``KeyError``.
        results_dir: Output directory for the JSON files.

    Returns:
        The list of result dicts produced by this worker.
    """
    my_assignments = get_worker_assignments()[worker_id]
    n_algorithms = len(REGRET_CONFIG['algorithms'])
    n_combos = len(my_assignments)

    print(f"Worker {worker_id}: {n_combos} (T, scenario, seed) combinations")
    print(f"  = {n_combos * n_algorithms} total runs ({n_algorithms} algorithms each)")

    all_results = []

    for idx, (T, scenario, seed) in enumerate(my_assignments):
        print(f"\n[Worker {worker_id}] ({idx+1}/{n_combos}) "
              f"T={T}, {scenario}, seed={seed}")

        all_results.extend(run_single_experiment(T, scenario, seed, results_dir))

        # Periodic checkpoint, so a crash loses at most a few combinations.
        if (idx + 1) % 5 == 0:
            _dump_json_atomic(all_results, results_dir / f"regret_worker_{worker_id}_partial.json")

    output_file = results_dir / f"regret_worker_{worker_id}.json"
    _dump_json_atomic(all_results, output_file)

    print(f"\nWorker {worker_id} complete: {len(all_results)} results saved to {output_file}")
    return all_results


def combine_results(results_dir: Path):
    """Merge the per-worker result files into one combined file and summarize it.

    Reads ``regret_worker_{i}.json`` for every configured worker. Missing files
    are not fatal, but they are reported so an incomplete combine is visible.
    All loaded results are written to ``regret_combined.json``. A summary is
    generated only when at least one result was loaded; summarizing an empty
    list would fail on the missing ``scenario`` column.

    Args:
        results_dir: Directory holding the worker files and receiving outputs.

    Returns:
        The combined list of result dicts (possibly empty).
    """
    all_results = []
    missing_workers = []

    for worker_id in range(REGRET_CONFIG['n_workers']):
        worker_file = results_dir / f"regret_worker_{worker_id}.json"
        if not worker_file.exists():
            missing_workers.append(worker_id)
            continue
        with open(worker_file, 'r') as f:
            worker_results = json.load(f)
        all_results.extend(worker_results)
        print(f"Loaded {len(worker_results)} results from worker {worker_id}")

    if missing_workers:
        print(f"WARNING: no results file for worker(s) {missing_workers}; "
              f"combined results are incomplete")

    # Save combined
    combined_file = results_dir / "regret_combined.json"
    with open(combined_file, 'w') as f:
        json.dump(all_results, f, indent=2)

    print(f"\nCombined {len(all_results)} total results to {combined_file}")

    if all_results:
        generate_summary(all_results, results_dir)
    else:
        print("No results found; skipping summary.")

    return all_results


def generate_summary(results: list, results_dir: Path):
    """Print regret tables and write per-seed and aggregated CSV files.

    Args:
        results: Result dicts as produced by ``run_single_experiment``. Each
            needs at least ``scenario``, ``algorithm``, ``T``, ``regret`` and
            ``competitive_ratio`` keys.
        results_dir: Directory receiving ``regret_per_seed.csv`` and
            ``regret_summary.csv``.

    Scenarios are taken from the data (sorted), so a partial result set, such
    as a combine with missing workers, is summarized without empty tables. An
    empty ``results`` list prints a notice and writes nothing. Without this
    guard, column lookups on the empty DataFrame would raise ``KeyError``.
    """
    import pandas as pd

    if not results:
        print("\nNo results to summarize.")
        return

    df = pd.DataFrame(results)
    scenarios = sorted(df['scenario'].unique())

    print("\n" + "=" * 80)
    print("REGRET ANALYSIS SUMMARY")
    print("=" * 80)

    for scenario in scenarios:
        print(f"\n{scenario} - Mean Regret by T:")
        pivot = df[df['scenario'] == scenario].pivot_table(
            values='regret',
            index='T',
            columns='algorithm',
            aggfunc='mean'
        )
        print(pivot.round(1).to_string())

    # Save to CSV
    df.to_csv(results_dir / "regret_per_seed.csv", index=False)

    # Save summary (mean/std across seeds); flatten the MultiIndex columns
    # produced by the dict-of-lists aggregation.
    summary_df = df.groupby(['scenario', 'algorithm', 'T']).agg({
        'regret': ['mean', 'std'],
        'competitive_ratio': ['mean', 'std'],
    }).reset_index()
    summary_df.columns = ['scenario', 'algorithm', 'T', 'regret_mean', 'regret_std',
                          'cr_mean', 'cr_std']
    summary_df.to_csv(results_dir / "regret_summary.csv", index=False)

    print(f"\nSaved per-seed data to {results_dir}/regret_per_seed.csv")
    print(f"Saved summary to {results_dir}/regret_summary.csv")

    # sqrt(T) scaling check: if regret grows like sqrt(T) (up to log factors),
    # these normalized values stay roughly flat as T increases.
    print("\n" + "=" * 80)
    print("REGRET SCALING ANALYSIS (Regret / √T)")
    print("=" * 80)

    df_norm = df.assign(regret_normalized=df['regret'] / np.sqrt(df['T']))
    for scenario in scenarios:
        print(f"\n{scenario}:")
        pivot = df_norm[df_norm['scenario'] == scenario].pivot_table(
            values='regret_normalized',
            index='T',
            columns='algorithm',
            aggfunc='mean'
        )
        print(pivot.round(2).to_string())


def main():
    """Parse command-line arguments and dispatch to the requested mode.

    Modes, selected with ``--worker``:
      * an integer worker ID runs that worker's share of the grid;
      * ``all`` runs every worker sequentially, then combines;
      * ``combine`` merges existing worker files and summarizes them;
      * omitting the flag prints the configuration, the worker assignments
        and example commands.
    """
    n_workers = REGRET_CONFIG['n_workers']
    n_algorithms = len(REGRET_CONFIG['algorithms'])

    parser = argparse.ArgumentParser(description='Regret Analysis Experiment')
    parser.add_argument('--worker', type=str, default=None,
                        help=f'Worker ID (0-{n_workers - 1}) or "all" to run all sequentially, '
                             'or "combine" to combine results')
    parser.add_argument('--results-dir', type=str, default='./results/regret_analysis',
                        help='Directory for results')
    args = parser.parse_args()

    # Validate a numeric worker ID up front so a typo yields a usage error
    # instead of a traceback (ValueError from int(), KeyError from run_worker).
    worker_id = None
    if args.worker not in (None, 'all', 'combine'):
        try:
            worker_id = int(args.worker)
        except ValueError:
            parser.error(f'--worker must be an integer, "all" or "combine" (got {args.worker!r})')
        if not 0 <= worker_id < n_workers:
            parser.error(f'--worker must be in the range 0-{n_workers - 1} (got {worker_id})')

    results_dir = Path(args.results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    if args.worker == 'combine':
        combine_results(results_dir)
    elif args.worker == 'all':
        # Run all workers sequentially
        for wid in range(n_workers):
            run_worker(wid, results_dir)
        combine_results(results_dir)
    elif worker_id is not None:
        run_worker(worker_id, results_dir)
    else:
        # Print assignment info
        assignments = get_worker_assignments()
        total_combos = sum(len(a) for a in assignments.values())
        print("Regret Analysis Experiment")
        print("=" * 60)
        print(f"T values: {REGRET_CONFIG['T_values']}")
        print(f"Algorithms: {list(REGRET_CONFIG['algorithms'].keys())}")
        print(f"Scenarios: {list(REGRET_CONFIG['scenarios'].keys())}")
        print(f"Seeds: {REGRET_CONFIG['seeds']}")
        print(f"ρ: {REGRET_CONFIG['rho']}")
        print(f"\nTotal combinations: {total_combos}")
        print(f"Total runs: {total_combos * n_algorithms}")
        print("\nWorker assignments:")
        for wid, combos in assignments.items():
            print(f"  Worker {wid}: {len(combos)} combinations ({len(combos) * n_algorithms} runs)")

        print("\nTo run:")
        print("  Single worker:  python run_regret_analysis.py --worker 0")
        print("  All sequential: python run_regret_analysis.py --worker all")
        print("  Combine only:   python run_regret_analysis.py --worker combine")
        print(f"\nTo run {n_workers} parallel workers:")
        # `cmd &; done` is a bash syntax error: `&` already terminates the
        # command, so no `;` may follow it. `wait` blocks until all jobs finish.
        print(f"  for i in {{0..{n_workers - 1}}}; do python run_regret_analysis.py --worker $i & done; wait")
        print("  python run_regret_analysis.py --worker combine")


if __name__ == '__main__':
    # Script entry point; importing this module only defines the config and functions.
    main()
