"""
Experiment Runner

Orchestrates running algorithms on data loaders and collecting results.
Runs are executed sequentially, with optional per-run progress printing.
"""

import numpy as np
from typing import Dict, List, Tuple, Any, Optional, Type
from dataclasses import dataclass, field
import time
import json
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import traceback

from .storage import RunTrajectory, ExperimentResults, create_results_directory
from .oracle import OracleComputer, compute_oracle_from_loader
from .algorithms import BaseAlgorithm, get_algorithm
from .data import BaseDataLoader


@dataclass
class RunConfig:
    """
    Settings describing a single (algorithm, seed) experiment run.

    Attributes
    ----------
    algorithm : str
        Name of the algorithm to run.
    seed : int
        Random seed for the run.
    rho : float
        Budget factor; defaults to 1.0.
    alpha : float
        UCB exploration parameter; defaults to 1.0.
    extra_config : Dict[str, Any]
        Additional algorithm options. Each instance gets its own fresh dict.
    """

    algorithm: str
    seed: int
    rho: float = 1.0
    alpha: float = 1.0
    extra_config: Dict[str, Any] = field(default_factory=dict)


def run_single_experiment(
    loader: BaseDataLoader,
    algorithm: BaseAlgorithm,
    B: np.ndarray,
    seed: int,
    record_trajectory: bool = True
) -> Tuple[RunTrajectory, Dict[str, Any]]:
    """
    Run a single experiment with given loader and algorithm.

    For each round t in range(loader.T):

    1. The algorithm selects a configuration.
    2. The loader produces an arrival for it.
    3. The algorithm decides admission.
    4. The algorithm is updated with the outcome.

    Parameters
    ----------
    loader : BaseDataLoader
        Data loader providing arrivals (reads ``T``, ``K``, ``d``,
        ``get_metadata()`` and ``get_arrival(theta, t)``).
    algorithm : BaseAlgorithm
        Algorithm to run. It is ``reset()`` before the first round.
    B : np.ndarray
        Total budget vector. Passed to ``trajectory.finalize``.
    seed : int
        Random seed. NumPy's *global* RNG is seeded with it via
        ``np.random.seed``.
    record_trajectory : bool
        Whether to record the full per-round trajectory (memory intensive).

    Returns
    -------
    trajectory : RunTrajectory
        Recorded trajectory. When ``record_trajectory=False``, only the
        summary fields (total_reward, acceptance_rate, budget_violation,
        T, K, d) are filled in.
    stats : Dict[str, Any]
        ``algorithm.get_statistics()`` with an added ``'seed'`` entry.
    """
    # Tolerance below zero before remaining budget counts as violated.
    budget_tol = 1e-9

    np.random.seed(seed)

    T = loader.T
    K = loader.K
    d = loader.d
    alg_name = algorithm.__class__.__name__

    trajectory = RunTrajectory(
        run_id=f"{alg_name}_{seed}",
        algorithm=alg_name,
        experiment_family=loader.get_metadata().get('family', 'unknown'),
        seed=seed
    )

    if record_trajectory:
        trajectory.initialize(T, K, d)

    algorithm.reset()

    for t in range(T):
        theta, w, p = algorithm.select_config(t)
        r, a = loader.get_arrival(theta, t)

        accept = algorithm.decide_admission(t, theta, r, a, p)
        x = 1 if accept else 0

        if record_trajectory:
            # One confidence radius per k in range(K), evaluated before update.
            beta = np.array([
                algorithm.compute_confidence_radius(k, t)
                for k in range(K)
            ])
            trajectory.record_step(t, theta, r, a, x, p, w, beta)

        algorithm.update(t, theta, r, a, accept)

    stats = algorithm.get_statistics()
    stats['seed'] = seed

    if record_trajectory:
        trajectory.finalize(B)
    else:
        # Without a full trajectory, still fill in the summary metrics
        # that results aggregation needs.
        trajectory.total_reward = stats['total_reward']
        trajectory.acceptance_rate = stats['acceptance_rate']
        # np.any returns numpy.bool_. Store a plain bool so downstream
        # serialization (e.g. json) does not fail on it.
        trajectory.budget_violation = bool(
            np.any(algorithm.B_remaining < -budget_tol)
        )
        trajectory.T = T
        trajectory.K = K
        trajectory.d = d

    return trajectory, stats


