"""
Feature engineering helpers.
"""

from __future__ import annotations

import numpy as np

from .constants import FEAT_TAU, KURT_SHRINK_K, SKEW_SHRINK_K


def _features_from_state(
    m, mu_hat, var_hat, y, T, t, N, lam_kelly, lam_end,
    s2, s3, s4, n, n_pos, n_low, n_high,
    tau=FEAT_TAU,
):
    """
    Build the scalar extended feature vector φ_t from predictable (past) statistics.

    All inputs are coerced to plain Python scalars. The result is a 1-D
    float32 array with 22 entries, in this order:

        0 bias (1.0)            1 delta = mu_hat - m      2 |delta|
        3 dist                  4 time_remaining_rel      5 var_hat (clipped)
        6 snr                   7 req                     8 lam_kelly
        9 lam_end              10 lam_gap                11 g_diff
       12 m                    13 log(N)                 14 mu2_z
       15 skew                 16 exkurt                 17 p_pos
       18 p_low                19 p_high                 20 log_kappa
       21 g4_diff

    ``s2``, ``s3``, ``s4`` are divided by ``n`` to form raw moments, and
    ``n_pos``, ``n_low``, ``n_high`` are divided by ``n`` to form fractions
    (presumably running power sums and event counts -- confirm with caller).
    When ``n <= 0`` neutral defaults are used instead.

    ``tau`` is part of the signature but is not read by this function.
    """
    eps = 1e-12

    def _clamp(value, lo, hi):
        # np.clip followed by a cast back to a built-in float.
        return float(np.clip(value, lo, hi))

    # --- coerce inputs ----------------------------------------------------
    m = float(m)
    mu_hat = float(mu_hat)
    var_hat = float(var_hat)
    y = float(y)
    T = float(T)
    t = int(t)
    N = int(N)
    lam_kelly = float(lam_kelly)
    lam_end = float(lam_end)

    # --- gap between estimated mean and reference m ------------------------
    gap = mu_hat - m

    # --- progress towards the target T, relative to |T| --------------------
    target_scale = abs(T) + eps
    shortfall = T - y
    dist = _clamp(shortfall / target_scale, -2.0, 2.0)

    time_remaining_rel = float((N - 1 - t) / max(1, N - 1))

    # --- signal-to-noise of the gap ----------------------------------------
    var_hat = _clamp(var_hat, 0.0, 0.25)
    std_err = float(np.sqrt(var_hat / max(1, t)))
    snr = _clamp(gap / (std_err + eps), -5.0, 5.0)

    # Required per-step progress over the remaining steps, relative to |T|.
    remaining_steps = max(1, N - t)
    req = _clamp((shortfall / remaining_steps) / target_scale, -2.0, 2.0)

    lam_gap = lam_end - lam_kelly

    # --- quadratic growth proxy at each lambda -----------------------------
    second_moment = var_hat + gap * gap

    def _growth2(lam):
        return lam * gap - 0.5 * (lam**2) * second_moment

    growth2_kelly = _growth2(lam_kelly)
    growth2_end = _growth2(lam_end)
    g_diff = _clamp(growth2_end - growth2_kelly, -1.0, 1.0)

    # --- moments about m, shape statistics, empirical fractions ------------
    sample_count = int(n)
    if sample_count <= 0:
        # No samples: fall back to the clipped variance and neutral shape.
        mu2_z = _clamp(var_hat, 1e-8, 1.0)
        mu3_z = 0.0
        mu4_z = 0.0
        skew = 0.0
        exkurt = 0.0
        p_pos = 0.5
        p_low = 0.0
        p_high = 0.0
    else:
        count = float(sample_count)
        raw2 = float(s2) / count
        raw3 = float(s3) / count
        raw4 = float(s4) / count

        m_sq = m * m
        m_cu = m_sq * m
        m_qu = m_sq * m_sq

        # Binomial expansion of E[(x - m)^k] in terms of raw moments.
        mu2_z = raw2 - 2.0 * m * mu_hat + m_sq
        mu3_z = raw3 - 3.0 * m * raw2 + 3.0 * m_sq * mu_hat - m_cu
        mu4_z = raw4 - 4.0 * m * raw3 + 6.0 * m_sq * raw2 - 4.0 * m_cu * mu_hat + m_qu

        mu2_z = _clamp(mu2_z, 1e-8, 1.0)

        raw_skew = mu3_z / (mu2_z**1.5 + eps)
        raw_exkurt = mu4_z / (mu2_z * mu2_z + eps) - 3.0

        # Shrink towards zero when the sample count is small.
        skew_weight = count / (count + float(SKEW_SHRINK_K))
        kurt_weight = count / (count + float(KURT_SHRINK_K))
        skew = _clamp(raw_skew * skew_weight, -3.0, 3.0)
        exkurt = _clamp(raw_exkurt * kurt_weight, -3.0, 10.0)

        p_pos = float(n_pos) / count
        p_low = float(n_low) / count
        p_high = float(n_high) / count

    # --- log of mu(1-mu)/var - 1 -------------------------------------------
    # NOTE(review): has the form of a Beta method-of-moments concentration;
    # confirm against the model derivation.
    mu_eff = _clamp(mu_hat, 1e-4, 1.0 - 1e-4)
    var_eff = _clamp(var_hat, 1e-6, 0.25)
    kappa = _clamp(mu_eff * (1.0 - mu_eff) / var_eff - 1.0, 1e-3, 1e6)
    log_kappa = _clamp(np.log(kappa), -10.0, 10.0)

    # --- quartic growth proxy at each lambda, normalised difference --------
    def _growth4(lam):
        return (
            lam * gap
            - 0.5 * (lam**2) * mu2_z
            + (1.0 / 3.0) * (lam**3) * mu3_z
            - 0.25 * (lam**4) * mu4_z
        )

    growth4_kelly = _growth4(lam_kelly)
    growth4_end = _growth4(lam_end)
    g4_diff = (growth4_end - growth4_kelly) / (
        1.0 + abs(growth4_end) + abs(growth4_kelly)
    )
    g4_diff = _clamp(g4_diff, -1.0, 1.0)

    log_horizon = float(np.log(float(N)))

    features = (
        1.0,
        gap,
        abs(gap),
        dist,
        time_remaining_rel,
        var_hat,
        snr,
        req,
        lam_kelly,
        lam_end,
        lam_gap,
        g_diff,
        m,
        log_horizon,
        mu2_z,
        skew,
        exkurt,
        p_pos,
        p_low,
        p_high,
        log_kappa,
        g4_diff,
    )
    return np.array(features, dtype=np.float32)


