#!/usr/bin/env python3
"""
Alibaba Experiment Runner for SP-UCB-OLP

Runs the same algorithms as synthetic experiments on Alibaba data:
- SP-UCB-α=0 (greedy)
- SP-UCB-α=0.01 (minimal exploration)
- SP-UCB-α=0.1 (moderate exploration)
- OneHot (per-config UCB)
- Oracle (Monte Carlo computed)

Key differences from synthetic experiments:
1. Non-stationarity handled via SHUFFLING
2. No closed-form oracle - uses MONTE CARLO estimation
3. K=3 regimes (8-bit, 4-bit, batching), d=2 resources (CPU, Memory)

Usage:
    # Run all workers sequentially
    python run_alibaba_experiments.py --worker all

    # Run specific worker
    python run_alibaba_experiments.py --worker 0

    # Run with GNU parallel (recommended), then merge the per-worker files
    seq 0 9 | parallel -j 10 python run_alibaba_experiments.py --worker {}
    python run_alibaba_experiments.py --worker combine

    # Smoke test
    python run_alibaba_experiments.py --smoke
"""

import argparse
import json
import time
import numpy as np
from pathlib import Path
from datetime import datetime
from scipy.optimize import minimize
from scipy.stats import lognorm, expon, norm
import sys

from sp_ucb_olp.algorithms import get_algorithm

# =============================================================================
# ALIBABA DATA LOADER WITH SHUFFLING
# =============================================================================

# Scenario catalogue. Each entry carries a display label, a description, a
# two-element multiplier applied to the base (CPU, memory) budgets, and a
# per-regime table of reward / CPU / memory multipliers keyed by regime index
# (0 = 8-bit, 1 = 4-bit, 2 = batching). Each tuple below is listed in
# (reward, cpu, mem) order.
SCENARIO_PROFILES = {
    'quant_8bit': dict(
        label='Quant-8bit',
        description='Moderate budgets with emphasis on high-quality 8-bit inference',
        budget_scale=np.array([0.9, 0.9]),
        regime_multipliers={
            regime: dict(zip(('reward', 'cpu', 'mem'), factors))
            for regime, factors in enumerate([
                (1.0, 1.0, 1.0),      # 8-bit baseline
                (0.72, 1.0, 0.55),    # 4-bit aggressive
                (1.1, 0.85, 0.92),    # Batching
            ])
        },
    ),
    'quant_4bit': dict(
        label='Quant-4bit',
        description='Tight memory budget favouring aggressive quantization',
        budget_scale=np.array([0.8, 0.55]),
        regime_multipliers={
            regime: dict(zip(('reward', 'cpu', 'mem'), factors))
            for regime, factors in enumerate([
                (1.0, 1.05, 1.0),
                (0.7, 0.95, 0.45),
                (1.05, 0.8, 0.8),
            ])
        },
    ),
    'batching': dict(
        label='Batching',
        description='Looser budgets highlighting batching policies',
        budget_scale=np.array([1.1, 1.0]),
        regime_multipliers={
            regime: dict(zip(('reward', 'cpu', 'mem'), factors))
            for regime, factors in enumerate([
                (1.05, 0.95, 0.95),
                (0.75, 0.9, 0.5),
                (1.25, 0.7, 0.75),
            ])
        },
    ),
}


