"""
Training loop for the DQN epsilon policy.
"""

from __future__ import annotations

import os
import numpy as np
import torch

from .constants import EPSILON_ACTIONS
from .core import _resolve_world, _sample_conc_uniform
from .evaluation import eval_greedy_hit_rate, sample_episode_params
from .features import _features_from_state
from .kelly import _kelly_and_endpoint_from_past
from .models import DQNEpsilonAgent
from .rollouts import rollout_batch_dqn_epsilon


def _count_due_updates(
    size_before,
    size_after,
    steps_before,
    num_new,
    min_buffer_size,
    train_freq,
):
    """
    Return how many gradient updates fell due while a batch was added.

    Transitions are numbered globally from 1, so the batch just added occupies
    steps ``steps_before + 1 .. steps_before + num_new``. A step is eligible
    once the replay buffer holds at least ``min_buffer_size`` transitions
    (counting that step). This assumes the buffer grows by one per added
    transition until it is full. One update is due for every eligible step
    that is a multiple of ``train_freq``. This matches a per-step
    ``if step % train_freq == 0: update()`` schedule, but the updates are
    run after the batched rollout.
    """
    if num_new <= 0 or size_after < min_buffer_size:
        return 0

    steps_after = steps_before + num_new
    if size_before >= min_buffer_size:
        first_eligible = steps_before + 1
    else:
        # The buffer crosses the warm-up threshold on this (1-based) new transition.
        first_eligible = steps_before + int(min_buffer_size - size_before)

    if first_eligible > steps_after:
        return 0
    # Count multiples of train_freq in [first_eligible, steps_after].
    due = (steps_after // train_freq) - ((first_eligible - 1) // train_freq)
    return max(0, int(due))


def train_dqn_epsilon_policy(
    episodes=2000,
    N=500,
    alpha=0.005,
    m=0.41,
    mu=0.40,
    world="beta_mixture",
    conc=6.0,
    conc_range=(1.0, 12.0),
    lr=1e-3,
    gamma=1.0,
    buffer_capacity=50000,
    batch_size=64,
    target_update_interval=1000,
    explore_eps_start=1.0,
    explore_eps_end=0.05,
    explore_eps_decay=0.995,
    min_buffer_size=1000,
    tau=None,
    seed=0,
    avg_window=1000,
    checkpoint_every=500,
    num_envs=1,
    train_freq=4,
    log_loss_every=200,
    actor_update_interval=100,
    policy_save_path="best_dqn_policy.pt",
    domain_randomize=False,
    N_range=(50, 500),
    m_range=(0.1, 0.9),
    difficulty_range=(0.7, 1.3),
    mu_clip=(0.02, 0.98),
    lcap=5.0,
    eval_episodes=2000,
    eval_seed=12345,
    eval_batch_size=256,
):
    """
    Train a DQN that chooses among discrete epsilon actions.

    Episodes are collected in batches of up to ``num_envs`` with
    ``rollout_batch_dqn_epsilon``. All transitions are pushed into the
    agent's replay buffer. After each batch, the gradient updates that fell
    due are run: one per ``train_freq`` transitions, once the buffer holds
    ``min_buffer_size`` transitions.

    Every ``checkpoint_every`` episodes, and after the final batch, the greedy
    policy is scored with ``eval_greedy_hit_rate``. That score does three
    things:

    - it drives a ``ReduceLROnPlateau`` scheduler;
    - it selects the best checkpoint, which is written to
      ``policy_save_path`` whenever the score improves;
    - it determines which weights are reloaded into the agent at the end.

    Parameters
    ----------
    episodes : int
        Total number of training episodes.
    N, m, mu, conc :
        Base environment settings. They are always forwarded to evaluation.
        They are used for training rollouts unless ``domain_randomize``
        resamples them per batch.
    alpha :
        Forwarded to rollouts and evaluation. ``log(1/alpha)`` is also used
        to build the probe state that determines the feature dimension.
    world :
        Passed through ``_resolve_world`` once per training batch; forwarded
        unchanged to evaluation.
    conc_range : tuple
        Range for ``_sample_conc_uniform`` when ``domain_randomize`` is set;
        also forwarded to evaluation.
    lr, gamma, buffer_capacity, target_update_interval, tau, seed, actor_update_interval :
        Forwarded to ``DQNEpsilonAgent``. ``seed`` also seeds the training RNG.
    batch_size : int
        Minibatch size for each ``agent.train_step`` call.
    explore_eps_start, explore_eps_end, explore_eps_decay : float
        Exploration schedule. Episode ``i`` (0-based) uses
        ``max(explore_eps_end, explore_eps_start * explore_eps_decay ** i)``.
    min_buffer_size : int
        No updates run until the replay buffer holds this many transitions.
    avg_window : int
        Window, in episodes, of the training hit average reported at
        checkpoints (clamped to >= 1).
    checkpoint_every : int
        Evaluation cadence in episodes (clamped to >= 1).
    num_envs : int
        Episodes rolled out per batch (clamped to >= 1).
    train_freq : int
        Environment transitions per gradient update (clamped to >= 1).
    log_loss_every : int
        A loss value is recorded when ``agent.train_steps`` is a multiple of
        this (clamped to >= 1).
    policy_save_path : str
        File written with ``torch.save`` whenever the evaluation score improves.
    domain_randomize : bool
        If True, ``N``, ``m``, ``mu`` and ``conc`` are resampled per batch
        (``sample_episode_params`` with ``N_range``, ``m_range``,
        ``difficulty_range``, ``mu_clip``) and ``lcap`` is applied.
        Otherwise the fixed values are used with ``lcap=None``.
    lcap : float
        Forwarded as ``lcap`` only when ``domain_randomize`` is True.
    eval_episodes, eval_seed, eval_batch_size :
        Forwarded to ``eval_greedy_hit_rate``.

    Returns
    -------
    agent : DQNEpsilonAgent
        The trained agent. If a best checkpoint was recorded, its policy and
        target networks hold those weights.
    history : dict
        Contains:

        - ``returns``: per-episode value, 1.0 when the rollout's hit time
          is >= 0, else 0.0;
        - ``losses``: sampled loss values;
        - ``explore_eps``: per-episode exploration rate;
        - the checkpoint evaluation history;
        - the best checkpoint's episode, scores and CPU ``state_dict``.
    """
    T = float(np.log(1.0 / alpha))

    # Build one probe state only to learn the feature dimension for the network.
    mu0, lam_k0, lam_e0, var0 = _kelly_and_endpoint_from_past(
        0.0, 0.0, 0, m,
        eps_cap=1e-3,
        var_floor=0.0,
        shrink_kappa=0.0,
        lcap=lcap if domain_randomize else None,
    )
    d = len(_features_from_state(
        m, mu0, var0, 0.0, T, 0, N, lam_k0, lam_e0,
        0.0, 0.0, 0.0, 0, 0, 0, 0
    ))

    agent = DQNEpsilonAgent(
        state_dim=d,
        epsilon_actions=EPSILON_ACTIONS,
        lr=lr,
        gamma=gamma,
        buffer_capacity=buffer_capacity,
        target_update_interval=target_update_interval,
        seed=seed,
        tau=tau,
        actor_update_interval=actor_update_interval,
    )
    print(f"[DQN] state_dim = {d}")

    # The LR is lowered when the greedy evaluation score ("max" mode) plateaus.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        agent.optimizer,
        mode="max",
        factor=0.8,
        patience=8,
        threshold=1e-4,
        min_lr=1e-6,
    )

    current_lr = agent.optimizer.param_groups[0]['lr']
    rng = np.random.default_rng(seed)

    returns = []
    losses = []
    explore_eps_history = []

    avg_window = max(1, int(avg_window))
    checkpoint_every = max(1, int(checkpoint_every))
    train_freq = max(1, int(train_freq))
    log_loss_every = max(1, int(log_loss_every))
    # A non-positive batch size would never advance episodes_done and the
    # training loop below would never terminate.
    num_envs = max(1, int(num_envs))

    if buffer_capacity < min_buffer_size:
        print(f"[DQN] WARNING: buffer_capacity={buffer_capacity} < "
              f"min_buffer_size={min_buffer_size}; the warm-up threshold may "
              f"never be reached and no gradient updates would run.")

    best_eval = -np.inf
    best_ep = None
    best_state_dict = None
    best_train_avg = None
    best_w = None  # averaging window actually used for best_train_avg

    eval_hit_history = []
    eval_episode_history = []
    train_avg_history = []

    episodes_done = 0
    global_env_steps = 0

    next_checkpoint = checkpoint_every

    while episodes_done < episodes:
        B = min(num_envs, episodes - episodes_done)

        # Environment settings shared by all B episodes of this batch.
        world_ep = _resolve_world(world, rng)
        conc_ep = float(conc)
        if domain_randomize:
            conc_ep = _sample_conc_uniform(rng, conc_range)

        # Per-episode exploration rate: exponential decay floored at explore_eps_end.
        ep_idxs = np.arange(episodes_done, episodes_done + B, dtype=np.int64)
        explore_eps_batch = explore_eps_start * (explore_eps_decay ** ep_idxs)
        explore_eps_batch = np.maximum(explore_eps_end, explore_eps_batch).astype(np.float32)
        explore_eps_history.extend(explore_eps_batch.tolist())

        if domain_randomize:
            N_ep, m_ep, mu_ep = sample_episode_params(
                rng,
                alpha=alpha,
                conc=conc_ep,
                world=world_ep,
                N_range=N_range,
                m_range=m_range,
                difficulty_range=difficulty_range,
                mu_clip=mu_clip,
            )
            lcap_ep = lcap
        else:
            N_ep, m_ep, mu_ep = N, m, mu
            lcap_ep = None

        (S, A, R, S_next, D), hit_times = rollout_batch_dqn_epsilon(
            agent,
            batch_episodes=B,
            N=N_ep,
            alpha=alpha,
            m=m_ep,
            mu=mu_ep,
            world=world_ep,
            conc=conc_ep,
            eps_cap=1e-3,
            var_floor=0.0,
            shrink_kappa=0.0,
            lcap=lcap_ep,
            explore_eps=explore_eps_batch,
            rng=rng,
        )

        # Training "return" per episode: 1.0 if its hit time is non-negative.
        returns.extend((hit_times >= 0).astype(np.float32).tolist())

        size_before = agent.buffer.size
        steps_before = global_env_steps
        num_new = int(A.shape[0])
        agent.buffer.add_batch(S, A, R, S_next, D)
        global_env_steps = steps_before + num_new

        updates_to_run = _count_due_updates(
            size_before,
            agent.buffer.size,
            steps_before,
            num_new,
            min_buffer_size,
            train_freq,
        )

        for _ in range(updates_to_run):
            loss_t = agent.train_step(batch_size=batch_size)
            if (loss_t is not None) and (agent.train_steps % log_loss_every == 0):
                losses.append(float(loss_t.item()))

        episodes_done += B

        do_ckpt = (episodes_done >= next_checkpoint) or (episodes_done == episodes)
        if do_ckpt:
            w = min(avg_window, len(returns))
            avg_last = float(np.mean(returns[-w:])) if w > 0 else 0.0

            # Evaluation always uses the base (non-sampled) settings and a fixed seed.
            eval_hit = eval_greedy_hit_rate(
                agent=agent,
                eval_episodes=eval_episodes,
                eval_seed=eval_seed,
                eval_batch_size=eval_batch_size,
                N=N,
                alpha=alpha,
                m=m,
                mu=mu,
                world=world,
                conc=conc,
                conc_range=conc_range,
                eps_cap=1e-3,
                var_floor=0.0,
                shrink_kappa=0.0,
                lcap=lcap if domain_randomize else None,
                domain_randomize=domain_randomize,
                N_range=N_range,
                m_range=m_range,
                difficulty_range=difficulty_range,
                mu_clip=mu_clip,
            )

            eval_episode_history.append(int(episodes_done))
            eval_hit_history.append(float(eval_hit))
            train_avg_history.append(float(avg_last))

            prev_lr = current_lr
            scheduler.step(eval_hit)
            current_lr = agent.optimizer.param_groups[0]['lr']
            if current_lr != prev_lr:
                print(f"[DQN] Learning rate changed: {prev_lr:.2e} -> {current_lr:.2e}")

            if eval_hit > best_eval:
                best_eval = float(eval_hit)
                best_ep = int(episodes_done)
                best_train_avg = float(avg_last)
                best_w = w

                # Detached CPU copy so later training steps cannot mutate the snapshot.
                best_state_dict = {k: v.detach().cpu().clone() for k, v in agent.policy_net.state_dict().items()}

                torch.save({
                    "policy_net_state_dict": best_state_dict,
                    "best_episode": best_ep,
                    "best_eval_hit_rate": best_eval,
                    "best_avg_return": best_train_avg,
                    "avg_window": int(avg_window),
                    "eval_episodes": int(eval_episodes),
                    "eval_seed": int(eval_seed),
                    "eval_batch_size": int(eval_batch_size),
                    "eval_checkpoint_episodes": np.array(eval_episode_history, dtype=int),
                    "eval_hit_rates": np.array(eval_hit_history, dtype=float),
                    "train_avg_at_checkpoints": np.array(train_avg_history, dtype=float),
                    "state_dim": agent.state_dim,
                    "num_actions": agent.num_actions,
                    "hidden_sizes": agent.hidden_sizes,
                    "epsilon_actions": agent.epsilon_actions,
                }, policy_save_path)

                print(f"[DQN] New BEST (by greedy eval) -> saved {policy_save_path} "
                      f"(ep {best_ep}, eval_hit={best_eval:.3f}, train_avg_last{w}={best_train_avg:.3f})")

            print(f"[DQN] Episode {episodes_done}/{episodes} | "
                  f"train_mean_last{w}={avg_last:.3f} | "
                  f"eval_hit_rate(greedy)={eval_hit:.3f} | "
                  f"explore_eps(last)={float(explore_eps_batch[-1]):.3f} | "
                  f"updates={updates_to_run}")

            # Next multiple of checkpoint_every strictly after episodes_done.
            next_checkpoint = ((episodes_done // checkpoint_every) + 1) * checkpoint_every

    if best_state_dict is not None:
        agent.policy_net.load_state_dict(best_state_dict)
        agent.target_net.load_state_dict(best_state_dict)
        agent.sync_actor()

        # best_state_dict is only assigned immediately before it is saved above,
        # so the file normally exists already; the fallback only matters if it
        # was removed in the meantime.
        if os.path.exists(policy_save_path):
            print(f"[DQN] Best policy already saved to {policy_save_path} (episode {best_ep}, eval_hit={best_eval:.3f})")
        else:
            torch.save({
                'policy_net_state_dict': best_state_dict,
                'best_episode': best_ep,
                'best_eval_hit_rate': best_eval,
                'best_avg_return': best_train_avg,
                'state_dim': agent.state_dim,
                'num_actions': agent.num_actions,
                'hidden_sizes': agent.hidden_sizes,
                'epsilon_actions': agent.epsilon_actions,
            }, policy_save_path)
            print(f"[DQN] Saved best policy to {policy_save_path}")

    if best_ep is not None:
        # Report the window that was actually used when best_train_avg was computed.
        print(f"[DQN] BEST CHECKPOINT (by greedy eval) = episode {best_ep}/{episodes} | "
              f"eval_hit_rate={best_eval:.3f} | "
              f"train_mean_last{best_w} at best={best_train_avg:.3f}")

    history = {
        "returns": np.array(returns, float),
        "losses": np.array(losses, float),
        "explore_eps": np.array(explore_eps_history, float),
        "checkpoint_every": checkpoint_every,
        "avg_window": avg_window,
        "eval_checkpoint_episodes": np.array(eval_episode_history, dtype=int),
        "eval_hit_rates": np.array(eval_hit_history, dtype=float),
        "train_avg_at_checkpoints": np.array(train_avg_history, dtype=float),
        "best_checkpoint_episode": best_ep,
        "best_checkpoint_eval_hit_rate": best_eval if best_ep is not None else None,
        "best_checkpoint_train_avg_return": best_train_avg if best_ep is not None else None,
        "best_state_dict": best_state_dict,
    }
    return agent, history


# Public API of this module: only the training entry point is exported.
__all__ = [
    "train_dqn_epsilon_policy",
]
