"""
Core utilities and allocation algorithms for the Exploration Free project.
"""

from __future__ import annotations

from typing import Callable, Iterable, List, Optional, Sequence, Tuple

import numpy as np
from scipy import stats


# Type of a reward sampler: a callable taking an optional sample count and
# returning a NumPy array of draws. ``Optional[int]`` is used instead of
# ``int | None`` because this expression is evaluated at import time (it is not
# an annotation), and the ``|`` form raises ``TypeError`` before Python 3.10.
RewardGenerator = Callable[[Optional[int]], np.ndarray]


def sample_many(generator: Callable[..., np.ndarray] | Callable[[], float], n: int) -> List[float]:
    """Sample ``n`` observations from the provided generator.

    The helper supports both vectorized callables that accept ``size=`` and scalar
    callables that return a single draw per invocation.

    Args:
        generator: Either a callable accepting a ``size`` keyword and returning
            an array-like of draws, or a zero-argument callable returning one draw.
        n: Number of observations to draw.

    Returns:
        A flat list of ``n`` Python floats.
    """
    try:
        # ravel() guarantees a flat result even when the callable returns a
        # 0-d scalar or an (n, 1)-shaped array.
        draws = np.asarray(generator(size=n), dtype=float).ravel()  # type: ignore[misc]
    except TypeError:
        # The callable does not take ``size`` (or its output is not numeric
        # array-like): fall through to one call per draw.
        draws = None

    if draws is not None and draws.size == n:
        return draws.tolist()

    # Scalar callables, and callables that accepted ``size`` but did not return
    # exactly ``n`` values, are sampled one draw at a time.
    return [float(generator()) for _ in range(n)]  # type: ignore[misc]


def compute_regret(
    allocations: Sequence[Sequence[int]],
    stds: Sequence[float],
    horizon: int,
    p_norm: float,
    *,
    return_ci: bool = False,
    confidence: float = 0.90,
) -> float | Tuple[float, float, float]:
    """Compute the regret (and optional confidence interval) for a set of runs.

    For each run the per-arm expected error is ``std**2 / pulls``; the run's
    loss is the ``p_norm`` norm of that vector (the maximum when ``p_norm`` is
    infinite). The regret of a run is its loss minus the oracle loss computed
    from ``stds`` and ``horizon``.

    Args:
        allocations: Per-run pull counts, shape ``(runs, arms)``. A single run
            may also be given as a flat vector of length ``arms``.
        stds: Standard deviation of each arm.
        horizon: Total sampling budget used to scale the oracle loss.
        p_norm: Norm exponent; ``np.inf`` selects the max-norm.
        return_ci: When true, return ``(mean, lower, upper)`` instead of the mean.
        confidence: Two-sided confidence level of the Student-t interval.

    Returns:
        The mean regret, or ``(mean, lower, upper)`` when ``return_ci`` is set.
        With a single run the interval collapses to ``(mean, mean, mean)``.

    Raises:
        ValueError: If ``allocations`` has more than two dimensions.
    """
    # atleast_2d lets a single run be passed as a flat vector of per-arm counts.
    allocation_arr = np.atleast_2d(np.asarray(allocations, dtype=float))
    if allocation_arr.ndim != 2:
        raise ValueError(
            f"allocations must be 1-D or 2-D (runs, arms), got shape {allocation_arr.shape}"
        )
    runs = allocation_arr.shape[0]

    stds_arr = np.asarray(stds, dtype=float)
    variances = stds_arr**2
    expected_errors = variances[None, :] / allocation_arr  # shape: (runs, arms)

    if p_norm == np.inf:
        empirical = np.max(expected_errors, axis=1)
        oracle = np.sum(variances) / float(horizon)
    else:
        empirical = np.sum(expected_errors**p_norm, axis=1) ** (1.0 / p_norm)
        # Exponent applied to the standard deviations in the oracle loss.
        q_norm = 2.0 * p_norm / (p_norm + 1.0)
        oracle = np.power(np.sum(stds_arr**q_norm), 2.0 / q_norm) / float(horizon)

    regrets = empirical - oracle
    regret_mean = float(np.mean(regrets))

    if not return_ci:
        return regret_mean
    if runs <= 1:
        # A sample standard deviation needs at least two runs.
        return regret_mean, regret_mean, regret_mean

    # Two-sided Student-t interval around the mean regret.
    sd = float(np.std(regrets, ddof=1))
    se = sd / np.sqrt(runs)
    alpha = 1.0 - confidence
    tcrit = float(stats.t.ppf(1.0 - alpha / 2.0, df=runs - 1))
    half_width = tcrit * se
    return regret_mean, float(regret_mean - half_width), float(regret_mean + half_width)