class AlibabaShuffledLoader:
    """
    Alibaba data loader with shuffling for pseudo-stationarity.

    Base arrivals are sampled once from fixed parametric distributions
    (log-normal memory, exponential duration, normal quality; the original
    author describes them as calibrated to Alibaba characteristics), optionally
    permuted, and then transformed into one ``(T, d + 1)`` array per regime
    whose columns are ``[reward, cpu, mem]``.

    Shuffling breaks temporal patterns in the trace data to create
    approximately i.i.d. arrivals from the empirical distribution.

    Parameters
    ----------
    T : int
        Time horizon (default: 10000)
    seed : int
        Random seed for reproducibility
    scenario : str
        Scenario: 'quant_8bit', 'quant_4bit', or 'batching'
    shuffle : bool
        Whether to shuffle arrivals (default: True)
    base_cpu_budget : float
        Base CPU budget before scaling
    base_mem_budget : float
        Base memory budget before scaling

    Raises
    ------
    ValueError
        If ``scenario`` is not a key of ``SCENARIO_PROFILES``.
    """

    def __init__(
        self,
        T: int = 10000,
        seed: int = 42,
        scenario: str = 'quant_8bit',
        shuffle: bool = True,
        base_cpu_budget: float = 600.0,
        base_mem_budget: float = 600.0
    ):
        self.K = 3  # 8-bit, 4-bit, batching
        self.d = 2  # CPU, Memory
        self.T = T
        self.seed = seed
        self.shuffle = shuffle

        if scenario not in SCENARIO_PROFILES:
            raise ValueError(f"Unknown scenario '{scenario}'. Supported: {list(SCENARIO_PROFILES.keys())}")

        self.scenario = scenario
        self.scenario_profile = SCENARIO_PROFILES[scenario]
        self.base_cpu_budget = base_cpu_budget
        self.base_mem_budget = base_mem_budget

        # One RNG drives both the base sampling and the batching draws, so the
        # order of the calls below determines the generated data.
        self._rng = np.random.RandomState(seed)

        # Initialize distributions (calibrated to Alibaba characteristics)
        self._mem_dist = lognorm(s=1.8, scale=np.exp(-1))
        self._duration_dist = expon(scale=300)
        self._quality_dist = norm(loc=1.0, scale=0.1)

        # Pre-generate base arrivals
        self._generate_base_arrivals()

        # Shuffle if requested
        if shuffle:
            self._shuffle_arrivals()

        # Generate transformed arrivals for each regime
        self._arrivals = {}
        self._generate_arrivals()

        # Compute nominal budget
        self._compute_nominal_budget()

    def _generate_base_arrivals(self):
        """Sample ``T`` regime-agnostic base arrivals into ``self._base_arrivals``.

        Each arrival is a dict with ``reward`` (duration * quality), a constant
        ``cpu`` of 0.5, and ``mem``. Every sampled quantity is floored at 0.01.
        Draws are made one at a time (mem, duration, quality per step) so the
        RNG stream, and therefore the data for a given seed, stays fixed.
        """
        self._base_arrivals = []

        for _ in range(self.T):
            mem = max(0.01, self._mem_dist.rvs(random_state=self._rng))
            duration = max(0.01, self._duration_dist.rvs(random_state=self._rng))
            quality = max(0.01, self._quality_dist.rvs(random_state=self._rng))

            self._base_arrivals.append({
                'reward': duration * quality,
                'cpu': 0.5,
                'mem': mem,
            })

    def _shuffle_arrivals(self):
        """Shuffle arrival order to create pseudo-i.i.d. environment."""
        # A dedicated RNG (seed + 1000) keeps the permutation independent of
        # how many draws the sampling RNG has consumed.
        shuffle_rng = np.random.RandomState(self.seed + 1000)
        perm = shuffle_rng.permutation(self.T)
        self._base_arrivals = [self._base_arrivals[i] for i in perm]

    def _generate_arrivals(self):
        """Fill ``self._arrivals[theta]`` for every regime from the base arrivals.

        Reward, CPU and memory are scaled by the scenario's per-regime
        multipliers. Regime 2 additionally multiplies the reward by a
        batch-size-dependent efficiency factor. All entries are floored at 0.01.
        """
        regime_multipliers = self.scenario_profile['regime_multipliers']
        # Loop-invariant denominator of the batching efficiency factor.
        efficiency_norm = np.log(2 + 1e-9)

        for theta in range(self.K):
            arrivals = np.zeros((self.T, self.d + 1))
            mult = regime_multipliers[theta]

            for t in range(self.T):
                base = self._base_arrivals[t]

                # Apply regime-specific transformations
                if theta == 2:  # Batching regime has efficiency scaling
                    # Drawn per step to keep the RNG stream stable.
                    batch_size = max(1, self._rng.poisson(3))
                    efficiency = np.log1p(batch_size) / efficiency_norm
                    r = base['reward'] * mult['reward'] * efficiency
                else:
                    r = base['reward'] * mult['reward']

                a_cpu = base['cpu'] * mult['cpu']
                a_mem = base['mem'] * mult['mem']

                arrivals[t, 0] = max(0.01, r)
                arrivals[t, 1] = max(0.01, a_cpu)
                arrivals[t, 2] = max(0.01, a_mem)

            self._arrivals[theta] = arrivals

    def _compute_nominal_budget(self):
        """Set the nominal (rho = 1) budget vector ``[CPU, memory]``.

        The budget is the configured base budget multiplied by the scenario's
        ``budget_scale``; it does not depend on the sampled arrivals.
        """
        scale = self.scenario_profile['budget_scale']
        self._nominal_budget = np.array([
            self.base_cpu_budget * scale[0],
            self.base_mem_budget * scale[1]
        ])

    def _check_theta(self, theta: int) -> None:
        """Raise ``ValueError`` unless ``theta`` is a valid regime index."""
        if theta < 0 or theta >= self.K:
            raise ValueError(f"theta {theta} out of range [0, {self.K-1}]")

    def get_arrival(self, theta: int, t: int):
        """Get arrival at timestep t under configuration theta.

        Returns
        -------
        tuple
            ``(reward, consumption)`` where ``reward`` is a Python float and
            ``consumption`` is a fresh length-``d`` array ``[cpu, mem]``.

        Raises
        ------
        ValueError
            If ``theta`` or ``t`` is out of range.
        """
        self._check_theta(theta)
        if t < 0 or t >= self.T:
            raise ValueError(f"t {t} out of range [0, {self.T-1}]")

        arrival = self._arrivals[theta][t]
        # Copy so callers cannot mutate the stored trace.
        return float(arrival[0]), arrival[1:].copy()

    def get_budget(self, rho: float) -> np.ndarray:
        """Get budget scaled by rho."""
        return rho * self._nominal_budget

    def get_all_samples(self, theta: int):
        """Get all samples for a configuration.

        Returns
        -------
        tuple
            ``(rewards, consumptions)`` with shapes ``(T,)`` and ``(T, d)``.
            Both are views into the stored array, not copies.

        Raises
        ------
        ValueError
            If ``theta`` is out of range (same contract as ``get_arrival``).
        """
        self._check_theta(theta)
        arrivals = self._arrivals[theta]
        return arrivals[:, 0], arrivals[:, 1:]

    def get_regime_name(self, theta: int) -> str:
        """Get human-readable regime name."""
        names = {0: "8-bit", 1: "4-bit", 2: "Batching"}
        return names.get(theta, f"Regime {theta}")


