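"""
Evaluation script: load a checkpoint produced by main.py and report the average episode score.

Protocol:
- ε-greedy action selection with ε=0.05 by default
- 30 episodes by default
- Uses the same Atari env wrappers as training (ALE frameskip=1 with sticky actions, an external
  action-repeat wrapper that takes the max over the last 2 frames, a time limit, and optional
  random no-op starts)
"""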

import argparse
import math
import os
import sys
import glob
from typing import Optional, Tuple

import cv2
import gymnasium as gym
from gymnasium.wrappers import TimeLimit
import numpy as np
import torch

# Make the ALE Atari environments known to gymnasium. ale_py is an optional
# dependency: if importing or registering it raises, `ale_py` is set to None and
# a warning is printed, unless the user only asked for the CLI help.
try:
    import ale_py  # type: ignore

    gym.register_envs(ale_py)
except Exception as exc:
    ale_py = None
    if not any(flag in sys.argv for flag in ("--help", "-h")):
        print(f"[Warning] ale_py import failed: {exc}")
        print("[Warning] Atari envs may be unavailable until you install ale_py (and ROMs).")

# Directory holding this script; the local modules imported below (config, model,
# atari_wrappers) are resolved from here regardless of the current working directory.
_THIS_DIR = os.path.split(os.path.abspath(__file__))[0]
if _THIS_DIR not in sys.path:
    sys.path.insert(0, _THIS_DIR)

from config import get_params
from model import make_q_network
from atari_wrappers import AtariMaxPoolWrapper, NoopResetEnv


def preprocessing(img: np.ndarray) -> np.ndarray:
    """Convert an RGB frame into a single-channel 84x84 grayscale image.

    The downscale uses area interpolation (``cv2.INTER_AREA``).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    return cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)


def stack_frames(
    stacked_frames: Optional[np.ndarray],
    state: np.ndarray,
    *,
    is_new_episode: bool,
    num_stack: int,
) -> np.ndarray:
    """Push the preprocessed ``state`` onto a rolling stack of frames (axis 0).

    At the start of an episode, or when no stack exists yet, the result is
    ``num_stack`` copies of the new frame. Otherwise the oldest frame (index 0)
    is dropped and the new frame is appended last, keeping the incoming stack's
    length. A new array is returned; ``stacked_frames`` is not modified.
    """
    latest = preprocessing(state)
    if not is_new_episode and stacked_frames is not None:
        return np.concatenate([stacked_frames[1:], latest[np.newaxis]], axis=0)
    return np.stack([latest] * num_stack, axis=0)


def _gym_make_strict(env_name: str, **kwargs):
    """Create ``env_name`` with ``gym.make`` without ever dropping the given kwargs.

    If ``gym.make`` raises ``TypeError`` (e.g. it rejects one of the protocol
    kwargs), a new ``TypeError`` is raised, chained to the original. It names
    the env and kwargs and suggests upgrading gymnasium/ale-py and using ALE v5
    envs.
    """
    try:
        env = gym.make(env_name, **kwargs)
    except TypeError as err:
        details = "\n".join(
            [
                f"[Env Error] Failed to create env={env_name!r} with required kwargs={kwargs}.",
                f"Original error: {err}",
                "Suggested fixes: upgrade `gymnasium`/`ale-py`, and ensure you're using ALE v5 environments.",
            ]
        )
        raise TypeError(details) from err
    return env


def make_atari(
    env_name: str,
    *,
    max_episode_steps: Optional[int],
    frame_skip: int,
    sticky_action_prob: float,
    full_action_space: bool,
    noop_on_reset: bool,
    noop_max: int,
):
    """Build an Atari env following the same protocol as training (main.py).

    Wrapper stack, innermost first:
      1. ``gym.make`` with ALE ``frameskip=1`` and sticky actions via
         ``repeat_action_probability=sticky_action_prob``.
      2. ``AtariMaxPoolWrapper``: external action repeat of ``frame_skip``
         plus max-over-2-frames flicker reduction.
      3. ``TimeLimit`` of ``max_episode_steps`` wrapper steps; when that is
         ``None`` the cap is ``ceil(18000 / frame_skip)`` (frame_skip clamped
         to >= 1 for the division).
      4. ``NoopResetEnv(noop_max)`` only when ``noop_on_reset`` is truthy.
    """
    env = _gym_make_strict(
        env_name,
        frameskip=1,
        repeat_action_probability=float(sticky_action_prob),
        full_action_space=bool(full_action_space),
    )
    skip = int(frame_skip)
    env = AtariMaxPoolWrapper(env, frame_skip=skip)

    if max_episode_steps is None:
        step_cap = int(math.ceil(18000 / max(1, skip)))
    else:
        step_cap = int(max_episode_steps)

    if isinstance(env, TimeLimit):
        # Defensive: only reachable if the wrapper built above is itself a TimeLimit.
        env._max_episode_steps = step_cap
    else:
        env = TimeLimit(env, max_episode_steps=step_cap)

    if noop_on_reset:
        env = NoopResetEnv(env, noop_max=int(noop_max))
    return env


@torch.no_grad()
def select_action_eps_greedy(
    q_net: torch.nn.Module,
    stacked_state: np.ndarray,
    *,
    eps: float,
    rng: np.random.Generator,
    n_actions: int,
    device: torch.device,
) -> int:
    """Choose an action epsilon-greedily.

    Every call draws one ``rng.random()``. If it falls below ``eps``, a
    uniformly random action in ``[0, n_actions)`` is drawn from ``rng``.
    Otherwise the state gets a leading batch dimension, is moved to ``device``
    and cast to ``uint8`` if needed, and the argmax of ``q_net``'s output over
    dim 1 is returned.
    """
    explore = rng.random() < float(eps)
    if explore:
        return int(rng.integers(0, n_actions))
    batch = torch.from_numpy(stacked_state[None]).to(device)
    batch = batch if batch.dtype == torch.uint8 else batch.to(torch.uint8)
    greedy = q_net(batch).argmax(dim=1)
    return int(greedy.item())


def _checkpoint_step(path: str) -> int:
    """Return N from a ``checkpoint_step_<N>.pth`` filename, or -1 if it does not parse."""
    stem = os.path.basename(path).removeprefix("checkpoint_step_").removesuffix(".pth")
    try:
        return int(stem)
    except ValueError:
        return -1


def _torch_load_compat(path: str, device: torch.device):
    """Load ``path`` onto ``device`` with ``weights_only=False``.

    Falls back to the legacy signature on PyTorch versions that reject the
    ``weights_only`` keyword. NOTE: ``weights_only=False`` unpickles arbitrary
    objects, so only point this at checkpoints you trust.
    """
    try:
        return torch.load(path, map_location=device, weights_only=False)
    except TypeError as e:
        if "weights_only" not in str(e):
            raise
        return torch.load(path, map_location=device)


def _checkpoint_candidates(checkpoint_path: str) -> list[str]:
    """List checkpoint files to try for ``checkpoint_path``, best first, de-duplicated by absolute path.

    A file path is tried first. Then, inside the directory (the path itself if it
    is a directory, else the file's parent), the order is:
      1. ``checkpoint_step_*.pth`` by step number, latest first;
      2. ``checkpoint.pth``;
      3. every other ``*.pth`` in name order.
    """
    candidates: list[str] = []
    if os.path.isdir(checkpoint_path):
        directory = checkpoint_path
    else:
        candidates.append(checkpoint_path)
        directory = os.path.dirname(os.path.abspath(checkpoint_path))

    # Sort numerically, not by name: a name sort puts step_30000 after step_200000.
    step_files = sorted(glob.glob(os.path.join(directory, "checkpoint_step_*.pth")))
    candidates.extend(sorted(step_files, key=_checkpoint_step, reverse=True))
    candidates.append(os.path.join(directory, "checkpoint.pth"))
    candidates.extend(sorted(glob.glob(os.path.join(directory, "*.pth"))))

    unique: list[str] = []
    seen: set[str] = set()
    for c in candidates:
        key = os.path.abspath(c)
        if key not in seen:
            seen.add(key)
            unique.append(c)
    return unique


def load_checkpoint_and_config(
    checkpoint_path: str, device: torch.device
) -> Tuple[dict, dict]:
    """Load a training checkpoint and the config stored inside it.

    ``checkpoint_path`` may be a ``.pth`` file or a directory of checkpoints.
    Candidates are tried in the order given by ``_checkpoint_candidates``, and
    the first one that loads is used.
    - A directory resolves to its latest readable ``checkpoint_step_*.pth``.
    - An unreadable requested file (e.g. from an interrupted save) falls back
      to a sibling checkpoint, with a warning that includes the load error.

    Returns:
        ``(ckpt, config)``. ``config`` is ``ckpt["config"]``, or the
        ``get_params()`` defaults when the checkpoint has none.

    Raises:
        FileNotFoundError: ``checkpoint_path`` does not exist.
        RuntimeError: no candidate could be loaded. The message includes the
            first load error encountered.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    candidates = _checkpoint_candidates(checkpoint_path)
    errors: list[str] = []
    loaded_path: str | None = None
    ckpt = None
    for c in candidates:
        if not os.path.exists(c) or os.path.isdir(c):
            continue
        try:
            ckpt = _torch_load_compat(c, device)
        except Exception as e:
            # Corrupted/incomplete/incompatible file: remember why, then try the next candidate.
            errors.append(f"{c}: {type(e).__name__}: {e}")
            ckpt = None
            continue
        if ckpt is not None:
            loaded_path = c
            break

    if ckpt is None:
        first_error = f" First error: {errors[0]}" if errors else ""
        raise RuntimeError(
            "Failed to load checkpoint (file may be corrupted/incomplete). "
            f"Tried: {candidates[:10]}{' ...' if len(candidates) > 10 else ''}"
            f"{first_error}"
        )

    if os.path.isdir(checkpoint_path):
        # Resolving a directory to a file is expected, not a failure.
        print(f"[Info] Resolved checkpoint directory {checkpoint_path} -> {loaded_path}")
    elif os.path.abspath(loaded_path) != os.path.abspath(checkpoint_path):
        print(f"[Warning] Requested checkpoint failed to load; fell back to: {loaded_path}")
        if errors:
            print(f"[Warning] Load error: {errors[0]}")

    config = ckpt.get("config", None)
    if config is None:
        print("[Warning] No config found in checkpoint; falling back to get_params() defaults.")
        config = get_params()
    return ckpt, config


def _resolve_env_settings(config: dict, *, env_name_override: Optional[str]) -> dict:
    """
    Resolve evaluation env settings to match training protocol.

    Key invariants the user cares about:
    - frame_skip=5 (unless checkpoint config says otherwise)
    - sticky_action_prob=0.25 (unless checkpoint config says otherwise)
    - noop_on_reset / noop_max match training (DQN-style random starts)
    - max_episode_steps matches training cap (if missing, use DQN 5-min cap: ceil(18000/frame_skip))
    """
    env_name = str(env_name_override or config.get("env_name", "ALE/MontezumaRevenge-v5"))
    frame_skip = int(config.get("frame_skip", 5))
    sticky_action_prob = float(config.get("sticky_action_prob", 0.25))
    full_action_space = bool(config.get("full_action_space", True))
    noop_on_reset = bool(config.get("noop_on_reset", True))
    noop_max = int(config.get("noop_max", 30))

    max_episode_steps = config.get("max_episode_steps", None)
    if max_episode_steps is None:
        max_episode_steps = int(math.ceil(18000 / max(1, int(frame_skip))))
    else:
        max_episode_steps = int(max_episode_steps)

    frame_stack = int(config.get("frame_stack", 4))

    return {
        "env_name": env_name,
        "frame_skip": frame_skip,
        "sticky_action_prob": sticky_action_prob,
        "full_action_space": full_action_space,
        "noop_on_reset": noop_on_reset,
        "noop_max": noop_max,
        "max_episode_steps": max_episode_steps,
        "frame_stack": frame_stack,
    }


def evaluate_eps_greedy(
    checkpoint_path: str,
    *,
    eps: float = 0.05,
    n_episodes: int = 30,
    seed: int = 0,
    env_name_override: Optional[str] = None,
    render: bool = False,
) -> float:
    """Run ``n_episodes`` epsilon-greedy episodes with a checkpoint's Q-network and print a score summary.

    The env is rebuilt from the settings resolved out of the checkpoint config
    (``_resolve_env_settings``), so evaluation uses the training protocol. A single
    ``np.random.default_rng(seed)`` supplies both the per-episode reset seeds and
    the exploration draws.

    Args:
        checkpoint_path: checkpoint file or directory (see ``load_checkpoint_and_config``).
        eps: exploration probability per step.
        n_episodes: number of episodes to run; 0 runs none and reports NaN.
        seed: seed for the evaluation RNG.
        env_name_override: env id to use instead of the one in the config.
        render: call ``env.render()`` every step.

    Returns:
        Mean undiscounted episode return, or NaN when no episodes were run.

    Raises:
        ValueError: the checkpoint has no ``'q_net'`` weights.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt, config = load_checkpoint_and_config(checkpoint_path, device)
    # Fail before creating any env when there is nothing to evaluate.
    if ckpt.get("q_net", None) is None:
        raise ValueError("Checkpoint does not contain 'q_net' weights; cannot evaluate.")

    env_cfg = _resolve_env_settings(config, env_name_override=env_name_override)
    env_name = env_cfg["env_name"]
    frame_stack = int(env_cfg["frame_stack"])
    state_shape = (frame_stack, 84, 84)

    # One env both sizes the action space and runs every episode. The try/finally
    # closes it even if building or loading the network fails.
    env = make_atari(
        env_name,
        max_episode_steps=int(env_cfg["max_episode_steps"]),
        frame_skip=int(env_cfg["frame_skip"]),
        sticky_action_prob=float(env_cfg["sticky_action_prob"]),
        full_action_space=bool(env_cfg["full_action_space"]),
        noop_on_reset=bool(env_cfg["noop_on_reset"]),
        noop_max=int(env_cfg["noop_max"]),
    )
    returns: list[float] = []
    try:
        n_actions = int(env.action_space.n)
        q_net_type = str(config.get("q_net_type", "nature")).lower()
        q_net = make_q_network(state_shape, n_actions, q_net_type).to(device)
        q_net.load_state_dict(ckpt["q_net"])
        q_net.eval()

        rng = np.random.default_rng(int(seed))
        for _ in range(int(n_episodes)):
            obs, _info = env.reset(seed=int(rng.integers(0, 2**31 - 1)))
            stacked = stack_frames(None, obs, is_new_episode=True, num_stack=frame_stack)
            done = False
            ep_ret = 0.0

            while not done:
                if render:
                    # NOTE(review): gymnasium envs typically only render when created with a
                    # render_mode, and make_atari passes none, so this may show nothing. Confirm.
                    env.render()
                a = select_action_eps_greedy(
                    q_net,
                    stacked,
                    eps=float(eps),
                    rng=rng,
                    n_actions=n_actions,
                    device=device,
                )
                next_obs, reward, terminated, truncated, _info = env.step(a)
                done = bool(terminated or truncated)
                ep_ret += float(reward)
                if not done:
                    # The terminal observation is never fed to the network, so skip stacking it.
                    stacked = stack_frames(stacked, next_obs, is_new_episode=False, num_stack=frame_stack)

            returns.append(ep_ret)
    finally:
        env.close()

    if returns:
        mean_ret = float(np.mean(returns))
        std_ret = float(np.std(returns))
        min_ret = float(np.min(returns))
        max_ret = float(np.max(returns))
    else:
        # np.min/np.max raise on an empty sequence (e.g. n_episodes=0), so report NaN instead.
        mean_ret = std_ret = min_ret = max_ret = float("nan")
    print(
        f"[EvalScore] checkpoint={checkpoint_path}\n"
        f"  env={env_name}\n"
        f"  env_cfg: frame_skip={int(env_cfg['frame_skip'])} sticky_action_prob={float(env_cfg['sticky_action_prob']):.3f} "
        f"max_episode_steps={int(env_cfg['max_episode_steps'])} "
        f"noop_on_reset={bool(env_cfg['noop_on_reset'])} noop_max={int(env_cfg['noop_max'])} "
        f"full_action_space={bool(env_cfg['full_action_space'])}\n"
        f"  episodes={int(n_episodes)} eps={float(eps):.3f} seed={int(seed)}\n"
        f"  mean={mean_ret:.2f} std={std_ret:.2f} min={min_ret:.2f} max={max_ret:.2f}"
    )
    return mean_ret


def main() -> int:
    """Command-line entry point; returns the process exit code.

    Parses the CLI flags, then takes one of two paths:
      * ``--dry_run``: loads the checkpoint on CPU and rebuilds the Q-network
        without creating an env. The action count is read from the
        ``head.bias`` entry of the saved ``q_net`` state dict.
      * otherwise: runs ``evaluate_eps_greedy`` with the requested settings.
    Returns 0 when it finishes; failures surface as exceptions.
    """
    default_checkpoint = os.path.join(os.path.dirname(__file__), "results", "checkpoint_step_20000000.pth")
    parser = argparse.ArgumentParser(
        description="Evaluate main.py checkpoint with epsilon-greedy and report mean score."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=default_checkpoint,
        help="Path to checkpoint .pth saved by main.py (default: pbe_dqn_test/results/checkpoint_step_20000000.pth).",
    )
    parser.add_argument("--episodes", type=int, default=30, help="Number of evaluation episodes (default: 30)")
    parser.add_argument(
        "--epsilon", type=float, default=0.05, help="Epsilon for epsilon-greedy evaluation (default: 0.05)"
    )
    parser.add_argument("--seed", type=int, default=0, help="RNG seed for evaluation episode seeds (default: 0)")
    parser.add_argument(
        "--env_name", type=str, default=None, help="Override env_name (default: from checkpoint config)"
    )
    parser.add_argument("--render", action="store_true", help="Render during evaluation (may be slow)")
    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="Only load checkpoint and build the Q network (no env). Useful when ALE is not installed.",
    )
    args = parser.parse_args()

    if not args.dry_run:
        evaluate_eps_greedy(
            args.checkpoint,
            eps=float(args.epsilon),
            n_episodes=int(args.episodes),
            seed=int(args.seed),
            env_name_override=args.env_name,
            render=bool(args.render),
        )
        return 0

    cpu = torch.device("cpu")
    ckpt, config = load_checkpoint_and_config(args.checkpoint, cpu)
    net_kind = str(config.get("q_net_type", "nature")).lower()
    weights = ckpt.get("q_net", None)
    if weights is None:
        raise ValueError("Checkpoint does not contain 'q_net' weights; cannot dry-run.")
    if not (isinstance(weights, dict) and "head.bias" in weights):
        raise ValueError("Unsupported q_net state_dict format; missing 'head.bias'.")
    action_count = int(weights["head.bias"].numel())
    shape = (int(config.get("frame_stack", 4)), 84, 84)
    net = make_q_network(shape, action_count, net_kind).to(cpu)
    net.load_state_dict(weights)
    print(
        f"[DryRun] checkpoint={args.checkpoint}\n"
        f"  q_net_type={net_kind} net={type(net).__name__} n_actions={action_count} state_shape={shape}"
    )
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code.
    sys.exit(main())

