"""
Rollout utilities for DQN policy.
"""

from __future__ import annotations

import numpy as np

from .constants import FEAT_TAU
from .core import _lambda_from_action, _lambda_from_action_batch
from .env import sample_world, sample_world_batch
from .features import _features_from_state, _features_from_state_batch
from .kelly import _kelly_and_endpoint_from_past, _kelly_and_endpoint_from_past_batch


def rollout_episode_dqn_epsilon(
    agent,
    N,
    alpha,
    m,
    mu,
    world="beta_mixture",
    conc=6.0,
    eps_cap=1e-3,
    var_floor=0.0,
    shrink_kappa=0.0,
    lcap=None,
    explore_eps=0.1,
    rng=None,
):
    """
    Roll out one episode using a DQN that chooses among discrete actions.

    At each step the agent sees features built from running statistics of the
    past observations. It picks a discrete action, which is mapped to a bet
    size ``lam_t`` between a Kelly-style and an endpoint-style value. The
    log-wealth is then updated with ``log1p(lam_t * (x_t - m))``.

    The episode ends when the log-wealth reaches ``log(1 / alpha)`` or after
    ``N`` steps, whichever comes first.

    Parameters
    ----------
    agent : object
        Must provide ``select_action(phi, explore_eps)``.
    N : int
        Maximum number of steps.
    alpha : float
        Crossing threshold on log-wealth is ``log(1 / alpha)``.
    m : float
        Reference value subtracted from each observation in the wealth update.
    mu, world, conc :
        Passed to ``sample_world`` to draw the observation sequence ``X``.
    eps_cap, var_floor, shrink_kappa, lcap :
        Forwarded to ``_kelly_and_endpoint_from_past``.
    explore_eps : float
        Exploration rate passed to ``agent.select_action``.
    rng : numpy.random.Generator, optional
        Source of randomness for ``sample_world``; a fresh default generator
        is created when omitted.

    Returns
    -------
    transitions : list of tuple
        ``(phi_t, a_t, reward, phi_next, done)`` per step taken. ``reward`` is
        1.0 on the step where the threshold is crossed and 0.0 otherwise.
        ``phi_next`` is all zeros on the final step (``t == N - 1``).
    hit_time : int or None
        Number of steps taken when the threshold was first crossed, or
        ``None`` if it was never crossed.
    """
    if rng is None:
        rng = np.random.default_rng()

    X = sample_world(N, mu, world=world, rng=rng, conc=conc)
    T = float(np.log(1.0 / alpha))

    # Running log-wealth and power sums / counters of past observations.
    y = 0.0
    s1 = 0.0
    s2 = 0.0
    s3 = 0.0
    s4 = 0.0
    n = 0
    n_pos = 0
    n_low = 0
    n_high = 0

    transitions = []
    hit_time = None

    for t in range(N):
        mu_hat_prev, lam_kelly, lam_end, var_hat = _kelly_and_endpoint_from_past(
            s1, s2, n, m,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
        )

        phi_t = _features_from_state(
            m, mu_hat_prev, var_hat, y, T, t, N, lam_kelly, lam_end,
            s2, s3, s4, n, n_pos, n_low, n_high
        )

        a_t = agent.select_action(phi_t, explore_eps)
        lam_t = _lambda_from_action(a_t, lam_kelly, lam_end)

        x_t = float(X[t])
        x2 = x_t * x_t
        # log1p matches the batch rollout and the trace helper, and is more
        # accurate than log(1 + z) when the increment z is small.
        y_next = y + np.log1p(lam_t * (x_t - m))

        s1_next = s1 + x_t
        s2_next = s2 + x2
        s3_next = s3 + x2 * x_t
        s4_next = s4 + x2 * x2
        n_next = n + 1

        n_pos_next = n_pos + int(x_t > m)
        n_low_next = n_low + int(x_t < FEAT_TAU)
        n_high_next = n_high + int(x_t > (1.0 - FEAT_TAU))

        if t < N - 1:
            mu_hat_next, lam_kelly_next, lam_end_next, var_hat_next = _kelly_and_endpoint_from_past(
                s1_next, s2_next, n_next, m,
                eps_cap=eps_cap,
                var_floor=var_floor,
                shrink_kappa=shrink_kappa,
                lcap=lcap,
            )
            phi_next = _features_from_state(
                m, mu_hat_next, var_hat_next, y_next, T, t + 1, N, lam_kelly_next, lam_end_next,
                s2_next, s3_next, s4_next, n_next, n_pos_next, n_low_next, n_high_next
            )
        else:
            # No successor state after the final step; the done flag makes
            # this placeholder irrelevant for bootstrapping.
            phi_next = np.zeros_like(phi_t)

        crossed = (y_next >= T)
        is_last_step = (t == N - 1)

        done = False
        reward = 0.0

        if crossed:
            hit_time = t + 1
            done = True
            reward = 1.0
        elif is_last_step:
            done = True

        # Copy defensively in case the feature builder reuses buffers.
        transitions.append((phi_t.copy(), a_t, reward, phi_next.copy(), done))

        y = y_next
        s1, s2, s3, s4, n = s1_next, s2_next, s3_next, s4_next, n_next
        n_pos, n_low, n_high = n_pos_next, n_low_next, n_high_next

        if done:
            break

    return transitions, hit_time