def _features_from_state_batch(
    m, mu_hat, var_hat, y, T, t, N, lam_kelly, lam_end,
    s2, s3, s4, n, n_pos, n_low, n_high,
    tau=FEAT_TAU,
):
    """Batch extended feature vector. Returns (B, d) float32 with d=22.

    ``m``, ``T``, ``t`` and ``N`` are scalars shared by the whole batch.
    Every other statistic may be a length-B array *or* a scalar: the 22
    feature columns are broadcast to one common shape before stacking, so a
    value shared by all rows (e.g. a single ``lam_kelly``) does not have to
    be tiled by the caller. All-scalar input yields a (1, 22) array.

    Column order:

        0 bias (1.0)            1 delta = mu_hat - m      2 |delta|
        3 dist                  4 time_remaining_rel      5 var_hat (clipped)
        6 snr                   7 req                     8 lam_kelly
        9 lam_end              10 lam_gap                11 g_diff
       12 m                    13 log(N)                 14 mu2_z
       15 skew                 16 exkurt                 17 p_pos
       18 p_low                19 p_high                 20 log_kappa
       21 g4_diff

    ``s2``, ``s3``, ``s4`` are divided by ``n`` to form raw moments, and
    ``n_pos``, ``n_low``, ``n_high`` are divided by ``n`` to form fractions
    (presumably running power sums and event counts -- confirm with caller).
    Rows with ``n <= 0`` get neutral defaults instead.

    ``tau`` is part of the signature but is not read by this function.
    """
    m = np.float32(m)
    T = np.float32(T)

    mu_hat = np.asarray(mu_hat, dtype=np.float32)
    var_hat = np.asarray(var_hat, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    lam_kelly = np.asarray(lam_kelly, dtype=np.float32)
    lam_end = np.asarray(lam_end, dtype=np.float32)

    s2 = np.asarray(s2, dtype=np.float32)
    s3 = np.asarray(s3, dtype=np.float32)
    s4 = np.asarray(s4, dtype=np.float32)

    n = np.asarray(n, dtype=np.int32)
    n_pos = np.asarray(n_pos, dtype=np.int32)
    n_low = np.asarray(n_low, dtype=np.int32)
    n_high = np.asarray(n_high, dtype=np.int32)

    # Gap between the estimated mean and the reference value m.
    delta = (mu_hat - m).astype(np.float32)
    abs_delta = np.abs(delta).astype(np.float32)

    # Remaining distance to the target T, relative to |T|.
    dist = ((T - y) / (np.abs(T) + 1e-12)).astype(np.float32)
    dist = np.clip(dist, -2.0, 2.0).astype(np.float32)

    time_remaining_rel = np.float32((N - 1 - t) / max(1, N - 1))

    # Signal-to-noise of the gap; max(1, t) avoids dividing by zero at t == 0.
    var_hat = np.clip(var_hat, 0.0, 0.25).astype(np.float32)
    se = np.sqrt(var_hat / np.float32(max(1, t))).astype(np.float32)
    snr = (delta / (se + 1e-12)).astype(np.float32)
    snr = np.clip(snr, -5.0, 5.0).astype(np.float32)

    # Required per-step progress over the remaining steps, relative to |T|.
    steps_left = max(1, N - t)
    req = ((T - y) / np.float32(steps_left)).astype(np.float32)
    req = np.clip(req / (np.abs(T) + 1e-12), -2.0, 2.0).astype(np.float32)

    lam_gap = (lam_end - lam_kelly).astype(np.float32)

    # Quadratic growth proxy lam*delta - 0.5*lam^2*E[z^2] at each lambda.
    ez2 = (var_hat + delta * delta).astype(np.float32)
    g_k = (lam_kelly * delta - 0.5 * (lam_kelly * lam_kelly) * ez2).astype(np.float32)
    g_e = (lam_end * delta - 0.5 * (lam_end * lam_end) * ez2).astype(np.float32)
    g_diff = np.clip(g_e - g_k, -1.0, 1.0).astype(np.float32)

    # n_f is floored at 1 so the divisions below are safe; rows with no
    # samples are overwritten with defaults via the `zero` mask.
    n_f = np.maximum(n.astype(np.float32), 1.0)
    zero = (n <= 0)

    e2 = s2 / n_f
    e3 = s3 / n_f
    e4 = s4 / n_f

    m2 = m * m
    m3 = m2 * m
    m4 = m2 * m2

    # Binomial expansion of E[(x - m)^k] in terms of raw moments.
    mu2_z = e2 - 2.0 * m * mu_hat + m2
    mu3_z = e3 - 3.0 * m * e2 + 3.0 * m2 * mu_hat - m3
    mu4_z = e4 - 4.0 * m * e3 + 6.0 * m2 * e2 - 4.0 * m3 * mu_hat + m4

    mu2_z = np.where(zero, var_hat, mu2_z)
    mu3_z = np.where(zero, 0.0, mu3_z)
    mu4_z = np.where(zero, 0.0, mu4_z)

    mu2_z = np.clip(mu2_z, 1e-8, 1.0).astype(np.float32)

    skew = mu3_z / (np.power(mu2_z, 1.5) + 1e-12)
    exkurt = mu4_z / (mu2_z * mu2_z + 1e-12) - 3.0

    # Shrink the shape statistics towards zero when the sample count is small.
    w3 = n_f / (n_f + np.float32(SKEW_SHRINK_K))
    w4 = n_f / (n_f + np.float32(KURT_SHRINK_K))
    skew = np.clip(skew * w3, -3.0, 3.0)
    exkurt = np.clip(exkurt * w4, -3.0, 10.0)

    skew = np.where(zero, 0.0, skew).astype(np.float32)
    exkurt = np.where(zero, 0.0, exkurt).astype(np.float32)

    p_pos = np.where(n > 0, n_pos.astype(np.float32) / n_f, np.float32(0.5)).astype(np.float32)
    p_low = np.where(n > 0, n_low.astype(np.float32) / n_f, np.float32(0.0)).astype(np.float32)
    p_high = np.where(n > 0, n_high.astype(np.float32) / n_f, np.float32(0.0)).astype(np.float32)

    # log of mu(1-mu)/var - 1.
    # NOTE(review): has the form of a Beta method-of-moments concentration;
    # confirm against the model derivation.
    mu_eff = np.clip(mu_hat, 1e-4, 1.0 - 1e-4)
    var_eff = np.clip(var_hat, 1e-6, 0.25)
    kappa = mu_eff * (1.0 - mu_eff) / var_eff - 1.0
    kappa = np.clip(kappa, 1e-3, 1e6)
    log_kappa = np.clip(np.log(kappa), -10.0, 10.0).astype(np.float32)

    # Quartic growth proxy at each lambda, then a normalised difference.
    g4_k = (
        lam_kelly * delta
        - 0.5 * (lam_kelly**2) * mu2_z
        + (1.0 / 3.0) * (lam_kelly**3) * mu3_z
        - 0.25 * (lam_kelly**4) * mu4_z
    )
    g4_e = (
        lam_end * delta
        - 0.5 * (lam_end**2) * mu2_z
        + (1.0 / 3.0) * (lam_end**3) * mu3_z
        - 0.25 * (lam_end**4) * mu4_z
    )
    g4_diff = (g4_e - g4_k) / (1.0 + np.abs(g4_e) + np.abs(g4_k))
    g4_diff = np.clip(g4_diff, -1.0, 1.0).astype(np.float32)

    log_n = np.float32(np.log(float(N)))

    columns = [
        np.float32(1.0), delta, abs_delta, dist, time_remaining_rel,
        var_hat, snr, req,
        lam_kelly, lam_end, lam_gap, g_diff,
        m, log_n,
        mu2_z, skew, exkurt, p_pos, p_low, p_high, log_kappa, g4_diff,
    ]
    # Batch-wide scalars (bias, time remaining, m, log N) and any statistic
    # the caller passed as a scalar are expanded to the common batch shape
    # here; np.stack alone would reject columns of differing shapes.
    columns = np.broadcast_arrays(*[np.atleast_1d(col) for col in columns])

    return np.stack(columns, axis=1).astype(np.float32)


# Both helpers have underscore-prefixed names, which `from module import *`
# skips by default; listing them here exports them explicitly.
__all__ = [
    "_features_from_state",
    "_features_from_state_batch",
]