# =============================================================================
# MONTE CARLO ORACLE COMPUTATION
# =============================================================================

def compute_alibaba_oracle_mc(
    loader: "AlibabaShuffledLoader",
    rho: float,
    n_restarts: int = 10
) -> dict:
    """
    Compute switching-aware oracle V^mix via Monte Carlo.

    V^mix(b) = min_p { <p, b> + max_θ g_θ(p) }

    where g_θ(p) = E[(r - <p, a>)_+] is estimated from samples.

    The minimisation uses multi-start L-BFGS-B over the box [0, P_max]^d.
    Random restart points come from a private RNG derived from ``loader.seed``,
    so the result is reproducible and the global NumPy random state is left
    untouched.

    Parameters
    ----------
    loader : AlibabaShuffledLoader
        Data loader with samples. Only ``K``, ``d``, ``T``, ``get_all_samples``,
        ``get_budget`` and (optionally) ``seed`` are used.
    rho : float
        Budget scaling parameter
    n_restarts : int
        Number of optimization restarts. The first restart starts at p = 0;
        with ``n_restarts <= 0`` the value at p = 0 is returned.

    Returns
    -------
    dict
        ``V_mix`` and ``V_star`` (the same value), ``w_star`` (uniform weights
        over the regimes whose surplus is within 1e-8 of the maximum at
        ``p_star``), ``p_star``, and the bounds ``R_max``, ``A_max``, ``P_max``.
    """
    K, d, T = loader.K, loader.d, loader.T

    # Collect all samples
    samples = {}
    for theta in range(K):
        rewards, consumptions = loader.get_all_samples(theta)
        samples[theta] = (rewards, consumptions)

    # Per-period budget
    B = loader.get_budget(rho)
    b = B / T

    # Bounds for optimization
    R_max = max(np.max(samples[theta][0]) for theta in range(K))
    A_max = max(np.max(samples[theta][1]) for theta in range(K))
    b_min = np.min(b[b > 0]) if np.any(b > 0) else 1.0
    P_max = R_max / b_min + 1.0

    def empirical_surplus(theta: int, p: np.ndarray) -> float:
        """g_θ(p) = (1/n) Σ (r - <p, a>)_+"""
        rewards, consumptions = samples[theta]
        margins = rewards - consumptions @ p
        return np.mean(np.maximum(margins, 0))

    def objective_V_mix(p: np.ndarray) -> float:
        """<p, b> + max_θ g_θ(p), the quantity minimised over p."""
        linear_term = np.dot(p, b)
        surpluses = [empirical_surplus(theta, p) for theta in range(K)]
        envelope = max(surpluses)
        return linear_term + envelope

    # Private RNG: the offset keeps restart points decorrelated from the
    # loader's own sampling stream while remaining a pure function of its seed.
    loader_seed = getattr(loader, 'seed', None)
    rng = np.random.RandomState(
        None if loader_seed is None else (int(loader_seed) + 2000) % (2 ** 32)
    )

    # p = 0 is always feasible, so it provides a finite incumbent even when no
    # restart is requested.
    best_p = np.zeros(d)
    best_val = objective_V_mix(best_p)

    # Multi-start optimization
    for restart in range(n_restarts):
        if restart == 0:
            p_init = np.zeros(d)
        else:
            p_init = rng.uniform(0, P_max / 2, d)

        result = minimize(
            objective_V_mix,
            p_init,
            method='L-BFGS-B',
            bounds=[(0, P_max)] * d,
            options={'maxiter': 200, 'ftol': 1e-8}
        )

        if result.fun < best_val:
            best_val = result.fun
            best_p = result.x.copy()

    # Compute optimal mixture w* (configs achieving envelope at p_star)
    surpluses = np.array([empirical_surplus(theta, best_p) for theta in range(K)])
    max_surplus = np.max(surpluses)
    w_star = (surpluses >= max_surplus - 1e-8).astype(float)
    if w_star.sum() > 0:
        w_star /= w_star.sum()
    else:
        w_star = np.ones(K) / K

    return {
        'V_mix': best_val,
        'V_star': best_val,  # Use same for simplicity
        'w_star': w_star,
        'p_star': best_p,
        'R_max': R_max,
        'A_max': A_max,
        'P_max': P_max,
    }