def rollout_batch_dqn_epsilon(
    agent,
    batch_episodes,
    N,
    alpha,
    m,
    mu,
    world="beta_mixture",
    conc=6.0,
    eps_cap=1e-3,
    var_floor=0.0,
    shrink_kappa=0.0,
    lcap=None,
    explore_eps=None,
    rng=None,
):
    """
    Vectorized rollout of B episodes in parallel.

    All still-running episodes advance together one step at a time. Running
    state (log-wealth, power sums, counters) is stored in length-``B`` float32
    / int32 arrays. An episode stops once its log-wealth reaches
    ``log(1 / alpha)`` or after ``N`` steps.

    Parameters
    ----------
    agent : object
        Must provide ``state_dim`` and ``select_action_batch(phi, eps)``.
    batch_episodes : int
        Number of episodes ``B``.
    N : int
        Maximum number of steps per episode.
    alpha : float
        Crossing threshold on log-wealth is ``log(1 / alpha)``.
    m : float
        Reference value subtracted from each observation in the wealth update.
    mu, world, conc :
        Passed to ``sample_world_batch`` to draw the observations ``X``.
    eps_cap, var_floor, shrink_kappa, lcap :
        Forwarded to ``_kelly_and_endpoint_from_past_batch``.
    explore_eps : None, float, or array-like of shape (B,)
        Per-episode exploration rate. ``None`` means 0.1 for every episode;
        a scalar is broadcast to all episodes.
    rng : numpy.random.Generator, optional
        Source of randomness for ``sample_world_batch``.

    Returns
    -------
    (S, A, R, S_next, D) : tuple of numpy arrays
        Transitions for every (step, active episode) pair. They are ordered
        by step, then by ascending episode index within a step. ``R`` is 1.0
        on the crossing step and 0.0 otherwise. ``S_next`` rows are zeros for
        transitions that end an episode.
    hit_time : numpy.ndarray of int32, shape (B,)
        Number of steps taken when the threshold was first crossed, or -1.

    Raises
    ------
    ValueError
        If ``explore_eps`` is neither a scalar nor a 1-D array of length B.
    """
    if rng is None:
        rng = np.random.default_rng()

    B = int(batch_episodes)
    N = int(N)
    d = int(agent.state_dim)

    if explore_eps is None:
        explore_eps = np.full(B, 0.1, dtype=np.float32)
    else:
        explore_eps = np.asarray(explore_eps, dtype=np.float32)
        if explore_eps.ndim == 0:
            explore_eps = np.full(B, float(explore_eps), dtype=np.float32)
        # Require exactly (B,): a multi-dimensional array whose first axis
        # happens to equal B would otherwise pass 2-D rows to the agent.
        if explore_eps.shape != (B,):
            raise ValueError("explore_eps must be scalar or shape (batch_episodes,)")

    X = sample_world_batch(B, N, mu, world=world, rng=rng, conc=conc)
    T = float(np.log(1.0 / alpha))

    y = np.zeros(B, dtype=np.float32)
    s1 = np.zeros(B, dtype=np.float32)
    s2 = np.zeros(B, dtype=np.float32)
    s3 = np.zeros(B, dtype=np.float32)
    s4 = np.zeros(B, dtype=np.float32)
    n = np.zeros(B, dtype=np.int32)
    n_pos = np.zeros(B, dtype=np.int32)
    n_low = np.zeros(B, dtype=np.int32)
    n_high = np.zeros(B, dtype=np.int32)
    done = np.zeros(B, dtype=np.bool_)
    hit_time = -np.ones(B, dtype=np.int32)

    # Preallocate for the worst case (no episode ends early); trimmed to the
    # number of transitions actually written (p) on return.
    max_trans = B * N
    S = np.zeros((max_trans, d), dtype=np.float32)
    A = np.zeros(max_trans, dtype=np.int64)
    R = np.zeros(max_trans, dtype=np.float32)
    S_next = np.zeros((max_trans, d), dtype=np.float32)
    D = np.zeros(max_trans, dtype=np.bool_)
    p = 0

    for t in range(N):
        active = ~done
        if not active.any():
            break
        idx = np.nonzero(active)[0]

        mu_hat, lam_k, lam_end, var_hat = _kelly_and_endpoint_from_past_batch(
            s1[idx], s2[idx], n[idx], m,
            eps_cap=eps_cap, var_floor=var_floor, shrink_kappa=shrink_kappa, lcap=lcap
        )

        phi = _features_from_state_batch(
            m, mu_hat, var_hat, y[idx], T, t, N, lam_k, lam_end,
            s2[idx], s3[idx], s4[idx], n[idx], n_pos[idx], n_low[idx], n_high[idx],
        )
        a = agent.select_action_batch(phi, explore_eps[idx])
        lam_t = _lambda_from_action_batch(a, lam_k, lam_end)

        x_t = X[idx, t]
        x2 = x_t * x_t

        y_next = (y[idx] + np.log1p(lam_t * (x_t - np.float32(m)))).astype(np.float32)

        s1_next = s1[idx] + x_t
        s2_next = s2[idx] + x2
        s3_next = s3[idx] + x2 * x_t
        s4_next = s4[idx] + x2 * x2
        n_next = n[idx] + 1

        n_pos_next = n_pos[idx] + (x_t > np.float32(m)).astype(np.int32)
        n_low_next = n_low[idx] + (x_t < np.float32(FEAT_TAU)).astype(np.int32)
        n_high_next = n_high[idx] + (x_t > np.float32(1.0 - FEAT_TAU)).astype(np.int32)

        crossed = (y_next >= np.float32(T))
        is_last = (t == N - 1)
        done_a = (crossed | is_last)
        reward = crossed.astype(np.float32)

        # Next-state features are only computed for episodes that continue;
        # terminal rows stay zero (masked by D when bootstrapping).
        phi_next_full = np.zeros_like(phi, dtype=np.float32)
        if t < N - 1:
            nd = ~done_a
            if nd.any():
                mu_hat_n, lam_k_n, lam_end_n, var_hat_n = _kelly_and_endpoint_from_past_batch(
                    s1_next[nd], s2_next[nd], n_next[nd], m,
                    eps_cap=eps_cap, var_floor=var_floor, shrink_kappa=shrink_kappa, lcap=lcap
                )
                phi_next_full[nd] = _features_from_state_batch(
                    m, mu_hat_n, var_hat_n, y_next[nd], T, t + 1, N, lam_k_n, lam_end_n,
                    s2_next[nd], s3_next[nd], s4_next[nd], n_next[nd],
                    n_pos_next[nd], n_low_next[nd], n_high_next[nd],
                )

        K = int(idx.shape[0])
        S[p:p+K] = phi
        A[p:p+K] = a
        R[p:p+K] = reward
        S_next[p:p+K] = phi_next_full
        D[p:p+K] = done_a
        p += K

        y[idx] = y_next
        s1[idx] = s1_next
        s2[idx] = s2_next
        s3[idx] = s3_next
        s4[idx] = s4_next
        n[idx] = n_next
        n_pos[idx] = n_pos_next
        n_low[idx] = n_low_next
        n_high[idx] = n_high_next

        ht = hit_time[idx]
        newly_hit = crossed & (ht < 0)
        if newly_hit.any():
            hit_time[idx[newly_hit]] = t + 1

        done[idx] = done_a

    return (S[:p], A[:p], R[:p], S_next[:p], D[:p]), hit_time


