"""
Command-line interface for training/evaluating the DQN epsilon policy.
"""

from __future__ import annotations

import argparse
import os
import numpy as np
import torch

from dp import dp_optimal_policy_discrete

from .evaluation import (
    aggregate_modal_eps_grid,
    evaluate_policies_with_dqn_epsilon,
    mask_low_visit_cells,
    dp_value_to_modal_grid,
)
from .plotting import (
    plot_confidence_grid,
    plot_eps_on_logwealth,
    plot_modal_eps_grid,
    plot_rejection_curves_multi,
    setting_title,
    trailing_mean_every_k,
)
from .rollouts import trace_dqn_epsilon_episode
from .training import train_dqn_epsilon_policy
from .models import DQNEpsilonAgent
from .kelly import _kelly_and_endpoint_from_past
from .features import _features_from_state


def _load_agent_from_checkpoint(policy_path, m, alpha, N):
    """Rebuild a ``DQNEpsilonAgent`` from a checkpoint file.

    The network input width is not read from the file. It is measured by
    building one feature vector from zero-valued inputs with the current
    feature code, and then compared against the checkpoint's ``"state_dim"``
    entry when that entry is present.

    Args:
        policy_path: Path of the checkpoint to load with ``torch.load``.
        m: Forwarded to ``_kelly_and_endpoint_from_past`` and
            ``_features_from_state``.
        alpha: Used to compute ``log(1 / alpha)``, which is forwarded to
            ``_features_from_state``.
        N: Forwarded to ``_features_from_state``.

    Returns:
        A ``(agent, checkpoint)`` tuple: the reconstructed agent and the raw
        object returned by ``torch.load``.

    Raises:
        FileNotFoundError: If ``policy_path`` does not exist.
        ValueError: If the checkpoint records a ``state_dim`` that differs
            from the length of the current feature vector.
    """
    if not os.path.exists(policy_path):
        raise FileNotFoundError(f"Policy file not found: {policy_path}")

    # Tensors are mapped to CPU regardless of the device they were saved from.
    checkpoint = torch.load(policy_path, map_location='cpu')

    log_inv_alpha = float(np.log(1.0 / alpha))
    mu0, lam_k0, lam_e0, var0 = _kelly_and_endpoint_from_past(
        0.0,
        0.0,
        0,
        m,
        eps_cap=1e-3,
        var_floor=0.0,
        shrink_kappa=0.0,
        lcap=None,
    )
    # Only the length of this probe vector matters; its values are discarded.
    probe_features = _features_from_state(
        m, mu0, var0, 0.0, log_inv_alpha, 0, N, lam_k0, lam_e0,
        0.0, 0.0, 0.0, 0, 0, 0, 0
    )
    d = len(probe_features)

    ckpt_d = checkpoint.get("state_dim")
    if ckpt_d is not None:
        if int(ckpt_d) != d:
            raise ValueError(
                f"Checkpoint state_dim={ckpt_d} does not match current feature dim={d}. "
                "This checkpoint was trained with a different feature set; retrain."
            )

    agent_kwargs = {
        "state_dim": d,
        "epsilon_actions": checkpoint['epsilon_actions'],
        "hidden_sizes": checkpoint['hidden_sizes'],
        "device": None,
        "seed": 0,
    }
    agent = DQNEpsilonAgent(**agent_kwargs)

    # Both networks receive the saved policy weights.
    agent.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
    agent.target_net.load_state_dict(checkpoint['policy_net_state_dict'])
    agent.sync_actor()

    return agent, checkpoint


def _optional_float(text):
    """Parse a CLI value that is either a float or the word ``none``.

    Returns ``None`` for ``"none"`` (case-insensitive), otherwise
    ``float(text)``. Raises ``argparse.ArgumentTypeError`` for anything else
    so argparse reports a normal usage error.
    """
    if text.strip().lower() == "none":
        return None
    try:
        return float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a float or 'none', got {text!r}"
        ) from None


def _suffixed_path(path, suffix):
    """Return *path* with *suffix* inserted just before its file extension.

    ``_suffixed_path("grid.png", "_unmasked")`` gives ``"grid_unmasked.png"``.
    Unlike ``path.replace(".png", ...)``, this also yields a distinct name for
    non-``.png`` files and never touches directory components.
    """
    root, ext = os.path.splitext(path)
    return f"{root}{suffix}{ext}"