# =============================================================================
# EXPERIMENT CONFIGURATION
# =============================================================================

# Full sweep: three scenarios, four budget scalings, ten seeds (42..51),
# with exactly one seed assigned to each of the ten workers.
FULL_CONFIG = dict(
    scenarios=['quant_8bit', 'quant_4bit', 'batching'],
    T=10000,
    algorithms={
        'SP-UCB-α=0': dict(type='SP-UCB-OLP', alpha=0.0),
        'SP-UCB-α=0.01': dict(type='SP-UCB-OLP', alpha=0.01),
        'SP-UCB-α=0.1': dict(type='SP-UCB-OLP', alpha=0.1),
        'OneHot': dict(type='OneHot', alpha=0.1),
        'Oracle': dict(type='Oracle'),
    },
    rho_values=[0.5, 0.7, 1.0, 1.2],
    seeds=[42 + offset for offset in range(10)],  # 10 seeds
    worker_seeds={worker: [seed] for worker, seed in enumerate(range(42, 52))},
    shuffle=True,
)

# Reduced sweep for quick sanity checks: one scenario, a shorter horizon,
# two budget scalings and two seeds, all handled by worker 0.
SMOKE_CONFIG = dict(
    scenarios=['quant_8bit'],
    T=1000,
    algorithms={
        'SP-UCB-α=0': dict(type='SP-UCB-OLP', alpha=0.0),
        'SP-UCB-α=0.01': dict(type='SP-UCB-OLP', alpha=0.01),
        'SP-UCB-α=0.1': dict(type='SP-UCB-OLP', alpha=0.1),
        'OneHot': dict(type='OneHot', alpha=0.1),
        'Oracle': dict(type='Oracle'),
    },
    rho_values=[0.7, 1.0],
    seeds=[42, 43],
    worker_seeds={0: [42, 43]},
    shuffle=True,
)


