"""
Evaluation and data-collection utilities for the DQN policy and baselines.
"""

from __future__ import annotations

import numpy as np

from .core import _resolve_world, _sample_conc_uniform
from .env import sample_world
from .features import _features_from_state
from .hedge import (
    build_time_only_epsilon_schedule_dict,
    epsilon_schedule_dict_to_matrix,
    simulate_expweights_hedge_time_only_eps_greedy_experts,
    simulate_uniform_hedge_time_only_eps_greedy_experts,
)
from .kelly import _kelly_and_endpoint_from_past
from .policies import (
    empirical_kelly_policy_predictable,
    eps_greedy_policy,
    linear_epsilon_schedule,
    simulate_with_policy,
)
from .rollouts import (
    lam_sequence_from_dqn_epsilon_policy,
    rollout_batch_dqn_epsilon,
    trace_dqn_epsilon_episode,
)
from .star import (
    star_bets_test_process,
    star_hoeffding_test_process,
)


def sample_episode_params(
    rng,
    alpha,
    conc,
    world="beta_mixture",
    N_range=(50, 500),
    m_range=(0.1, 0.9),
    difficulty_range=(0.7, 1.3),
    mu_clip=(0.02, 0.98),
):
    """
    Draw one episode configuration (N, m, μ) whose difficulty scales with N.

    Draws from ``rng`` happen in this fixed order:

    1. ``N`` log-uniformly on ``N_range``, rounded to an int;
    2. the null mean ``m`` uniformly on ``m_range``;
    3. a difficulty factor ``c`` uniformly on ``difficulty_range``;
    4. one ``rng.random()`` coin for the sign of the offset.

    The returned true mean is
    ``mu = clip(m ± sqrt(max(1e-12, 2 * v * c * log(1/alpha) / N)), mu_clip[0], mu_clip[1])``.
    Here ``v`` is a variance proxy at ``m``:

    * ``m(1-m)/(conc+1)`` for ``world="beta"``;
    * the mean of that and ``m(1-m)/(2*conc+1)`` for ``world="beta_mixture"``.

    Returns:
        tuple ``(N, m, mu)``: ``N`` is an int; ``m`` and ``mu`` are floats.

    Raises:
        ValueError: if ``world`` is neither "beta" nor "beta_mixture". The check
            happens after ``N`` and ``m`` have already been drawn from ``rng``.
    """
    lo_N, hi_N = N_range
    log_draw = rng.uniform(np.log(lo_N), np.log(hi_N))
    N = int(np.round(np.exp(log_draw)))

    m = float(rng.uniform(*m_range))

    threshold = float(np.log(1.0 / alpha))

    # Both variance proxies are computed before the world check, as before.
    bernoulli_var = m * (1 - m)
    var_single = bernoulli_var / (conc + 1.0)
    var_double = bernoulli_var / (2.0 * conc + 1.0)
    if world == "beta":
        var_proxy = var_single
    elif world == "beta_mixture":
        var_proxy = 0.5 * (var_single + var_double)
    else:
        raise ValueError("world must be 'beta' or 'beta_mixture' (sample_episode_params expects a concrete world)")

    difficulty = float(rng.uniform(*difficulty_range))

    # Floor the radicand so the offset never collapses to exactly zero.
    radicand = max(1e-12, 2.0 * var_proxy * difficulty * threshold / N)
    offset = np.sqrt(radicand)

    direction = -1.0 if rng.random() >= 0.5 else 1.0
    low_mu, high_mu = mu_clip[0], mu_clip[1]
    mu = float(np.clip(m + direction * offset, low_mu, high_mu))

    return N, m, mu