def lam_sequence_from_dqn_epsilon_policy(
    X,
    m,
    alpha,
    agent,
    eps_cap=1e-3,
    var_floor=0.0,
    shrink_kappa=0.0,
    lcap=None,
    explore_eps=0.0,
):
    """
    Build λ_t using a trained DQNEpsilonAgent.

    Replays the agent on a fixed observation sequence ``X`` and records the
    bet size chosen at every step. Unlike the rollout functions, this does not
    stop when the log-wealth crosses ``log(1 / alpha)``: a λ is produced for
    every element of ``X``.

    Parameters
    ----------
    X : array-like of float
        Observation sequence; its length sets the number of steps.
    m : float
        Reference value subtracted from each observation in the wealth update.
    alpha : float
        Defines the threshold ``log(1 / alpha)`` fed into the features.
    agent : object
        Must provide ``select_action(phi, explore_eps)``.
    eps_cap, var_floor, shrink_kappa, lcap :
        Forwarded to ``_kelly_and_endpoint_from_past``.
    explore_eps : float
        Exploration rate passed to ``agent.select_action`` (0.0 by default).

    Returns
    -------
    numpy.ndarray of float, shape (len(X),)
        The bet size λ_t chosen at each step.
    """
    X = np.asarray(X, float)
    N = len(X)
    lam = np.zeros(N, float)
    T = float(np.log(1.0 / alpha))

    # Power sums and counters over the observations seen so far.
    s1 = 0.0
    s2 = 0.0
    s3 = 0.0
    s4 = 0.0
    n = 0
    n_pos = 0
    n_low = 0
    n_high = 0
    y = 0.0

    for t in range(N):
        mu_hat_prev, lam_kelly, lam_end, var_hat = _kelly_and_endpoint_from_past(
            s1, s2, n, m,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
        )

        phi = _features_from_state(
            m, mu_hat_prev, var_hat, y, T, t, N, lam_kelly, lam_end,
            s2, s3, s4, n, n_pos, n_low, n_high
        )

        a_t = agent.select_action(phi, explore_eps)
        lam[t] = _lambda_from_action(a_t, lam_kelly, lam_end)

        x = float(X[t])
        x2 = x * x
        # log1p keeps the wealth feature consistent with the rollout/trace
        # helpers and is accurate for small increments.
        y += np.log1p(lam[t] * (x - m))
        s1 += x
        s2 += x2
        s3 += x2 * x
        s4 += x2 * x2
        n_pos += int(x > m)
        n_low += int(x < FEAT_TAU)
        n_high += int(x > (1.0 - FEAT_TAU))
        n += 1

    return lam


