"""
Hedging baselines over fixed epsilon schedules.
"""

from __future__ import annotations

import numpy as np

from .kelly import _kelly_and_endpoint_from_past
from .core import safe_bounds


def build_time_only_epsilon_schedule_dict(N, ks=(0.0, 0.25, 0.5, 0.75)):
    """
    Build deterministic time-only ε schedules (experts).

    For each start fraction ``k`` in ``ks`` let ``t0 = floor(k * N)``. Over the
    steps ``t = 1..N`` the linear schedule is::

        u_t = clip(max(0, t - t0) / max(1, N - t0), 0, 1)

    It is zero up to step ``t0`` and ramps linearly afterwards, reaching 1 at
    step ``N`` whenever ``t0 < N``. Each ``k`` contributes two experts: the
    ramp itself (key ``"lin_k=<k>"``) and its square (key ``"quad_k=<k>"``).
    Here ``<k>`` is ``k`` formatted with ``:g``.

    Parameters
    ----------
    N : int
        Horizon length (number of steps).
    ks : iterable of float
        Start fractions of the horizon.

    Returns
    -------
    dict[str, np.ndarray]
        Expert name -> float32 array of shape ``(N,)`` with values in [0, 1].
        Insertion order is ``lin``/``quad`` pairs in the order of ``ks``.
    """
    horizon = int(N)
    steps = np.arange(1, horizon + 1, dtype=np.float32)

    schedules = {}
    for frac in map(float, ks):
        start = int(np.floor(frac * horizon))
        span = float(max(1, horizon - start))
        ramp = np.clip(
            np.maximum(0.0, steps - float(start)) / span, 0.0, 1.0
        ).astype(np.float32)

        schedules[f"lin_k={frac:g}"] = ramp
        schedules[f"quad_k={frac:g}"] = (ramp * ramp).astype(np.float32)
    return schedules


def epsilon_schedule_dict_to_matrix(schedule_dict):
    """
    Convert ``{name -> eps_seq}`` into ``(names, eps_mat)``.

    Row ``i`` of ``eps_mat`` is ``schedule_dict[names[i]]`` cast to float32.
    ``names`` follows the dict's iteration order. The sequences must share
    one length ``N``, giving ``eps_mat`` shape ``(K, N)``.
    """
    names = [name for name in schedule_dict]
    rows = [np.asarray(schedule_dict[name], dtype=np.float32) for name in names]
    eps_mat = np.stack(rows, axis=0)
    return names, eps_mat