def _build_arg_parser():
    """Create the argument parser for the train/evaluate command line."""
    parser = argparse.ArgumentParser(description='Train or evaluate DQN epsilon policy')
    add = parser.add_argument

    add('mode', type=str, choices=['train', 'evaluate'],
        help='Mode: "train" to train and evaluate, "evaluate" to only evaluate')

    # Input/output locations.
    add('--policy_path', type=str, default='best_dqn_policy.pt',
        help='Path to saved policy file (for evaluate mode)')
    add('--checkpoint_path', type=str, default='best_dqn_policy.pt',
        help='Path to save best policy checkpoint (for train mode)')
    add('--training_plot', type=str, default='training_returns.png',
        help='Filename for training returns plot')
    add('--rejection_plot', type=str, default='rejection_curves.png',
        help='Filename for rejection curves plot')
    add('--eps_plot', type=str, default='dqn_eps_on_logwealth.png',
        help='Filename for epsilon on logwealth plot')
    add('--modal_eps_plot', type=str, default='modal_eps_grid.png',
        help='Filename for modal epsilon grid plot')
    add('--modal_confidence_plot', type=str, default='modal_confidence_grid.png',
        help='Filename for modal confidence grid plot')

    # Modal-grid diagnostics.
    add('--min_visits', type=int, default=5,
        help='Minimum visit count to include cell in masked plots')
    add('--modal_grid_trials', type=int, default=5000,
        help='Number of episodes to use for modal epsilon grid evaluation')
    add('--t_bin_width', type=int, default=1,
        help='Time bin width for modal epsilon grid (1 = no binning)')

    # Problem setting.
    add('--N', type=int, default=200,
        help='Horizon length N')
    add('--alpha', type=float, default=0.05,
        help='Significance level alpha')
    add('--m', type=float, default=0.45,
        help='Null hypothesis parameter m')
    add('--mu', type=float, default=0.40,
        help='True mean parameter mu')
    add('--world', type=str, default='beta_mixture',
        choices=['beta', 'beta_mixture', 'random'],
        help="World distribution")
    add('--conc', type=float, default=6.0,
        help="Beta concentration parameter")
    add('--conc_range', type=float, nargs=2, default=[1.0, 12.0],
        metavar=('MIN', 'MAX'),
        help='Uniform range for conc when --domain_randomize is enabled')

    # Training hyper-parameters.
    add('--episodes', type=int, default=9000,
        help='Number of training episodes')
    add('--buffer_capacity', type=int, default=10000,
        help='Replay buffer capacity')
    add('--batch_size', type=int, default=512,
        help='Training batch size')
    add('--target_update_interval', type=int, default=5000,
        help='Target network update interval')
    add('--tau', type=float, default=None,
        help='Soft update coefficient')
    add('--min_buffer_size', type=int, default=1000,
        help='Minimum buffer size before training starts')
    add('--num_envs', type=int, default=16,
        help='Number of parallel environments for vectorized rollouts')
    add('--train_freq', type=int, default=1,
        help='Train every N environment steps')
    add('--actor_update_interval', type=int, default=100,
        help='CPU actor sync interval (train steps)')
    add('--explore_eps_start', type=float, default=1.0,
        help='Initial exploration epsilon')
    add('--explore_eps_end', type=float, default=0.05,
        help='Final exploration epsilon')
    add('--explore_eps_decay', type=float, default=0.998,
        help='Exploration epsilon decay rate per episode')
    add('--lr', type=float, default=1e-3,
        help='Learning rate for optimizer')

    # Checkpoint selection.
    add('--checkpoint_every', type=int, default=500,
        help='Evaluate/save checkpoint every K episodes')
    add('--eval_episodes', type=int, default=2000,
        help='Number of greedy eval episodes for checkpoint selection')
    add('--eval_seed', type=int, default=12345,
        help='Fixed seed for greedy eval')
    add('--eval_batch_size', type=int, default=256,
        help='Batch size for vectorized greedy eval rollouts')

    # Domain randomization.
    add('--domain_randomize', action='store_true',
        help='Enable domain randomization')
    add('--N_range', type=float, nargs=2, default=[50, 500],
        metavar=('MIN', 'MAX'),
        help='Range for N when domain_randomize=True')
    add('--m_range', type=float, nargs=2, default=[0.1, 0.9],
        metavar=('MIN', 'MAX'),
        help='Range for m when domain_randomize=True')
    add('--difficulty_range', type=float, nargs=2, default=[0.7, 1.3],
        metavar=('MIN', 'MAX'),
        help='Difficulty multiplier range when domain_randomize=True')
    add('--mu_clip', type=float, nargs=2, default=[0.02, 0.98],
        metavar=('MIN', 'MAX'),
        help='Clipping range for μ when domain_randomize=True')
    # type=float alone could never produce None, so accept the word "none".
    add('--lcap', type=_optional_float, default=5.0,
        help='Lambda cap when domain_randomize=True ("none" to disable)')

    # Optional baselines.
    add('--include_star', action='store_true',
        help='Include STaR baselines')
    add('--include_uniform_hedge', action='store_true',
        help='Include uniform hedge baseline')
    add('--include_expweights_hedge', action='store_true',
        help='Include exp-weights hedge baseline')
    add('--expweights_eta', type=float, default=2.0,
        help='Exp-weights hedge learning rate η')
    add('--expweights_gamma', type=float, default=0.01,
        help='Rebalancing rate γ for exp-weights hedge')
    add('--expweights_score_mode', type=str, default='shadow',
        choices=['shadow', 'capital'],
        help='Score mode for exp-weights hedge')
    return parser