def eval_greedy_hit_rate(
    agent,
    eval_episodes,
    N,
    alpha,
    m,
    mu,
    world="beta_mixture",
    conc=6.0,
    conc_range=(1.0, 12.0),
    eps_cap=1e-3,
    var_floor=0.0,
    shrink_kappa=0.0,
    lcap=None,
    eval_seed=12345,
    eval_batch_size=256,
    domain_randomize=False,
    N_range=(50, 500),
    m_range=(0.1, 0.9),
    difficulty_range=(0.7, 1.3),
    mu_clip=(0.02, 0.98),
):
    """
    Greedy hit-rate metric using explore_eps=0.0.

    Runs ``eval_episodes`` episodes through ``rollout_batch_dqn_epsilon`` in
    batches of at most ``eval_batch_size``, always with ``explore_eps=0.0``.
    An episode counts as a hit when its entry in the returned ``hit_times`` is
    >= 0. All randomness comes from a private
    ``np.random.default_rng(int(eval_seed))``, so the metric is reproducible
    for a fixed seed.

    Per batch:

    * ``world`` is passed through ``_resolve_world`` with that generator.
    * With ``domain_randomize=True``, a concentration is drawn from
      ``conc_range``, and (N, m, mu) come from ``sample_episode_params`` with
      the given ranges. The ``N``, ``m``, ``mu`` and ``conc`` arguments are
      then ignored.
    * Otherwise the fixed ``N``, ``m``, ``mu`` and ``conc`` are used.

    ``eps_cap``, ``var_floor``, ``shrink_kappa`` and ``lcap`` are forwarded in
    both modes. (Previously ``lcap`` was silently dropped in the fixed mode.)

    Returns:
        float: fraction of evaluated episodes that hit, in [0, 1].

    Raises:
        ValueError: if ``eval_episodes`` is smaller than 1.
    """
    eval_episodes = int(eval_episodes)
    if eval_episodes < 1:
        # Without this guard the final division fails (0) or returns -0.0 (< 0).
        raise ValueError("eval_episodes must be >= 1")
    eval_batch_size = max(1, int(eval_batch_size))

    agent.sync_actor()
    rng = np.random.default_rng(int(eval_seed))

    hits = 0
    done = 0

    while done < eval_episodes:
        B = min(eval_batch_size, eval_episodes - done)

        world_ep = _resolve_world(world, rng)

        if domain_randomize:
            conc_ep = _sample_conc_uniform(rng, conc_range)
            N_ep, m_ep, mu_ep = sample_episode_params(
                rng,
                alpha=alpha,
                conc=conc_ep,
                world=world_ep,
                N_range=N_range,
                m_range=m_range,
                difficulty_range=difficulty_range,
                mu_clip=mu_clip,
            )
        else:
            N_ep, m_ep, mu_ep = int(N), float(m), float(mu)
            conc_ep = float(conc)

        (_, _, _, _, _), hit_times = rollout_batch_dqn_epsilon(
            agent,
            batch_episodes=B,
            N=N_ep,
            alpha=alpha,
            m=m_ep,
            mu=mu_ep,
            world=world_ep,
            conc=conc_ep,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
            explore_eps=0.0,
            rng=rng,
        )

        # A non-negative hit time marks an episode that hit.
        hits += int((hit_times >= 0).sum())
        done += B

    return hits / float(eval_episodes)


