import numba as nb
import numpy as np
from tqdm import tqdm

# ------------------------------------------------------------------------------#
# Numba-JIT helpers & runners
# ------------------------------------------------------------------------------#

# Integer codes for the ``dist_type`` argument of the kernels below
# (0 -> normal rewards, 1 -> Laplace rewards).
GAUSSIAN = 0
LAPLACE = 1

# Readable label for each distribution code.
dict_TYPE = {
    GAUSSIAN: "gaussian",
    LAPLACE: "laplace",
}

# NOTE(review): not referenced in this section; presumably a saving/logging
# interval used elsewhere in the file — confirm.
save_step = 10


@nb.njit
def _draw_distribution(mean, scale, dist_type):
    """
    Draw a single sample from a specified distribution.

    Parameters
    ----------
    mean : float
        Mean or location parameter.
    scale : float
        Standard deviation (for normal) or b (for Laplace).
    dist_type : int
        0 for normal, 1 for Laplace.

    Returns
    -------
    float
        One random draw, or NaN when ``dist_type`` is neither 0 nor 1.

    Notes
    -----
    ``np.random.random()`` samples from the half-open interval [0, 1), so
    both inverse transforms below can evaluate ``log(0)`` at the closed
    endpoint and return an infinite reward. That single endpoint is
    rejected and redrawn. Whenever it is not hit, the consumed random
    numbers and the returned value are identical to a plain transform.
    """
    if dist_type == 0:
        # Normal (Box–Muller transform); u1 must be strictly positive so
        # that log(u1) is finite.
        u1 = np.random.random()
        while u1 == 0.0:
            u1 = np.random.random()
        u2 = np.random.random()
        z = np.sqrt(-2.0 * np.log(u1)) * np.cos(2 * np.pi * u2)
        return mean + scale * z

    elif dist_type == 1:
        # Laplace (inverse CDF); u == -0.5 would make the log argument 0.
        u = np.random.random() - 0.5
        while u == -0.5:
            u = np.random.random() - 0.5
        return mean - scale * np.sign(u) * np.log(1 - 2 * abs(u))

    else:
        # Invalid type — return NaN to keep it Numba-compatible
        return np.nan


@nb.njit
def _sample_posterior(_mu, _var, inflation):
    """Draw one Thompson sample from a Gaussian posterior.

    The draw is ``Normal(_mu, sqrt(_var * inflation))``. ``inflation``
    scales the posterior variance before the standard deviation is taken.
    """
    inflated_var = _var * inflation
    return np.random.normal(_mu, np.sqrt(inflated_var))


@nb.njit
def _sample_posterior_GN(_mu, _kappa, _alpha, _beta, inflation):
    """Draw one mean sample from a Normal-Gamma posterior.

    A precision is drawn from ``Gamma(shape=_alpha, scale=1/_beta)``. Its
    reciprocal (the variance) is divided by ``_kappa`` and multiplied by
    ``inflation``, and the result is the variance of the returned
    ``Normal(_mu, ...)`` draw.
    """
    precision = np.random.gamma(_alpha, 1.0 / _beta)
    mean_var = (1.0 / precision) / _kappa * inflation
    return np.random.normal(_mu, np.sqrt(mean_var))