def _plot_training_curve(dqn_hist, args, m, mu, alpha):
    """Save the trailing-mean training-return curve to ``args.training_plot``."""
    k = int(dqn_hist.get("checkpoint_every", args.checkpoint_every))
    x_tm, y_tm = trailing_mean_every_k(dqn_hist["returns"], k)

    import matplotlib.pyplot as plt
    plt.figure()
    plt.plot(x_tm, y_tm)
    plt.xlabel("episode")
    plt.ylabel(f"mean return over last {k} episodes")
    plt.title(setting_title(f"DQN training return (trailing mean every {k})", m, mu, alpha))
    plt.tight_layout()
    plt.savefig(args.training_plot, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[DQN] Saved training curve to {args.training_plot}")


def _report_checkpoint_metrics(checkpoint, args, m, mu, alpha):
    """Print the metrics stored in *checkpoint* and plot them when complete.

    Every metric key is optional. The plot is written next to
    ``args.training_plot`` with an ``_eval`` suffix, and only when all three
    per-checkpoint series are present and have equal length.
    """
    best_ep = checkpoint.get("best_episode")
    best_avg = checkpoint.get("best_avg_return")
    best_eval = checkpoint.get("best_eval_hit_rate")

    msg = f"[DQN] Loaded policy from episode {best_ep}"
    if best_eval is not None:
        msg += f" | eval_hit_rate={best_eval:.3f}"
    if best_avg is not None:
        msg += f" | train_avg_return(ref)={best_avg:.3f}"
    print(msg)

    eval_ckpt_eps = checkpoint.get("eval_checkpoint_episodes")
    eval_hit_rates = checkpoint.get("eval_hit_rates")
    train_avg_at_ckpt = checkpoint.get("train_avg_at_checkpoints")

    if eval_ckpt_eps is not None and eval_hit_rates is not None:
        eval_ckpt_eps = np.asarray(eval_ckpt_eps)
        eval_hit_rates = np.asarray(eval_hit_rates)
        # Guard both series: an empty list must not raise IndexError.
        latest_rate = eval_hit_rates[-1] if len(eval_hit_rates) > 0 else float('nan')
        latest_ep = eval_ckpt_eps[-1] if len(eval_ckpt_eps) > 0 else 'n/a'
        print(f"[DQN] Loaded eval_hit_rates with {len(eval_hit_rates)} checkpoints "
              f"(latest={latest_rate:.3f} at ep {latest_ep})")
    if train_avg_at_ckpt is not None:
        train_avg_at_ckpt = np.asarray(train_avg_at_ckpt)
        latest_train_avg = train_avg_at_ckpt[-1] if len(train_avg_at_ckpt) > 0 else float('nan')
        print(f"[DQN] Loaded train_avg_at_checkpoints with {len(train_avg_at_ckpt)} entries "
              f"(latest={latest_train_avg:.3f})")

    if eval_ckpt_eps is None or eval_hit_rates is None or train_avg_at_ckpt is None:
        return
    if not (len(eval_ckpt_eps) == len(eval_hit_rates) == len(train_avg_at_ckpt)):
        # matplotlib would raise on mismatched x/y lengths; skip instead of
        # aborting the whole evaluation run.
        print("[DQN] Skipping checkpoint metrics plot: stored series have mismatched lengths")
        return

    import matplotlib.pyplot as plt
    plt.figure()
    plt.plot(eval_ckpt_eps, train_avg_at_ckpt, label="train mean (last window)")
    plt.plot(eval_ckpt_eps, eval_hit_rates, label="eval hit-rate (greedy)")
    plt.xlabel("episode (checkpoint)")
    plt.ylabel("metric")
    plt.title(setting_title("Checkpoint metrics: train mean & eval hit-rate", m, mu, alpha))
    plt.legend()
    plt.tight_layout()
    eval_plot_path = _suffixed_path(args.training_plot, "_eval")
    plt.savefig(eval_plot_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"[DQN] Saved checkpoint metrics plot to {eval_plot_path}")


def _evaluate_and_plot(dqn_agent, args, N, alpha, m, mu, policy_label):
    """Evaluate *dqn_agent* and write every diagnostic plot.

    Shared by both CLI modes. Produces, in order: the rejection-curve plot,
    the epsilon-on-log-wealth plot for one traced episode, and the modal
    epsilon / confidence grids (first unmasked, then with low-visit cells
    masked). *policy_label* only changes the progress message
    (``"trained"`` or ``"loaded"``).
    """
    print(f"[DQN] Evaluating {policy_label} policy...")
    t, curves = evaluate_policies_with_dqn_epsilon(
        N=N,
        alpha=alpha,
        trials=5000,
        m=m,
        mu=mu,
        world=args.world,
        conc=args.conc,
        agent=dqn_agent,
        seed=123,
        include_star=args.include_star,
        include_uniform_hedge=args.include_uniform_hedge,
        include_expweights_hedge=args.include_expweights_hedge,
        expweights_eta=args.expweights_eta,
        expweights_gamma=args.expweights_gamma,
        expweights_score_mode=args.expweights_score_mode,
    )
    plot_rejection_curves_multi(
        t,
        curves,
        title=setting_title("Kelly vs fixed ε vs DQN", m, mu, alpha),
        filename=args.rejection_plot,
    )
    print(f"[DQN] Saved rejection curves to {args.rejection_plot}")

    # Single traced episode with exploration switched off (explore_eps=0.0).
    rng = np.random.default_rng(123)
    _, Y_path, eps_path, _, _ = trace_dqn_epsilon_episode(
        agent=dqn_agent,
        N=N,
        alpha=alpha,
        m=m,
        mu=mu,
        world=args.world,
        conc=args.conc,
        explore_eps=0.0,
        rng=rng,
        stop_on_hit=True,
    )
    plot_eps_on_logwealth(
        Y_path,
        eps_path,
        alpha=alpha,
        title=setting_title("DQN actions along log-wealth trajectory", m, mu, alpha),
        filename=args.eps_plot,
    )
    print(f"Saved: {args.eps_plot}")

    print("[DQN] Generating modal epsilon grid...")
    modal_eps, conf, visits, y_edges, t_edges = aggregate_modal_eps_grid(
        agent=dqn_agent,
        trials=args.modal_grid_trials,
        N=N,
        alpha=alpha,
        m=m,
        mu=mu,
        world=args.world,
        conc=args.conc,
        explore_eps=0.0,
        stop_on_hit=True,
        num_y_bins=40,
        t_bin_width=args.t_bin_width,
        seed=123,
    )

    # The DP reference is computed on "beta_mixture" whenever the CLI world is
    # "random" (presumably the DP solver has no "random" world — TODO confirm).
    dp_world = args.world if args.world != "random" else "beta_mixture"
    dp_info = dp_optimal_policy_discrete(
        N=N,
        alpha=alpha,
        m=m,
        mu=mu,
        conc=args.conc,
        world=dp_world,
        eps_cap=1e-3,
        num_y_bins=360,
        mc_samples=6000,
        seed=0,
        y_margin=0.75,
    )
    dp_val_grid = dp_value_to_modal_grid(dp_info, y_edges=y_edges, t_edges=t_edges)

    # Derive the "unmasked" names with splitext so they always differ from the
    # masked names, whatever extension the user supplied.
    unmasked_eps_plot = _suffixed_path(args.modal_eps_plot, "_unmasked")
    unmasked_conf_plot = _suffixed_path(args.modal_confidence_plot, "_unmasked")

    plot_modal_eps_grid(
        modal_eps=modal_eps,
        y_edges=y_edges,
        epsilon_actions=dqn_agent.epsilon_actions,
        alpha=alpha,
        title=setting_title("Modal action (unmasked)", m, mu, alpha),
        filename=unmasked_eps_plot,
        t_edges=t_edges,
        dp_value=dp_val_grid,
        hopeless_threshold=0.005,
    )
    print(f"Saved: {unmasked_eps_plot}")

    plot_confidence_grid(
        confidence=conf,
        y_edges=y_edges,
        alpha=alpha,
        title=setting_title("Modal-action confidence (unmasked)", m, mu, alpha),
        filename=unmasked_conf_plot,
        t_edges=t_edges,
    )
    print(f"Saved: {unmasked_conf_plot}")

    min_visits = args.min_visits
    modal_eps_m, conf_m, mask = mask_low_visit_cells(
        modal_eps, conf, visits, min_visits=min_visits
    )
    print(f"Masked {mask.sum()} / {mask.size} cells with visits < {min_visits}")

    plot_modal_eps_grid(
        modal_eps=modal_eps_m,
        y_edges=y_edges,
        epsilon_actions=dqn_agent.epsilon_actions,
        alpha=alpha,
        title=setting_title(f"Modal action (masked: visits < {min_visits})", m, mu, alpha),
        filename=args.modal_eps_plot,
        t_edges=t_edges,
        dp_value=dp_val_grid,
        hopeless_threshold=0.005,
    )
    print(f"Saved: {args.modal_eps_plot}")

    plot_confidence_grid(
        confidence=conf_m,
        y_edges=y_edges,
        alpha=alpha,
        title=setting_title(f"Modal-action confidence (masked: visits < {min_visits})", m, mu, alpha),
        filename=args.modal_confidence_plot,
        t_edges=t_edges,
    )
    print(f"Saved: {args.modal_confidence_plot}")


def main():
    """Command-line entry point: train a policy or evaluate a saved one.

    ``train`` mode trains with ``train_dqn_epsilon_policy``, saves the
    training curve, then runs the full evaluation/plotting pass on the
    returned agent. ``evaluate`` mode loads the agent from
    ``--policy_path``, reports the metrics stored in the checkpoint, then
    runs the same evaluation/plotting pass.
    """
    args = _build_arg_parser().parse_args()

    N = args.N
    alpha = args.alpha
    m = args.m
    mu = args.mu

    if args.mode == 'train':
        print("[DQN] Starting training...")
        dqn_agent, dqn_hist = train_dqn_epsilon_policy(
            episodes=args.episodes,
            N=N,
            alpha=alpha,
            m=m,
            mu=mu,
            world=args.world,
            conc=args.conc,
            conc_range=tuple(args.conc_range),
            lr=args.lr,
            gamma=1.0,
            buffer_capacity=args.buffer_capacity,
            batch_size=args.batch_size,
            target_update_interval=args.target_update_interval,
            checkpoint_every=args.checkpoint_every,
            explore_eps_start=args.explore_eps_start,
            explore_eps_end=args.explore_eps_end,
            explore_eps_decay=args.explore_eps_decay,
            min_buffer_size=args.min_buffer_size,
            tau=args.tau,
            seed=0,
            num_envs=args.num_envs,
            train_freq=args.train_freq,
            log_loss_every=200,
            actor_update_interval=args.actor_update_interval,
            policy_save_path=args.checkpoint_path,
            eval_episodes=args.eval_episodes,
            eval_seed=args.eval_seed,
            eval_batch_size=args.eval_batch_size,
            domain_randomize=args.domain_randomize,
            N_range=tuple(args.N_range),
            m_range=tuple(args.m_range),
            difficulty_range=tuple(args.difficulty_range),
            mu_clip=tuple(args.mu_clip),
            # The cap is only forwarded when domain randomization is on.
            lcap=args.lcap if args.domain_randomize else None,
        )
        _plot_training_curve(dqn_hist, args, m, mu, alpha)
        _evaluate_and_plot(dqn_agent, args, N, alpha, m, mu, "trained")

    elif args.mode == 'evaluate':
        dqn_agent, checkpoint = _load_agent_from_checkpoint(args.policy_path, m, alpha, N)
        _report_checkpoint_metrics(checkpoint, args, m, mu, alpha)
        _evaluate_and_plot(dqn_agent, args, N, alpha, m, mu, "loaded")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
