"""
Baseline betting policies.
"""

from __future__ import annotations

import numpy as np

from .core import safe_bounds
from .kelly import _kelly_and_endpoint_from_past


def simulate_with_policy(X, lam, m, alpha):
    """
    Given path X and bets lam, return log-wealth path and hit indicators.

    At step ``t`` wealth is multiplied by ``1 + lam[t] * (X[t] - m)``, so the
    log-wealth is the running sum of ``log1p(lam[t] * (X[t] - m))``. The
    threshold ``log(1 / alpha)`` is "hit" the first time log-wealth reaches
    it, and the indicator stays set for the remainder of the path.

    Parameters
    ----------
    X : array_like, shape (N,)
        Observed path.
    lam : array_like, shape (>= N,)
        Bets; ``lam[t]`` multiplies ``X[t] - m``. Entries past ``N`` are ignored.
    m : float
        Reference value subtracted from each observation.
    alpha : float
        Level; the log-wealth threshold is ``log(1 / alpha)``.

    Returns
    -------
    Y : ndarray of float, shape (N + 1,)
        Log-wealth path with ``Y[0] = 0``.
    hit : ndarray of bool, shape (N + 1,)
        ``hit[t]`` is True iff ``Y[s] >= log(1 / alpha)`` for some
        ``1 <= s <= t``; ``hit[0]`` is always False.

    Raises
    ------
    ValueError
        If ``lam`` has fewer entries than ``X``.
    """
    X = np.asarray(X, dtype=float)
    N = X.shape[0]
    # Only the first N bets are used, matching per-step indexing lam[t].
    lam = np.asarray(lam, dtype=float)[:N]
    if lam.shape[0] != N:
        raise ValueError(f"lam has {lam.shape[0]} entries but X has {N}")
    eta = np.log(1.0 / alpha)

    Y = np.zeros(N + 1)
    # log1p stays accurate for the tiny increments produced by near-zero bets,
    # where log(1 + x) rounds 1 + x and loses digits. np.cumsum is a sequential
    # accumulation, so Y is the same running sum the step-by-step loop built.
    # A step with 1 + lam*(X - m) <= 0 still yields -inf/nan.
    Y[1:] = np.cumsum(np.log1p(lam * (X - m)))

    hit = np.zeros(N + 1, dtype=bool)
    # Once the threshold is crossed the indicator remains True.
    hit[1:] = np.logical_or.accumulate(Y[1:] >= eta)
    return Y, hit


def empirical_kelly_policy_predictable(
    X,
    m: float,
    eps_cap: float = 1e-3,
    var_floor: float = 0.00,
    shrink_kappa: float = 0.0,
    lcap=None,
):
    """
    Build a predictable sequence of empirical-Kelly (Taylor-Kelly) bets.

    The bet for step ``t`` comes from ``_kelly_and_endpoint_from_past``. That
    helper receives the running sum, the running sum of squares and the
    count of ``X[0..t-1]``, so ``X[t]`` is never used to size its own bet.
    Only the second element of the helper's 4-tuple is kept.

    Returns
    -------
    ndarray of float, shape (N,)
        One bet per observation.
    """
    obs = np.asarray(X, dtype=float)
    horizon = obs.shape[0]
    bets = np.zeros(horizon, dtype=float)

    helper_opts = dict(
        eps_cap=eps_cap,
        var_floor=var_floor,
        shrink_kappa=shrink_kappa,
        lcap=lcap,
    )
    running_sum = 0.0
    running_sq = 0.0

    # ``step`` doubles as the number of observations seen so far.
    for step, row in enumerate(obs):
        # Size the bet first, then reveal the observation (predictability).
        _, taylor_bet, _, _ = _kelly_and_endpoint_from_past(
            running_sum, running_sq, step, m, **helper_opts
        )
        bets[step] = float(taylor_bet)

        value = float(row)
        running_sum += value
        running_sq += value * value

    return bets


def linear_epsilon_schedule(N, t0_frac=0.2):
    """
    Return array eps[0..N-1] that ramps from 0 to 1, starting at t0 = floor(t0_frac*N).

    Entries before ``t0`` are 0. From ``t0`` through ``N-1`` the values are
    evenly spaced from 0 to 1 inclusive; if only one step remains it gets 0.
    ``t0`` is clamped to ``[0, N]``, so a negative ``t0_frac`` ramps over the
    whole horizon and ``t0_frac >= 1`` yields all zeros.

    Parameters
    ----------
    N : int
        Length of the schedule.
    t0_frac : float
        Fraction of the horizon to hold at 0 before the ramp starts.

    Returns
    -------
    ndarray of float, shape (N,)
    """
    eps = np.zeros(N, dtype=float)
    # int() truncates toward zero, which is floor for non-negative values.
    # The clamp stops a negative fraction from becoming a negative slice start.
    # Unclamped, it would index from the end and mismatch the ramp length.
    t0 = min(max(int(t0_frac * N), 0), N)
    if t0 == N:
        return eps
    eps[t0:] = np.linspace(0.0, 1.0, num=N - t0, endpoint=True)
    return eps


def eps_greedy_policy(
    X,
    m,
    eps_seq=None,
    eps_cap=1e-3,
    var_ridge=1e-6,
    rng=None,
    *,
    var_floor=0.0,
    shrink_kappa=0.0,
    lcap=None,
    cap_ramp=(0.5, 2.0),
):
    """
    Predictable ε-greedy policy mixing Kelly and endpoint bets.

    At each step ``t`` one uniform draw is taken from ``rng``. If it is below
    ``eps_seq[t]``, the bet is the endpoint bet, i.e. the third element
    returned by ``_kelly_and_endpoint_from_past``. Otherwise it is the Kelly
    bet, the second element. Both come from statistics of ``X[0..t-1]`` only.
    When ``eps_seq`` is None, ``linear_epsilon_schedule(N, t0_frac=0.7)`` is
    used.

    NOTE(review): ``var_ridge`` and ``cap_ramp`` are accepted but unused.
    The bounds derived from ``safe_bounds`` and ``lcap`` are also computed
    but never applied to the bets. Confirm whether a clip was intended.

    Returns
    -------
    ndarray of float, shape (N,)
    """
    if rng is None:
        rng = np.random.default_rng()

    obs = np.asarray(X, dtype=float)
    horizon = len(obs)
    schedule = (
        linear_epsilon_schedule(horizon, t0_frac=0.7) if eps_seq is None else eps_seq
    )

    bets = np.zeros(horizon, dtype=float)

    # Computed for parity with the original; the results are not used below.
    _upper, _lower = safe_bounds(m, eps_cap)
    if lcap is not None:
        _upper = min(_upper, float(lcap))
        _lower = max(_lower, -float(lcap))

    total = 0.0
    total_sq = 0.0

    # ``step`` doubles as the number of observations seen so far.
    for step in range(horizon):
        _, kelly_bet, endpoint_bet, _ = _kelly_and_endpoint_from_past(
            total,
            total_sq,
            step,
            m,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
        )

        # One draw per step, regardless of the schedule value.
        explore = rng.random() < float(schedule[step])
        bets[step] = endpoint_bet if explore else kelly_bet

        x_t = obs[step]
        total += x_t
        total_sq += x_t * x_t

    return bets


# Public API: the log-wealth simulator, the two predictable betting policies,
# and the epsilon schedule that eps_greedy_policy uses by default.
__all__ = [
    "simulate_with_policy",
    "empirical_kelly_policy_predictable",
    "linear_epsilon_schedule",
    "eps_greedy_policy",
]