@nb.njit
def run_TS_jit(mu_0, var_0, loc, scale, steps, pulls_out, rewards_out, eta, dist_type):
    """
    Run Gaussian Thompson sampling with a known-variance conjugate update.

    Parameters
    ----------
    mu_0, var_0 : float
        Prior mean and prior variance shared by every arm.
    loc, scale : per-arm arrays
        Location/scale passed to ``_draw_distribution``; the number of arms
        is ``loc.shape[0]``.
    steps : int
        Number of rounds.
    pulls_out, rewards_out : arrays of length >= ``steps``
        Filled in place with the chosen arm and the observed reward per round.
    eta : float
        The posterior variance is multiplied by ``eta**2`` when sampling.
    dist_type : int
        0 for normal rewards, 1 for Laplace rewards.

    Returns
    -------
    None
        Results are written into ``pulls_out`` and ``rewards_out``.
    """
    arms = loc.shape[0]
    # Force float64 state: np.full infers its dtype from the fill value, so
    # an integer prior (e.g. mu_0=0, var_0=1) would otherwise produce integer
    # arrays and silently truncate every posterior update.
    mu_n = np.full(arms, mu_0, dtype=np.float64)
    var_n = np.full(arms, var_0, dtype=np.float64)
    counts = np.zeros(arms, dtype=np.int32)
    sum_x = np.zeros(arms, dtype=np.float64)
    # Reward variance per arm: (1 + dist_type) * scale**2, i.e. scale**2 for
    # normal and 2 * b**2 for Laplace.
    known_var = np.zeros(arms, dtype=np.float64)
    for a in range(arms):
        known_var[a] = (1 + dist_type) * scale[a] * scale[a]

    # Loop invariants.
    precision_prior = 1.0 / var_0
    inflation = eta**2

    for t in range(steps):
        # Pick the arm with the largest posterior sample.
        best_a = 0
        best_sample = -np.inf
        for a in range(arms):
            sample = _sample_posterior(mu_n[a], var_n[a], inflation)
            if sample > best_sample:
                best_sample = sample
                best_a = a

        pulls_out[t] = best_a

        r = _draw_distribution(loc[best_a], scale[best_a], dist_type)
        rewards_out[t] = r

        # Conjugate Gaussian update of the chosen arm's mean (known variance).
        counts[best_a] += 1
        sum_x[best_a] += r
        precision_obs = counts[best_a] / known_var[best_a]
        var_n[best_a] = 1.0 / (precision_prior + precision_obs)
        mu_n[best_a] = var_n[best_a] * (
            mu_0 * precision_prior + sum_x[best_a] / known_var[best_a]
        )


@nb.njit
def run_TSVI_jit(
    mu_0, var_0, loc, scale, steps, pulls_out, rewards_out, hyper, dist_type
):
    """
    Run Gaussian Thompson sampling with a round- and count-dependent inflation.

    At round ``t`` (0-based), arm ``a`` is sampled with its posterior variance
    multiplied by ``hyper**2 * (t + 1) / (arms * max(n_a, 1))``. Here ``n_a``
    is the number of times ``a`` has been pulled so far.

    Parameters
    ----------
    mu_0, var_0 : float
        Prior mean and prior variance shared by every arm.
    loc, scale : per-arm arrays
        Location/scale passed to ``_draw_distribution``; the number of arms
        is ``loc.shape[0]``.
    steps : int
        Number of rounds.
    pulls_out, rewards_out : arrays of length >= ``steps``
        Filled in place with the chosen arm and the observed reward per round.
    hyper : float
        Inflation strength.
    dist_type : int
        0 for normal rewards, 1 for Laplace rewards.

    Returns
    -------
    None
        Results are written into ``pulls_out`` and ``rewards_out``.
    """
    arms = loc.shape[0]
    # Force float64 state: np.full infers its dtype from the fill value, so
    # an integer prior would otherwise truncate every posterior update.
    mu_n = np.full(arms, mu_0, dtype=np.float64)
    var_n = np.full(arms, var_0, dtype=np.float64)
    counts = np.zeros(arms, dtype=np.int32)
    sum_x = np.zeros(arms, dtype=np.float64)
    # Reward variance per arm: scale**2 for normal, 2 * b**2 for Laplace.
    known_var = np.zeros(arms, dtype=np.float64)
    for a in range(arms):
        known_var[a] = (1 + dist_type) * scale[a] * scale[a]

    # Loop invariants.
    precision_prior = 1.0 / var_0
    hyper_sq = hyper * hyper

    for t in range(steps):
        best_a = 0
        best_sample = -np.inf
        for a in range(arms):
            # Unpulled arms are treated as pulled once to avoid dividing by 0.
            n_a = counts[a] if counts[a] > 0 else 1
            inflation = (hyper_sq * (t + 1)) / (arms * n_a)

            sample = _sample_posterior(mu_n[a], var_n[a], inflation)
            if sample > best_sample:
                best_sample = sample
                best_a = a

        pulls_out[t] = best_a

        r = _draw_distribution(loc[best_a], scale[best_a], dist_type)
        rewards_out[t] = r

        # Conjugate Gaussian update of the chosen arm's mean (known variance).
        counts[best_a] += 1
        sum_x[best_a] += r
        precision_obs = counts[best_a] / known_var[best_a]
        var_n[best_a] = 1.0 / (precision_prior + precision_obs)
        mu_n[best_a] = var_n[best_a] * (
            mu_0 * precision_prior + sum_x[best_a] / known_var[best_a]
        )


