import numpy as np
import torch
import matplotlib.pyplot as plt
from multiprocessing import Pool
from datetime import datetime
import os
import argparse
from tqdm import tqdm
from dataclasses import dataclass
from typing import Callable, Type, Optional

from src.wpm import WPMSolver, wpm_swf
from src.kolm import KolmSolver, kolm_swf
from src.gini import GiniSolver, gini_swf

from src.armsets import BetaArmSet
from src.sampler import PiPSSampler
from src.swf_ucb import SWFUCB

from src.logger import log_experiment


@dataclass
class ExperimentResults:
    """Plain data container for the output of ``run_experiments_core``.

    Bundles the raw regret traces together with the aggregate statistics
    derived from them. Constructing an instance has no side effects.
    """

    # Raw per-step regrets, one row per experiment: (n_experiments, n_rounds).
    all_regrets: np.ndarray
    # SWF weights shared by every experiment in the batch.
    weights: torch.Tensor
    # Per-step regret averaged over experiments.
    avg_regret: np.ndarray
    # Standard error of the per-step regret.
    std_error: np.ndarray
    # Cumulative regret averaged over experiments.
    avg_cum_regret: np.ndarray
    # Standard error of the cumulative regret.
    std_cum_error: np.ndarray


def run_single_experiment(
    seed: int,
    n_arms: int,
    n_rounds: int,
    num_alloc: int,
    delta: float,
    weights: torch.Tensor,
    solver_cls: Type,
    swf_func: Callable,
    pow_val: float,
    objective: str,
    verbose: bool = True,
) -> np.ndarray:
    """
    Run a single experiment instance.

    Creates one arm set, one solver and one SWF-UCB learner, runs the bandit
    loop for ``n_rounds`` and records, after every update, the gap between the
    SWF value of the optimal allocation and the SWF value of the learner's
    current allocation probabilities.

    Args:
        seed: Random seed for this instance. Seeds both a dedicated
            ``torch.Generator`` (shared by the arm set and the sampler) and
            torch's global RNG.
        n_arms: Number of arms for bandit.
        n_rounds: Number of rounds per experiment.
        num_alloc: Number of arms (k) to allocate per round.
        delta: UCB exploration parameter.
        weights: SWF weights (shared across experiments).
        solver_cls: Solver class for the SWF.
        swf_func: Social welfare function to use.
        pow_val: Power parameter for WPM/Kolm (ignored for Gini).
        objective: Objective function name (wpm, kolm, gini).
        verbose: Whether to show a progress bar and print a final summary.

    Returns:
        float64 array of per-step regrets with shape (n_rounds,); empty when
        ``n_rounds == 0``.
    """
    # Seeds may arrive as NumPy integers (e.g. from a seed array); normalise once.
    seed = int(seed)
    gen = torch.Generator().manual_seed(seed)
    # NOTE(review): global seeding is kept in case downstream components draw
    # from torch's global RNG rather than `gen` — confirm before removing.
    torch.manual_seed(seed)

    arm_set = BetaArmSet(n_arms, gen)

    # Gini's solver and SWF take no power parameter; every other objective does.
    is_gini = objective == "gini"
    if is_gini:
        solver = solver_cls(weights, num_alloc)
    else:
        solver = solver_cls(weights, pow_val, num_alloc)

    def evaluate_swf(probs):
        """SWF value of the arm means scaled by the allocation probabilities."""
        scaled = arm_set.means * probs
        if is_gini:
            return swf_func(scaled, weights)
        return swf_func(scaled, weights, pow_val)

    # Benchmark: SWF of the solver's allocation under the true means.
    opt_probs = solver.get_allocation_probabilities(arm_set.means)
    opt_swf = evaluate_swf(opt_probs)

    sampler = PiPSSampler(gen)
    ucb = SWFUCB(n_arms, num_alloc, solver, sampler, delta=delta)

    regrets = []
    current_swf = None
    rounds = tqdm(range(n_rounds)) if verbose else range(n_rounds)
    for _ in rounds:
        inds = ucb.select_arms()
        rewards = arm_set.sample(inds)
        ucb.update(inds, rewards)

        # Regret of the learner's current allocation against the optimum.
        current_swf = evaluate_swf(ucb.probs)
        regrets.append((opt_swf - current_swf).item())

    if verbose:
        # current_swf is still None when n_rounds == 0; don't crash on .item().
        final_swf = current_swf.item() if current_swf is not None else None
        print(f"""
        Optimal:
            Probs: {opt_probs}
            SWF: {opt_swf.item()}
        Final UCB:
            Probs: {ucb.probs}
            SWF: {final_swf}
        """)

    return np.array(regrets)


