#!/usr/bin/env python3
"""
K-Sweep Experiments for SP-UCB-OLP

Validates the sqrt(K) scaling in the regret bound by varying K at fixed T.

Setup:
- Fix T=5000, d=3, rho=0.7
- K in {2, 4, 8, 16}
- For K=2: Use S3 scenario directly
- For K=4: Use S1 scenario directly
- For K=8/16: Use S1 configs + distractor configs

Expected result: Regret/sqrt(T) should scale linearly with sqrt(K).

Usage:
    # Run all experiments
    python run_k_sweep.py --worker all

    # Run specific worker
    python run_k_sweep.py --worker 0

    # Run with GNU parallel
    seq 0 9 | parallel -j 10 python run_k_sweep.py --worker {}

    # Smoke test
    python run_k_sweep.py --smoke
"""

import argparse
import json
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import sys
from typing import Dict, Any, Tuple

from sp_ucb_olp.data import S1ComplementarityLoader, S3DominantLoader
from sp_ucb_olp.data.base_loader import BaseDataLoader
from sp_ucb_olp.algorithms import get_algorithm


# =============================================================================
# EXTENDED K LOADERS (K > 4)
# =============================================================================

class ExtendedKLoader(BaseDataLoader):
    """
    Synthetic loader that pads the S1 scenario with distractor configs.

    Configs 0-3 are the hard-coded S1 base profiles: two high-variance traps
    (0 and 2), the low-variance "hidden gem" (1, efficiency 1.12 / 0.70 = 1.6)
    and one distractor (3).  For K > 4, (K-4) extra distractor configs are
    drawn with efficiency in [1.0, 1.5), i.e. below the hidden gem, so that
    exploration is required to identify the optimum.

    "Efficiency" means reward_mean / sum(consumption_mean).  Rewards and
    consumptions are independent normal draws clipped below at 0.01.

    NOTE: the base profiles carry 3-component consumption vectors, so the
    loader is written for d=3.
    """

    # Upper bound of the box [0, _P_MAX]^d searched for dual prices.
    _P_MAX = 5.0
    # Index of the low-variance, highest-efficiency S1 config.
    _HIDDEN_GEM = 1
    # Offsets added to the loader seed so that distractor generation and the
    # optimizer restarts use streams distinct from the arrival stream.
    _DISTRACTOR_SEED_OFFSET = 1000
    _RESTART_SEED_OFFSET = 2000

    def __init__(
        self,
        K: int = 8,
        T: int = 5000,
        seed: int = 42,
        d: int = 3,
    ):
        """
        Build the config profiles, sample all arrivals and set the budget.

        Args:
            K: Number of configs (arms).
            T: Horizon, i.e. number of arrivals sampled per config.
            seed: Seed for the arrival stream; offset variants of it seed the
                distractor profiles and the oracle solver restarts.
            d: Number of resources.
        """
        super().__init__(K, d, T, seed)

        # Start with S1 base profiles (4 configs)
        self._config_profiles = self._base_profiles()

        # Add (K-4) distractor configs with similar efficiency range.
        # A separate seed keeps the profiles independent of the arrival draws.
        rng = np.random.RandomState(seed + self._DISTRACTOR_SEED_OFFSET)
        for k in range(4, K):
            self._config_profiles[k] = self._draw_distractor_profile(rng)

        self._generate_arrivals()
        self._compute_nominal_budget()

    @staticmethod
    def _base_profiles() -> Dict[int, Dict[str, Any]]:
        """Return fresh copies of the four S1 base profiles."""
        return {
            # Config 0 (Trap A): HIGH VARIANCE, efficiency ~1.4
            0: {
                'reward_mean': 1.4, 'reward_std': 0.6,
                'consumption_mean': np.array([0.35, 0.35, 0.30]),
                'consumption_std': np.array([0.15, 0.15, 0.12]),
            },
            # Config 1 (Hidden Gem): LOW VARIANCE, efficiency ~1.6 (OPTIMAL)
            1: {
                'reward_mean': 1.12, 'reward_std': 0.08,
                'consumption_mean': np.array([0.25, 0.22, 0.23]),
                'consumption_std': np.array([0.02, 0.02, 0.02]),
            },
            # Config 2 (Trap B): HIGH VARIANCE
            2: {
                'reward_mean': 1.2, 'reward_std': 0.5,
                'consumption_mean': np.array([0.30, 0.28, 0.27]),
                'consumption_std': np.array([0.12, 0.12, 0.10]),
            },
            # Config 3 (Distractor)
            3: {
                'reward_mean': 0.9, 'reward_std': 0.3,
                'consumption_mean': np.array([0.28, 0.26, 0.26]),
                'consumption_std': np.array([0.08, 0.08, 0.08]),
            },
        }

    def _draw_distractor_profile(self, rng: np.random.RandomState) -> Dict[str, Any]:
        """Draw one distractor profile with efficiency in [1.0, 1.5)."""
        # The order of the draws below fixes the generated profiles for a
        # given seed; do not reorder.
        eff = 1.0 + 0.5 * rng.random()
        total_consumption = 0.6 + 0.4 * rng.random()
        r_mean = eff * total_consumption

        # Random variance (mix of high and low)
        is_high_variance = rng.random() > 0.5
        if is_high_variance:
            r_std = 0.3 + 0.3 * rng.random()
            c_std_mult = 0.10 + 0.05 * rng.random()
        else:
            r_std = 0.05 + 0.10 * rng.random()
            c_std_mult = 0.02 + 0.03 * rng.random()

        # Dirichlet weights sum to 1, so c_mean sums to total_consumption and
        # the realised efficiency r_mean / sum(c_mean) equals eff.
        c_mean = rng.dirichlet(np.ones(self.d)) * total_consumption
        c_std = c_mean * c_std_mult

        return {
            'reward_mean': r_mean,
            'reward_std': r_std,
            'consumption_mean': c_mean,
            'consumption_std': c_std,
        }

    def _generate_arrivals(self):
        """Generate i.i.d. arrivals from stationary distributions."""
        rng = np.random.RandomState(self.seed)

        for theta in range(self.K):
            profile = self._config_profiles[theta]
            # Column 0 holds the reward, columns 1..d the consumption vector.
            arrivals = np.zeros((self.T, self.d + 1))

            # Kept as a per-step loop: the interleaved reward/consumption
            # draw order determines the data produced for a given seed.
            for t in range(self.T):
                r = rng.normal(profile['reward_mean'], profile['reward_std'])
                r = max(r, 0.01)

                a = rng.normal(profile['consumption_mean'], profile['consumption_std'])
                a = np.maximum(a, 0.01)

                arrivals[t, 0] = r
                arrivals[t, 1:] = a

            self._arrivals[theta] = arrivals

    def _compute_nominal_budget(self):
        """Budget based on hidden gem (config 1) consumption."""
        hidden_gem_consumption = self._config_profiles[self._HIDDEN_GEM]['consumption_mean']
        self._nominal_budget = hidden_gem_consumption * self.T * 0.7

    def get_arrival(self, theta: int, t: int) -> Tuple[float, np.ndarray]:
        """
        Return the pre-sampled (reward, consumption) pair for config theta at step t.

        Raises:
            ValueError: If theta is outside [0, K-1] or t is outside [0, T-1].
        """
        if theta < 0 or theta >= self.K:
            raise ValueError(f"theta {theta} out of range [0, {self.K-1}]")
        if t < 0 or t >= self.T:
            raise ValueError(f"t {t} out of range [0, {self.T-1}]")

        arrival = self._arrivals[theta][t]
        # Copy so callers cannot mutate the stored arrivals.
        return float(arrival[0]), arrival[1:].copy()

    def get_budget(self, rho: float) -> np.ndarray:
        """Return the total budget vector: rho times the nominal budget."""
        return rho * self._nominal_budget

    def _compute_surplus(self, theta: int, p: np.ndarray, n_samples: int = 10000) -> float:
        """
        Compute g_theta(p) using closed-form for normal distributions.

        The surplus r - <p, a> is treated as N(mu_s, sigma_s^2) with
        independent reward and consumption components, and the value returned
        is E[max(surplus, 0)] = mu_s * Phi(z) + sigma_s * phi(z), z = mu_s / sigma_s.
        The 0.01 clipping applied to sampled arrivals is not modelled.

        Args:
            theta: Config index.
            p: Price vector.
            n_samples: Unused; kept for interface compatibility.
        """
        from scipy.stats import norm

        profile = self._config_profiles[theta]
        mu_s = profile['reward_mean'] - np.dot(p, profile['consumption_mean'])
        var_s = profile['reward_std']**2 + np.sum((p * profile['consumption_std'])**2)
        sigma_s = np.sqrt(var_s)

        # Degenerate (deterministic) surplus: avoid dividing by ~0.
        if sigma_s < 1e-10:
            return max(mu_s, 0.0)

        z = mu_s / sigma_s
        return mu_s * norm.cdf(z) + sigma_s * norm.pdf(z)

    def _all_surpluses(self, p: np.ndarray) -> np.ndarray:
        """Return g_theta(p) for every active config as an array of length K."""
        return np.array([self._compute_surplus(k, p) for k in range(self.K)])

    def _solve_V_mix(self, b: np.ndarray, n_grid: int = 20) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Solve V^mix(b) = min_p { <p, b> + max_theta g_theta(p) }.

        Runs L-BFGS-B from p = 0 and from four random starting points inside
        the price box, keeping the best objective value.  The random starts
        come from a RandomState seeded from the loader seed, so repeated
        calls return identical oracle values.

        Args:
            b: Per-period budget vector.
            n_grid: Unused; kept for interface compatibility.

        Returns:
            (best objective value, minimizing price vector, w_star), where
            w_star is uniform over the configs whose surplus at the minimizer
            is within 1e-8 of the maximum.
        """
        from scipy.optimize import minimize

        p_max = self._P_MAX

        def objective(p):
            return np.dot(p, b) + np.max(self._all_surpluses(p))

        # Seeded locally instead of using the global np.random state, which
        # made the oracle (and every regret computed from it) irreproducible.
        restart_rng = np.random.RandomState(self.seed + self._RESTART_SEED_OFFSET)

        best_val = np.inf
        best_p = np.zeros(self.d)

        n_restarts = 5
        for restart in range(n_restarts):
            if restart == 0:
                p_init = np.zeros(self.d)
            else:
                p_init = restart_rng.uniform(0, p_max / 2, self.d)

            result = minimize(
                objective,
                p_init,
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100, 'ftol': 1e-6}
            )

            if result.fun < best_val:
                best_val = result.fun
                best_p = result.x.copy()

        best_surpluses = self._all_surpluses(best_p)
        max_surplus = np.max(best_surpluses)
        w_star = (best_surpluses >= max_surplus - 1e-8).astype(float)
        w_star /= w_star.sum()

        return best_val, best_p, w_star

    def _efficiencies(self) -> Dict[int, float]:
        """Return reward_mean / sum(consumption_mean) for each active config."""
        return {
            k: profile['reward_mean'] / np.sum(profile['consumption_mean'])
            for k, profile in self._config_profiles.items()
            if k < self.K
        }

    def _trap_configs(self) -> list:
        """Return the active non-optimal configs labelled as traps."""
        # Configs 0 and 2 are the S1 traps; every config from 4 on is a
        # generated distractor.  Indices >= K are not arms of this instance.
        return [k for k in [0, 2] + list(range(4, self.K)) if k < self.K]

    def get_oracle_values(self, rho: float = 1.0) -> Dict[str, Any]:
        """
        Compute theoretical oracle values.

        Args:
            rho: Budget scaling factor applied to the nominal budget.

        Returns:
            Dict with 'V_mix' (mixed benchmark), 'V_star' (largest
            single-config value min_p {<p, b> + g_theta(p)} over theta),
            'gap' (V_mix - V_star), 'w_star', 'p_star', 'efficiencies',
            'best_fixed_config', 'rho', 'b_per_period', 'hidden_gem'
            and 'traps'.
        """
        from scipy.optimize import minimize

        b = rho * self._nominal_budget / self.T

        V_mix, p_star, w_star = self._solve_V_mix(b)

        # Compute V*
        V_star = 0.0
        best_theta = 0
        p_max = self._P_MAX

        for theta in range(self.K):
            # th=theta binds the current config at definition time.
            def obj_theta(p, th=theta):
                g = self._compute_surplus(th, p)
                return np.dot(p, b) + g

            result = minimize(
                obj_theta,
                np.zeros(self.d),
                method='L-BFGS-B',
                bounds=[(0, p_max)] * self.d,
                options={'maxiter': 100}
            )

            if result.fun > V_star:
                V_star = result.fun
                best_theta = theta

        return {
            'V_mix': V_mix,
            'V_star': V_star,
            'gap': V_mix - V_star,
            'w_star': w_star,
            'p_star': p_star,
            'efficiencies': self._efficiencies(),
            'best_fixed_config': best_theta,
            'rho': rho,
            'b_per_period': b,
            'hidden_gem': self._HIDDEN_GEM,
            'traps': self._trap_configs(),
        }

    def get_metadata(self) -> Dict[str, Any]:
        """Return the base loader metadata extended with scenario labels."""
        base = super().get_metadata()

        # These fields do not depend on the optimisation, so they are computed
        # directly rather than by solving the oracle problems.
        return {
            **base,
            'family': f'S1-Extended-K{self.K}',
            'name': f'Exploration-Critical (K={self.K})',
            'efficiencies': self._efficiencies(),
            'hidden_gem': self._HIDDEN_GEM,
            'traps': self._trap_configs(),
        }


# =============================================================================
# EXPERIMENT CONFIGURATION
# =============================================================================

def get_loader_for_K(K: int, T: int, seed: int, d: int = 3):
    """
    Build the data loader used for a given number of configs.

    K == 2 maps to the S3 loader, K == 4 to the S1 loader, and every other
    value to ExtendedKLoader.  All loaders receive the same keyword arguments.
    """
    loader_cls = (
        S3DominantLoader if K == 2
        else S1ComplementarityLoader if K == 4
        else ExtendedKLoader
    )
    return loader_cls(K=K, T=T, seed=seed, d=d)


# Configuration for the full sweep: ten seeds (42..51), one seed per worker
# (worker i runs seed 42 + i), five algorithms at each K.
FULL_CONFIG = {
    'K_values': [2, 4, 8, 16],
    'T': 5000,
    'd': 3,
    'rho': 0.7,
    'algorithms': [
        'SP-UCB-OLP',
        'Greedy',
        'OneHot',
        'Oracle',
        'Random',
    ],
    'seeds': [42 + offset for offset in range(10)],
    'worker_seeds': {
        worker: [worker_seed]
        for worker, worker_seed in enumerate(range(42, 52))
    },
    # Extra keyword settings passed to each algorithm (empty = defaults).
    'algorithm_configs': {
        'SP-UCB-OLP': {'alpha': 0.1, 'warm_start': True},
        'Greedy': {},
        'OneHot': {'alpha': 0.1},
        'Oracle': {},
        'Random': {},
    },
}

# Configuration for the quick smoke test: short horizon, three K values,
# three algorithms, and both seeds handled by the single worker 0.
SMOKE_CONFIG = {
    'K_values': [2, 4, 8],
    'T': 500,
    'd': 3,
    'rho': 0.7,
    'algorithms': [
        'SP-UCB-OLP',
        'Greedy',
        'Oracle',
    ],
    'seeds': [42, 43],
    'worker_seeds': {
        0: [42, 43],
    },
    # Extra keyword settings passed to each algorithm (empty = defaults).
    'algorithm_configs': {
        'SP-UCB-OLP': {'alpha': 0.1, 'warm_start': True},
        'Greedy': {},
        'Oracle': {},
    },
}


# =============================================================================
# SINGLE RUN FUNCTION
# =============================================================================

def run_single_experiment(
    K: int,
    algorithm_name: str,
    seed: int,
    config: dict,
    loader=None,
    oracle_values=None
) -> dict:
    """
    Run a single K-sweep experiment and return its metrics.

    Args:
        K: Number of configs (arms).
        algorithm_name: Key passed to get_algorithm and used to look up the
            algorithm's settings in config['algorithm_configs'].
        seed: Seed recorded in the result; also used to build the loader when
            none is supplied.
        config: Experiment configuration; reads 'T', 'rho',
            'algorithm_configs' and, only when a loader must be built, 'd'.
        loader: Optional pre-built loader to reuse across algorithms.
        oracle_values: Optional pre-computed result of
            loader.get_oracle_values(rho=...), to avoid re-solving it.

    Returns:
        A flat dict of identifiers, metrics, budget information, oracle
        values and timing.  Array-valued budget entries are converted to
        plain Python lists so the dict can be written with json.dump.
    """
    T = config['T']
    rho = config['rho']

    # Use provided loader or create new one
    if loader is None:
        loader = get_loader_for_K(K, T, seed, config['d'])

    # The loader's dimension is authoritative for the algorithm and the result.
    d = loader.d
    B = loader.get_budget(rho)

    # Use provided oracle values or compute new ones
    if oracle_values is None:
        oracle_values = loader.get_oracle_values(rho=rho)
    V_mix = oracle_values['V_mix']
    V_star = oracle_values['V_star']

    # Setup algorithm config (copied so the shared config is never mutated)
    alg_config = config['algorithm_configs'].get(algorithm_name, {}).copy()
    if algorithm_name == 'Oracle':
        alg_config['w_star'] = oracle_values['w_star']
        alg_config['p_star'] = oracle_values['p_star']

    # Create algorithm
    algorithm = get_algorithm(algorithm_name, K, d, T, B, alg_config)

    # Run simulation
    start_time = time.time()

    for t in range(T):
        theta, w, p = algorithm.select_config(t)
        r, a = loader.get_arrival(theta, t)
        accept = algorithm.decide_admission(t, theta, r, a, p)
        algorithm.update(t, theta, r, a, accept)

    elapsed_time = time.time() - start_time

    # Collect results
    stats = algorithm.get_statistics()
    total_reward = stats['total_reward']

    # Compute metrics against both benchmarks (T * V is the benchmark reward)
    competitive_ratio_mix = total_reward / (T * V_mix) if V_mix > 0 else 0.0
    competitive_ratio_star = total_reward / (T * V_star) if V_star > 0 else 0.0
    regret_mix = T * V_mix - total_reward
    regret_star = T * V_star - total_reward

    # Normalized regret for sqrt(K) scaling analysis
    regret_normalized = regret_mix / np.sqrt(T) if T > 0 else 0.0

    # The budget statistics may be numpy arrays (B itself is one), which
    # json.dump rejects; store them as plain Python values instead.
    b_remaining = np.asarray(stats['B_remaining']).tolist()
    budget_utilization = [float(u) for u in stats['budget_utilization']]

    result = {
        # Identifiers
        'K': K,
        'algorithm': algorithm_name,
        'rho': rho,
        'seed': seed,

        # Key metrics
        'total_reward': total_reward,
        'competitive_ratio_mix': competitive_ratio_mix,
        'competitive_ratio_star': competitive_ratio_star,
        'regret_mix': regret_mix,
        'regret_star': regret_star,
        'regret_normalized': regret_normalized,
        'acceptance_rate': stats['acceptance_rate'],

        # Budget info
        'B_remaining': b_remaining,
        'budget_utilization': budget_utilization,

        # Oracle values
        'V_mix': V_mix,
        'V_star': V_star,

        # Metadata
        'T': T,
        'd': d,
        'elapsed_time': elapsed_time,
    }

    return result


# =============================================================================
# WORKER FUNCTION
# =============================================================================

def run_worker(worker_id: int, results_dir: Path, config: dict, verbose: bool = True):
    """
    Run all K-sweep experiments for a single worker and save them as JSON.

    For every seed assigned to the worker and every K, the loader and oracle
    values are built once and shared by all algorithms.  Failures are handled
    best-effort: a failing run, or a failing loader/oracle setup, is logged
    with its traceback and recorded as an entry with an 'error' key, and the
    sweep continues.

    Args:
        worker_id: Key into config['worker_seeds'].
        results_dir: Directory receiving worker_<id>_results.json; created if
            it does not exist.
        config: Experiment configuration (see FULL_CONFIG / SMOKE_CONFIG).
        verbose: Print progress and a completion summary.

    Returns:
        The list of per-run result dicts (including error entries).

    Raises:
        KeyError: If worker_id has no entry in config['worker_seeds'].
    """
    import traceback

    def _json_default(obj):
        """Convert numpy arrays/scalars that json cannot encode natively."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    seeds = config['worker_seeds'][worker_id]
    K_values = config['K_values']
    algorithms = config['algorithms']
    T = config['T']
    d = config['d']
    rho = config['rho']

    total_runs = len(seeds) * len(K_values) * len(algorithms)

    if verbose:
        print(f"Worker {worker_id}: Starting {total_runs} runs")
        print(f"  Seeds: {seeds}")
        print(f"  K values: {K_values}")
        print(f"  Algorithms: {algorithms}")

    results = []
    run_count = 0
    start_time = time.time()

    for seed in seeds:
        for K in K_values:
            # Create loader and compute oracle once per (K, seed).  A setup
            # failure must not discard the results gathered so far, so it is
            # recorded against every algorithm of this (K, seed) instead.
            try:
                loader = get_loader_for_K(K, T, seed, d)
                oracle_values = loader.get_oracle_values(rho=rho)
            except Exception as e:
                print(f"  ERROR Worker {worker_id}: K={K}/seed={seed}: setup failed: {e}")
                traceback.print_exc()
                for algorithm in algorithms:
                    results.append({
                        'K': K,
                        'algorithm': algorithm,
                        'seed': seed,
                        'error': str(e),
                    })
                run_count += len(algorithms)
                continue

            for algorithm in algorithms:
                run_count += 1

                if verbose and run_count % 10 == 0:
                    elapsed = time.time() - start_time
                    eta = elapsed / run_count * (total_runs - run_count)
                    print(f"  Worker {worker_id}: {run_count}/{total_runs} "
                          f"({100*run_count/total_runs:.0f}%) ETA: {eta:.0f}s")

                try:
                    result = run_single_experiment(
                        K, algorithm, seed, config,
                        loader=loader, oracle_values=oracle_values
                    )
                    results.append(result)
                except Exception as e:
                    print(f"  ERROR Worker {worker_id}: K={K}/{algorithm}/seed={seed}: {e}")
                    traceback.print_exc()
                    results.append({
                        'K': K,
                        'algorithm': algorithm,
                        'seed': seed,
                        'error': str(e),
                    })

    total_time = time.time() - start_time

    if verbose:
        print(f"Worker {worker_id}: Completed {total_runs} runs in {total_time:.1f}s")

    # Save results.  The default hook keeps a stray numpy value in a result
    # from aborting the dump after all runs have already been paid for.
    results_dir.mkdir(parents=True, exist_ok=True)
    output_file = results_dir / f"worker_{worker_id}_results.json"
    with open(output_file, 'w') as f:
        json.dump({
            'worker_id': worker_id,
            'seeds': seeds,
            'total_runs': total_runs,
            'total_time': total_time,
            'results': results,
        }, f, indent=2, default=_json_default)

    if verbose:
        print(f"Worker {worker_id}: Saved to {output_file}")

    return results