def _logmeanexp(a, axis=None):
    """Stable log(mean(exp(a))) along an axis."""
    a = np.asarray(a, dtype=np.float64)
    amax = np.max(a, axis=axis, keepdims=True)
    out = amax + np.log(np.mean(np.exp(a - amax), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis) if axis is not None else float(out.squeeze())


def _logsumexp(a, axis=None):
    """Stable log(sum(exp(a))) along an axis."""
    a = np.asarray(a, dtype=np.float64)
    amax = np.max(a, axis=axis, keepdims=True)
    out = amax + np.log(np.sum(np.exp(a - amax), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis) if axis is not None else float(out.squeeze())


def simulate_uniform_hedge_time_only_eps_greedy_experts(
    X,
    m,
    alpha,
    eps_mat,
    rng=None,
    *,
    coupled=True,
    stop_on_hit=True,
    eps_cap=1e-3,
    var_floor=0.0,
    shrink_kappa=0.0,
    lcap=None,
):
    """
    Uniform hedge over K experts with deterministic ε schedules.

    Expert ``i`` keeps log-wealth ``y_i`` (starting at 0). At step ``t`` two
    betting fractions, ``lam_kelly`` and ``lam_end``, are obtained from
    ``_kelly_and_endpoint_from_past`` using only past observations (running
    sum, sum of squares and count). Expert ``i`` uses ``lam_end`` when a
    uniform draw falls below ``eps_mat[i, t]``, and ``lam_kelly`` otherwise.
    Its log-wealth then grows by ``log(1 + lam * (X[t] - m))``. The hedge
    wealth is the equal-weight average of the experts' wealths.

    Parameters
    ----------
    X : array_like
        Observations; ``X.shape[0]`` is the horizon ``N``.
    m : float
        Reference value in the wealth factor ``1 + lam * (X[t] - m)``. It is
        also forwarded to ``_kelly_and_endpoint_from_past``.
    alpha : float
        Level; a hit is recorded once ``ybar >= log(1 / alpha)``.
    eps_mat : array_like, shape (K, N)
        Per-expert, per-step probability of taking ``lam_end``.
    rng : numpy Generator, optional
        Source of uniforms; defaults to ``np.random.default_rng()``.
    coupled : bool
        If True, one uniform per step is shared by all experts. Otherwise each
        expert draws its own uniform.
    stop_on_hit : bool
        If True, stop at the first hit and fill the remaining entries of
        ``ybar`` / ``hit`` with the values at the hit time.
    eps_cap, var_floor, shrink_kappa, lcap
        Forwarded unchanged to ``_kelly_and_endpoint_from_past``.

    Returns
    -------
    ybar : np.ndarray, shape (N + 1,)
        ``ybar[t]`` is ``log(mean_i exp(y_i))`` after ``t`` observations
        (``ybar[0] == 0``).
    hit : np.ndarray of bool, shape (N + 1,)
        Running indicator that ``ybar`` has reached ``log(1 / alpha)``.

    Raises
    ------
    ValueError
        If ``eps_mat``'s second dimension differs from ``N``. Also raised if
        some factor ``1 + lam * (X[t] - m)`` is negative, which would give NaN
        log-wealth; the exp-weights variant applies the same check.
    """
    if rng is None:
        rng = np.random.default_rng()

    X = np.asarray(X, dtype=np.float64)
    N = int(X.shape[0])

    eps_mat = np.asarray(eps_mat, dtype=np.float64)
    K, N2 = eps_mat.shape
    if N2 != N:
        raise ValueError(f"eps_mat has N={N2} but X has N={N}")

    m_f = float(m)
    logT = float(np.log(1.0 / alpha))

    y_i = np.zeros(K, dtype=np.float64)  # per-expert log-wealth
    ybar = np.zeros(N + 1, dtype=np.float64)
    hit = np.zeros(N + 1, dtype=np.bool_)

    # Past-only statistics (sum, sum of squares, count) for the lambda helper.
    s1 = 0.0
    s2 = 0.0
    n = 0

    for t in range(N):
        _, lam_kelly, lam_end, _ = _kelly_and_endpoint_from_past(
            s1, s2, n, m,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
        )
        lam_kelly = float(lam_kelly)
        lam_end = float(lam_end)

        # Coupled: one shared uniform; uncoupled: an independent uniform per expert.
        u = float(rng.random()) if coupled else rng.random(K)
        take_end = u < eps_mat[:, t]
        lam_vec = np.where(take_end, lam_end, lam_kelly).astype(np.float64)

        x_t = float(X[t])
        incr = lam_vec * (x_t - m_f)
        if np.any(incr < -1.0):
            raise ValueError("1 + lam*(X-m) < 0 encountered; adjust bounds/caps.")
        y_i += np.log1p(incr)

        ybar[t + 1] = _logmeanexp(y_i)
        hit[t + 1] = bool(hit[t] or (ybar[t + 1] >= logT))

        s1 += x_t
        s2 += x_t * x_t
        n += 1

        if stop_on_hit and hit[t + 1]:
            hit[t + 1:] = True
            ybar[t + 1:] = ybar[t + 1]
            break

    return ybar, hit


def simulate_expweights_hedge_time_only_eps_greedy_experts(
    X,
    m,
    alpha,
    eps_mat,
    *,
    eta=2.0,
    gamma=0.01,
    score_mode="shadow",
    pi=None,
    rng=None,
    coupled=True,
    stop_on_hit=True,
    eps_cap=1e-3,
    var_floor=0.0,
    shrink_kappa=0.0,
    lcap=None,
):
    """
    Predictable rebalanced exp-weights hedge (self-financing).

    Expert ``i`` holds log-capital ``logC[i]``, initially ``log(pi[i])``, with
    ``pi`` normalized so the total capital starts at 1. At each step ``t``,
    before ``X[t]`` is used:

    1. ``lam_kelly`` / ``lam_end`` are obtained from
       ``_kelly_and_endpoint_from_past`` using only past observations
       (running sum, sum of squares and count).
    2. A fraction ``gamma`` of the total capital ``W`` is pooled. It is
       redistributed according to ``p ∝ exp(eta * score)``. The score is
       ``y_shadow`` (each expert's cumulative log-growth from its own bets,
       ignoring the prior and rebalancing) when ``score_mode="shadow"``, or
       the current ``logC`` when ``score_mode="capital"``. The other
       ``1 - gamma`` stays with each expert, so the rebalance leaves ``W``
       unchanged.
    3. Expert ``i`` uses ``lam_end`` when a uniform draw falls below
       ``eps_mat[i, t]``, and ``lam_kelly`` otherwise. Its capital is then
       multiplied by ``1 + lam * (X[t] - m)``.

    Parameters
    ----------
    X : array_like
        Observations; ``X.shape[0]`` is the horizon ``N``.
    m : float
        Reference value in the wealth factor. It is also forwarded to
        ``_kelly_and_endpoint_from_past``.
    alpha : float
        Level; a hit is recorded once ``y_mix >= log(1 / alpha)``.
    eps_mat : array_like, shape (K, N)
        Per-expert, per-step probability of taking ``lam_end``.
    eta : float, >= 0
        Inverse temperature of the exp-weights distribution; 0 gives uniform.
    gamma : float in [0, 1]
        Fraction of total capital redistributed each step.
    score_mode : {"shadow", "capital"}
        Which per-expert score drives the exp-weights distribution.
    pi : array_like, shape (K,), optional
        Nonnegative, finite prior weights with a positive sum; they are
        normalized internally. Defaults to uniform.
    rng : numpy Generator, optional
        Source of uniforms; defaults to ``np.random.default_rng()``.
    coupled : bool
        If True, one uniform per step is shared by all experts. Otherwise each
        expert draws its own uniform.
    stop_on_hit : bool
        If True, stop at the first hit and fill the remaining entries with the
        values at the hit time.
    eps_cap, var_floor, shrink_kappa, lcap
        Forwarded unchanged to ``_kelly_and_endpoint_from_past``.

    Returns
    -------
    y_mix : np.ndarray, shape (N + 1,)
        Log of total capital after ``t`` observations (``y_mix[0] == 0``).
    hit : np.ndarray of bool, shape (N + 1,)
        Running indicator that ``y_mix`` has reached ``log(1 / alpha)``.

    Raises
    ------
    ValueError
        On shape mismatch, ``gamma`` outside [0, 1], negative or NaN ``eta``,
        an unknown ``score_mode``, an invalid ``pi``, or a negative factor
        ``1 + lam * (X[t] - m)``.
    """
    if rng is None:
        rng = np.random.default_rng()

    X = np.asarray(X, dtype=np.float64)
    N = int(X.shape[0])

    eps_mat = np.asarray(eps_mat, dtype=np.float64)
    K, N2 = eps_mat.shape
    if N2 != N:
        raise ValueError(f"eps_mat has N={N2} but X has N={N}")

    eta = float(eta)
    gamma = float(gamma)
    if not (0.0 <= gamma <= 1.0):
        raise ValueError("gamma must be in [0,1]")
    # Written as ``not (eta >= 0)`` so a NaN eta is rejected as well.
    if not (eta >= 0.0):
        raise ValueError("eta must be >= 0")
    if score_mode not in ("shadow", "capital"):
        raise ValueError('score_mode must be "shadow" or "capital"')

    m_f = float(m)
    logT = float(np.log(1.0 / alpha))

    if pi is None:
        pi = np.full(K, 1.0 / K, dtype=np.float64)
    else:
        pi = np.asarray(pi, dtype=np.float64)
        if pi.shape != (K,):
            raise ValueError(f"pi must have shape ({K},)")
        if np.any(pi < 0):
            raise ValueError("pi must be nonnegative")
        # NaN / +inf pass the sign check but would make every normalized
        # weight, and therefore every output, NaN.
        if not np.all(np.isfinite(pi)):
            raise ValueError("pi must be finite")
        s = float(pi.sum())
        if s <= 0:
            raise ValueError("pi must sum to a positive value")
        pi = pi / s

    # A zero prior weight means log-capital -inf for that expert; silence the
    # divide-by-zero warning from log(0).
    with np.errstate(divide="ignore"):
        logC = np.log(pi)
    y_shadow = np.zeros(K, dtype=np.float64)

    y_mix = np.zeros(N + 1, dtype=np.float64)
    hit = np.zeros(N + 1, dtype=np.bool_)
    y_mix[0] = _logsumexp(logC)

    # Past-only statistics (sum, sum of squares, count) for the lambda helper.
    s1 = 0.0
    s2 = 0.0
    n = 0

    if gamma == 0.0:
        log_gamma = -np.inf
        log_1mgamma = 0.0
    elif gamma == 1.0:
        log_gamma = 0.0
        log_1mgamma = -np.inf
    else:
        log_gamma = float(np.log(gamma))
        log_1mgamma = float(np.log1p(-gamma))

    for t in range(N):
        _, lam_kelly, lam_end, _ = _kelly_and_endpoint_from_past(
            s1, s2, n, m,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
        )
        lam_kelly = float(lam_kelly)
        lam_end = float(lam_end)

        if gamma == 0.0:
            # Nothing is pooled, so the mixing distribution is never needed.
            logC_tilde = logC
        else:
            logW_prev = _logsumexp(logC)
            if eta == 0.0:
                logp = np.full(K, -np.log(K), dtype=np.float64)
            else:
                score = y_shadow if score_mode == "shadow" else logC
                logits = eta * score
                logp = logits - _logsumexp(logits)
            # logsumexp(logp) == 0, so the rebalanced total is
            # log((1 - gamma) * W + gamma * W) == logW_prev (self-financing).
            if gamma == 1.0:
                logC_tilde = logW_prev + logp
            else:
                logC_tilde = np.logaddexp(
                    logC + log_1mgamma,
                    (logW_prev + log_gamma) + logp,
                )

        # Coupled: one shared uniform; uncoupled: an independent uniform per expert.
        u = float(rng.random()) if coupled else rng.random(K)
        take_end = u < eps_mat[:, t]
        lam_vec = np.where(take_end, lam_end, lam_kelly).astype(np.float64)

        x_t = float(X[t])
        incr = lam_vec * (x_t - m_f)
        if np.any(incr < -1.0):
            raise ValueError("1 + lam*(X-m) < 0 encountered; adjust bounds/caps.")
        log_factor = np.log1p(incr)

        logC = logC_tilde + log_factor
        y_shadow = y_shadow + log_factor

        y_mix[t + 1] = _logsumexp(logC)
        hit[t + 1] = bool(hit[t] or (y_mix[t + 1] >= logT))

        if stop_on_hit and hit[t + 1]:
            hit[t + 1:] = True
            y_mix[t + 1:] = y_mix[t + 1]
            break

        s1 += x_t
        s2 += x_t * x_t
        n += 1

    return y_mix, hit


# Public API of this module. The underscore-prefixed log-sum-exp helpers are
# intentionally not exported.
__all__ = [
    "build_time_only_epsilon_schedule_dict",
    "epsilon_schedule_dict_to_matrix",
    "simulate_uniform_hedge_time_only_eps_greedy_experts",
    "simulate_expweights_hedge_time_only_eps_greedy_experts",
]