class ExperimentRunner:
    """
    Main experiment runner class.

    Handles:
    - Running multiple algorithms on a data loader
    - Multiple random seeds for statistical significance
    - Computing oracle values
    - Saving results

    Seeds and algorithms are run sequentially in the calling process.
    """

    def __init__(
        self,
        loader: BaseDataLoader,
        rho: float = 1.0,
        results_dir: str = "./results"
    ):
        """
        Initialize experiment runner.

        Parameters
        ----------
        loader : BaseDataLoader
            Data loader to use
        rho : float
            Budget scaling factor, passed to ``loader.get_budget``
        results_dir : str
            Directory to save results. It is created (with parents) if missing.
        """
        self.loader = loader
        self.rho = rho
        self.results_dir = Path(results_dir)
        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Problem dimensions as reported by the loader
        self.K = loader.K
        self.d = loader.d
        self.T = loader.T
        # Total budget, and the per-round budget used by the oracle
        self.B = loader.get_budget(rho)
        self.b = self.B / self.T

        # Set by compute_oracle(); None until then
        self.oracle_result = None

    def compute_oracle(self, n_samples: Optional[int] = None) -> Dict[str, Any]:
        """
        Compute oracle values from the loader's samples.

        Stores the full oracle result on ``self.oracle_result`` as a side effect.

        Parameters
        ----------
        n_samples : int, optional
            Accepted for backward compatibility but currently unused. All
            samples returned by ``loader.get_samples_dict()`` are used.

        Returns
        -------
        Dict[str, Any]
            Keys ``'V_mix'``, ``'V_star'``, ``'gap'``, ``'w_star'`` (list),
            ``'p_star'`` (list) and ``'best_fixed_config'``.
        """
        samples_dict = self.loader.get_samples_dict()

        oracle = OracleComputer(self.K, self.d)
        self.oracle_result = oracle.compute_full_oracle(samples_dict, self.b)

        return {
            'V_mix': self.oracle_result.V_mix,
            'V_star': self.oracle_result.V_star,
            'gap': self.oracle_result.gap,
            'w_star': self.oracle_result.w_star.tolist(),
            'p_star': self.oracle_result.p_star.tolist(),
            'best_fixed_config': self.oracle_result.best_fixed_config,
        }

    def run_algorithm(
        self,
        algorithm_name: str,
        seeds: List[int],
        config: Optional[Dict[str, Any]] = None,
        record_trajectories: bool = False,
        verbose: bool = True
    ) -> List[Tuple[RunTrajectory, Dict[str, Any]]]:
        """
        Run an algorithm with multiple seeds.

        A fresh algorithm instance is built for every seed.

        Parameters
        ----------
        algorithm_name : str
            Name of algorithm (see algorithms/__init__.py)
        seeds : List[int]
            Random seeds to use
        config : Dict[str, Any], optional
            Algorithm configuration. The caller's dict is not modified.
            A copy is made, and ``w_star``/``p_star`` are added to it when
            the oracle has been computed.
        record_trajectories : bool
            Whether to record full trajectories
        verbose : bool
            Print progress

        Returns
        -------
        results : List[Tuple[RunTrajectory, Dict[str, Any]]]
            List of (trajectory, stats) tuples, one per seed, in seed order
        """
        # Copy so the oracle injection below never leaks into the caller's
        # dict (e.g. the per-algorithm entries of run_comparison's configs).
        config = dict(config) if config else {}

        if self.oracle_result is not None:
            config['w_star'] = self.oracle_result.w_star
            config['p_star'] = self.oracle_result.p_star

        results = []
        n_seeds = len(seeds)

        for i, seed in enumerate(seeds):
            if verbose:
                print(f"  Running {algorithm_name} seed {seed} ({i+1}/{n_seeds})...", end=" ")

            start_time = time.time()

            algorithm = get_algorithm(
                algorithm_name, self.K, self.d, self.T, self.B, config
            )

            trajectory, stats = run_single_experiment(
                self.loader, algorithm, self.B, seed, record_trajectories
            )

            elapsed = time.time() - start_time

            if verbose:
                reward = stats['total_reward']
                accept_rate = stats['acceptance_rate']
                print(f"reward={reward:.2f}, accept={accept_rate:.2%}, time={elapsed:.1f}s")

            results.append((trajectory, stats))

        return results

    def run_comparison(
        self,
        algorithm_names: List[str],
        seeds: List[int],
        configs: Optional[Dict[str, Dict[str, Any]]] = None,
        record_trajectories: bool = False,
        verbose: bool = True
    ) -> ExperimentResults:
        """
        Run comparison across multiple algorithms.

        The oracle is computed first if it has not been already.

        Parameters
        ----------
        algorithm_names : List[str]
            List of algorithm names to compare
        seeds : List[int]
            Random seeds to use (the same seeds for every algorithm)
        configs : Dict[str, Dict[str, Any]], optional
            Per-algorithm configurations keyed by algorithm name.
            Missing names get an empty config.
        record_trajectories : bool
            Whether to record trajectories
        verbose : bool
            Print progress

        Returns
        -------
        results : ExperimentResults
            Aggregated comparison results (``finalize()`` already called)
        """
        configs = configs or {}

        if self.oracle_result is None:
            if verbose:
                print("Computing oracle values...")
            self.compute_oracle()

        if verbose:
            print("\nOracle values:")
            print(f"  V^mix = {self.oracle_result.V_mix:.4f}")
            print(f"  V* = {self.oracle_result.V_star:.4f}")
            print(f"  Gap = {self.oracle_result.gap:.4f}")
            print(f"  Best fixed config = {self.oracle_result.best_fixed_config}")
            print()

        metadata = self.loader.get_metadata()
        exp_results = ExperimentResults(
            experiment_id=f"{metadata.get('family', 'exp')}_{self.rho}",
            family=metadata.get('family', 'unknown'),
            n_runs=len(seeds)
        )
        exp_results.T = self.T
        exp_results.K = self.K
        exp_results.d = self.d
        exp_results.budget_factor = self.rho
        exp_results.V_mix = self.oracle_result.V_mix
        exp_results.V_star = self.oracle_result.V_star
        exp_results.w_star = self.oracle_result.w_star
        exp_results.p_star = self.oracle_result.p_star

        for alg_name in algorithm_names:
            if verbose:
                print(f"Running {alg_name}...")

            alg_config = configs.get(alg_name, {})
            results = self.run_algorithm(
                alg_name, seeds, alg_config, record_trajectories, verbose
            )

            for trajectory, _stats in results:
                exp_results.add_run(alg_name, trajectory)

        exp_results.finalize()

        return exp_results

    def save_results(
        self,
        results: ExperimentResults,
        filename: Optional[str] = None
    ) -> str:
        """
        Save results to a JSON file inside ``self.results_dir``.

        Parameters
        ----------
        results : ExperimentResults
            Results to save via ``results.save_json``.
        filename : str, optional
            File name. Defaults to ``"<experiment_id>_results.json"``.

        Returns
        -------
        str
            Path of the written file.
        """
        if filename is None:
            filename = f"{results.experiment_id}_results.json"

        filepath = self.results_dir / filename
        results.save_json(str(filepath))

        return str(filepath)