# =============================================================================
# SINGLE RUN FUNCTION
# =============================================================================

def run_single_experiment(
    scenario: str,
    algorithm_name: str,
    alg_params: dict,
    rho: float,
    seed: int,
    config: dict,
    loader=None,
    oracle_values=None
) -> dict:
    """Run one algorithm on one (scenario, rho, seed) instance and return metrics.

    Parameters
    ----------
    scenario : str
        Scenario key passed to the loader and echoed in the result.
    algorithm_name : str
        Display name recorded in the result.
    alg_params : dict
        Must contain ``'type'`` (forwarded to ``get_algorithm``); an optional
        ``'alpha'`` is copied into the algorithm config.
    rho : float
        Budget scaling parameter.
    seed : int
        Seed for the global NumPy RNG and, if a loader is built here, the loader.
    config : dict
        Experiment config; ``'T'`` is required, ``'shuffle'`` defaults to True.
    loader : optional
        Pre-built loader to reuse; created from ``config`` when None.
    oracle_values : dict, optional
        Pre-computed oracle output; computed from ``loader`` when None.

    Returns
    -------
    dict
        Run identifiers plus reward, regret and competitive ratio against
        ``T * V_mix``, acceptance statistics, budget utilization and the
        simulation time in seconds. NumPy scalars are converted to built-in
        Python numbers so the record can be written with ``json``.
    """
    T = config['T']

    # Seed before any setup so that stochastic setup steps (oracle restarts,
    # algorithm construction) are reproducible too.
    np.random.seed(seed)

    # Use provided loader or create new one
    if loader is None:
        loader = AlibabaShuffledLoader(
            T=T, seed=seed, scenario=scenario,
            shuffle=config.get('shuffle', True)
        )

    K, d = loader.K, loader.d
    B = loader.get_budget(rho)

    # Use provided oracle or compute new one
    if oracle_values is None:
        oracle_values = compute_alibaba_oracle_mc(loader, rho)
    V_mix = oracle_values['V_mix']

    # Setup algorithm config
    alg_config = {
        'R_max': oracle_values['R_max'],
        'A_max': oracle_values['A_max'],
        'warm_start': True,
    }

    if 'alpha' in alg_params:
        alg_config['alpha'] = alg_params['alpha']

    if alg_params['type'] == 'Oracle':
        alg_config['w_star'] = oracle_values['w_star']
        alg_config['p_star'] = oracle_values['p_star']

    # Create algorithm
    algorithm = get_algorithm(alg_params['type'], K, d, T, B, alg_config)

    # Re-seed immediately before the simulation so the stream seen by the loop
    # does not depend on how many draws the setup above consumed (for example
    # on whether the oracle was passed in or computed here).
    np.random.seed(seed)

    # Run simulation (perf_counter is monotonic, unlike wall-clock time.time)
    start_time = time.perf_counter()

    for t in range(T):
        theta, w, p = algorithm.select_config(t)
        r, a = loader.get_arrival(theta, t)
        accept = algorithm.decide_admission(t, theta, r, a, p)
        algorithm.update(t, theta, r, a, accept)

    elapsed_time = time.perf_counter() - start_time

    def _native(value):
        """Convert a NumPy scalar to the matching built-in Python number."""
        return value.item() if isinstance(value, np.generic) else value

    # Collect results
    stats = algorithm.get_statistics()
    total_reward = _native(stats['total_reward'])
    V_mix = _native(V_mix)

    # Compute metrics
    benchmark = T * V_mix
    regret = benchmark - total_reward
    competitive_ratio = total_reward / benchmark if V_mix > 0 else 0.0

    return {
        'scenario': scenario,
        'algorithm': algorithm_name,
        'rho': rho,
        'seed': seed,
        'total_reward': total_reward,
        'regret': _native(regret),
        'competitive_ratio': _native(competitive_ratio),
        'acceptance_rate': _native(stats['acceptance_rate']),
        'total_accepts': _native(stats['total_accepts']),
        'V_mix': V_mix,
        'T_V_mix': _native(benchmark),
        'K': K,
        'd': d,
        'T': T,
        'elapsed_time': elapsed_time,
        'budget_utilization': [_native(u) for u in stats['budget_utilization']],
    }


