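"""
Predictable Kelly and endpoint bet computations.

Bets are derived only from running statistics of past observations
(``s1``, ``s2`` and ``n``, used as their sum, sum of squares and count)
around a reference value ``m``, so they depend only on data seen so far.
"""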

from __future__ import annotations

import math
from typing import Optional, Tuple

import numpy as np

from .core import safe_bounds


def _kelly_and_endpoint_from_past(
    s1: float,
    s2: float,
    n: int,
    m: float,
    *,
    eps_cap: float = 1e-3,
    var_floor: float = 0.0,
    shrink_kappa: float = 0.0,
    lcap: Optional[float] = None,
    delta_tol: float = 1e-8,
) -> Tuple[float, float, float, float]:
    """
    Stable predictable Taylor-Kelly bet and endpoint bet around a fixed ``m``.

    Past observations enter only through running sums: ``s1`` and ``s2`` are
    used as the sum and sum of squares of the ``n`` past values. Centred at
    ``m`` they give

        S = sum(x - m)        V = sum((x - m) ** 2)

    ``shrink_kappa`` is interpreted as a pseudo-count: it adds
    ``shrink_kappa`` phantom observations with squared deviation
    ``var_floor`` (ridge-style regularisation), i.e.
    ``V_reg = V + shrink_kappa * var_floor``.

    Args:
        s1: Sum of past observations.
        s2: Sum of squares of past observations.
        n: Number of past observations; ``n <= 0`` means "no data".
        m: Reference value the bets are centred on.
        eps_cap: Forwarded to ``safe_bounds`` to obtain the bet range
            ``(lam_max_pos, lam_max_neg)``.
        var_floor: Squared deviation assigned to each pseudo-observation.
        shrink_kappa: Pseudo-count strength of the regularisation.
        lcap: Optional symmetric cap ``|lambda| <= lcap`` applied on top of
            the ``safe_bounds`` range.
        delta_tol: Dead-zone for the endpoint bet: if
            ``|mean_hat - m| <= delta_tol`` the endpoint bet is 0. Same
            meaning as the batch variant's ``delta_tol``.

    Returns:
        ``(mean_hat, lam_taylor, lam_end, em2m_hat)``:
        ``mean_hat`` is ``s1 / n`` (``m`` without data);
        ``lam_taylor`` is ``S / V_reg`` clipped to the bet range (0 when
        ``V_reg`` is not meaningfully positive);
        ``lam_end`` is the range bound on the side where ``mean_hat``
        deviates from ``m`` (0 inside the dead-zone);
        ``em2m_hat`` is ``V_reg / (n + shrink_kappa)`` (``var_floor`` when
        that denominator is not positive).
    """
    lam_max_pos, lam_max_neg = safe_bounds(m, eps_cap=eps_cap)
    if lcap is not None:
        lam_max_pos = min(lam_max_pos, float(lcap))
        lam_max_neg = max(lam_max_neg, -float(lcap))

    m_f = float(m)
    # Non-positive counts mean "no data" everywhere, including the shrinkage
    # denominator below: a negative n must not cancel the pseudo-counts.
    n_eff = n if n > 0 else 0

    if n_eff > 0:
        mean_hat = float(s1 / n)
        delta_hat = mean_hat - m_f
        S = float(s1 - n * m_f)
        # Expanded form of sum((x - m)**2); clamp round-off negatives.
        V = max(0.0, float(s2 - 2.0 * m_f * float(s1) + n * m_f * m_f))
    else:
        mean_hat = m_f
        delta_hat = 0.0
        S = 0.0
        V = 0.0

    k = float(shrink_kappa)
    V_reg = V + k * float(var_floor)

    denom = n_eff + k
    em2m_hat = float(V_reg / denom) if denom > 0 else float(var_floor)

    # Divide only by a strictly positive regularised variance. A zero or
    # negative V_reg (the latter possible only with negative var_floor or
    # shrink_kappa) would otherwise blow up or flip the sign of the bet.
    # This matches the guard used by the batch variant.
    lam_taylor = float(S / V_reg) if V_reg > 1e-12 else 0.0
    lam_taylor = float(np.clip(lam_taylor, lam_max_neg, lam_max_pos))

    if delta_hat > delta_tol:
        lam_end = float(lam_max_pos)
    elif delta_hat < -delta_tol:
        lam_end = float(lam_max_neg)
    else:
        lam_end = 0.0

    return mean_hat, lam_taylor, lam_end, em2m_hat


def _kelly_and_endpoint_from_past_batch(
    s1,
    s2,
    n,
    m,
    eps_cap=1e-3,
    var_floor=0.00,
    shrink_kappa=0.0,
    lcap=None,
    *,
    delta_tol=1e-8,
    out_dtype=np.float32,
):
    """
    Vectorized Taylor-Kelly and endpoint bets around a fixed ``m``.

    Mirrors ``_kelly_and_endpoint_from_past`` element-wise. ``s1``, ``s2`` and
    ``n`` are array-likes, broadcast against each other, used as the running
    sum, sum of squares and count of past observations. ``m`` is a single
    scalar reference value shared by all elements.

    Args:
        s1: Sums of past observations (cast to float64).
        s2: Sums of squares of past observations (cast to float64).
        n: Observation counts (cast to int64, so fractional counts are
            truncated); ``n <= 0`` means "no data".
        m: Scalar reference value.
        eps_cap: Forwarded to ``safe_bounds`` to obtain the bet range.
        var_floor: Squared deviation assigned to each pseudo-observation.
        shrink_kappa: Pseudo-count strength of the ridge-style
            regularisation ``V_reg = V + shrink_kappa * var_floor``.
        lcap: Optional symmetric cap ``|lambda| <= lcap`` on the bet range.
        delta_tol: Dead-zone for the endpoint bet around ``m``.
        out_dtype: dtype of the returned arrays (default float32, so results
            are rounded from the float64 working precision).

    Returns:
        ``(mean_hat, lam_taylor, lam_end, em2m_hat)``, each an array of the
        broadcast input shape in ``out_dtype``.
    """
    # Broadcast up front so every intermediate and every output share one
    # shape. Otherwise, e.g., a scalar s1 with an array s2 gives S and V_reg
    # different shapes, and np.divide(..., out=zeros_like(S)) raises.
    s1, s2, n = np.broadcast_arrays(
        np.asarray(s1, dtype=np.float64),
        np.asarray(s2, dtype=np.float64),
        np.asarray(n, dtype=np.int64),
    )

    # Pass eps_cap by keyword, exactly as the scalar variant does, instead of
    # relying on it being safe_bounds' second positional parameter.
    lam_max_pos, lam_max_neg = safe_bounds(m, eps_cap=eps_cap)
    if lcap is not None:
        lam_max_pos = min(lam_max_pos, float(lcap))
        lam_max_neg = max(lam_max_neg, -float(lcap))

    m_f = float(m)
    has_data = n > 0
    n_f = n.astype(np.float64)

    # np.maximum(..., 1.0) only prevents 0/0 in lanes that np.where discards.
    mean_hat = np.where(has_data, s1 / np.maximum(n_f, 1.0), m_f)
    delta_hat = mean_hat - m_f

    S = np.where(has_data, s1 - n_f * m_f, 0.0)
    # Expanded form of sum((x - m)**2); clamp round-off negatives.
    V = np.where(has_data, s2 - 2.0 * m_f * s1 + n_f * (m_f * m_f), 0.0)
    V = np.maximum(V, 0.0)

    k = float(shrink_kappa)
    V_reg = V + k * float(var_floor)

    # Clamp counts at 0 so a negative n cannot cancel the pseudo-counts.
    denom = np.maximum(n_f, 0.0) + k
    em2m_hat = np.divide(
        V_reg,
        denom,
        out=np.full_like(V_reg, float(var_floor), dtype=np.float64),
        where=denom > 1e-12,
    )

    # Taylor bet only where the regularised variance is meaningfully
    # positive; elsewhere it stays 0.
    lam_taylor = np.divide(
        S,
        V_reg,
        out=np.zeros_like(S, dtype=np.float64),
        where=V_reg > 1e-12,
    )
    lam_taylor = np.clip(lam_taylor, lam_max_neg, lam_max_pos)

    # Endpoint bet: the range bound on the side where mean_hat deviates from
    # m, with a +-delta_tol dead-zone that yields 0.
    lam_end = np.zeros_like(lam_taylor, dtype=np.float64)
    lam_end = np.where(delta_hat > delta_tol, lam_max_pos, lam_end)
    lam_end = np.where(delta_hat < -delta_tol, lam_max_neg, lam_end)

    return (
        mean_hat.astype(out_dtype),
        lam_taylor.astype(out_dtype),
        lam_end.astype(out_dtype),
        em2m_hat.astype(out_dtype),
    )


# Star-import surface. Both helpers are underscore-prefixed, so without this
# list ``from <module> import *`` would skip them.
__all__ = ["_kelly_and_endpoint_from_past", "_kelly_and_endpoint_from_past_batch"]
