"""
STaR-Bets baselines (Hoeffding and Bets variants).
"""

from __future__ import annotations

import numpy as np

from .core import safe_bounds


def star_hoeffding_test_process(
    X,
    m,
    delta,
    *,
    two_sided=False,
    stop_on_hit=False,
):
    """
    Algorithm 2 (Testing with ★-Hoeffding) from the STaR-Bets paper.

    Runs a log-wealth process over the observations in ``X``. At every round
    the bet size is recomputed from the log-wealth still missing to reach the
    threshold ``log(1 / delta)`` and from the number of rounds left.

    Parameters
    ----------
    X : array-like
        Observations; converted to a float array and consumed in order.
    m : float
        Reference value subtracted from every observation.
    delta : float
        Level in the open interval (0, 1); the threshold is ``log(1 / delta)``.
    two_sided : bool, keyword-only
        When true, from the second round on the bet is negated whenever the
        mean of the previously seen observations is not ``>= m``.
    stop_on_hit : bool, keyword-only
        When true, the loop ends at the first crossing; the rest of the path
        is filled with the crossing value and the rest of ``hit`` with True.
        Bets for rounds that were never played stay at 0.

    Returns
    -------
    lgW_path : ndarray of float, length ``len(X) + 1``
        Log-wealth after each round (entry 0 is the initial value 0).
    hit : ndarray of bool, length ``len(X) + 1``
        Cumulative indicator that the log-wealth has reached the threshold.
    ell_seq : ndarray of float, length ``len(X)``
        Unsigned bet size used in each round.

    Raises
    ------
    ValueError
        If ``delta`` is not strictly between 0 and 1.
    """
    samples = np.asarray(X, dtype=float)
    horizon = int(len(samples))
    m = float(m)
    delta = float(delta)

    # Written as a negated chained comparison so that NaN is rejected too.
    if not 0.0 < delta < 1.0:
        raise ValueError("delta must be in (0,1)")

    threshold = float(np.log(1.0 / delta))

    wealth_path = np.zeros(horizon + 1, dtype=float)
    crossed = np.zeros(horizon + 1, dtype=bool)
    bets = np.zeros(horizon, dtype=float)

    log_wealth = 0.0
    running_sum = 0.0
    already_crossed = False

    for step, raw in enumerate(samples):
        rounds_left = float(horizon - step)

        # Log-wealth still needed to reach the threshold, floored at zero.
        gap = threshold - log_wealth
        gap = gap if gap > 0.0 else 0.0

        bet = np.sqrt(8.0 * gap / max(1.0, rounds_left))
        bets[step] = float(bet)

        # `step` equals the number of observations already consumed, so the
        # running mean of the past is running_sum / step.
        direction = 1.0
        if two_sided and step > 0 and not (running_sum / step >= m):
            direction = -1.0

        obs = float(raw)
        log_wealth = log_wealth + (direction * bet) * (obs - m) - (bet * bet) / 8.0

        nxt = step + 1
        wealth_path[nxt] = log_wealth
        already_crossed = already_crossed or (log_wealth >= threshold)
        crossed[nxt] = already_crossed

        running_sum += obs

        if stop_on_hit and already_crossed:
            # Freeze the tail at the crossing state and stop playing.
            crossed[nxt:] = True
            wealth_path[nxt:] = log_wealth
            break

    return wealth_path, crossed, bets