# =============================================================================
# WORKER FUNCTION
# =============================================================================

def run_worker(worker_id: int, results_dir: Path, config: dict, verbose: bool = True):
    """Execute every run assigned to one worker and write them to a JSON file.

    The worker's seeds come from ``config['worker_seeds'][worker_id]``. One
    loader is built per (seed, scenario) and one oracle per (seed, scenario,
    rho); both are shared by all algorithms. A run that raises is logged and
    recorded as an entry with an ``'error'`` key instead of aborting the worker.

    Returns the list of per-run result dicts (also saved to
    ``alibaba_worker_<id>_results.json`` inside ``results_dir``).
    """
    seeds = config['worker_seeds'][worker_id]
    scenarios = config['scenarios']
    algorithms = config['algorithms']
    rho_values = config['rho_values']

    total_runs = len(seeds) * len(scenarios) * len(algorithms) * len(rho_values)

    if verbose:
        banner = [
            f"Worker {worker_id}: Starting {total_runs} runs",
            f"  Seeds: {seeds}",
            f"  Scenarios: {scenarios}",
            f"  Algorithms: {list(algorithms.keys())}",
            f"  Rho values: {rho_values}",
        ]
        print("\n".join(banner))

    collected = []
    run_count = 0
    started = time.time()

    # Flattened (seed, scenario) iteration; the order matches nested loops with
    # seeds outermost.
    for seed, scenario in ((s, sc) for s in seeds for sc in scenarios):
        # Create loader once per (scenario, seed)
        loader = AlibabaShuffledLoader(
            T=config['T'],
            seed=seed,
            scenario=scenario,
            shuffle=config.get('shuffle', True),
        )

        for rho in rho_values:
            # Compute oracle once per (scenario, seed, rho)
            if verbose:
                print(f"  Computing MC oracle: {scenario}, rho={rho}, seed={seed}...")
            oracle_values = compute_alibaba_oracle_mc(loader, rho)

            for alg_name, alg_params in algorithms.items():
                run_count += 1
                tag = f"{scenario}/{alg_name}/rho={rho}/seed={seed}"

                if verbose:
                    print(f"    [{run_count}/{total_runs}] {tag}")

                try:
                    outcome = run_single_experiment(
                        scenario, alg_name, alg_params, rho, seed, config,
                        loader=loader, oracle_values=oracle_values
                    )
                    collected.append(outcome)

                    if verbose:
                        print(f"      CR={outcome['competitive_ratio']:.3f}, "
                              f"R={outcome['total_reward']:.1f} ({outcome['elapsed_time']:.1f}s)")
                except Exception as exc:
                    # Best-effort: keep the sweep going and record the failure.
                    print(f"  ERROR: {tag}: {exc}")
                    import traceback
                    traceback.print_exc()
                    failure = {
                        'scenario': scenario,
                        'algorithm': alg_name,
                        'rho': rho,
                        'seed': seed,
                        'error': str(exc),
                    }
                    collected.append(failure)

    total_time = time.time() - started

    if verbose:
        print(f"\nWorker {worker_id}: Completed {total_runs} runs in {total_time:.1f}s")

    # Save results
    payload = {
        'worker_id': worker_id,
        'seeds': seeds,
        'total_runs': total_runs,
        'total_time': total_time,
        'results': collected,
    }
    output_file = results_dir / f"alibaba_worker_{worker_id}_results.json"
    with open(output_file, 'w') as handle:
        json.dump(payload, handle, indent=2)

    if verbose:
        print(f"Worker {worker_id}: Saved to {output_file}")

    return collected


