"""
FPVR-DQN Configuration
Configuration for the Future-Past Visitation Redundancy (FPVR) exploration algorithm
with experience replay (DQN-style training loop).

References:
- This supplementary package contains all required configuration defaults.
"""
import argparse
import math


def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser declaring every FPVR-DQN option and its default."""
    parser = argparse.ArgumentParser(
        description="FPVR with Experience Replay - DQN-style exploration")

    # ========== Environment ==========
    parser.add_argument("--env_name", default="ALE/MontezumaRevenge-v5", type=str,
                        help="Gymnasium environment id")
    parser.add_argument("--frame_skip", default=5, type=int,
                        help="Frame skip value. If use_max_pool=True, this is handled by the wrapper. "
                             "If use_max_pool=False, this is passed to the ALE env as frameskip.")
    parser.add_argument("--sticky_action_prob", default=0.25, type=float,
                        help="Sticky action probability (ALE repeat_action_probability). "
                             "Paper setting uses 0.25 for stochastic Atari.")
    parser.add_argument("--full_action_space", action=argparse.BooleanOptionalAction, default=False,
                        help="Use full ALE action set (A=18) when supported.")
    parser.add_argument("--noop_on_reset", action=argparse.BooleanOptionalAction, default=True,
                        help="If enabled, at the start of every episode reset, perform random NOOP actions "
                             "in [0, noop_max] (DQN-style random starts).")
    parser.add_argument("--noop_max", default=30, type=int,
                        help="Maximum number of random NOOP actions performed on episode reset "
                             "(uniformly sampled in [0, noop_max]).")
    parser.add_argument("--max_episode_steps", default=None, type=int,
                        help="Maximum episode length in agent steps (number of env.step() calls). "
                             "If not set, defaults to ceil(18000 / frame_skip) to match the DQN 5-minute cap "
                             "(18,000 ALE frames).")

    # ========== Training ==========
    parser.add_argument("--total_frames", default=int(100e6 + 1e3), type=int,
                        help="Total ALE frames to train (paper default: 100M). "
                             "If --total_timesteps is set, it takes precedence.")
    parser.add_argument("--total_timesteps", default=None, type=int,
                        help="Total env.step() calls to train. Overrides --total_frames if provided.")
    parser.add_argument("--learning_starts", default=10000, type=int,
                        help="Number of steps before training starts")
    parser.add_argument("--train_freq", default=4, type=int,
                        help="Train every N steps")
    parser.add_argument("--gradient_steps", default=1, type=int,
                        help="Number of gradient steps per training call")
    parser.add_argument("--batch_size", default=32, type=int,
                        help="Batch size for SR training")
    # Both learning-rate flags default to None so get_params() can tell "not given"
    # apart from an explicit value and apply the deprecated --lr alias correctly.
    parser.add_argument("--sf_lr", default=None, type=float,
                        help="Learning rate for the FPVR network optimizer (FPVRNetwork). "
                             "If not set, uses --lr (deprecated) or 5e-4.")
    parser.add_argument("--lr", default=None, type=float,
                        help="DEPRECATED: alias for --sf_lr (kept for backward compatibility).")
    parser.add_argument("--frame_stack", default=4, type=int,
                        help="Number of frames stacked as state input")
    # NOTE(review): the default is 1, so the "if not set, uses frame_stack" fallback
    # described in the help cannot trigger from this module -- confirm this is intended.
    parser.add_argument("--num_sf_channel", default=1, type=int,
                        help="Number of stacked frames fed into the FPVR network. "
                             "If not set, uses frame_stack. (Replay buffer/env still use frame_stack.)")

    # ========== Reward Processing ==========
    parser.add_argument("--reward_clipping", default="clip", type=str,
                        choices=["none", "sign", "clip"],
                        help="Reward clipping as in the original DQN setup.\n"
                             "- sign: r <- sign(r) in {-1,0,1} (DQN default)\n"
                             "- clip: r <- clip(r, -1, 1)\n"
                             "- none: use raw rewards")

    # ========== Replay Buffer ==========
    parser.add_argument("--buffer_size", default=1000000, type=int,
                        help="Replay buffer capacity")
    parser.add_argument("--prioritized_replay", action="store_true",
                        help="Use prioritized experience replay")
    parser.add_argument("--prioritized_alpha", default=0.4, type=float,
                        help="PER alpha (priority exponent)")
    parser.add_argument("--prioritized_beta", default=0.4, type=float,
                        help="PER beta (importance sampling correction)")

    # ========== FPVR Core ==========
    parser.add_argument(
        "--fpvr_lambda_c",
        default=0.9,
        type=float,
        help="Decay factor λ for the persistence representation (past feature accumulator).",
    )
    parser.add_argument("--phi_dim", default=1024, type=int,
                        help="Feature dimension φ")
    parser.add_argument("--sf_gamma", default=0.5, type=float,
                        help="Successor feature discount factor")
    # Validated in get_params() (raises ValueError) rather than through `choices`.
    parser.add_argument(
        "--sf_target",
        default="min_redundancy",
        type=str,
        help="SR target policy (canonical): uniform_policy | min_redundancy",
    )
    # Note: SR always considers done (no sr_no_done flag).
    # Note: redundancy score is always cosine_similarity (no composite_novelty flag).

    # ========== Whitening ==========
    parser.add_argument("--whitening_update_every", default=1000, type=int,
                        help="Update ZCA matrix every N training steps")
    parser.add_argument("--whitening_ema_alpha", default=0.001, type=float,
                        help="EMA alpha for whitening statistics")
    parser.add_argument("--whitening_eps", default=1e-10, type=float,
                        help="Numerical regularizer ε for ZCA whitening. "
                             "Implements (Σ + εI)^(-1/2).")
    parser.add_argument("--cov_buffer", default=10000, type=int,
                        help="Number of most recent feature samples kept for covariance/mean estimation")
    # Note: statistics_method is always 'buffer' (no online option).
    # Note: FPVRNetwork has no reconstruction branch (no recon_w option).
    parser.add_argument("--sr_coeff", default=1.0, type=float,
                        help="SR loss coefficient")

    # ========== Logging & Saving ==========
    parser.add_argument("--interval", default=100, type=int,
                        help="Print/log interval (in training steps)")
    parser.add_argument("--eval_interval", default=100000, type=int,
                        help="Run Q-network epsilon-greedy evaluation every N training steps (0 to disable). "
                             "Evaluation uses Q-network only (ignores FPVR).")
    parser.add_argument("--eval_epsilon", default=0.05, type=float,
                        help="Fixed epsilon for epsilon-greedy evaluation of Q policy (Q-network only).")
    # argparse applies `type` only to string defaults, so the default must already be
    # an int; a bare 10e6 would leak through as the float 10000000.0.
    parser.add_argument("--save_interval", default=int(10e6), type=int,
                        help="Checkpoint save interval (in training steps)")
    parser.add_argument("--save_optimizers", action=argparse.BooleanOptionalAction, default=False,
                        help="Whether to save optimizer states in checkpoints. "
                             "Disabling makes checkpoints much smaller but you lose optimizer momentum/Adam moments "
                             "when resuming training.")
    parser.add_argument("--save_q_target", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to save the Q target network weights. "
                             "If disabled, target will be reconstructed from q_net on load.")
    parser.add_argument("--checkpoint_fp16", action=argparse.BooleanOptionalAction, default=False,
                        help="Store model weights in float16 inside the checkpoint to reduce file size. "
                             "Recommended only when --no-save_optimizers (weights-only checkpoints).")
    parser.add_argument("--gif_interval", default=200, type=int,
                        help="GIF save interval (in episodes, 0 to disable)")
    parser.add_argument("--gif_length", default=5000, type=int,
                        help="GIF episode length")
    parser.add_argument("--gif_fps", default=30, type=int,
                        help="GIF frames per second")
    parser.add_argument("--eval_episodes", default=5, type=int,
                        help="Number of episodes for evaluation")

    # ========== Misc ==========
    parser.add_argument("--seed", default=1, type=int,
                        help="Random seed")
    parser.add_argument("--num_seeds", default=5, type=int,
                        help="Number of independent runs (seeds) to execute sequentially. Paper uses 10 seeds.")
    parser.add_argument("--do_test", action="store_true",
                        help="Test mode (load and visualize)")
    parser.add_argument("--render", action="store_true",
                        help="Render environment")
    # NOTE: store_false means the value is True by default and *passing* the flag
    # turns it off. The inverted spelling is kept so existing launch scripts still work.
    parser.add_argument("--train_from_scratch", action="store_false",
                        help="Passing this flag sets train_from_scratch=False (continue from checkpoint). "
                             "Default: True.")
    parser.add_argument("--verbose", action="store_true",
                        help="Print detailed action selection info")
    # IMPORTANT: do NOT use type=bool with argparse; `--reset_c False` would still parse as True.
    # BooleanOptionalAction derives the negation from each option string, so the hyphenated
    # alias is registered explicitly to make `--reset-c/--no-reset-c` work alongside
    # `--reset_c/--no-reset_c`.
    parser.add_argument(
        "--reset_c", "--reset-c",
        dest="reset_c",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Reset c vector to zero at the start of each episode (use --no-reset-c to disable).",
    )
    # Note: fixed-state c mode / fpvr_only are removed (always run DDQN + SR training).

    # ========== Exploration (ε-greedy) ==========
    parser.add_argument("--eps_start", default=1.0, type=float,
                        help="Initial epsilon value for ε-greedy exploration")
    parser.add_argument("--eps_end", default=0.1, type=float,
                        help="Final (minimum) epsilon value for ε-greedy exploration")
    parser.add_argument("--eps_decay_frames", default=int(1e6), type=int,
                        help="Number of ALE frames over which epsilon decays from eps_start to eps_end "
                             "(paper default: 1,000,000 frames).")
    # Backward-compat alias (kept to avoid breaking old scripts): treated as frames.
    # NOTE(review): this alias is not mapped onto eps_decay_frames in this module;
    # presumably the consumer resolves it -- confirm.
    parser.add_argument("--eps_decay_steps", default=None, type=int,
                        help="DEPRECATED: use --eps_decay_frames. If set, it will be treated as frames.")

    # ========== DDQN Parameters ==========
    parser.add_argument("--dqn_type", default="dqn", type=str,
                        choices=["dqn", "ddqn"],
                        help="Which TD target to use for Q-learning.\n"
                             "- dqn: y = r + gamma * max_a Q_target(s', a)\n"
                             "- ddqn: a* = argmax_a Q_online(s', a), y = r + gamma * Q_target(s', a*)")
    parser.add_argument("--q_lr", default=2.5e-4, type=float,
                        help="Learning rate (step-size) for Q-network RMSprop (paper default: 2.5e-4).")
    parser.add_argument("--q_gamma", default=0.99, type=float,
                        help="Discount factor for Q-learning TD target")
    parser.add_argument("--q_target_update", default=40000, type=int,
                        help="Target Q-network update frequency (in training steps)")
    parser.add_argument("--q_coeff", default=1.0, type=float,
                        help="Q-loss coefficient")
    parser.add_argument(
        "--q_net_type",
        default="iclr",
        type=str,
        choices=["nature", "iclr"],
        help="Q-network architecture.\n"
             "- nature: standard Nature DQN CNN\n"
             "- iclr: ICLR-style CNN from function_approximation (Q branch only, no extra heads) "
             "(current default)",
    )
    # Note: alternating period mode is removed (always use combined decision).
    parser.add_argument("--policy_type", default="q_bias", type=str,
                        choices=["q_bias", "filtered_zscore"],
                        help="How to combine Q and FPVR redundancy score for action selection.\n"
                             "q_bias: scores = raw_Q - alpha * zscore(redundancy)\n"
                             "filtered_zscore: if abs(max_a Q(s,a)) < q_abs_threshold, use ONLY redundancy for decision; "
                             "otherwise use zscore(Q) - alpha * zscore(redundancy).")
    parser.add_argument("--policy_alpha", default=0.001, type=float,
                        help="Fixed weight alpha for FPVR redundancy term.")
    parser.add_argument("--q_abs_threshold", default=0.1, type=float,
                        help="Threshold for filtered_zscore. If abs(max_a Q(s,a)) is below this value, "
                             "ignore Q and choose actions using only the FPVR redundancy score.")

    # ========== Mixed Monte Carlo ==========
    # IMPORTANT: do NOT use type=bool with argparse; `--mixed_mc False` would still parse as True.
    # Hyphenated aliases are registered so `--mixed-mc/--no-mixed-mc` and
    # `--full-mc/--no-full-mc` are accepted in addition to the underscore spellings.
    parser.add_argument(
        "--mixed_mc", "--mixed-mc",
        dest="mixed_mc",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enable mixed Monte Carlo: combine 1-step and n-step TD targets (use --no-mixed-mc to disable).",
    )
    parser.add_argument(
        "--full_mc", "--full-mc",
        dest="full_mc",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="If enabled, use *full-episode* Monte Carlo return-to-go in the mixed target "
             "(use --no-full-mc to disable). When enabled, --n_step is ignored.",
    )
    parser.add_argument("--n_step", default=200, type=int,
                        help="Number of steps for n-step return (used when mixed_mc is enabled)")
    parser.add_argument("--mixed_mc_weight", default=0.5, type=float,
                        help="Weight for n-step return in mixed target (1-step weight = 1 - mixed_mc_weight)")

    return parser


def get_params() -> dict:
    """Parse the command line and return the resolved FPVR-DQN configuration.

    The raw argparse values are post-processed as follows:

    * ``sf_target`` must be ``uniform_policy`` or ``min_redundancy``.
    * ``max_episode_steps`` defaults to ``ceil(18000 / max(1, frame_skip))``
      when not given.
    * ``sf_lr`` resolves to ``--sf_lr`` if given, else the deprecated ``--lr``
      if given, else ``5e-4``; the ``lr`` key is removed from the result.
    * ``state_shape`` is set to ``(frame_stack, 84, 84)`` and ``n_actions`` to
      ``None`` as a placeholder.

    The resolved configuration is also printed to stdout.

    Returns:
        dict mapping option names to their resolved values.

    Raises:
        ValueError: if ``--sf_target`` is not one of the supported names.
        SystemExit: raised by argparse on malformed command-line input.
    """
    # Parse args; argparse defaults are the single source of truth (avoid stale duplicate defaults).
    total_params = vars(_build_parser().parse_args())

    # Validate SR target naming.
    sf_t = str(total_params["sf_target"])
    if sf_t not in ("uniform_policy", "min_redundancy"):
        raise ValueError(f"Invalid --sf_target={sf_t!r}. Use uniform_policy|min_redundancy.")
    total_params["sf_target"] = sf_t

    # DQN-style episode cap: 5 minutes of gameplay ~= 18,000 ALE frames.
    # Our TimeLimit counts *agent steps*; one agent step corresponds to ~frame_skip ALE frames.
    if total_params["max_episode_steps"] is None:
        frame_skip = int(total_params["frame_skip"])
        # max(1, ...) guards the division against a zero/negative frame_skip.
        total_params["max_episode_steps"] = math.ceil(18000 / max(1, frame_skip))

    # Backward-compat: map deprecated --lr -> --sf_lr when sf_lr is not explicitly set.
    # Keep only one canonical key (`sf_lr`) downstream.
    deprecated_lr = total_params.pop("lr", None)
    if total_params["sf_lr"] is None:
        total_params["sf_lr"] = float(deprecated_lr) if deprecated_lr is not None else 5e-4

    frame_stack = int(total_params["frame_stack"])
    total_params["state_shape"] = (frame_stack, 84, 84)
    total_params["n_actions"] = None  # filled in main.py after env creation

    banner = "=" * 80
    print(banner)
    print("FPVR-DQN Configuration")
    print(banner)
    for k, v in total_params.items():
        print(f"  {k:25s} = {v}")
    print(banner)

    return total_params