def star_bets_test_process(
    X,
    m,
    delta,
    *,
    alpha_var=0.05,
    use_impl_details=True,
    c=1.0,
    clip_v="m1m",
    two_sided=False,
    eps_cap=1e-3,
    last_round_randomize=False,
    rng=None,
    stop_on_hit=False,
):
    """
    Algorithm 4 (Testing with ★-Bets / STaR-Bets).

    Runs a betting log-wealth process over ``X``. Each round the bet is sized
    from the log-wealth still missing to reach ``log(1 / delta)``, the number
    of rounds left, and a running variance estimate around ``m``.

    Parameters
    ----------
    X : array-like
        Observations; converted to a float array and consumed in order.
    m : float
        Reference value subtracted from every observation.
    delta : float
        Level in the open interval (0, 1); the threshold is ``log(1 / delta)``.
    alpha_var : float, keyword-only
        Must be in (0, 1). Only enters the variance correction term
        ``10 * log(8 n / alpha_var) / max(1, t)^2`` used when
        ``use_impl_details`` is false.
    use_impl_details : bool, keyword-only
        When true, the variance correction is ``c * m * n / max(1, t)^2``.
    c : float, keyword-only
        Multiplier of the correction term when ``use_impl_details`` is true.
    clip_v : {"m1m", "vanilla"}, keyword-only
        Upper cap on the variance estimate: ``max(0, m * (1 - m))`` for
        ``"m1m"``, ``1.0`` for ``"vanilla"``.
    two_sided : bool, keyword-only
        When true, from the second round on the bet is negated whenever the
        mean of the previously seen observations is not ``>= m``, and the
        signed bet is clipped to the interval returned by ``safe_bounds``.
    eps_cap : float, keyword-only
        Forwarded unchanged to ``safe_bounds``.
    last_round_randomize : bool, keyword-only
        When true and no crossing occurred, the final entry is declared a hit
        with probability ``clip(delta * exp(final log-wealth), 0, 1)`` and the
        final path value is set to the threshold.
    rng : numpy.random.Generator or None, keyword-only
        Source of the uniform draw for the final randomisation; a fresh
        ``np.random.default_rng()`` is used when None.
    stop_on_hit : bool, keyword-only
        When true, the loop ends at the first crossing; the rest of the path
        is filled with the crossing value and the rest of ``hit`` with True.
        Bets for rounds that were never played stay at 0.

    Returns
    -------
    lgW_path : ndarray of float, length ``len(X) + 1``
        Log-wealth after each round (entry 0 is the initial value 0).
    hit : ndarray of bool, length ``len(X) + 1``
        Cumulative indicator that the log-wealth has reached the threshold.
    lam_seq : ndarray of float, length ``len(X)``
        Signed bet used in each round.

    Raises
    ------
    ValueError
        If ``delta`` or ``alpha_var`` is not strictly between 0 and 1, or if
        ``clip_v`` is not one of the two accepted strings.
    """
    samples = np.asarray(X, dtype=float)
    horizon = int(len(samples))
    m = float(m)
    delta = float(delta)
    alpha_var = float(alpha_var)

    # Negated chained comparisons so that NaN is rejected as well.
    if not 0.0 < delta < 1.0:
        raise ValueError("delta must be in (0,1)")
    if not 0.0 < alpha_var < 1.0:
        raise ValueError("alpha_var must be in (0,1)")

    if rng is None:
        rng = np.random.default_rng()

    threshold = float(np.log(1.0 / delta))

    if clip_v == "m1m":
        variance_cap = float(max(0.0, m * (1.0 - m)))
    elif clip_v == "vanilla":
        variance_cap = 1.0
    else:
        raise ValueError("clip_v must be either 'm1m' or 'vanilla'")

    # NOTE(review): presumably the largest positive / most negative bets that
    # keep the wealth factor positive -- confirm against core.safe_bounds.
    upper_bet, lower_bet = safe_bounds(m, eps_cap=eps_cap)

    wealth_path = np.zeros(horizon + 1, dtype=float)
    crossed = np.zeros(horizon + 1, dtype=bool)
    bets = np.zeros(horizon, dtype=float)

    log_wealth = 0.0
    sq_dev_sum = 0.0
    running_sum = 0.0

    for step, raw in enumerate(samples):
        rounds_left = float(horizon - step)
        # Number of past observations, floored at 1 to avoid dividing by zero.
        past = float(max(1, step))
        past_sq = past * past

        if use_impl_details:
            correction = float(c) * m * float(horizon) / past_sq
        else:
            correction = (
                10.0 * float(np.log(8.0 * float(horizon) / alpha_var)) / past_sq
            )

        # Running variance estimate, capped above and kept strictly positive.
        variance = (sq_dev_sum / past) + correction
        variance = max(min(variance, variance_cap), 1e-12)

        gap = max(0.0, threshold - log_wealth)
        magnitude = np.sqrt(2.0 * gap / (rounds_left * variance))
        magnitude = float(min(magnitude, 1.0))

        # `step` equals the number of observations already consumed.
        direction = 1.0
        if two_sided and step > 0 and not (running_sum / step >= m):
            direction = -1.0

        bet = direction * magnitude
        if two_sided:
            bet = float(np.clip(bet, lower_bet, upper_bet))
        bets[step] = bet

        obs = float(raw)
        deviation = obs - m
        log_wealth = log_wealth + float(np.log1p(bet * deviation))
        sq_dev_sum = sq_dev_sum + deviation * deviation

        nxt = step + 1
        wealth_path[nxt] = log_wealth
        crossed[nxt] = crossed[step] or (log_wealth >= threshold)

        running_sum += obs

        if stop_on_hit and crossed[nxt]:
            # Freeze the tail at the crossing state and stop playing.
            crossed[nxt:] = True
            wealth_path[nxt:] = log_wealth
            break

    if last_round_randomize and not crossed[-1]:
        # delta * exp(log_wealth), evaluated in log space, used as a probability.
        accept_prob = float(np.exp(np.log(delta) + log_wealth))
        accept_prob = float(np.clip(accept_prob, 0.0, 1.0))
        if rng.random() < accept_prob:
            crossed[-1] = True
            wealth_path[-1] = threshold

    return wealth_path, crossed, bets