@nb.njit
def run_UCB_jit(mu_0, var_0, loc, scale, steps, pulls_out, rewards_out, eta, dist_type):
    """
    Run a UCB policy on top of a known-variance Gaussian posterior mean.

    At round ``t`` (0-based), arm ``a`` scores
    ``mu_n[a] + eta * sqrt(2 * known_var[a] * ln(t + 2) / max(n_a, 1))``.

    Parameters
    ----------
    mu_0, var_0 : float
        Prior mean and prior variance shared by every arm.
    loc, scale : per-arm arrays
        Location/scale passed to ``_draw_distribution``; the number of arms
        is ``loc.shape[0]``.
    steps : int
        Number of rounds.
    pulls_out, rewards_out : arrays of length >= ``steps``
        Filled in place with the chosen arm and the observed reward per round.
    eta : float
        Exploration-bonus multiplier.
    dist_type : int
        0 for normal rewards, 1 for Laplace rewards.

    Returns
    -------
    None
        Results are written into ``pulls_out`` and ``rewards_out``.
    """
    arms = loc.shape[0]
    # Force float64 state: np.full infers its dtype from the fill value, so
    # an integer prior would otherwise truncate every posterior update.
    mu_n = np.full(arms, mu_0, dtype=np.float64)
    var_n = np.full(arms, var_0, dtype=np.float64)
    counts = np.zeros(arms, dtype=np.int32)
    sum_x = np.zeros(arms, dtype=np.float64)
    # Reward variance per arm: scale**2 for normal, 2 * b**2 for Laplace.
    known_var = np.zeros(arms, dtype=np.float64)
    for a in range(arms):
        known_var[a] = (1 + dist_type) * scale[a] * scale[a]

    precision_prior = 1.0 / var_0

    for t in range(steps):
        log_t = np.log(t + 2)
        best_a = 0
        best_ucb = -np.inf
        for a in range(arms):
            # Unpulled arms are treated as pulled once to avoid dividing by 0.
            n_a = counts[a] if counts[a] > 0 else 1
            bonus = eta * np.sqrt(2.0 * known_var[a] * log_t / n_a)
            ucb = mu_n[a] + bonus
            if ucb > best_ucb:
                best_ucb = ucb
                best_a = a

        pulls_out[t] = best_a
        r = _draw_distribution(loc[best_a], scale[best_a], dist_type)
        rewards_out[t] = r

        # Conjugate Gaussian update of the chosen arm's mean (known variance).
        counts[best_a] += 1
        sum_x[best_a] += r
        precision_obs = counts[best_a] / known_var[best_a]
        var_n[best_a] = 1.0 / (precision_prior + precision_obs)
        mu_n[best_a] = var_n[best_a] * (
            mu_0 * precision_prior + sum_x[best_a] / known_var[best_a]
        )