# =============================================================================
# COMBINE RESULTS
# =============================================================================

def combine_worker_results(results_dir: Path, verbose: bool = True):
    """Merge the per-worker JSON files in ``results_dir`` and summarise them.

    Looks for ``alibaba_worker_<id>_results.json`` for ids 0-9, skipping any
    that are missing, writes the concatenated runs to
    ``alibaba_combined_results.json``, then calls ``create_summary``.

    Returns the merged list of result dicts.
    """
    merged = []

    for worker_id in range(10):
        worker_file = results_dir / f"alibaba_worker_{worker_id}_results.json"
        if not worker_file.exists():
            continue

        with open(worker_file) as handle:
            worker_results = json.load(handle)['results']
            merged.extend(worker_results)
            if verbose:
                print(f"Loaded {len(worker_results)} results from worker {worker_id}")

    # Save combined results
    combined_payload = {
        'total_results': len(merged),
        'results': merged,
    }
    combined_file = results_dir / "alibaba_combined_results.json"
    with open(combined_file, 'w') as handle:
        json.dump(combined_payload, handle, indent=2)

    if verbose:
        print(f"\nCombined {len(merged)} results into {combined_file}")

    # Create summary statistics
    create_summary(merged, results_dir, verbose)

    return merged


def create_summary(results: list, results_dir: Path, verbose: bool = True):
    """Write per-seed and aggregated CSV summaries of the successful runs.

    Entries containing an ``'error'`` key are ignored. If nothing remains, a
    message is printed and no file is written. Otherwise the raw rows go to
    ``alibaba_per_seed.csv`` and statistics grouped by (scenario, algorithm,
    rho) go to ``alibaba_summary.csv``. When ``verbose`` is true, pivot tables
    of mean competitive ratio and mean regret are printed as well.
    """
    import pandas as pd

    # Filter out errors
    successes = [entry for entry in results if 'error' not in entry]
    if not successes:
        print("No valid results to summarize!")
        return

    frame = pd.DataFrame(successes)

    # Save raw per-seed data
    frame.to_csv(results_dir / "alibaba_per_seed.csv", index=False)

    # Statistics per metric, aggregated over seeds.
    metric_stats = {
        'competitive_ratio': ['mean', 'std', 'min', 'max'],
        'regret': ['mean', 'std'],
        'total_reward': ['mean', 'std'],
        'acceptance_rate': ['mean', 'std'],
        'elapsed_time': ['mean', 'sum'],
    }
    grouped = frame.groupby(['scenario', 'algorithm', 'rho'])
    summary = grouped.agg(metric_stats).round(4)

    # Save summary
    summary_file = results_dir / "alibaba_summary.csv"
    summary.to_csv(summary_file)

    if not verbose:
        return

    rule = "=" * 80
    print(f"\nSummary saved to {summary_file}")

    # (heading, value column, decimals) for each printed pivot table.
    reports = (
        ("ALIBABA EXPERIMENT SUMMARY: Mean Competitive Ratio", 'competitive_ratio', 4),
        ("ALIBABA EXPERIMENT SUMMARY: Mean Regret", 'regret', 1),
    )
    for heading, column, decimals in reports:
        print("\n" + rule)
        print(heading)
        print(rule)

        table = frame.pivot_table(
            values=column,
            index=['scenario', 'rho'],
            columns='algorithm',
            aggfunc='mean'
        ).round(decimals)
        print(table.to_string())