# =============================================================================
# COMBINE RESULTS
# =============================================================================

def combine_worker_results(results_dir: Path, verbose: bool = True):
    """
    Merge the per-worker result files into combined_results.json.

    Looks for worker_0_results.json .. worker_9_results.json in results_dir,
    skips the ones that are missing, writes the merged list together with its
    length, then builds the summary statistics from it.

    Returns:
        The merged list of result dicts.
    """
    merged = []

    for idx in range(10):
        path = results_dir / f"worker_{idx}_results.json"
        if not path.exists():
            continue

        with open(path) as handle:
            worker_results = json.load(handle)['results']

        merged.extend(worker_results)
        if verbose:
            print(f"Loaded {len(worker_results)} results from worker {idx}")

    # Persist the merged results next to the worker files.
    combined_file = results_dir / "combined_results.json"
    payload = {
        'total_results': len(merged),
        'results': merged,
    }
    with open(combined_file, 'w') as handle:
        json.dump(payload, handle, indent=2)

    if verbose:
        print(f"\nCombined {len(merged)} results into {combined_file}")

    # Summary statistics are derived from the merged list.
    create_summary(merged, results_dir, verbose)

    return merged


def create_summary(results: list, results_dir: Path, verbose: bool = True):
    """
    Write per-(K, algorithm) summary statistics and print a scaling report.

    Entries containing an 'error' key are ignored.  The aggregated table is
    saved to summary_statistics.csv in results_dir; when verbose, a pivot of
    mean regret/sqrt(T) and the same quantity divided by sqrt(K) are printed.
    Nothing is written if no valid entries remain.
    """
    import pandas as pd

    # Failed runs are marked by an 'error' key and are left out.
    successful = [entry for entry in results if 'error' not in entry]
    df = pd.DataFrame(successful)

    if df.empty:
        print("No valid results to summarize")
        return

    # Aggregate every metric per (K, algorithm) pair.
    aggregations = {
        'regret_mix': ['mean', 'std'],
        'regret_normalized': ['mean', 'std'],
        'competitive_ratio_mix': ['mean', 'std'],
        'elapsed_time': ['mean', 'sum'],
    }
    summary = df.groupby(['K', 'algorithm']).agg(aggregations).round(4)

    summary_file = results_dir / "summary_statistics.csv"
    summary.to_csv(summary_file)

    if not verbose:
        return

    banner = "=" * 70

    print(f"\nSummary saved to {summary_file}")
    print("\n" + banner)
    print("K-SWEEP SUMMARY: Mean Regret/sqrt(T) by K and Algorithm")
    print("(If regret scales as sqrt(KT), this should scale as sqrt(K))")
    print(banner)

    pivot = df.pivot_table(
        values='regret_normalized',
        index='K',
        columns='algorithm',
        aggfunc='mean'
    ).round(4)
    print(pivot.to_string())

    # Dividing by sqrt(K) should flatten the curve if the bound is tight.
    print("\n" + banner)
    print("SQRT(K) SCALING ANALYSIS")
    print("Regret_normalized / sqrt(K) should be roughly constant")
    print(banner)

    k_levels = sorted(df['K'].unique())
    for alg in df['algorithm'].unique():
        per_alg = df[df['algorithm'] == alg]
        print(f"\n{alg}:")
        for K in k_levels:
            normalized = per_alg.loc[per_alg['K'] == K, 'regret_normalized']
            if normalized.empty:
                continue
            mean_regret_norm = normalized.mean()
            scaled = mean_regret_norm / np.sqrt(K)
            print(f"  K={K:2d}: regret/sqrt(T)={mean_regret_norm:.2f}, "
                  f"/ sqrt(K)={scaled:.2f}")