@nb.njit
def run_UCBI_any_jit(
    mu_0,
    var_0,
    loc,
    scale,
    steps,
    pulls_out,
    rewards_out,
    hyper,
    eta,
    alpha,
    beta,
    dist_type,
):
    """
    Run an anytime UCB policy whose bonus is the minimum of two terms.

    At round ``t`` (0-based), arm ``a`` scores ``mu_n[a] + min(WC, ID)``,
    where ``K`` is the number of arms and ``n_a = max(pulls of a, 1)``:

    * ``WC = hyper * sqrt(known_var[a] * log2(K)) * (t + 1)**alpha / n_a``
    * ``ID = eta * sqrt(known_var[a] * log2(t + 2) * (t + 1)**beta / n_a)``

    Parameters
    ----------
    mu_0, var_0 : float
        Prior mean and prior variance shared by every arm.
    loc, scale : per-arm arrays
        Location/scale passed to ``_draw_distribution``; the number of arms
        is ``loc.shape[0]``.
    steps : int
        Number of rounds.
    pulls_out, rewards_out : arrays of length >= ``steps``
        Filled in place with the chosen arm and the observed reward per round.
    hyper, eta : float
        Multipliers of the WC and ID bonus terms.
    alpha, beta : float
        Growth exponents of the WC and ID bonus terms.
    dist_type : int
        0 for normal rewards, 1 for Laplace rewards.

    Returns
    -------
    None
        Results are written into ``pulls_out`` and ``rewards_out``.
    """
    arms = loc.shape[0]
    # Force float64 state: np.full infers its dtype from the fill value, so
    # an integer prior would otherwise truncate every posterior update.
    mu_n = np.full(arms, mu_0, dtype=np.float64)
    var_n = np.full(arms, var_0, dtype=np.float64)
    counts = np.zeros(arms, dtype=np.int32)
    sum_x = np.zeros(arms, dtype=np.float64)
    # Reward variance per arm: scale**2 for normal, 2 * b**2 for Laplace.
    known_var = np.zeros(arms, dtype=np.float64)
    for a in range(arms):
        known_var[a] = (1 + dist_type) * scale[a] * scale[a]

    log2_K = np.log2(arms)
    precision_prior = 1.0 / var_0

    for t in range(steps):
        # Per-round factors shared by every arm; a float base keeps the power
        # in floating point regardless of the exponent's type.
        growth_wc = (t + 1.0) ** alpha
        growth_id = (t + 1.0) ** beta
        log2_t = np.log2(t + 2)

        best_a = 0
        best_ucb = -np.inf
        for a in range(arms):
            # Unpulled arms are treated as pulled once to avoid dividing by 0.
            n_a = counts[a] if counts[a] > 0 else 1

            bonus_WC = hyper * np.sqrt(known_var[a] * log2_K) * growth_wc / n_a
            bonus_ID = eta * np.sqrt(known_var[a] * log2_t * growth_id / n_a)
            bonus = bonus_WC if bonus_WC < bonus_ID else bonus_ID  # min{}

            ucb = mu_n[a] + bonus
            if ucb > best_ucb:
                best_ucb = ucb
                best_a = a

        pulls_out[t] = best_a
        r = _draw_distribution(loc[best_a], scale[best_a], dist_type)
        rewards_out[t] = r

        # Conjugate Gaussian update of the chosen arm's mean (known variance).
        counts[best_a] += 1
        sum_x[best_a] += r
        precision_obs = counts[best_a] / known_var[best_a]
        var_n[best_a] = 1.0 / (precision_prior + precision_obs)
        mu_n[best_a] = var_n[best_a] * (
            mu_0 * precision_prior + sum_x[best_a] / known_var[best_a]
        )