# =============================================================================
# MAIN
# =============================================================================

def main():
    """Parse command-line arguments and dispatch to the requested mode.

    ``--worker combine`` merges existing per-worker files; ``--worker all`` (or
    ``--smoke``) runs every worker listed in the selected config and then
    combines; any other value must be an integer worker id present in the
    config's ``worker_seeds``. An invalid worker prints an error and exits
    with status 1.
    """
    parser = argparse.ArgumentParser(description='Run Alibaba experiments')
    parser.add_argument('--worker', type=str, default='0',
                       help='Worker ID (0-9) or "all" or "combine"')
    parser.add_argument('--results-dir', type=str, default='./results/alibaba',
                       help='Results directory')
    parser.add_argument('--smoke', action='store_true',
                       help='Run quick smoke test')
    parser.add_argument('--quiet', action='store_true', help='Suppress output')

    args = parser.parse_args()

    # Select config
    if args.smoke:
        config = SMOKE_CONFIG
        results_dir = Path(args.results_dir) / 'smoke_test'
    else:
        config = FULL_CONFIG
        results_dir = Path(args.results_dir)

    results_dir.mkdir(parents=True, exist_ok=True)
    verbose = not args.quiet

    # The config is the single source of truth for which workers exist.
    worker_ids = sorted(config['worker_seeds'])

    if verbose:
        T = config['T']
        n_seeds = len(config['seeds'])
        n_scenarios = len(config['scenarios'])
        n_rho = len(config['rho_values'])
        n_alg = len(config['algorithms'])
        total = n_seeds * n_scenarios * n_rho * n_alg
        print("="*80)
        print("ALIBABA EXPERIMENT RUNNER")
        print("="*80)
        print(f"Config: T={T}, scenarios={n_scenarios}, seeds={n_seeds}, "
              f"rho_values={n_rho}, algorithms={n_alg}")
        print(f"Total runs: {total}")
        print(f"Results dir: {results_dir}")
        print("="*80)

    if args.worker == 'combine':
        combine_worker_results(results_dir, verbose)
    elif args.worker == 'all' or args.smoke:
        if verbose:
            print("\nRUNNING ALL WORKERS" + (" (SMOKE TEST)" if args.smoke else ""))
            print("="*80 + "\n")

        start = time.time()

        for worker_id in worker_ids:
            run_worker(worker_id, results_dir, config, verbose)

        if verbose:
            print(f"\nCompleted all workers in {time.time()-start:.1f}s")

        combine_worker_results(results_dir, verbose)
    else:
        valid_range = f"{worker_ids[0]}-{worker_ids[-1]}"
        try:
            worker_id = int(args.worker)
        except ValueError:
            # Report a clean error instead of an int() traceback.
            print(f'Error: worker must be {valid_range}, "all" or "combine", '
                  f"got {args.worker!r}")
            sys.exit(1)
        if worker_id not in config['worker_seeds']:
            print(f"Error: worker must be {valid_range}, got {worker_id}")
            sys.exit(1)
        run_worker(worker_id, results_dir, config, verbose)


if __name__ == '__main__':
    main()