def run_quick_test():
    """
    Smoke-test the runner end to end.

    Steps:

    1. Build an S1ComplementarityLoader with K=4, T=1000.
    2. Compute the oracle.
    3. Compare three algorithms over three seeds.
    4. Print a per-algorithm summary.
    5. Save the results under ./test_results.
    """
    from .data import S1ComplementarityLoader

    rule = "=" * 60

    print(rule)
    print("Quick Test: Experiment Runner")
    print(rule)

    runner = ExperimentRunner(
        S1ComplementarityLoader(K=4, T=1000),
        rho=1.0,
        results_dir="./test_results",
    )

    oracle_values = runner.compute_oracle()
    print(f"\nOracle: V^mix={oracle_values['V_mix']:.4f}, V*={oracle_values['V_star']:.4f}")

    comparison = runner.run_comparison(
        ['SP-UCB-OLP', 'Greedy', 'Random'],
        [42, 43, 44],
        verbose=True,
    )

    print("\n" + rule)
    print("Results Summary")
    print(rule)

    for name in comparison.algorithms:
        summary = comparison.get_summary(name)
        # Pull ratio/regret before printing the header, as before.
        competitive_ratio = summary['mean_competitive_ratio']
        mean_regret = summary['mean_regret']
        print(f"{name}:")
        print(f"  Mean reward: {summary['mean_reward']:.2f}")
        print(f"  Competitive ratio: {competitive_ratio:.4f}")
        print(f"  Mean regret: {mean_regret:.2f}")

    saved_path = runner.save_results(comparison)
    print(f"\nResults saved to: {saved_path}")

    print("\n" + rule)


# Run the smoke test when this module is executed directly. Because the module
# uses relative imports, it presumably must be launched with `python -m` from
# the package root. TODO confirm the intended invocation.
if __name__ == "__main__":
    run_quick_test()