@nb.njit
def run_UCBI_fix_jit(
    mu_0,
    var_0,
    loc,
    scale,
    steps,
    pulls_out,
    rewards_out,
    hyper,
    eta,
    alpha,
    beta,
    dist_type,
):
    """
    Run a fixed-horizon UCB policy whose bonus is the minimum of two terms.

    The horizon ``T = steps`` replaces the round index, so the bonus depends
    only on the arm's pull count. Arm ``a`` scores ``mu_n[a] + min(WC, ID)``,
    where ``K`` is the number of arms and ``n_a = max(pulls of a, 1)``:

    * ``WC = hyper * sqrt(known_var[a] * log2(K)) * (T + 1)**alpha / n_a``
    * ``ID = eta * sqrt(known_var[a] * log2(T + 1) * (T + 1)**beta / n_a)``

    Parameters
    ----------
    mu_0, var_0 : float
        Prior mean and prior variance shared by every arm.
    loc, scale : per-arm arrays
        Location/scale passed to ``_draw_distribution``; the number of arms
        is ``loc.shape[0]``.
    steps : int
        Number of rounds (also the horizon ``T``).
    pulls_out, rewards_out : arrays of length >= ``steps``
        Filled in place with the chosen arm and the observed reward per round.
    hyper, eta : float
        Multipliers of the WC and ID bonus terms.
    alpha, beta : float
        Growth exponents of the WC and ID bonus terms.
    dist_type : int
        0 for normal rewards, 1 for Laplace rewards.

    Returns
    -------
    None
        Results are written into ``pulls_out`` and ``rewards_out``.
    """
    arms = loc.shape[0]
    # Force float64 state: np.full infers its dtype from the fill value, so
    # an integer prior would otherwise truncate every posterior update.
    mu_n = np.full(arms, mu_0, dtype=np.float64)
    var_n = np.full(arms, var_0, dtype=np.float64)
    counts = np.zeros(arms, dtype=np.int32)
    sum_x = np.zeros(arms, dtype=np.float64)
    # Reward variance per arm: scale**2 for normal, 2 * b**2 for Laplace.
    known_var = np.zeros(arms, dtype=np.float64)
    for a in range(arms):
        known_var[a] = (1 + dist_type) * scale[a] * scale[a]

    T = steps
    # Horizon-level factors are constant for the whole run, so compute once.
    # A float base keeps the power in floating point regardless of exponent type.
    log2_K = np.log2(arms)
    growth_wc = (T + 1.0) ** alpha
    growth_id = (T + 1.0) ** beta
    # NOTE(review): the anytime variant evaluated at t = T would use
    # log2(T + 2); this uses log2(T + 1) — confirm which is intended.
    log2_T = np.log2(T + 1)
    precision_prior = 1.0 / var_0

    for t in range(steps):
        best_a = 0
        best_ucb = -np.inf
        for a in range(arms):
            # Unpulled arms are treated as pulled once to avoid dividing by 0.
            n_a = counts[a] if counts[a] > 0 else 1

            bonus_WC = hyper * np.sqrt(known_var[a] * log2_K) * growth_wc / n_a
            bonus_ID = eta * np.sqrt(known_var[a] * log2_T * growth_id / n_a)
            bonus = bonus_WC if bonus_WC < bonus_ID else bonus_ID  # min{}

            ucb = mu_n[a] + bonus
            if ucb > best_ucb:
                best_ucb = ucb
                best_a = a

        pulls_out[t] = best_a
        r = _draw_distribution(loc[best_a], scale[best_a], dist_type)
        rewards_out[t] = r

        # Conjugate Gaussian update of the chosen arm's mean (known variance).
        counts[best_a] += 1
        sum_x[best_a] += r
        precision_obs = counts[best_a] / known_var[best_a]
        var_n[best_a] = 1.0 / (precision_prior + precision_obs)
        mu_n[best_a] = var_n[best_a] * (
            mu_0 * precision_prior + sum_x[best_a] / known_var[best_a]
        )


@nb.njit
def run_TS_GN_jit(
    mu_0,
    kappa_0,
    alpha_0,
    beta_0,
    laplace_loc,
    laplace_b,
    steps,
    pulls_out,
    rewards_out,
    eta,
    dist_type,
):
    """
    Run Thompson sampling with a Normal-Gamma posterior per arm.

    Each arm's posterior ``(mu, kappa, alpha, beta)`` is updated after every
    pull with the single-observation Normal-Gamma conjugate rule. The reward
    variance is therefore learned rather than assumed known.

    Parameters
    ----------
    mu_0, kappa_0, alpha_0, beta_0 : float
        Prior hyper-parameters shared by every arm.
    laplace_loc, laplace_b : per-arm arrays
        Location/scale passed to ``_draw_distribution``. Despite the names,
        they are used for whichever family ``dist_type`` selects. The number
        of arms is ``laplace_loc.shape[0]``.
    steps : int
        Number of rounds.
    pulls_out, rewards_out : arrays of length >= ``steps``
        Filled in place with the chosen arm and the observed reward per round.
    eta : float
        The sampled variance is multiplied by ``eta**2``.
    dist_type : int
        0 for normal rewards, 1 for Laplace rewards.

    Returns
    -------
    None
        Results are written into ``pulls_out`` and ``rewards_out``.
    """
    arms = laplace_loc.shape[0]
    # Force float64 state: np.full infers its dtype from the fill value, and
    # integer hyper-parameters would truncate updates such as alpha + 0.5.
    mu_n = np.full(arms, mu_0, dtype=np.float64)
    kappa_n = np.full(arms, kappa_0, dtype=np.float64)
    alpha_n = np.full(arms, alpha_0, dtype=np.float64)
    beta_n = np.full(arms, beta_0, dtype=np.float64)

    inflation = eta**2

    for t in range(steps):
        # Pick the arm with the largest posterior sample.
        best_a = 0
        best_sample = -np.inf
        for a in range(arms):
            sample = _sample_posterior_GN(
                mu_n[a], kappa_n[a], alpha_n[a], beta_n[a], inflation
            )
            if sample > best_sample:
                best_sample = sample
                best_a = a

        pulls_out[t] = best_a

        r = _draw_distribution(laplace_loc[best_a], laplace_b[best_a], dist_type)
        rewards_out[t] = r

        # Normal-Gamma update with one observation r (uses the pre-update mu).
        k_old = kappa_n[best_a]
        k_new = k_old + 1.0
        mu_old = mu_n[best_a]

        mu_n[best_a] = (k_old * mu_old + r) / k_new
        kappa_n[best_a] = k_new
        alpha_n[best_a] = alpha_n[best_a] + 0.5
        beta_n[best_a] = beta_n[best_a] + 0.5 * (r - mu_old) ** 2 * k_old / k_new