def evaluate_policies_with_dqn_epsilon(
    N=500,
    alpha=0.005,
    trials=1000,
    m=0.41,
    mu=0.40,
    world="beta_mixture",
    conc=6.0,
    agent=None,
    seed=123,
    *,
    include_star=True,
    include_uniform_hedge=False,
    include_expweights_hedge=True,
    expweights_eta=2.0,
    expweights_gamma=0.01,
    expweights_score_mode="shadow",
    hedge_ks=(0.0, 0.25, 0.5, 0.75),
    hedge_coupled=True,
    eps_cap=1e-3,
    var_floor=0.0,
    shrink_kappa=0.0,
    lcap=None,
):
    """
    Compare per-time hit curves of the DQN policy against baseline tests.

    For each of ``trials`` trials, one sample
    ``X = sample_world(N, mu, world=world, rng=rng, conc=conc)`` is drawn.
    Every enabled method is then run on that same ``X`` with null mean ``m``
    and level ``alpha``:

    * "Empirical Kelly": ``empirical_kelly_policy_predictable`` scored by
      ``simulate_with_policy``.
    * "Linear-ε baseline": ``eps_greedy_policy`` with
      ``linear_epsilon_schedule(N, t0_frac=0.5)``.
    * "DQN policy": ``lam_sequence_from_dqn_epsilon_policy`` with
      ``explore_eps=0.0``.
    * Optional uniform hedge and exponential-weights hedge ("Hedge"). Both run
      over the ε schedules from ``build_time_only_epsilon_schedule_dict(N,
      ks=hedge_ks)``, with ``stop_on_hit=True``.
    * Optional "STaR-Bets" / "STaR-Hoeffding" two-sided tests with
      ``delta=alpha`` and ``stop_on_hit=False``.

    The main generator (seeded with ``seed``) feeds ``sample_world`` and
    ``eps_greedy_policy``. STaR and each hedge use their own generators,
    seeded at fixed offsets from ``seed``, so toggling them does not consume
    draws from the main generator.

    ``eps_cap``, ``var_floor``, ``shrink_kappa`` and ``lcap`` are forwarded to
    every method above except the STaR tests.

    Returns:
        ``(t, curves)``. ``t`` is ``np.arange(N + 1)``. ``curves`` maps each
        method label to the mean over trials of that method's boolean per-time
        hit array, of length ``N + 1``. This is presumably the fraction of
        trials that have rejected by time ``t``; confirm against
        ``simulate_with_policy``.

    Raises:
        ValueError: if ``agent`` is None.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if agent is None:
        raise ValueError("Need a trained DQNEpsilonAgent.")

    rng = np.random.default_rng(seed)

    hits_kelly = np.zeros((trials, N + 1), dtype=bool)
    hits_eps_baseline = np.zeros((trials, N + 1), dtype=bool)
    hits_dqn = np.zeros((trials, N + 1), dtype=bool)

    hits_star_bets = None
    hits_star_hoeffding = None

    if include_star:
        hits_star_bets = np.zeros((trials, N + 1), dtype=bool)
        hits_star_hoeffding = np.zeros((trials, N + 1), dtype=bool)
        rng_star = np.random.default_rng(int(seed) + 999999)
    else:
        rng_star = None

    eps_seq = linear_epsilon_schedule(N, t0_frac=0.5)

    eps_mat = None
    hits_uniform_hedge = None
    hits_exphedge = None
    rng_hedge = None
    rng_exphedge = None

    if include_uniform_hedge or include_expweights_hedge:
        sched = build_time_only_epsilon_schedule_dict(N, ks=hedge_ks)
        _, eps_mat = epsilon_schedule_dict_to_matrix(sched)

    if include_uniform_hedge:
        hits_uniform_hedge = np.zeros((trials, N + 1), dtype=bool)
        rng_hedge = np.random.default_rng(int(seed) + 424242)

    if include_expweights_hedge:
        hits_exphedge = np.zeros((trials, N + 1), dtype=bool)
        rng_exphedge = np.random.default_rng(int(seed) + 777777)

    for i in range(trials):
        # One shared sample per trial: every method sees identical data.
        X = sample_world(N, mu, world=world, rng=rng, conc=conc)

        lam_I = empirical_kelly_policy_predictable(
            X, m,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
        )
        _, hitI = simulate_with_policy(X, lam_I, m, alpha)
        hits_kelly[i] = hitI

        lam_eps = eps_greedy_policy(
            X, m, eps_seq=eps_seq, rng=rng,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
        )
        _, hit_eps = simulate_with_policy(X, lam_eps, m, alpha)
        hits_eps_baseline[i] = hit_eps

        lam_dqn = lam_sequence_from_dqn_epsilon_policy(
            X, m, alpha, agent,
            eps_cap=eps_cap,
            var_floor=var_floor,
            shrink_kappa=shrink_kappa,
            lcap=lcap,
            explore_eps=0.0,
        )
        _, hit_dqn = simulate_with_policy(X, lam_dqn, m, alpha)
        hits_dqn[i] = hit_dqn

        if include_uniform_hedge:
            _, hit_h = simulate_uniform_hedge_time_only_eps_greedy_experts(
                X, m, alpha, eps_mat,
                rng=rng_hedge,
                coupled=hedge_coupled,
                stop_on_hit=True,
                eps_cap=eps_cap,
                var_floor=var_floor,
                shrink_kappa=shrink_kappa,
                lcap=lcap,
            )
            hits_uniform_hedge[i] = hit_h

        if include_expweights_hedge:
            _, hit_eh = simulate_expweights_hedge_time_only_eps_greedy_experts(
                X, m, alpha, eps_mat,
                eta=expweights_eta,
                gamma=expweights_gamma,
                score_mode=expweights_score_mode,
                rng=rng_exphedge,
                coupled=hedge_coupled,
                stop_on_hit=True,
                eps_cap=eps_cap,
                var_floor=var_floor,
                shrink_kappa=shrink_kappa,
                lcap=lcap,
            )
            hits_exphedge[i] = hit_eh

        if include_star:
            _, hit_sb, _ = star_bets_test_process(
                X, m, delta=alpha,
                alpha_var=alpha,
                use_impl_details=True,
                c=1.0,
                clip_v="m1m",
                two_sided=True,
                last_round_randomize=False,
                rng=rng_star,
                stop_on_hit=False,
            )
            hits_star_bets[i] = hit_sb

            _, hit_sh, _ = star_hoeffding_test_process(
                X, m, delta=alpha,
                two_sided=True,
                stop_on_hit=False,
            )
            hits_star_hoeffding[i] = hit_sh

    t = np.arange(N + 1)
    curves = {
        "Empirical Kelly": hits_kelly.mean(axis=0),
        "Linear-ε baseline": hits_eps_baseline.mean(axis=0),
        "DQN policy": hits_dqn.mean(axis=0),
    }

    if include_star:
        curves["STaR-Bets"] = hits_star_bets.mean(axis=0)
        curves["STaR-Hoeffding"] = hits_star_hoeffding.mean(axis=0)

    if include_uniform_hedge:
        curves[f"Uniform hedge (K={eps_mat.shape[0]} schedules)"] = hits_uniform_hedge.mean(axis=0)

    if include_expweights_hedge:
        curves["Hedge"] = hits_exphedge.mean(axis=0)

    return t, curves


def sample_trajectories_kelly_dqn_linear(
    agent,
    num_paths=10,
    N=500,
    alpha=0.005,
    m=0.41,
    mu=0.40,
    world="beta_mixture",
    conc=6.0,
    t0_frac_eps=0.5,
    seed=123,
):
    """
    Record bet-size and wealth trajectories for three policies on shared data.

    For each of ``num_paths`` paths, one sample
    ``X = sample_world(N, mu, world=world, rng=rng, conc=conc)`` is drawn.
    The same ``X`` is then fed, in this order, to:

    * "Empirical Kelly": ``empirical_kelly_policy_predictable(X, m)``;
    * "Linear-ε baseline": ``eps_greedy_policy`` with
      ``linear_epsilon_schedule(N, t0_frac=t0_frac_eps)``, sharing the same
      generator as ``sample_world``;
    * "DQN ε-policy": ``lam_sequence_from_dqn_epsilon_policy`` with
      ``explore_eps=0.0``.

    Each bet sequence is scored with ``simulate_with_policy(X, lam, m, alpha)``.
    The generator is ``np.random.default_rng(seed)``.

    Returns:
        ``(lam_paths, y_paths)``: dicts keyed by the labels above.
        ``lam_paths[label]`` has shape ``(num_paths, N)``;
        ``y_paths[label]`` has shape ``(num_paths, N + 1)``.
    """
    rng = np.random.default_rng(seed)

    labels = ("Empirical Kelly", "Linear-ε baseline", "DQN ε-policy")
    lam_paths = {name: np.zeros((num_paths, N), dtype=float) for name in labels}
    y_paths = {name: np.zeros((num_paths, N + 1), dtype=float) for name in labels}

    eps_seq = linear_epsilon_schedule(N, t0_frac=t0_frac_eps)

    bet_rules = {
        "Empirical Kelly": lambda X: empirical_kelly_policy_predictable(X, m),
        "Linear-ε baseline": lambda X: eps_greedy_policy(X, m, eps_seq=eps_seq, rng=rng),
        "DQN ε-policy": lambda X: lam_sequence_from_dqn_epsilon_policy(
            X, m, alpha, agent, explore_eps=0.0
        ),
    }

    for row in range(num_paths):
        X = sample_world(N, mu, world=world, rng=rng, conc=conc)

        # Evaluate policies in the fixed label order: the ε-greedy rule draws
        # from the shared ``rng``, so this order determines the random stream.
        traces = []
        for name in labels:
            lam = bet_rules[name](X)
            wealth, _ = simulate_with_policy(X, lam, m, alpha)
            traces.append((name, lam, wealth))

        for name, lam, wealth in traces:
            lam_paths[name][row] = lam
            y_paths[name][row] = wealth

    return lam_paths, y_paths


def aggregate_modal_eps_grid(
    agent,
    trials,
    N,
    alpha,
    m,
    mu,
    world="beta_mixture",
    conc=6.0,
    explore_eps=0.0,
    stop_on_hit=True,
    num_y_bins=40,
    y_range=None,
    quantile_clip=(0.01, 0.99),
    seed=123,
    t_bin_width=1,
):
    """
    Empirically estimate a projected policy map on (t, y=log-wealth).

    Rolls out ``trials`` episodes with ``trace_dqn_epsilon_episode``. For every
    recorded action it keeps:

    * the step index ``t``;
    * the log-wealth ``y`` taken from ``Y_path[:-1]``;
    * the action index.

    These triples are histogrammed on a grid of ``t_bin_width``-wide time bins
    and ``num_y_bins`` equal-width y bins. Triples falling outside the grid, or
    with an action index outside ``[0, agent.num_actions)``, are dropped.

    Args:
        agent: must expose ``num_actions`` and ``epsilon_actions``. The latter
            is indexed by action id and may be a list or array.
        y_range: ``(ymin, ymax)`` for the y grid. If None, the grid spans the
            ``quantile_clip`` quantiles of the observed y. The upper end is
            raised to at least ``log(1/alpha)``, and the range is padded by 5%
            on each side.
        t_bin_width: number of consecutive steps per time bin (>= 1).
        num_y_bins: number of y bins (>= 1).

    Returns:
        ``(modal_eps, confidence, visit_counts, y_edges, t_edges)``.

        * ``modal_eps``: ``(num_y_bins, num_t_bins)`` array holding
          ``epsilon_actions`` of the most frequent action per cell. NaN where
          the cell has no visits.
        * ``confidence``: same shape; the share of the cell's visits that took
          the modal action. NaN where the cell has no visits.
        * ``visit_counts``: same shape; integer transition counts.
        * ``y_edges``: float32 edges of the y bins.
        * ``t_edges``: float32 edges of the time bins, with the last edge
          equal to ``N``.

    Raises:
        ValueError: for an empty ``y_range``, ``t_bin_width < 1`` or
            ``num_y_bins < 1``. These are checked before any rollouts run.
        RuntimeError: if no transitions were collected.
    """
    # Validate grid arguments up front so bad values fail before the
    # (potentially slow) rollouts rather than after them.
    num_y_bins = int(num_y_bins)
    if y_range is not None:
        y_lo, y_hi = map(float, y_range)
        if y_hi <= y_lo:
            raise ValueError("y_range must satisfy ymax > ymin")
    t_bin_width = int(t_bin_width)
    if t_bin_width < 1:
        raise ValueError("t_bin_width must be >= 1")
    if num_y_bins < 1:
        raise ValueError("num_y_bins must be >= 1")

    rng = np.random.default_rng(seed)

    t_list = []
    y_list = []
    a_list = []

    for _ in range(int(trials)):
        _, Y_path, _, _, a_path = trace_dqn_epsilon_episode(
            agent=agent,
            N=N,
            alpha=alpha,
            m=m,
            mu=mu,
            world=world,
            conc=conc,
            explore_eps=explore_eps,
            rng=rng,
            stop_on_hit=stop_on_hit,
        )

        L = len(a_path)
        if L == 0:
            continue

        # Pairs each action with Y_path[:-1]. This assumes
        # len(Y_path) == len(a_path) + 1, i.e. Y_path[k] is the log-wealth
        # before action k — TODO confirm against trace_dqn_epsilon_episode.
        t_list.append(np.arange(L, dtype=np.int32))
        y_list.append(np.asarray(Y_path[:-1], dtype=np.float32))
        a_list.append(np.asarray(a_path, dtype=np.int64))

    if len(t_list) == 0:
        raise RuntimeError("No transitions collected. Try stop_on_hit=False or increase trials.")

    t_all = np.concatenate(t_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)
    a_all = np.concatenate(a_list, axis=0)

    if y_range is None:
        T = float(np.log(1.0 / alpha))
        qlo, qhi = quantile_clip
        y_lo = float(np.quantile(y_all, qlo))
        y_hi = float(np.quantile(y_all, qhi))
        # Keep the rejection threshold log(1/alpha) inside the plotted range.
        y_hi = max(y_hi, T)
        if y_hi <= y_lo:
            y_hi = y_lo + 1.0
        pad = 0.05 * (y_hi - y_lo)
        y_lo -= pad
        y_hi += pad

    y_edges = np.linspace(y_lo, y_hi, num_y_bins + 1, dtype=np.float32)

    # Half-open bins [e_k, e_{k+1}); the last bin is closed on the right (as
    # in np.histogram) so a value exactly at ymax is counted, not dropped.
    ybin = np.searchsorted(y_edges, y_all, side="right") - 1
    ybin[y_all == y_edges[-1]] = num_y_bins - 1

    num_t_bins = int(np.ceil(N / t_bin_width))
    tbin = (t_all // t_bin_width).astype(np.int32)

    num_actions = int(agent.num_actions)
    counts = np.zeros((num_t_bins, num_y_bins, num_actions), dtype=np.int32)

    valid = (
        (tbin >= 0) & (tbin < num_t_bins) &
        (ybin >= 0) & (ybin < num_y_bins) &
        (a_all >= 0) & (a_all < num_actions)
    )
    # np.add.at accumulates correctly when the same cell appears repeatedly.
    np.add.at(counts, (tbin[valid], ybin[valid], a_all[valid]), 1)

    totals = counts.sum(axis=2)
    max_counts = counts.max(axis=2)
    # Ties resolve to the lowest action index (np.argmax returns the first max).
    modal_action = counts.argmax(axis=2)

    # asarray allows epsilon_actions to be a plain list as well as an ndarray.
    epsilon_actions = np.asarray(agent.epsilon_actions)
    modal_eps_tb_yb = np.where(totals > 0, epsilon_actions[modal_action], np.nan)
    confidence_tb_yb = np.where(totals > 0, max_counts / np.maximum(totals, 1), np.nan)

    # Transpose from (t, y) to (y, t) so rows index log-wealth.
    modal_eps = modal_eps_tb_yb.T
    confidence = confidence_tb_yb.T
    visit_counts = totals.T

    t_edges = (np.arange(num_t_bins + 1) * t_bin_width).astype(np.float32)
    t_edges[-1] = float(N)

    return modal_eps, confidence, visit_counts, y_edges, t_edges


def mask_low_visit_cells(modal_eps, confidence, visit_counts, min_visits=5):
    """
    Blank out grid cells visited fewer than ``min_visits`` times.

    The inputs are not modified: ``modal_eps`` and ``confidence`` are copied to
    float arrays, and NaN is written into the copies wherever
    ``visit_counts < int(min_visits)``.

    Returns:
        ``(modal_eps, confidence, mask)``: the masked float copies and the
        boolean mask that was applied.
    """
    eps_out, conf_out = (
        np.array(arr, dtype=float, copy=True) for arr in (modal_eps, confidence)
    )
    counts = np.asarray(visit_counts)

    threshold = int(min_visits)
    too_few = counts < threshold

    for grid in (eps_out, conf_out):
        grid[too_few] = np.nan

    return eps_out, conf_out, too_few


def dp_value_to_modal_grid(dp_info, y_edges, t_edges, *, left=0.0, right=1.0):
    """
    Interpolate DP value V_t(y) onto the modal grid.

    Args:
        dp_info: mapping with two entries. ``"V"`` is a 2-D table indexed as
            ``V[y_index, t]``. ``"y_centers"`` gives the y coordinate of each
            row of ``V``; any order is accepted.
        y_edges: y-bin edges. Values are interpolated at the bin midpoints.
        t_edges: time-bin edges. Each bin is represented by its left edge,
            truncated to int and clipped to a valid column index of ``V``.
        left, right: values returned for midpoints below / above the DP
            y-grid. The defaults 0.0 and 1.0 are the previous hard-coded values.

    Returns:
        float array of shape ``(len(y_edges) - 1, len(t_edges) - 1)``.
    """
    V = np.asarray(dp_info["V"], dtype=float)
    dp_y = np.asarray(dp_info["y_centers"], dtype=float)

    # np.interp needs increasing sample points and silently returns nonsense
    # otherwise, so reorder the grid (and the matching rows of V) if required.
    if dp_y.size > 1 and np.any(np.diff(dp_y) < 0):
        order = np.argsort(dp_y, kind="stable")
        dp_y = dp_y[order]
        V = V[order]

    y_cent = 0.5 * (np.asarray(y_edges[:-1]) + np.asarray(y_edges[1:]))

    t_rep = np.asarray(t_edges[:-1], dtype=int)
    t_rep = np.clip(t_rep, 0, V.shape[1] - 1)

    out = np.empty((y_cent.shape[0], t_rep.shape[0]), dtype=float)
    for j, t in enumerate(t_rep):
        out[:, j] = np.interp(y_cent, dp_y, V[:, t], left=left, right=right)

    return out


# Public API of this module: the names bound by a star-import of it.
__all__ = [
    "sample_episode_params",
    "eval_greedy_hit_rate",
    "evaluate_policies_with_dqn_epsilon",
    "sample_trajectories_kelly_dqn_linear",
    "aggregate_modal_eps_grid",
    "mask_low_visit_cells",
    "dp_value_to_modal_grid",
]