def get_solver_and_swf(objective: str) -> tuple[Type, Callable]:
    """
    Look up the solver class and social welfare function for an objective.

    Args:
        objective: Name of the objective; must be 'wpm', 'kolm' or 'gini'.

    Returns:
        A ``(solver_class, swf_function)`` pair.

    Raises:
        ValueError: If ``objective`` is not one of the supported names.
    """
    supported = ("wpm", "kolm", "gini")
    # Reject unknown names first, before any solver symbol is referenced.
    if objective not in supported:
        raise ValueError(f"Unknown objective: {objective}")

    registry = {
        "wpm": (WPMSolver, wpm_swf),
        "kolm": (KolmSolver, kolm_swf),
        "gini": (GiniSolver, gini_swf),
    }
    return registry[objective]


def generate_weights(gen: torch.Generator, n_arms: int, objective: str) -> torch.Tensor:
    """
    Generate SWF weights for the experiment.

    Draws ``n_arms`` standard-normal logits from ``gen`` and maps them onto the
    probability simplex with a softmax.

    Args:
        gen: Torch random generator. The draw advances its state, so anything
            sampled from ``gen`` afterwards depends on this call.
        n_arms: Number of arms (length of the returned tensor).
        objective: Objective function name. For ``"gini"`` the weights are
            sorted in non-increasing order.

    Returns:
        float64 tensor of shape (n_arms,) whose entries sum to 1.
    """
    logits = torch.randn(n_arms, dtype=torch.float64, generator=gen)
    # softmax is shift-invariant, so the +0.1 offset has no mathematical effect;
    # it is kept so the weights stay bit-identical to previously logged runs.
    weights = torch.softmax(logits + 0.1, dim=0)

    # Gini requires non-increasing weights.
    if objective == "gini":
        weights = torch.sort(weights, descending=True).values

    return weights