# =============================================================================
# MAIN
# =============================================================================

def main():
    """
    Command-line entry point for the K-sweep.

    Modes, selected by --worker:
      * "combine": merge existing worker result files and summarize them.
      * "all" (or any value together with --smoke): run every worker
        sequentially, then combine.
      * an integer 0-9: run only that worker.

    --smoke switches to SMOKE_CONFIG and writes under <results-dir>/smoke_test.
    Exits with status 1 if --worker is neither a keyword nor an integer 0-9.
    """
    parser = argparse.ArgumentParser(description='Run K-sweep experiments')
    parser.add_argument('--worker', type=str, default='0',
                       help='Worker ID (0-9) or "all" or "combine"')
    parser.add_argument('--results-dir', type=str, default='./results/k_sweep',
                       help='Results directory')
    parser.add_argument('--smoke', action='store_true',
                       help='Run quick smoke test')
    parser.add_argument('--quiet', action='store_true', help='Suppress output')

    args = parser.parse_args()

    # Select config
    if args.smoke:
        config = SMOKE_CONFIG
        results_dir = Path(args.results_dir) / 'smoke_test'
    else:
        config = FULL_CONFIG
        results_dir = Path(args.results_dir)

    results_dir.mkdir(parents=True, exist_ok=True)
    verbose = not args.quiet

    if verbose:
        n_K = len(config['K_values'])
        n_seeds = len(config['seeds'])
        n_alg = len(config['algorithms'])
        total = n_K * n_seeds * n_alg
        print(f"Config: K_values={config['K_values']}, T={config['T']}, "
              f"seeds={n_seeds}, algorithms={n_alg}")
        print(f"Total runs: {total}")

    if args.worker == 'combine':
        combine_worker_results(results_dir, verbose)
    elif args.worker == 'all' or args.smoke:
        if verbose:
            print("="*70)
            print("RUNNING ALL WORKERS" + (" (SMOKE TEST)" if args.smoke else ""))
            print("="*70)

        start = time.time()
        # The smoke config assigns all of its seeds to worker 0.
        n_workers = 1 if args.smoke else 10
        for worker_id in range(n_workers):
            run_worker(worker_id, results_dir, config, verbose)

        if verbose:
            print(f"\nCompleted in {time.time()-start:.1f}s")

        combine_worker_results(results_dir, verbose)
    else:
        # Anything that is not a keyword must be an integer worker id;
        # report a bad value the same way as an out-of-range one instead of
        # surfacing a raw ValueError traceback from int().
        try:
            worker_id = int(args.worker)
        except ValueError:
            print(f"Error: worker must be 0-9, 'all' or 'combine', got {args.worker!r}")
            sys.exit(1)
        if worker_id < 0 or worker_id > 9:
            print(f"Error: worker must be 0-9, got {worker_id}")
            sys.exit(1)
        run_worker(worker_id, results_dir, config, verbose)


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