# ---------------------------------------------------------------------------
# Allocation algorithms


def adaptive_algorithm_gsg(
    horizon: int,
    q_norm: float,
    delta: float,
    num_arms: int,
    sigma_min: float,
    sigma_sg: float,
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
) -> Tuple[List[float], List[int], np.ndarray]:
    """Adaptive elimination algorithm for the general sub-Gaussian setting.

    The budget of ``horizon`` pulls is spent in three stages:

    1. Initialisation. With ``sigma_min == 0`` every arm gets
       ``min(ceil(64 * sigma_sg**4 * log(horizon)), horizon // num_arms)`` pulls,
       then one extra pull per arm per round until every arm's lower confidence
       bound on the variance is positive or the budget runs out. Otherwise
       every arm gets a fixed number of pulls derived from ``sigma_min``,
       ``sigma_sg`` and ``q_norm``.
    2. Elimination loop. Active arms are topped up to ``floor(lam * horizon)``
       pulls, confidence bounds are recomputed, and an arm is marked eliminated
       once its pull count reaches its new target.
    3. Leftover budget is distributed according to the sample variances.

    Args:
        horizon: Total number of pulls.
        q_norm: Exponent; variances are raised to ``q_norm / 2`` to form weights.
        delta: Confidence parameter; clamped into the open interval (0, 1).
        num_arms: Number of arms.
        sigma_min: Value selecting the initialisation (0 selects the data-driven
            warm-up). Presumably a known lower bound on the arm standard
            deviations -- TODO confirm.
        sigma_sg: Scale used in the confidence widths. Presumably the
            sub-Gaussian parameter -- TODO confirm.
        reward_generators: One sampler per arm, usable through ``sample_many``
            and also callable with no arguments for a single draw.

    Returns:
        ``(final_means, n_k, t_counts)``: per-arm sample means, per-arm pull
        counts as a list, and the per-arm pull counts as an integer array.
    """
    # Keep delta strictly inside (0, 1) so log(1/delta) is finite and positive,
    # matching the clamping done by the SSG and Gaussian variants in this file.
    delta = float(delta)
    delta = max(min(delta, 1.0 - 1e-12), np.finfo(float).tiny)
    log_inv_delta = -np.log(delta)

    def epsilon_t_minus(t):
        # Width added to the sample variance to form the upper confidence bound.
        t = np.asarray(t, dtype=float)
        t_safe = np.maximum(t, 2.0)  # keeps (t - 1) strictly positive
        term1 = 4 * sigma_sg**2 * (1 + np.sqrt(t_safe - 1)) / np.sqrt(t_safe)
        term1 *= np.sqrt(2 * log_inv_delta / (t_safe - 1))
        term2 = 13 * sigma_sg**2 * log_inv_delta / (3 * t_safe)
        return term1 + term2

    def epsilon_t_plus(t):
        # Width subtracted from the sample variance to form the lower confidence bound.
        t = np.asarray(t, dtype=float)
        t_safe = np.maximum(t, 2.0)  # keeps (t - 1) strictly positive
        term1 = 4 * sigma_sg**2 * (1 + np.sqrt(t_safe - 1)) / np.sqrt(t_safe)
        term1 *= np.sqrt(2 * log_inv_delta / (t_safe - 1))
        term2 = 6 * sigma_sg**2 * log_inv_delta / t_safe
        return term1 + term2

    T_int = int(horizon)
    rewards = [list() for _ in range(num_arms)]
    t_counts = np.zeros(num_arms, dtype=int)
    n_max = max(T_int // num_arms, 0)  # largest equal per-arm share of the budget

    def safe_sigma2_hats():
        # Unbiased sample variances; arms with fewer than two samples count as 0.
        arr = np.array([np.var(r, ddof=1) if len(r) >= 2 else np.nan for r in rewards], float)
        return np.nan_to_num(arr, nan=0.0)

    if sigma_min == 0:
        m = 1
        n0 = int(np.ceil(64.0 * (m**2) * sigma_sg**4 * np.log(T_int)))
        n = min(n0, n_max)

        for k in range(num_arms):
            rewards[k].extend(sample_many(reward_generators[k], n))
            t_counts[k] = n
        total_pulls = int(t_counts.sum())

        # Round-robin warm-up: one pull per arm per round until every arm's
        # variance lower bound is positive or the budget is exhausted.
        while total_pulls < T_int:
            rem = T_int - total_pulls
            if 0 < rem < num_arms:
                # Not enough budget for a full round: the last pulls go to the
                # arms with the smallest lower confidence bound.
                sigma2 = safe_sigma2_hats()
                eps_p = float(epsilon_t_plus(n))
                lcb_vals = sigma2 - eps_p
                order = np.argsort(lcb_vals)
                for k in order[:rem]:
                    rewards[k].append(reward_generators[k]())
                    t_counts[k] += 1
                total_pulls += rem
                break

            sigma2 = safe_sigma2_hats()
            eps_p = float(epsilon_t_plus(n))
            lcb_min = float(np.min(sigma2 - eps_p))
            if lcb_min > 0.0:
                break

            for k in range(num_arms):
                rewards[k].append(reward_generators[k]())
                t_counts[k] += 1
            n += 1
            total_pulls += num_arms
    else:
        # Fixed initial sample size per arm, capped at the equal share.
        n_raw = int(sigma_min**q_norm * T_int / (sigma_min**q_norm + (num_arms - 1) * sigma_sg**q_norm))
        n = min(n_raw, n_max)
        rewards = [sample_many(g, n) for g in reward_generators]

    # Pull counts are always re-derived from the stored rewards.
    t_counts = np.array([len(r) for r in rewards], dtype=int)
    tau = np.full(num_arms, -1, dtype=int)  # pull count at elimination; -1 = never eliminated

    lam = t_counts.astype(float) / float(T_int)
    active = np.ones(num_arms, dtype=bool)

    max_loops = 10000  # hard cap on elimination rounds
    for _ in range(max_loops):
        if n_max == n or not active.any():
            break

        remaining = T_int - int(t_counts.sum())
        if remaining <= 0:
            break

        # Top up every active arm to its current target, in index order,
        # without exceeding the remaining budget.
        for k in np.flatnonzero(active):
            target_nk = int(np.floor(lam[k] * T_int))
            to_add = max(0, target_nk - t_counts[k])
            if to_add <= 0:
                continue
            add_k = min(to_add, remaining)
            if add_k <= 0:
                break
            rewards[k].extend(sample_many(reward_generators[k], add_k))
            t_counts[k] += add_k
            remaining -= add_k
            if remaining <= 0:
                break

        eps_minus = epsilon_t_minus(t_counts)
        eps_plus = epsilon_t_plus(t_counts)
        sigma2_hats = np.array(
            [np.var(rew, ddof=1) if len(rew) >= 2 else 0.0 for rew in rewards],
            dtype=float,
        )
        lcb_vals = np.maximum(sigma2_hats - eps_plus, 0.0).astype(float)
        ucb_vals = (sigma2_hats + eps_minus).astype(float)

        # Each arm's share uses its own lower bound against the upper bounds
        # of all the other arms.
        lcb_q = np.power(lcb_vals, q_norm / 2.0)
        ucb_q = np.power(ucb_vals, q_norm / 2.0)
        denom = lcb_q + (ucb_q.sum() - ucb_q)
        denom = np.where(denom <= 0.0, 1e-12, denom)
        lam_new = lcb_q / denom

        # An arm is eliminated once its pull count reaches its new target.
        lam_thresh = np.floor(np.maximum(lam_new, 0.0) * T_int + 1e-12).astype(int)
        elim = t_counts >= lam_thresh
        tau[elim] = t_counts[elim]

        # The active set is recomputed from scratch each round, so an arm can
        # become active again if its target grows.
        active = ~elim
        lam = lam_new

    current_total = int(t_counts.sum())
    remaining = T_int - current_total

    if remaining > 0:
        # Spend leftover budget towards targets proportional to var_hat ** (q_norm / 2).
        var_hats = np.array([np.var(rew, ddof=1) if len(rew) >= 2 else 0.0 for rew in rewards], dtype=float)
        weights = np.power(var_hats, q_norm / 2.0)
        if not np.isfinite(weights).all() or np.all(weights == 0):
            weights = np.ones_like(weights, dtype=float)
        weights = weights / weights.sum()

        target = np.floor(weights * T_int).astype(int)
        # NOTE(review): the deficit is measured against tau (-1 for arms that
        # were never eliminated), not against t_counts -- confirm this is intended.
        deficit = np.maximum(target - tau, 0)
        give = np.zeros(num_arms, dtype=int)
        total_need = int(deficit.sum())

        if total_need <= remaining:
            give = deficit.astype(int)
        else:
            # Budget is short: serve the largest deficits first.
            order = np.argsort(-deficit)
            rem = remaining
            for k in order:
                if rem <= 0:
                    break
                take = int(min(deficit[k], rem))
                give[k] = take
                rem -= take

        for k, extra in enumerate(give):
            if extra <= 0:
                continue
            rewards[k].extend(sample_many(reward_generators[k], int(extra)))
            t_counts[k] += int(extra)

    residual = int(T_int - int(t_counts.sum()))
    if residual > 0:
        # Whatever is still unspent goes to the single arm with the largest
        # sample variance; the weights are fixed, so the arm is picked once.
        weights = np.array(
            [np.var(rew, ddof=1) if len(rew) >= 2 else 1.0 for rew in rewards],
            dtype=float,
        )
        weights = np.where(np.isfinite(weights) & (weights > 0), weights, 1.0)
        weights = weights / weights.sum()
        k = int(np.argmax(weights))
        for _ in range(residual):
            rewards[k].append(reward_generators[k]())
            t_counts[k] += 1

    final_means = [float(np.mean(r)) if len(r) > 0 else 0.0 for r in rewards]
    n_k = [len(r) for r in rewards]
    # NOTE(review): the SSG and Gaussian variants in this file return ``tau`` as
    # the third element; this one returns ``t_counts`` -- confirm which is intended.
    return final_means, n_k, t_counts


def adaptive_algorithm_ssg(
    horizon: int,
    q_norm: float,
    delta: float,
    num_arms: int,
    sigma_min: float,
    sigma_sg: float,
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
) -> Tuple[List[float], List[int], np.ndarray]:
    """Adaptive elimination algorithm specialized for strongly sub-Gaussian arms.

    The budget of ``horizon`` pulls is spent in three stages: an initial batch
    per arm, an elimination loop that tops up active arms and recomputes
    multiplicative confidence bounds on the sample variances, and a final
    distribution of any leftover pulls.

    Args:
        horizon: Total number of pulls.
        q_norm: Exponent; variances are raised to ``q_norm / 2`` to form weights.
        delta: Confidence parameter; clamped into the open interval (0, 1).
        num_arms: Number of arms.
        sigma_min: 0 selects the data-driven warm-up; any other value selects a
            fixed initial batch size computed from ``sigma_min`` and ``sigma_sg``.
        sigma_sg: Only used in the fixed initial batch size formula.
        reward_generators: One sampler per arm, usable through ``sample_many``
            and also callable with no arguments for a single draw.

    Returns:
        ``(final_means, n_k, tau)``: per-arm sample means, per-arm pull counts,
        and the per-arm pull count recorded when the arm was last marked
        eliminated (-1 if it never was).
    """
    delta = float(delta)
    # Keep delta strictly inside (0, 1) so log(1/delta) is finite and positive.
    delta = max(min(delta, 1.0 - 1e-12), np.finfo(float).tiny)
    log_inv_delta = -np.log(delta)

    # Relative width used in the upper confidence bound: var_hat / (1 - s_minus).
    def s_t_minus(t):
        t = np.asarray(t, dtype=float)
        t_safe = np.maximum(t, 2.0)  # keeps (t - 1) strictly positive

        term1 = 4.0 * (1.0 + np.sqrt((t_safe - 1.0) / 8.0)) / np.sqrt(t_safe)
        term1 *= np.sqrt(2.0 * log_inv_delta / (t_safe - 1.0))
        term2 = 13 * log_inv_delta / (3.0 * t_safe)
        return term1 + term2

    # Relative width used in the lower confidence bound: var_hat / (1 + s_plus).
    def s_t_plus(t):
        t = np.asarray(t, dtype=float)
        t_safe = np.maximum(t, 2.0)  # keeps (t - 1) strictly positive

        term1 = 4.0 * (1.0 + np.sqrt((t_safe - 1.0) / 8.0)) / np.sqrt(t_safe)
        term1 *= np.sqrt(2.0 * log_inv_delta / (t_safe - 1.0))
        term2 = 6.0 * log_inv_delta / t_safe
        return term1 + term2

    T_int = int(horizon)
    rewards = [list() for _ in range(num_arms)]
    t_counts = np.zeros(num_arms, dtype=int)

    if sigma_min == 0:
        # Warm-up: start from ceil(18 * log(T)) pulls per arm, capped at the
        # equal share of the budget.
        n0 = int(np.ceil(18.0 * np.log(T_int)))
        n_max = max(T_int // num_arms, 0)
        n = min(n0, n_max)

        for k in range(num_arms):
            rewards[k].extend(sample_many(reward_generators[k], n))
            t_counts[k] = n
        total_pulls = int(t_counts.sum())

        # Add one pull per arm per round until s_t_minus(n) < 1 (so that
        # 1 - s_minus is positive below) or a full round no longer fits.
        while True:
            if total_pulls + num_arms > T_int:
                break
            s_val = float(s_t_minus(float(n)))
            if s_val < 1.0:
                break

            for k in range(num_arms):
                rewards[k].append(reward_generators[k]())
                t_counts[k] += 1
            n += 1
            total_pulls += num_arms
    else:
        # Fixed initial batch size per arm, capped at the equal share.
        tau_1 = int(sigma_min**q_norm * T_int / (sigma_min**q_norm + (num_arms - 1) * sigma_sg**q_norm))
        tau_1 = int(min(tau_1, T_int / num_arms))
        rewards = [sample_many(g, tau_1) for g in reward_generators]

    # Pull counts are re-derived from the stored rewards.
    t_counts = np.array([len(r) for r in rewards], dtype=int)
    tau = np.full(num_arms, -1, dtype=int)  # pull count at elimination; -1 = never eliminated
    active = np.ones(num_arms, dtype=bool)

    lam = t_counts.astype(float) / float(T_int)

    max_loops = 10000  # hard cap on elimination rounds
    for _ in range(max_loops):
        if not active.any():
            break

        remaining = T_int - int(t_counts.sum())
        if remaining <= 0:
            break

        # Top up every active arm to floor(lam[k] * T) pulls, in index order,
        # without exceeding the remaining budget.
        for k in np.flatnonzero(active):
            target_nk = int(np.floor(lam[k] * T_int))
            to_add = max(0, target_nk - t_counts[k])
            if to_add <= 0:
                continue
            add_k = min(to_add, remaining)
            rewards[k].extend(sample_many(reward_generators[k], add_k))
            t_counts[k] += add_k
            remaining -= add_k
            if remaining <= 0:
                break

        s_minus = s_t_minus(t_counts)
        s_plus = s_t_plus(t_counts)
        sigma2_hats = np.array(
            [np.var(rew, ddof=1) if len(rew) >= 2 else 0.0 for rew in rewards],
            dtype=float,
        )
        lcb_vals = (sigma2_hats / (1 + s_plus)).astype(float)
        # NOTE(review): when s_minus >= 1 the divisor is <= 0, which makes the
        # upper bound negative or infinite; only the sigma_min == 0 warm-up
        # tries to avoid this -- confirm for the fixed-batch branch.
        ucb_vals = (sigma2_hats / (1 - s_minus)).astype(float)

        # Each arm's share uses its own lower bound against the upper bounds
        # of all the other arms.
        lcb_q = np.power(lcb_vals, q_norm / 2.0)
        ucb_q = np.power(ucb_vals, q_norm / 2.0)
        denom = lcb_q + (ucb_q.sum() - ucb_q)
        denom = np.where(denom <= 0.0, 1e-12, denom)
        lam_new = lcb_q / denom

        # An arm is eliminated once its pull count reaches its new target.
        lam_thresh = np.floor(np.maximum(lam_new, 0.0) * T_int + 1e-12).astype(int)
        elim = t_counts >= lam_thresh
        tau[elim] = t_counts[elim]

        # The active set is recomputed from scratch each round.
        active = ~elim
        lam = lam_new

    T_int = int(horizon)
    current_total = int(t_counts.sum())
    remaining = T_int - current_total

    if remaining > 0:
        # Spend leftover budget towards targets proportional to var_hat ** (q_norm / 2).
        var_hats = np.array([np.var(rew, ddof=1) if len(rew) >= 2 else 0.0 for rew in rewards], dtype=float)
        weights = np.power(var_hats, q_norm / 2.0)
        if not np.isfinite(weights).all() or np.all(weights == 0):
            weights = np.ones_like(weights, dtype=float)
        weights = weights / weights.sum()

        target = np.floor(weights * T_int).astype(int)
        # NOTE(review): the deficit is measured against tau (-1 for arms that
        # were never eliminated), not against t_counts -- confirm this is intended.
        deficit = np.maximum(target - tau, 0)
        give = np.zeros(num_arms, dtype=int)
        total_need = int(deficit.sum())

        if total_need <= remaining:
            give = deficit.astype(int)
        else:
            # Budget is short: serve the largest deficits first.
            order = np.argsort(-deficit)
            rem = remaining
            for k in order:
                if rem <= 0:
                    break
                take = int(min(deficit[k], rem))
                give[k] = take
                rem -= take

        for k, extra in enumerate(give):
            if extra <= 0:
                continue
            rewards[k].extend(sample_many(reward_generators[k], int(extra)))
            t_counts[k] += int(extra)

    residual = int(T_int - int(t_counts.sum()))
    if residual > 0:
        # Whatever is still unspent goes to the arm with the largest sample
        # variance; the weights are not updated inside the loop, so every
        # residual pull lands on the same arm.
        weights = np.array(
            [np.var(rew, ddof=1) if len(rew) >= 2 else 1.0 for rew in rewards],
            dtype=float,
        )
        weights = np.where(np.isfinite(weights) & (weights > 0), weights, 1.0)
        weights = weights / weights.sum()
        for _ in range(residual):
            k = int(np.argmax(weights))
            rewards[k].append(reward_generators[k]())
            t_counts[k] += 1

    final_means = [float(np.mean(r)) if len(r) > 0 else 0.0 for r in rewards]
    n_k = [len(r) for r in rewards]
    return final_means, n_k, tau


def adaptive_algorithm_gaussian(
    horizon: int,
    q_norm: float,
    delta: float,
    num_arms: int,
    sigma_min: float,
    sigma_sg: float,
    reward_generators: Sequence[Callable[[int | None], Iterable[float]]],
) -> Tuple[List[float], List[int], np.ndarray]:
    """Adaptive elimination variant with Gaussian-specific confidence bounds.

    The budget of ``horizon`` pulls is spent in three stages: an initial batch
    per arm, an elimination loop that tops up active arms and recomputes
    multiplicative confidence bounds on the sample variances, and a final
    distribution of any leftover pulls.

    Args:
        horizon: Total number of pulls.
        q_norm: Exponent; variances are raised to ``q_norm / 2`` to form weights.
        delta: Confidence parameter; clamped into the open interval (0, 1).
        num_arms: Number of arms.
        sigma_min: 0 selects the data-driven warm-up; any other value selects a
            fixed initial batch size computed from ``sigma_min`` and ``sigma_sg``.
        sigma_sg: Only used in the fixed initial batch size formula.
        reward_generators: One sampler per arm, usable through ``sample_many``
            and also callable with no arguments for a single draw.

    Returns:
        ``(final_means, n_k, tau)``: per-arm sample means, per-arm pull counts,
        and the per-arm pull count recorded when the arm was last marked
        eliminated (-1 if it never was).
    """
    delta = float(delta)
    # Keep delta strictly inside (0, 1) so log(1/delta) is finite and positive.
    delta = max(min(delta, 1.0 - 1e-12), np.finfo(float).tiny)
    log_inv_delta = -np.log(delta)

    # Relative width used in the upper confidence bound: var_hat / (1 - s_minus).
    def s_t_minus(t):
        t = np.asarray(t, dtype=float)
        t_safe = np.maximum(t, 2.0)  # keeps (t - 1) strictly positive
        return 2.0 * np.sqrt(log_inv_delta / (t_safe - 1.0))

    # Relative width used in the lower confidence bound: var_hat / (1 + s_plus).
    def s_t_plus(t):
        t = np.asarray(t, dtype=float)
        t_safe = np.maximum(t, 2.0)  # keeps (t - 1) strictly positive
        term1 = 2.0 * np.sqrt(log_inv_delta / (t_safe - 1.0))
        term2 = 2.0 * log_inv_delta / (t_safe - 1.0)
        return term1 + term2

    T_int = int(horizon)
    rewards = [list() for _ in range(num_arms)]
    t_counts = np.zeros(num_arms, dtype=int)

    if sigma_min == 0:
        # Warm-up: start from ceil(8 * log(T)) pulls per arm, capped at the
        # equal share of the budget.
        n0 = int(np.ceil(8.0 * np.log(T_int)))
        n_max = max(T_int // num_arms, 0)
        n = min(n0, n_max)

        for k in range(num_arms):
            rewards[k].extend(sample_many(reward_generators[k], n))
            t_counts[k] = n
        total_pulls = int(t_counts.sum())

        # Add one pull per arm per round until s_t_minus(n) < 1 (so that
        # 1 - s_minus is positive below) or a full round no longer fits.
        while True:
            if total_pulls + num_arms > T_int:
                break
            s_val = float(s_t_minus(float(n)))
            if s_val < 1.0:
                break

            for k in range(num_arms):
                rewards[k].append(reward_generators[k]())
                t_counts[k] += 1
            n += 1
            total_pulls += num_arms
    else:
        # Fixed initial batch size per arm, capped at the equal share.
        tau_1 = int(sigma_min**q_norm * T_int / (sigma_min**q_norm + (num_arms - 1) * sigma_sg**q_norm))
        tau_1 = int(min(tau_1, T_int / num_arms))
        rewards = [sample_many(g, tau_1) for g in reward_generators]

    # Pull counts are re-derived from the stored rewards.
    t_counts = np.array([len(r) for r in rewards], dtype=int)
    tau = np.full(num_arms, -1, dtype=int)  # pull count at elimination; -1 = never eliminated
    active = np.ones(num_arms, dtype=bool)

    lam = t_counts.astype(float) / float(T_int)

    max_loops = 10000  # hard cap on elimination rounds
    for _ in range(max_loops):
        if not active.any():
            break

        remaining = T_int - int(t_counts.sum())
        if remaining <= 0:
            break

        # Top up every active arm to floor(lam[k] * T) pulls, in index order,
        # without exceeding the remaining budget.
        for k in np.flatnonzero(active):
            target_nk = int(np.floor(lam[k] * T_int))
            to_add = max(0, target_nk - t_counts[k])
            if to_add <= 0:
                continue
            add_k = min(to_add, remaining)
            rewards[k].extend(sample_many(reward_generators[k], add_k))
            t_counts[k] += add_k
            remaining -= add_k
            if remaining <= 0:
                break

        s_minus = s_t_minus(t_counts)
        s_plus = s_t_plus(t_counts)
        sigma2_hats = np.array(
            [np.var(rew, ddof=1) if len(rew) >= 2 else 0.0 for rew in rewards],
            dtype=float,
        )
        lcb_vals = (sigma2_hats / (1 + s_plus)).astype(float)
        # NOTE(review): when s_minus >= 1 the divisor is <= 0, which makes the
        # upper bound negative or infinite; only the sigma_min == 0 warm-up
        # tries to avoid this -- confirm for the fixed-batch branch.
        ucb_vals = (sigma2_hats / (1 - s_minus)).astype(float)

        # Each arm's share uses its own lower bound against the upper bounds
        # of all the other arms.
        lcb_q = np.power(lcb_vals, q_norm / 2.0)
        ucb_q = np.power(ucb_vals, q_norm / 2.0)
        denom = lcb_q + (ucb_q.sum() - ucb_q)
        denom = np.where(denom <= 0.0, 1e-12, denom)
        lam_new = lcb_q / denom

        # An arm is eliminated once its pull count reaches its new target.
        lam_thresh = np.floor(np.maximum(lam_new, 0.0) * T_int + 1e-12).astype(int)
        elim = t_counts >= lam_thresh
        tau[elim] = t_counts[elim]

        # The active set is recomputed from scratch each round.
        active = ~elim
        lam = lam_new

    T_int = int(horizon)
    current_total = int(t_counts.sum())
    remaining = T_int - current_total

    if remaining > 0:
        # Spend leftover budget towards targets proportional to var_hat ** (q_norm / 2).
        var_hats = np.array([np.var(rew, ddof=1) if len(rew) >= 2 else 0.0 for rew in rewards], dtype=float)
        weights = np.power(var_hats, q_norm / 2.0)
        if not np.isfinite(weights).all() or np.all(weights == 0):
            weights = np.ones_like(weights, dtype=float)
        weights = weights / weights.sum()

        target = np.floor(weights * T_int).astype(int)
        # NOTE(review): the deficit is measured against tau (-1 for arms that
        # were never eliminated), not against t_counts -- confirm this is intended.
        deficit = np.maximum(target - tau, 0)
        give = np.zeros(num_arms, dtype=int)
        total_need = int(deficit.sum())

        if total_need <= remaining:
            give = deficit.astype(int)
        else:
            # Budget is short: serve the largest deficits first.
            order = np.argsort(-deficit)
            rem = remaining
            for k in order:
                if rem <= 0:
                    break
                take = int(min(deficit[k], rem))
                give[k] = take
                rem -= take

        for k, extra in enumerate(give):
            if extra <= 0:
                continue
            rewards[k].extend(sample_many(reward_generators[k], int(extra)))
            t_counts[k] += int(extra)

    residual = int(T_int - int(t_counts.sum()))
    if residual > 0:
        # Whatever is still unspent goes to the arm with the largest sample
        # variance; the weights are not updated inside the loop, so every
        # residual pull lands on the same arm.
        weights = np.array(
            [np.var(rew, ddof=1) if len(rew) >= 2 else 1.0 for rew in rewards],
            dtype=float,
        )
        weights = np.where(np.isfinite(weights) & (weights > 0), weights, 1.0)
        weights = weights / weights.sum()
        for _ in range(residual):
            k = int(np.argmax(weights))
            rewards[k].append(reward_generators[k]())
            t_counts[k] += 1

    final_means = [float(np.mean(r)) if len(r) > 0 else 0.0 for r in rewards]
    n_k = [len(r) for r in rewards]
    return final_means, n_k, tau


def make_gauss_and_rademacher_generators(var_gaussian: float, seed_gauss: int = 123, seed_radem: int = 456):
    """Return generators for a Gaussian arm and a Rademacher arm.

    Args:
        var_gaussian: Variance of the Gaussian arm; must be non-negative.
        seed_gauss: Seed of the random generator backing the Gaussian arm.
        seed_radem: Seed of the random generator backing the Rademacher arm.

    Returns:
        ``[gen_gauss, gen_rademacher]``. Both accept an optional ``size`` and
        draw from their own seeded ``numpy.random.Generator``, so each stream
        is reproducible and independent of the other.

    Raises:
        ValueError: If ``var_gaussian`` is negative or NaN.
    """
    # ``not x >= 0`` also rejects NaN, which would otherwise silently yield a
    # NaN standard deviation and NaN samples.
    if not var_gaussian >= 0:
        raise ValueError(f"var_gaussian must be non-negative, got {var_gaussian!r}")

    std_g = float(np.sqrt(var_gaussian))
    rng_g = np.random.default_rng(seed_gauss)
    rng_r = np.random.default_rng(seed_radem)

    # The generators are bound as default arguments so each closure keeps its
    # own stream; callers may still override mu/std/rng per call.
    def gen_gauss(size=None, mu=0.0, std=std_g, rng=rng_g):
        return rng.normal(mu, std, size=size)

    def gen_rademacher(size=None, rng=rng_r):
        # Uniform draw from {-1, +1}.
        return rng.choice([-1.0, 1.0], size=size)

    return [gen_gauss, gen_rademacher]


# Backwards-compatible aliases: alternate names bound to the very same function
# objects as the snake_case definitions in this module (no wrappers, no copies).
compute_regret_1 = compute_regret
adaptive_algorithm_GSG = adaptive_algorithm_gsg
adaptive_algorithm_SSG = adaptive_algorithm_ssg
adaptive_algorithm_Gaussian = adaptive_algorithm_gaussian
