"""
Experiment orchestration helpers for the Exploration Free project.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Mapping, MutableMapping, Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from .core import (
    adaptive_algorithm_GSG,
    adaptive_algorithm_gaussian,
    adaptive_algorithm_gsg,
    adaptive_algorithm_ssg,
    compute_regret,
    make_gauss_and_rademacher_generators,
)

# Type of the allocation routines driven by ``run_regret_sweep``. That driver
# calls them positionally as ``algorithm(T, q, delta, n_arms, sigma_min,
# sigma_sg, reward_generators)`` and unpacks a 3-tuple
# ``(final_means, n_k, tau)`` from the return value.
Algorithm = Callable[..., tuple]


@dataclass
class ExperimentResult:
    """Bundle of the regret statistics gathered during one sweep."""

    # One row per (T, p, replication), as assembled by ``run_regret_sweep``.
    per_run: pd.DataFrame
    # One row per (T, p) with the aggregated regret statistics.
    summary: pd.DataFrame
    # Keyed by ``p``; ``run_regret_sweep`` appends one value per entry of
    # ``T_list``, in sweep order.
    regret_mean: Dict[float, List[float]]
    regret_lb: Dict[float, List[float]]
    regret_ub: Dict[float, List[float]]

    def save_to_excel(self, prefix: str) -> None:
        """Export both dataframes as Excel workbooks.

        ``per_run`` is written to ``{prefix}_runs.xlsx`` and ``summary`` to
        ``{prefix}_summary.xlsx``; the dataframe index is omitted from both.
        """
        targets = (
            (self.per_run, f"{prefix}_runs.xlsx"),
            (self.summary, f"{prefix}_summary.xlsx"),
        )
        for frame, path in targets:
            frame.to_excel(path, index=False)


def default_delta_rule(p_norm: float, horizon: float) -> float:
    """Return the default confidence level ``delta`` for a given horizon.

    Args:
        p_norm: Norm index ``p``; ``np.inf`` selects the sup-norm schedule.
        horizon: Time horizon ``T``.

    Returns:
        ``1 / T**2`` when ``p_norm`` is infinite, ``1 / T**2.5`` otherwise.
    """
    exponent = 2 if p_norm == np.inf else 2.5
    return 1.0 / (horizon**exponent)


def run_regret_sweep(
    algorithm: Algorithm,
    T_list: Sequence[int],
    p_list: Sequence[float],
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
    stds: Sequence[float],
    sigma_min: float,
    sigma_sg: float,
    *,
    runs: int = 100,
    confidence: float = 0.90,
    delta_rule: Callable[[float, float], float] = default_delta_rule,
) -> ExperimentResult:
    """Run a regret sweep over ``T_list`` and ``p_list`` for the given algorithm.

    For every horizon ``T`` and norm index ``p`` the algorithm is executed
    ``runs`` times. Each run is called as
    ``algorithm(int(T), q, delta, n_arms, sigma_min, sigma_sg, reward_generators)``
    with ``q = 2p / (p + 1)`` (``2`` when ``p`` is infinite),
    ``delta = delta_rule(p, float(T))`` and ``n_arms = len(reward_generators)``.
    It must return a 3-tuple ``(final_means, n_k, tau)``; only ``n_k`` and
    ``tau`` are used here.

    Args:
        algorithm: Allocation routine following the call convention above.
        T_list: Horizons to sweep.
        p_list: Norm indices to sweep; ``np.inf`` is accepted.
        reward_generators: One reward generator per arm, forwarded untouched.
        stds: Per-arm standard deviations handed to ``compute_regret``.
        sigma_min: Forwarded to ``algorithm``.
        sigma_sg: Forwarded to ``algorithm``.
        runs: Number of replications per ``(T, p)`` pair.
        confidence: Confidence level forwarded to ``compute_regret`` for the
            aggregated interval.
        delta_rule: Maps ``(p, T)`` to the ``delta`` passed to ``algorithm``.

    Returns:
        An ``ExperimentResult`` whose ``per_run`` frame has one row per
        replication, whose ``summary`` frame has one row per ``(T, p)`` pair,
        and whose dictionaries hold, for each ``p``, one value per horizon in
        sweep order.
    """
    regret_mean: Dict[float, List[float]] = {p: [] for p in p_list}
    regret_lb: Dict[float, List[float]] = {p: [] for p in p_list}
    regret_ub: Dict[float, List[float]] = {p: [] for p in p_list}
    per_run_records: List[Dict[str, object]] = []
    summary_records: List[Dict[str, object]] = []

    n_arms = len(reward_generators)

    for T in T_list:
        T_float = float(T)
        horizon = int(T)
        for p in p_list:
            # q -> 2 is the limit of 2p / (p + 1) as p -> infinity.
            q = 2 * p / (p + 1) if p != np.inf else 2.0
            delta = delta_rule(p, T_float)

            allocations = []
            run_regrets = []

            for rep in range(runs):
                # The returned means are not needed for the regret statistics,
                # so they are not retained across replications.
                _final_means, n_k, tau = algorithm(
                    horizon,
                    q,
                    delta,
                    n_arms,
                    sigma_min,
                    sigma_sg,
                    reward_generators,
                )
                allocations.append(n_k)

                # Regret of this single replication (a one-row allocation batch).
                regret_run = compute_regret([n_k], stds, T, p)
                run_regrets.append(regret_run)

                per_run_records.append(
                    {
                        "T": T_float,
                        "p": float(p),
                        "replication": rep + 1,
                        "allocations": list(map(int, n_k)),
                        "elimination_counts": list(map(int, tau)),
                        "regret": float(regret_run),
                    }
                )

            # Aggregate over all replications; with ``return_ci=True`` the
            # helper's result is unpacked as (mean, lower bound, upper bound).
            allocations_arr = np.asarray(allocations, dtype=int)
            regret_stats = compute_regret(
                allocations_arr,
                stds,
                T,
                p,
                return_ci=True,
                confidence=confidence,
            )
            mean_regret, ci_lb, ci_ub = regret_stats  # type: ignore[misc]

            regret_mean[p].append(float(mean_regret))
            regret_lb[p].append(float(ci_lb))
            regret_ub[p].append(float(ci_ub))

            summary_records.append(
                {
                    "T": T_float,
                    "p": float(p),
                    "runs": runs,
                    "confidence": confidence,
                    "regret_mean": float(mean_regret),
                    "regret_lb": float(ci_lb),
                    "regret_ub": float(ci_ub),
                    "regret_mean_of_runs": float(np.mean(run_regrets)),
                    # The sample standard deviation is undefined for one run.
                    "regret_std_of_runs": float(np.std(run_regrets, ddof=1)) if runs > 1 else 0.0,
                }
            )

    return ExperimentResult(
        per_run=pd.DataFrame(per_run_records),
        summary=pd.DataFrame(summary_records),
        regret_mean=regret_mean,
        regret_lb=regret_lb,
        regret_ub=regret_ub,
    )


# ---------------------------------------------------------------------------
# Theoretical bounds helpers


def compute_sigma_q(stds: Sequence[float], q: float) -> float:
    """Return the power sum ``sum(std ** q for std in stds)`` as a float.

    Args:
        stds: Per-arm standard deviations.
        q: Exponent applied element-wise before summing (may be negative).
    """
    powered = np.asarray(stds, dtype=float) ** q
    return float(powered.sum())


def gsg_bounds_known(
    stds: Sequence[float],
    sigma_min: float,
    sigma_sg: float,
    T_list: Sequence[int],
    p_list: Sequence[float],
    *,
    beta: float = 1.0,
) -> Dict[float, np.ndarray]:
    """Evaluate the GSG bound curves (``sigma_min`` supplied) over ``T_list``.

    With ``L = log(max(beta * T, 1))`` and ``S_q = compute_sigma_q(stds, q)``:

    * ``p == inf``:
      ``T**-1.5 * sqrt(L) * 8 * sigma_sg**2 * sqrt(S_2)
      * (S_-1 + S_2 / sigma_min**3 - 2 / sigma_min)``
    * otherwise, with ``q = 2p / (p + 1)``:
      ``40 * sigma_sg**4 * S_q**(2/q) * p**2 / (p + 1) * T**-2 * L``

    Args:
        stds: Per-arm standard deviations.
        sigma_min: Lower bound on the standard deviations; a value of zero
            makes the ``p == inf`` branch raise ``ZeroDivisionError``.
        sigma_sg: Sub-Gaussian scale entering the constants.
        T_list: Horizons at which the curves are evaluated.
        p_list: Norm indices; ``np.inf`` is accepted.
        beta: Multiplier of ``T`` inside the logarithm.

    Returns:
        A dict mapping each ``p`` to an array aligned with ``T_list``, floored
        at ``1e-300`` so that a later logarithm stays finite.
    """
    floor = 1e-300
    horizons = np.asarray(T_list, dtype=float)
    sigma = float(sigma_sg)
    sigma_min = float(sigma_min)
    log_term = np.log(np.maximum(beta * horizons, 1.0))

    curves: Dict[float, np.ndarray] = {}
    for p in p_list:
        if p != np.inf:
            q = 2.0 * p / (p + 1.0)
            power_sum = compute_sigma_q(stds, q) ** (2.0 / q)
            coef = 40.0 * (sigma**4) * power_sum * (p**2) / (p + 1.0)
            values = coef * (horizons ** (-2.0)) * log_term
        else:
            Sigma2 = compute_sigma_q(stds, 2.0)
            Sigma_neg1 = compute_sigma_q(stds, -1.0)
            bracket = np.sqrt(Sigma2) * (Sigma_neg1 + Sigma2 / (sigma_min**3) - 2.0 / sigma_min)
            decay = horizons ** (-1.5)
            scale = 8.0 * sigma**2 * bracket
            values = decay * np.sqrt(log_term) * scale
        curves[p] = np.maximum(values, floor)
    return curves


def gsg_bounds_unknown(
    stds: Sequence[float],
    sigma_min: float,
    sigma_sg: float,
    T_list: Sequence[int],
    p_list: Sequence[float],
    *,
    beta: float = 1.0,
) -> Dict[float, np.ndarray]:
    """Evaluate the GSG bound curves for the unknown-``sigma_min`` setting.

    With ``L = log(max(beta * T, 1))`` and ``S_q = compute_sigma_q(stds, q)``:

    * ``p == inf``, with ``s = max(sigma_min, 1e-12)``:
      ``sqrt(8 * T**-3 * L) * (sqrt(S_2) * (S_2 - 2 s**2) / s + sqrt(S_2) * S_1)``
    * otherwise, with ``q = 2p / (p + 1)``:
      ``5 * K * p**2 / (p + 1) * S_q**(2/q) * T**-2 * L`` where ``K = len(stds)``

    Args:
        stds: Per-arm standard deviations.
        sigma_min: Lower bound on the standard deviations; floored at
            ``1e-12`` before it is used as a divisor.
        sigma_sg: Accepted for signature parity; not used by these formulas.
        T_list: Horizons at which the curves are evaluated.
        p_list: Norm indices; ``np.inf`` is accepted.
        beta: Multiplier of ``T`` inside the logarithm.

    Returns:
        A dict mapping each ``p`` to an array aligned with ``T_list``, floored
        at ``1e-300``.
    """
    floor = 1e-300
    horizons = np.asarray(T_list, dtype=float)

    curves: Dict[float, np.ndarray] = {}
    Sigma1 = compute_sigma_q(stds, 1.0)
    Sigma2 = compute_sigma_q(stds, 2.0)
    log_term = np.log(np.maximum(beta * horizons, 1.0))

    for p in p_list:
        if p != np.inf:
            q = 2.0 * p / (p + 1.0)
            coef = (5.0 * len(stds) * (p**2) / (p + 1.0)) * (compute_sigma_q(stds, q) ** (2.0 / q))
            values = coef * (horizons ** -2.0) * log_term
        else:
            # Keep the divisor strictly positive.
            safe_min = float(max(sigma_min, 1e-12))
            root = np.sqrt(Sigma2)
            bracket = root * (Sigma2 - 2.0 * safe_min**2) / safe_min + root * Sigma1
            values = np.sqrt(8.0 * (horizons ** -3.0) * log_term) * bracket
        curves[p] = np.maximum(values, floor)
    return curves


def gaussian_bounds_known(
    stds: Sequence[float],
    sigma_min: float,
    T_list: Sequence[int],
    p_list: Sequence[float],
    *,
    beta: float = 1.0,
) -> Dict[float, np.ndarray]:
    """Evaluate the Gaussian-case bound curves over ``T_list`` for every ``p``.

    With ``L = log(max(beta * T, 1))`` and ``S_q = sum(std ** q)``:

    * ``p == inf``, with ``s = max(sigma_min, 1e-12)``:
      ``sqrt(8 * T**-3 * L) * (sqrt(max(S_2 * (S_2 - 2 s**2), 0)) / s + sqrt(S_2 * S_1))``
    * otherwise, with ``q = 2p / (p + 1)``:
      ``5 * K * p**2 / (p + 1) * S_q**(2/q) * T**-2 * L`` where ``K = len(stds)``

    Args:
        stds: Per-arm standard deviations.
        sigma_min: Lower bound on the standard deviations. Values below
            ``1e-12`` (including ``0.0``) are floored at ``1e-12`` so the
            ``p == inf`` curve stays finite instead of becoming ``inf``/``nan``.
        T_list: Horizons at which the curves are evaluated.
        p_list: Norm indices; ``np.inf`` is accepted.
        beta: Multiplier of ``T`` inside the logarithm.

    Returns:
        A dict mapping each ``p`` to an array aligned with ``T_list``, floored
        at ``1e-300`` so that a later logarithm stays finite.
    """
    T_float = np.asarray(T_list, dtype=float)
    eps = 1e-300

    # Convert once and reuse for every power sum needed below.
    stds_arr = np.asarray(stds, dtype=float)
    Sigma1 = float(np.sum(stds_arr))
    Sigma2 = float(np.sum(stds_arr**2.0))

    # Dividing a NumPy scalar by 0.0 silently yields inf (and 0 * inf = nan at
    # horizons where the log term vanishes), so keep the divisor positive.
    sigma_floor = max(float(sigma_min), 1e-12)

    # Identical for every p, so evaluate it a single time.
    log_term = np.log(np.maximum(beta * T_float, 1.0))

    bounds: Dict[float, np.ndarray] = {}
    for p in p_list:
        if p == np.inf:
            bracket = (
                np.sqrt(max(Sigma2 * (Sigma2 - 2.0 * sigma_floor**2), 0.0)) / sigma_floor
                + np.sqrt(Sigma2 * Sigma1)
            )
            values = np.sqrt(8.0 * (T_float ** -3.0) * log_term) * bracket
        else:
            q = 2.0 * p / (p + 1.0)
            sigma_q = float(np.sum(stds_arr**q))
            coef = (5.0 * len(stds) * (p**2) / (p + 1.0)) * (sigma_q ** (2.0 / q))
            values = coef * (T_float ** -2.0) * log_term
        bounds[p] = np.maximum(values, eps)
    return bounds


def gaussian_bounds_unknown(
    stds: Sequence[float],
    sigma_min: float,
    T_list: Sequence[int],
    p_list: Sequence[float],
    *,
    beta: float = 1.0,
) -> Dict[float, np.ndarray]:
    """Evaluate the Gaussian bound curves for the unknown-``sigma_min`` setting.

    This delegates to ``gaussian_bounds_known`` with the same arguments, so
    both settings currently share one set of curves.
    """
    return gaussian_bounds_known(
        stds,
        sigma_min,
        T_list,
        p_list,
        beta=beta,
    )


def plot_regret_vs_bounds(
    T_list: Sequence[int],
    regret_curves: Mapping[float, Sequence[float]],
    bounds_curves: Mapping[float, Sequence[float]],
    *,
    colors: Mapping[float, str] | None = None,
    log_plot: bool = True,
    xlabel: str = r"$\log T$",
    ylabel: str = r"$\log$ Regret",
    title: str | None = None,
    legend_fontsize: int = 12,
) -> None:
    """Draw empirical regret curves (solid) and theoretical bounds (dashed).

    Args:
        T_list: Horizons shared by every curve.
        regret_curves: Empirical regret per ``p``, aligned with ``T_list``.
        bounds_curves: Theoretical bound per ``p``, aligned with ``T_list``.
        colors: Optional colour per ``p``. When omitted, the keys of
            ``regret_curves`` are paired in order with red, green, blue and
            purple; keys without a colour use matplotlib's default.
        log_plot: When true, both axes show natural logarithms of the data.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Optional figure title.
        legend_fontsize: Font size of the legend entries.

    The figure is created and shown via ``pyplot``; nothing is returned.
    """
    floor = 1e-300
    horizons = np.asarray(T_list, dtype=float)
    x = np.log(horizons) if log_plot else horizons

    if colors is None:
        colors = dict(zip(regret_curves.keys(), ["red", "green", "blue", "purple"]))

    def _as_y(series):
        # Clip away zeros/negatives so the logarithm is always defined.
        clipped = np.maximum(np.asarray(series, dtype=float), floor)
        return np.log(clipped) if log_plot else clipped

    plt.figure(figsize=(8, 6))

    for p, regrets in regret_curves.items():
        plt.plot(x, _as_y(regrets), label=f"Regret p={p}", color=colors.get(p), lw=3)

    for p, bounds in bounds_curves.items():
        plt.plot(
            x,
            _as_y(bounds),
            "--",
            label=f"Theoretical bound p={p}",
            color=colors.get(p),
            lw=3,
        )

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if title:
        plt.title(title)
    plt.grid(True, alpha=0.4)
    plt.legend(fontsize=legend_fontsize)
    plt.tight_layout()
    plt.show()


# ---------------------------------------------------------------------------
# Convenience wrappers for specific scenarios


def run_gsg_known_bounds(
    T_list: Sequence[int],
    p_list: Sequence[float],
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
    stds: Sequence[float],
    sigma_min: float,
    sigma_sg: float,
    *,
    runs: int = 100,
    confidence: float = 0.90,
    delta_rule: Callable[[float, float], float] = default_delta_rule,
) -> ExperimentResult:
    """Sweep the GSG allocation algorithm with a caller-supplied ``sigma_min``.

    Every argument is forwarded unchanged to ``run_regret_sweep`` with
    ``adaptive_algorithm_gsg`` as the algorithm.
    """
    return run_regret_sweep(
        algorithm=adaptive_algorithm_gsg,
        T_list=T_list,
        p_list=p_list,
        reward_generators=reward_generators,
        stds=stds,
        sigma_min=sigma_min,
        sigma_sg=sigma_sg,
        runs=runs,
        confidence=confidence,
        delta_rule=delta_rule,
    )


def run_gsg_unknown_bounds(
    T_list: Sequence[int],
    p_list: Sequence[float],
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
    stds: Sequence[float],
    sigma_sg: float,
    *,
    runs: int = 100,
    confidence: float = 0.90,
    delta_rule: Callable[[float, float], float] = default_delta_rule,
) -> ExperimentResult:
    """Sweep the GSG allocation algorithm when no lower bound is known.

    Forwards to ``run_regret_sweep`` with ``adaptive_algorithm_gsg`` and
    ``sigma_min`` fixed to ``0.0``.
    """
    no_lower_bound = 0.0
    return run_regret_sweep(
        algorithm=adaptive_algorithm_gsg,
        T_list=T_list,
        p_list=p_list,
        reward_generators=reward_generators,
        stds=stds,
        sigma_min=no_lower_bound,
        sigma_sg=sigma_sg,
        runs=runs,
        confidence=confidence,
        delta_rule=delta_rule,
    )


def run_ssg_known_bounds(
    T_list: Sequence[int],
    p_list: Sequence[float],
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
    stds: Sequence[float],
    sigma_min: float,
    sigma_sg: float,
    *,
    runs: int = 100,
    confidence: float = 0.90,
    delta_rule: Callable[[float, float], float] = default_delta_rule,
) -> ExperimentResult:
    """Sweep the SSG allocation algorithm with a caller-supplied ``sigma_min``.

    Every argument is forwarded unchanged to ``run_regret_sweep`` with
    ``adaptive_algorithm_ssg`` as the algorithm.
    """
    return run_regret_sweep(
        algorithm=adaptive_algorithm_ssg,
        T_list=T_list,
        p_list=p_list,
        reward_generators=reward_generators,
        stds=stds,
        sigma_min=sigma_min,
        sigma_sg=sigma_sg,
        runs=runs,
        confidence=confidence,
        delta_rule=delta_rule,
    )


def run_ssg_unknown_bounds(
    T_list: Sequence[int],
    p_list: Sequence[float],
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
    stds: Sequence[float],
    sigma_sg: float,
    *,
    runs: int = 100,
    confidence: float = 0.90,
    delta_rule: Callable[[float, float], float] = default_delta_rule,
) -> ExperimentResult:
    """Sweep the SSG allocation algorithm when no lower bound is known.

    Forwards to ``run_regret_sweep`` with ``adaptive_algorithm_ssg`` and
    ``sigma_min`` fixed to ``0.0``.
    """
    no_lower_bound = 0.0
    return run_regret_sweep(
        algorithm=adaptive_algorithm_ssg,
        T_list=T_list,
        p_list=p_list,
        reward_generators=reward_generators,
        stds=stds,
        sigma_min=no_lower_bound,
        sigma_sg=sigma_sg,
        runs=runs,
        confidence=confidence,
        delta_rule=delta_rule,
    )


def run_gaussian_known_bounds(
    T_list: Sequence[int],
    p_list: Sequence[float],
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
    stds: Sequence[float],
    sigma_min: float,
    sigma_sg: float,
    *,
    runs: int = 100,
    confidence: float = 0.90,
    delta_rule: Callable[[float, float], float] = default_delta_rule,
) -> ExperimentResult:
    """Sweep the Gaussian-specialized algorithm with a caller-supplied ``sigma_min``.

    Every argument is forwarded unchanged to ``run_regret_sweep`` with
    ``adaptive_algorithm_gaussian`` as the algorithm.
    """
    return run_regret_sweep(
        algorithm=adaptive_algorithm_gaussian,
        T_list=T_list,
        p_list=p_list,
        reward_generators=reward_generators,
        stds=stds,
        sigma_min=sigma_min,
        sigma_sg=sigma_sg,
        runs=runs,
        confidence=confidence,
        delta_rule=delta_rule,
    )


def run_gaussian_unknown_bounds(
    T_list: Sequence[int],
    p_list: Sequence[float],
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
    stds: Sequence[float],
    sigma_sg: float,
    *,
    runs: int = 100,
    confidence: float = 0.90,
    delta_rule: Callable[[float, float], float] = default_delta_rule,
) -> ExperimentResult:
    """Sweep the Gaussian-specialized algorithm when no lower bound is known.

    Forwards to ``run_regret_sweep`` with ``adaptive_algorithm_gaussian`` and
    ``sigma_min`` fixed to ``0.0``.
    """
    no_lower_bound = 0.0
    return run_regret_sweep(
        algorithm=adaptive_algorithm_gaussian,
        T_list=T_list,
        p_list=p_list,
        reward_generators=reward_generators,
        stds=stds,
        sigma_min=no_lower_bound,
        sigma_sg=sigma_sg,
        runs=runs,
        confidence=confidence,
        delta_rule=delta_rule,
    )


def run_two_arm_rademacher(
    T_list: Sequence[int],
    var_g_list: Sequence[float],
    *,
    runs: int = 100,
    confidence: float = 0.90,
    delta_rule: Callable[[float, float], float] = default_delta_rule,
    sigma_sg: float = 1.0,
) -> Dict[float, ExperimentResult]:
    """Run the two-arm Gaussian-versus-Rademacher benchmark once per variance.

    For each ``var_g`` the arms get standard deviations
    ``[sqrt(var_g), 1.0]`` and generators from
    ``make_gauss_and_rademacher_generators(var_g)``. The sweep uses
    ``adaptive_algorithm_ssg`` with ``p_list=[np.inf]`` and ``sigma_min=0.0``.

    Args:
        T_list: Horizons to sweep.
        var_g_list: Variances of the first arm, one benchmark per entry.
        runs: Number of replications per horizon.
        confidence: Confidence level of the reported interval.
        delta_rule: Maps ``(p, T)`` to the confidence parameter ``delta``.
        sigma_sg: Forwarded to the algorithm.

    Returns:
        A dict mapping ``float(var_g)`` to the corresponding sweep result.
    """
    by_variance: Dict[float, ExperimentResult] = {}

    for var_g in var_g_list:
        arm_stds = [float(np.sqrt(var_g)), 1.0]
        generators = make_gauss_and_rademacher_generators(var_g)
        sweep = run_regret_sweep(
            algorithm=adaptive_algorithm_ssg,
            T_list=T_list,
            p_list=[np.inf],
            reward_generators=generators,
            stds=arm_stds,
            sigma_min=0.0,
            sigma_sg=sigma_sg,
            runs=runs,
            confidence=confidence,
            delta_rule=delta_rule,
        )
        by_variance[float(var_g)] = sweep

    return by_variance