@nb.njit
def run_TSVI_GN_jit(
    mu_0,
    kappa_0,
    alpha_0,
    beta_0,
    laplace_loc,
    laplace_b,
    steps,
    pulls_out,
    rewards_out,
    hyper,
    dist_type,
):
    """
    Run Normal-Gamma Thompson sampling with a round- and count-dependent inflation.

    At round ``t`` (0-based), arm ``a`` is sampled with its variance
    multiplied by ``hyper**2 * (t + 1) / (arms * max(n_a, 1))``. Here ``n_a``
    is the number of times ``a`` has been pulled so far. The posterior update
    is the single-observation Normal-Gamma conjugate rule.

    Parameters
    ----------
    mu_0, kappa_0, alpha_0, beta_0 : float
        Prior hyper-parameters shared by every arm.
    laplace_loc, laplace_b : per-arm arrays
        Location/scale passed to ``_draw_distribution``. Despite the names,
        they are used for whichever family ``dist_type`` selects. The number
        of arms is ``laplace_loc.shape[0]``.
    steps : int
        Number of rounds.
    pulls_out, rewards_out : arrays of length >= ``steps``
        Filled in place with the chosen arm and the observed reward per round.
    hyper : float
        Inflation strength.
    dist_type : int
        0 for normal rewards, 1 for Laplace rewards.

    Returns
    -------
    None
        Results are written into ``pulls_out`` and ``rewards_out``.
    """
    arms = laplace_loc.shape[0]
    # Force float64 state: np.full infers its dtype from the fill value, and
    # integer hyper-parameters would truncate updates such as alpha + 0.5.
    mu_n = np.full(arms, mu_0, dtype=np.float64)
    kappa_n = np.full(arms, kappa_0, dtype=np.float64)
    alpha_n = np.full(arms, alpha_0, dtype=np.float64)
    beta_n = np.full(arms, beta_0, dtype=np.float64)
    counts = np.zeros(arms, np.int32)

    hyper_sq = hyper * hyper

    for t in range(steps):
        best_a = 0
        best_sample = -np.inf
        for a in range(arms):
            # Unpulled arms are treated as pulled once to avoid dividing by 0.
            n_a = counts[a] if counts[a] > 0 else 1
            inflation = (hyper_sq * (t + 1)) / (arms * n_a)

            sample = _sample_posterior_GN(
                mu_n[a], kappa_n[a], alpha_n[a], beta_n[a], inflation
            )
            if sample > best_sample:
                best_sample = sample
                best_a = a

        pulls_out[t] = best_a

        r = _draw_distribution(laplace_loc[best_a], laplace_b[best_a], dist_type)
        rewards_out[t] = r
        counts[best_a] += 1

        # Normal-Gamma update with one observation r (uses the pre-update mu).
        k_old = kappa_n[best_a]
        k_new = k_old + 1.0
        mu_old = mu_n[best_a]

        mu_n[best_a] = (k_old * mu_old + r) / k_new
        kappa_n[best_a] = k_new
        alpha_n[best_a] = alpha_n[best_a] + 0.5
        beta_n[best_a] = beta_n[best_a] + 0.5 * (r - mu_old) ** 2 * k_old / k_new