def _mean_and_sem(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return the mean and standard error over axis 0 (population std, ddof=0)."""
    return np.mean(values, axis=0), np.std(values, axis=0) / np.sqrt(values.shape[0])


def run_experiments_core(
    n_experiments: int,
    n_arms: int,
    n_rounds: int,
    num_alloc: int,
    delta: float,
    seed: int,
    objective: str,
    pow_val: float,
    parallel: bool = True,
    verbose: bool = True,
) -> ExperimentResults:
    """
    Core experiment logic - pure function with no I/O side effects.

    This function contains all the computational logic for running experiments:
    setting up RNG, generating weights, running experiment instances, and
    aggregating statistics. It does NOT do any logging, plotting, or printing
    (beyond what run_single_experiment does when verbose=True).

    This design makes the core logic easy to test for reproducibility without
    worrying about file I/O, timestamps, or plot windows.

    Args:
        n_experiments: Number of independent experiments to run (must be >= 1).
        n_arms: Number of arms for bandit.
        n_rounds: Number of rounds per experiment.
        num_alloc: Number of arms to allocate per round.
        delta: UCB exploration parameter.
        seed: Master random seed (determines weights and per-experiment seeds).
        objective: Objective function (wpm, kolm, gini).
        pow_val: Power parameter for WPM/Kolm.
        parallel: Whether to run experiments in parallel (set False for testing).
        verbose: Whether to show progress bars and print results.

    Returns:
        ExperimentResults dataclass containing all results and statistics.

    Raises:
        ValueError: If ``n_experiments`` is smaller than 1 (there would be
            nothing to aggregate).
    """
    if n_experiments < 1:
        raise ValueError(f"n_experiments must be >= 1, got {n_experiments}")

    # Master generator. Draw order matters for reproducibility: weights are
    # drawn first, then the per-experiment seeds.
    gen = torch.Generator().manual_seed(seed)
    weights = generate_weights(gen, n_arms, objective)
    solver_cls, swf_func = get_solver_and_swf(objective)
    experiment_seeds = torch.randint(0, 1_000_000, (n_experiments,), generator=gen).numpy()

    args_list = [
        (
            exp_seed,
            n_arms,
            n_rounds,
            num_alloc,
            delta,
            weights,
            solver_cls,
            swf_func,
            pow_val,
            objective,
            verbose,
        )
        for exp_seed in experiment_seeds
    ]

    if parallel and n_experiments > 1:
        # Don't spawn more processes than there are cores. starmap preserves
        # task order, and run_single_experiment reseeds torch at its start, so
        # a worker running several experiments should not change results.
        n_workers = min(n_experiments, os.cpu_count() or 1)
        with Pool(processes=n_workers) as pool:
            results = pool.starmap(run_single_experiment, args_list)
    else:
        # Sequential execution - useful for debugging and reproducibility tests.
        results = [run_single_experiment(*args) for args in args_list]

    # Shape: (n_experiments, n_rounds).
    all_regrets = np.array(results)
    avg_regret, std_error = _mean_and_sem(all_regrets)
    avg_cum_regret, std_cum_error = _mean_and_sem(np.cumsum(all_regrets, axis=1))

    return ExperimentResults(
        all_regrets=all_regrets,
        weights=weights,
        avg_regret=avg_regret,
        std_error=std_error,
        avg_cum_regret=avg_cum_regret,
        std_cum_error=std_cum_error,
    )


def plot_regret_curves(
    avg_regret: np.ndarray,
    std_error: np.ndarray,
    avg_cum_regret: np.ndarray,
    std_cum_error: np.ndarray,
    objective: str,
    n_arms: int,
    num_alloc: int,
    timestamp: str,
    show: bool = True,
) -> str:
    """
    Plot cumulative and per-step regret curves side by side and save them.

    Each curve is drawn with a shaded band of +/- 1.96 standard errors.
    The figure is saved to ``./plots/<timestamp>/`` (created if missing) and
    is always closed before returning, even if saving fails.

    Args:
        avg_regret: Average per-step regret across experiments.
        std_error: Standard error of per-step regret.
        avg_cum_regret: Average cumulative regret across experiments.
        std_cum_error: Standard error of cumulative regret.
        objective: Objective function name (wpm, kolm, gini).
        n_arms: Number of arms.
        num_alloc: Number of allocations per round.
        timestamp: Timestamp string for output directory.
        show: Whether to display the plot interactively.

    Returns:
        Path to the saved plot.
    """
    rounds = np.arange(len(avg_regret))
    label = f"SWF-UCB ({objective})"
    z = 1.96  # two-sided 95% quantile of the standard normal

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    try:
        # Cumulative Regret Plot
        ax1.plot(avg_cum_regret, label=label)
        ax1.fill_between(
            rounds,
            avg_cum_regret - z * std_cum_error,
            avg_cum_regret + z * std_cum_error,
            alpha=0.2,
        )
        ax1.set_xlabel("Rounds")
        ax1.set_ylabel("Cumulative Regret")
        ax1.set_title(f"Cumulative Regret ({objective})")
        ax1.legend()
        ax1.grid()

        # Per-Step Regret Plot
        ax2.plot(avg_regret, label=label, color="orange")
        ax2.fill_between(
            rounds,
            avg_regret - z * std_error,
            avg_regret + z * std_error,
            alpha=0.2,
            color="orange",
        )
        ax2.set_xlabel("Rounds")
        ax2.set_ylabel("Per-Step Regret")
        ax2.set_title(f"Per-Step Regret ({objective})")
        ax2.legend()
        ax2.grid()

        fig.tight_layout()

        # Create timestamped output directory
        plot_dir = f"./plots/{timestamp}"
        os.makedirs(plot_dir, exist_ok=True)
        plot_path = f"{plot_dir}/swf_ucb_{objective}_arms_{n_arms}_alloc_{num_alloc}.png"
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(plot_path, dpi=150)
        print(f"Plot saved to {plot_path}")

        if show:
            plt.show()
    finally:
        # Release the figure even on failure so repeated calls don't leak figures.
        plt.close(fig)

    return plot_path


def multi_run_experiment(
    n_experiments: int,
    n_arms: int,
    n_rounds: int,
    num_alloc: int,
    delta: float,
    seed: int,
    objective: str,
    pow_val: float,
    log: bool = True,
    show_plot: bool = True,
) -> dict:
    """
    Run multiple experiments with full I/O: logging, plotting, and console output.

    This is the main entry point for running experiments from the command line.
    It wraps run_experiments_core and adds all the side effects (printing,
    logging via log_experiment, saving/showing a plot).

    Args:
        n_experiments: Number of independent experiments to run.
        n_arms: Number of arms for bandit.
        n_rounds: Number of rounds per experiment.
        num_alloc: Number of arms to allocate per round.
        delta: UCB exploration parameter.
        seed: Master random seed.
        objective: Objective function (wpm, kolm, gini).
        pow_val: Power parameter for WPM/Kolm (may be +/-inf).
        log: Whether to log the experiment via log_experiment.
        show_plot: Whether to display plots interactively.

    Returns:
        Dictionary with keys ``avg_regret``, ``std_error``, ``avg_cum_regret``,
        ``std_cum_error``, ``plot_path`` and ``timestamp``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H.%M.%S")

    # Print experiment configuration
    print(f"Running {objective.upper()} Experiment")
    print(
        f"Params: Arms={n_arms}, Rounds={n_rounds}, Alloc={num_alloc}, "
        f"Pow={pow_val}, delta={delta}, Seed={seed}"
    )

    # Run the core experiment logic (pure computation, no side effects)
    results = run_experiments_core(
        n_experiments=n_experiments,
        n_arms=n_arms,
        n_rounds=n_rounds,
        num_alloc=num_alloc,
        delta=delta,
        seed=seed,
        objective=objective,
        pow_val=pow_val,
        parallel=True,
        verbose=True,
    )

    print(f"Final Average Cumulative Regret: {results.avg_cum_regret[-1]:.4f}")

    # ==================== LOGGING (side effect) ====================
    if log:
        # Non-finite powers (both -inf and +inf, which parse_pow_val can return)
        # are recorded as strings ("-inf" / "inf"). NOTE(review): presumably the
        # logger serializes to JSON, which has no infinity literal — confirm.
        logged_pow = pow_val if np.isfinite(pow_val) else str(float(pow_val))
        config = {
            "objective": objective,
            "n_arms": n_arms,
            "n_rounds": n_rounds,
            "num_alloc": num_alloc,
            "delta": delta,
            "pow_val": logged_pow,
            "seed": seed,
            "n_experiments": n_experiments,
            "weights": results.weights.tolist(),
        }
        results_dict = {
            "regrets": results.avg_regret.tolist(),
            "cumulative_regrets": results.avg_cum_regret.tolist(),
            "std_error": results.std_error.tolist(),
            "std_cum_error": results.std_cum_error.tolist(),
            "avg_final_cumulative_regret": float(results.avg_cum_regret[-1]),
        }
        log_experiment(config, results_dict, run_id=f"{timestamp}_{objective}")

    # ==================== PLOTTING (side effect) ====================
    plot_path = plot_regret_curves(
        avg_regret=results.avg_regret,
        std_error=results.std_error,
        avg_cum_regret=results.avg_cum_regret,
        std_cum_error=results.std_cum_error,
        objective=objective,
        n_arms=n_arms,
        num_alloc=num_alloc,
        timestamp=timestamp,
        show=show_plot,
    )

    return {
        "avg_regret": results.avg_regret,
        "std_error": results.std_error,
        "avg_cum_regret": results.avg_cum_regret,
        "std_cum_error": results.std_cum_error,
        "plot_path": plot_path,
        "timestamp": timestamp,
    }


def parse_pow_val(value: str) -> float:
    """Convert a CLI power string to a float.

    Case and surrounding whitespace are ignored. The spellings '-inf',
    '-infinity' and 'neginf' map to negative infinity; 'inf' and 'infinity'
    map to positive infinity. Anything else goes through ``float()``, which
    raises ValueError for unparsable input.
    """
    token = value.lower().strip()
    special_values = {
        "-inf": -torch.inf,
        "-infinity": -torch.inf,
        "neginf": -torch.inf,
        "inf": torch.inf,
        "infinity": torch.inf,
    }
    if token in special_values:
        return special_values[token]
    return float(token)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run SWF-UCB Experiment",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Core experiment parameters
    parser.add_argument("--n-arms", type=int, default=20, help="Number of arms for bandit")
    parser.add_argument(
        "--n-rounds", type=int, default=5000, help="Number of rounds per experiment"
    )
    parser.add_argument(
        "--num-alloc", type=int, default=10, help="Number of arms (k) to allocate per round"
    )
    parser.add_argument(
        "--n-experiments", type=int, default=5, help="Number of independent experiments to run"
    )
    parser.add_argument("--delta", type=float, default=0.05, help="UCB exploration parameter")
    parser.add_argument("--seed", type=int, default=42, help="Master random seed")
    parser.add_argument(
        "--pow",
        type=str,
        default="-inf",
        help="Power parameter for WPM/Kolm (use '-inf' for egalitarian)",
    )

    # Objective selection
    parser.add_argument(
        "--objective",
        type=str,
        default="wpm",
        choices=["wpm", "kolm", "gini"],
        help="Objective function",
    )

    # Output options
    parser.add_argument("--no-log", action="store_true", help="Disable experiment logging")
    parser.add_argument("--no-show", action="store_true", help="Don't display plots (just save)")

    args = parser.parse_args()

    # Parse power value (handles -inf, inf, and numeric values). Report bad
    # input as a normal CLI error instead of a traceback.
    try:
        pow_val = parse_pow_val(args.pow)
    except ValueError:
        parser.error(f"--pow must be a number, 'inf' or '-inf' (got {args.pow!r})")

    # Sizes must be positive: zero rounds leaves no final cumulative regret to
    # report, and zero experiments leaves nothing to average.
    for flag, value in (
        ("--n-arms", args.n_arms),
        ("--n-rounds", args.n_rounds),
        ("--num-alloc", args.num_alloc),
        ("--n-experiments", args.n_experiments),
    ):
        if value < 1:
            parser.error(f"{flag} must be a positive integer (got {value})")

    # Validate num_alloc <= n_arms
    if args.num_alloc > args.n_arms:
        parser.error(f"num-alloc ({args.num_alloc}) cannot exceed n-arms ({args.n_arms})")

    # Validate/adjust pow based on objective
    if args.objective == "kolm" and pow_val > 0:
        print(f"Warning: pow ({pow_val}) invalid for Kolm (requires pow <= 0)")
        pow_val = -1.0
        print(f"Changed pow to {pow_val}")

    # Run the experiment
    multi_run_experiment(
        n_experiments=args.n_experiments,
        n_arms=args.n_arms,
        n_rounds=args.n_rounds,
        num_alloc=args.num_alloc,
        delta=args.delta,
        seed=args.seed,
        objective=args.objective,
        pow_val=pow_val,
        log=not args.no_log,
        show_plot=not args.no_show,
    )