def trace_dqn_epsilon_episode(
    agent,
    N,
    alpha,
    m,
    mu,
    world="beta_mixture",
    conc=6.0,
    eps_cap=1e-3,
    var_floor=0.0,
    shrink_kappa=0.0,
    lcap=None,
    explore_eps=0.0,
    rng=None,
    stop_on_hit=True,
):
    """
    Run a single evaluation episode and record its trajectory.

    The agent acts greedily by default (``explore_eps=0.0``). When
    ``stop_on_hit`` is true, the episode stops as soon as the log-wealth
    reaches ``log(1 / alpha)``.

    Returns
    -------
    tuple of numpy arrays
        ``(X_used, Y_path, eps_path, lam_path, a_path)``:

        - ``X_used``: the observations that were consumed.
        - ``Y_path``: the log-wealth, starting at 0.0, with one more entry
          than the number of steps taken.
        - ``eps_path``: ``agent.epsilon_actions`` for each chosen action.
        - ``lam_path``: the bet size used at each step.
        - ``a_path``: the integer action indices.
    """
    rng = np.random.default_rng() if rng is None else rng

    X = sample_world(N, mu, world=world, rng=rng, conc=conc)
    threshold = float(np.log(1.0 / alpha))

    log_wealth = 0.0
    sum_x = sum_x2 = sum_x3 = sum_x4 = 0.0
    count = count_pos = count_low = count_high = 0

    wealth_trace = [log_wealth]
    eps_trace, lam_trace, action_trace = [], [], []

    for step in range(N):
        mu_hat, lam_kelly, lam_end, var_hat = _kelly_and_endpoint_from_past(
            sum_x, sum_x2, count, m,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
        )
        state = _features_from_state(
            m, mu_hat, var_hat, log_wealth, threshold, step, N, lam_kelly, lam_end,
            sum_x2, sum_x3, sum_x4, count, count_pos, count_low, count_high
        )

        action = agent.select_action(state, explore_eps)
        eps_value = float(agent.epsilon_actions[action])
        lam_value = _lambda_from_action(action, lam_kelly, lam_end)

        x = float(X[step])
        log_wealth = log_wealth + np.log1p(lam_value * (x - m))

        action_trace.append(int(action))
        eps_trace.append(eps_value)
        lam_trace.append(float(lam_value))
        wealth_trace.append(float(log_wealth))

        # Fold the new observation into the running statistics.
        sq = x * x
        sum_x += x
        sum_x2 += sq
        sum_x3 += sq * x
        sum_x4 += sq * sq
        count_pos += int(x > m)
        count_low += int(x < FEAT_TAU)
        count_high += int(x > (1.0 - FEAT_TAU))
        count += 1

        if stop_on_hit and log_wealth >= threshold:
            break

    steps_taken = len(eps_trace)
    return (
        np.asarray(X[:steps_taken], dtype=float),
        np.asarray(wealth_trace, dtype=float),
        np.asarray(eps_trace, dtype=float),
        np.asarray(lam_trace, dtype=float),
        np.asarray(action_trace, dtype=int),
    )


# Public API of this module: the rollout, λ-sequence and tracing helpers.
__all__ = [
    "rollout_episode_dqn_epsilon",
    "rollout_batch_dqn_epsilon",
    "lam_sequence_from_dqn_epsilon_policy",
    "trace_dqn_epsilon_episode",
]