def star_hoeffding_two_sided_mixture(X, m, delta, stop_on_hit=False):
    """
    Two-sided mixture of ★-Hoeffding.

    Runs two one-sided ★-Hoeffding log-wealth processes in parallel, one
    betting on ``x - m`` and one on ``m - x``, and tracks the log of their
    equal-weight average, ``logaddexp(pos, neg) - log(2)``. Each arm sizes its
    bet from its own remaining gap to the threshold ``log(1 / delta)`` and
    from the number of rounds left.

    Parameters
    ----------
    X : array-like
        Observations; converted to a float array and consumed in order.
    m : float
        Reference value the observations are compared against.
    delta : float
        Level in the open interval (0, 1); the threshold is ``log(1 / delta)``.
    stop_on_hit : bool
        When true, the loop ends at the first crossing of the mixture; the
        rest of the path is filled with the crossing value and the rest of
        ``hit`` with True.

    Returns
    -------
    lgW_mix_path : ndarray of float, length ``len(X) + 1``
        Mixture log-wealth after each round (entry 0 is the initial value 0).
    hit : ndarray of bool, length ``len(X) + 1``
        Cumulative indicator that the mixture has reached the threshold.

    Raises
    ------
    ValueError
        If ``delta`` is not strictly between 0 and 1.
    """
    X = np.asarray(X, dtype=float)
    n = int(len(X))
    m = float(m)
    delta = float(delta)

    # Same validation as the other test processes in this module. Without it,
    # delta == 0 raises ZeroDivisionError, delta < 0 yields a NaN threshold
    # that can never be hit, and delta >= 1 yields a non-positive threshold.
    if not (0.0 < delta < 1.0):
        raise ValueError("delta must be in (0,1)")

    T = float(np.log(1.0 / delta))
    log_two = np.log(2.0)  # loop-invariant mixture normaliser

    lgW_pos = 0.0
    lgW_neg = 0.0

    lgW_mix_path = np.zeros(n + 1, dtype=float)
    hit = np.zeros(n + 1, dtype=bool)

    for t0 in range(n):
        remaining = float(n - t0)

        # NOTE(review): each arm stops betting once its own log-wealth reaches
        # T, while the mixture needs roughly T + log(2) from a single arm to
        # cross. Confirm that targeting T (not T + log 2) per arm is intended.
        ell_pos = np.sqrt(8.0 * max(0.0, T - lgW_pos) / max(1.0, remaining))
        ell_neg = np.sqrt(8.0 * max(0.0, T - lgW_neg) / max(1.0, remaining))

        x = float(X[t0])

        lgW_pos += ell_pos * (x - m) - (ell_pos * ell_pos) / 8.0
        lgW_neg += ell_neg * (m - x) - (ell_neg * ell_neg) / 8.0

        # log((exp(pos) + exp(neg)) / 2), computed stably in log space.
        lgW_mix = np.logaddexp(lgW_pos, lgW_neg) - log_two

        lgW_mix_path[t0 + 1] = lgW_mix
        hit[t0 + 1] = hit[t0] or (lgW_mix >= T)

        if stop_on_hit and hit[t0 + 1]:
            # Freeze the tail at the crossing state and stop playing.
            hit[t0 + 1:] = True
            lgW_mix_path[t0 + 1:] = lgW_mix
            break

    return lgW_mix_path, hit


# Public API: the names exported by `from <this module> import *`.
__all__ = [
    "star_hoeffding_test_process",
    "star_bets_test_process",
    "star_hoeffding_two_sided_mixture",
]
