"""
Evaluation script: load a trained checkpoint and evaluate the policy using the DDQN Q-network.

This script uses the Q-network only (greedy action selection) and ignores FPVR (SR/whitening/c) branches.
"""
import gymnasium as gym
from gymnasium.wrappers import TimeLimit
import numpy as np
import torch
import os
import sys
import argparse
import cv2
from tqdm import tqdm
import math

# Register ALE (Atari) environments with gymnasium.
# The whole step is best-effort: any failure while importing `ale_py` or while
# calling `gym.register_envs` is caught, `ale_py` is set to None, and a warning
# is printed instead of aborting the import of this module.
try:
    import ale_py  # type: ignore

    gym.register_envs(ale_py)
except Exception as e:
    ale_py = None
    # Stay silent when the user only asked for the CLI help text.
    if "--help" not in sys.argv and "-h" not in sys.argv:
        print(f"[Warning] ale_py import failed: {e}")
        print("[Warning] Atari envs may be unavailable until you install ale_py (and ROMs).")

# Ensure local modules are importable: put the directory containing this file
# at the front of sys.path so the `config`, `model` and `atari_wrappers`
# imports below resolve regardless of the current working directory.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
if _THIS_DIR not in sys.path:
    sys.path.insert(0, _THIS_DIR)

from config import get_params
from model import make_q_network
from atari_wrappers import NoopResetEnv, AtariMaxPoolWrapper


def find_latest_checkpoint(runs_dir: str) -> str | None:
    """Find the most recently modified checkpoint under runs_dir."""
    if not os.path.isdir(runs_dir):
        return None
    best_path: str | None = None
    best_mtime: float = -1.0
    for root, _dirs, files in os.walk(runs_dir):
        for fn in files:
            if not (fn.endswith(".pth") and fn.startswith("checkpoint")):
                continue
            p = os.path.join(root, fn)
            try:
                mt = os.path.getmtime(p)
            except OSError:
                continue
            if mt > best_mtime:
                best_mtime = mt
                best_path = p
    return best_path


def preprocessing(img):
    """Convert an RGB frame to grayscale and resize it to 84x84.

    Resizing uses OpenCV's ``INTER_AREA`` interpolation.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    return cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)


def stack_frames(stacked_frames, state, is_new_episode, num_stack):
    """Return the frame stack updated with the preprocessed ``state``.

    Args:
        stacked_frames: current stack (frames along axis 0); ignored when
            ``is_new_episode`` is true.
        state: raw frame, passed through ``preprocessing`` first.
        is_new_episode: when true, the stack is rebuilt from ``num_stack``
            copies of the new frame.
        num_stack: number of frames in a freshly built stack.

    Returns:
        A new array; the input stack is not modified in place.
    """
    newest = preprocessing(state)

    if not is_new_episode:
        # Slide the window: drop the oldest frame, append the newest at the end.
        return np.concatenate([stacked_frames[1:], newest[np.newaxis]], axis=0)

    # No history yet, so fill every slot with the first frame.
    return np.stack([newest] * num_stack, axis=0)


def _gym_make_strict(env_name: str, **kwargs):
    """Create an env via ``gym.make`` without ever dropping the given kwargs.

    If ``gym.make`` raises ``TypeError``, the error is re-raised as a
    ``TypeError`` with a more descriptive message (chained to the original)
    instead of retrying with fewer arguments.
    """
    try:
        env = gym.make(env_name, **kwargs)
    except TypeError as e:
        raise TypeError(
            f"[Env Error] Failed to create env={env_name!r} with required kwargs={kwargs}.\n"
            f"This evaluation requires strict control over sticky actions and full action space.\n"
            f"Original error: {e}\n"
            f"Suggested fixes: upgrade `gymnasium`/`ale-py`, and ensure you're using the ALE v5 environments."
        ) from e
    else:
        return env


def make_atari(
    env_name,
    max_episode_steps=None,
    use_max_pool=False,
    frame_skip=5,
    sticky_action_prob=0.25,
    full_action_space=True,
    noop_on_reset=True,
    noop_max=30,
):
    """
    Build the wrapped Atari environment used for evaluation.
    Mirrors the construction in main.py so training and evaluation agree.

    Args:
        env_name: environment id handed to ``gym.make``.
        max_episode_steps: TimeLimit budget in ``env.step()`` calls; when None
            it defaults to ``ceil(18000 / frame_skip)``.
        use_max_pool: accepted for signature compatibility but ignored; the
            max-pool wrapper is always applied to avoid protocol drift.
        frame_skip: action-repeat count given to ``AtariMaxPoolWrapper``.
        sticky_action_prob: forwarded as ``repeat_action_probability``.
        full_action_space: forwarded to the base environment.
        noop_on_reset: when truthy, wrap with ``NoopResetEnv``.
        noop_max: ``noop_max`` argument for ``NoopResetEnv``.

    Returns:
        The fully wrapped environment.
    """
    # The base env runs with frameskip=1; action repeat is handled by the wrapper.
    base_env = _gym_make_strict(
        env_name,
        frameskip=1,
        repeat_action_probability=float(sticky_action_prob),
        full_action_space=full_action_space,
    )
    env = AtariMaxPoolWrapper(base_env, frame_skip=frame_skip)

    # TimeLimit counts env.step() calls (agent steps), not ALE frames.
    step_limit = (
        max_episode_steps
        if max_episode_steps is not None
        else int(math.ceil(18000 / max(1, int(frame_skip))))
    )

    if isinstance(env, TimeLimit):
        # Outermost wrapper is already a TimeLimit: just update its budget.
        env._max_episode_steps = step_limit
    else:
        # An explicit wrapper is more reliable than setting the attribute alone.
        env = TimeLimit(env, max_episode_steps=step_limit)

    if not bool(noop_on_reset):
        return env
    return NoopResetEnv(env, noop_max=int(noop_max))


@torch.no_grad()
def select_action_greedy(q_net, state, device, n_actions):
    """
    Greedy action selection using the Q-network (epsilon=0).

    Both NumPy arrays and torch tensors are accepted. An unbatched
    ``[C,H,W]`` input gets a leading batch dimension, and the input is moved
    to ``device`` before the forward pass.

    Args:
        q_net: callable mapping a batched state tensor to Q-values ``[1, A]``
        state: observation ``[C,H,W]`` or ``[1,C,H,W]`` (ndarray or tensor)
        device: torch device to run the forward pass on
        n_actions: number of actions (unused; kept for interface compatibility)

    Returns:
        action: index of the highest Q-value, as a Python int
    """
    if isinstance(state, np.ndarray):
        if state.ndim == 3:
            state = np.expand_dims(state, 0)
        # Keep uint8; per the original author's note the Q-network converts
        # uint8 -> float internally (x.float() / 255.0).
        state = torch.from_numpy(state).to(device)
        if state.dtype != torch.uint8:
            # NOTE(review): this truncates float frames already scaled to
            # [0, 1]; callers are expected to pass raw 0-255 frames.
            state = state.to(torch.uint8)
    elif isinstance(state, torch.Tensor):
        # Tensors previously bypassed batching and device placement, so an
        # unbatched or off-device tensor broke the forward pass / argmax(dim=1).
        # The dtype is left untouched, as before.
        if state.dim() == 3:
            state = state.unsqueeze(0)
        state = state.to(device)

    q_values = q_net(state)  # [1, A]
    return int(torch.argmax(q_values, dim=1).item())


def _build_eval_env(config):
    """Create the evaluation env from a checkpoint config mapping."""
    return make_atari(
        config["env_name"],
        max_episode_steps=config.get("max_episode_steps", 4500),
        use_max_pool=config.get("use_max_pool", False),
        frame_skip=config.get("frame_skip", 5),
        sticky_action_prob=config.get("sticky_action_prob", 0.25),
        full_action_space=config.get("full_action_space", True),
        noop_on_reset=config.get("noop_on_reset", True),
        noop_max=config.get("noop_max", 30),
    )


def _play_greedy_episode(env, q_net, device, n_actions, frame_stack, reset_seed, render):
    """Run one greedy episode and return ``(total_reward, n_steps)``."""
    reset_result = env.reset(seed=reset_seed)
    obs = reset_result[0] if isinstance(reset_result, tuple) else reset_result

    # The first argument is ignored by stack_frames when starting a new episode.
    stacked_state = stack_frames(None, obs, is_new_episode=True, num_stack=frame_stack)

    total_reward = 0.0
    n_steps = 0
    done = False
    while not done:
        if render:
            # NOTE(review): make_atari does not pass a render_mode to gym.make,
            # so this call may be a no-op (gymnasium warns) -- confirm.
            env.render()

        # Greedy action (epsilon=0) from the Q-network.
        action = select_action_greedy(q_net, stacked_state, device, n_actions)
        next_obs, reward, terminated, truncated, _info = env.step(action)
        done = terminated or truncated

        stacked_state = stack_frames(
            stacked_state, next_obs, is_new_episode=False, num_stack=frame_stack
        )
        total_reward += reward
        n_steps += 1
    return total_reward, n_steps


def _summarize_episodes(episode_rewards, episode_lengths, n_episodes):
    """Aggregate per-episode rewards/lengths into the results dict."""
    positive_reward_episodes = sum(1 for r in episode_rewards if r > 0)
    return {
        'episode_rewards': episode_rewards,
        'episode_lengths': episode_lengths,
        'mean_reward': np.mean(episode_rewards),
        'std_reward': np.std(episode_rewards),
        'min_reward': np.min(episode_rewards),
        'max_reward': np.max(episode_rewards),
        'mean_length': np.mean(episode_lengths),
        'std_length': np.std(episode_lengths),
        'positive_reward_episodes': positive_reward_episodes,
        'positive_reward_rate': positive_reward_episodes / n_episodes,
        'n_episodes': n_episodes,
    }


def evaluate_policy(checkpoint_path, n_episodes=10, seed=None, render=False, verbose=True):
    """
    Evaluate a trained policy using the Q-network only (greedy actions).

    Args:
        checkpoint_path: path to checkpoint file
        n_episodes: number of evaluation episodes (must be >= 1)
        seed: RNG seed; when None a random seed is drawn and logged, and that
            logged seed reproduces the run when passed back in
        render: whether to call ``env.render()`` every step
        verbose: whether to print verbose logs and show a progress bar

    Returns:
        results: dict with per-episode rewards/lengths and summary statistics

    Raises:
        ValueError: if ``n_episodes`` < 1, or the checkpoint has no ``q_net``.
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    # Fail fast: the summary statistics are undefined for zero episodes
    # (division by zero / empty reductions), so reject it before any heavy work.
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # weights_only=False is needed because the checkpoint also stores a config
    # dict (PyTorch 2.6+ defaults to weights_only=True).
    # SECURITY: this unpickles arbitrary objects and can execute code embedded
    # in the file -- only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Load config from checkpoint (fallback to defaults if missing)
    if 'config' in checkpoint:
        config = checkpoint['config']
    else:
        print("[Warning] No config found in checkpoint, using default config")
        config = get_params()
    # The code below uses config.get(...), which needs a mapping. Tolerate an
    # argparse.Namespace as well (get_params' return type is not visible here).
    if isinstance(config, argparse.Namespace):
        config = vars(config)

    # Seed NumPy/torch. When no seed is given, draw one so the run can be
    # reproduced from the logged value.
    eval_seed = seed if seed is not None else int(np.random.randint(0, 10000))
    np.random.seed(eval_seed)
    torch.manual_seed(eval_seed)

    if verbose:
        print(f"[Evaluate] Loading checkpoint: {checkpoint_path}")
        print(f"[Evaluate] Device: {device}")
        print(f"[Evaluate] Evaluation seed: {eval_seed}")
        print(f"[Evaluate] Number of episodes: {n_episodes}")

    frame_stack = int(config.get("frame_stack", 4))
    state_shape = (frame_stack, 84, 84)

    # One env serves both the action-space query and the rollouts; the
    # try/finally guarantees it is closed even if evaluation fails midway.
    env = _build_eval_env(config)
    try:
        n_actions = env.action_space.n

        if verbose:
            print(f"[Evaluate] Environment: {config['env_name']}")
            print(f"[Evaluate] Action space: {n_actions}")
            print(f"[Evaluate] State shape: {state_shape}")

        # Build Q-network (architecture is stored in checkpoint config when available)
        q_net_type = str(config.get("q_net_type", "nature")).lower()
        q_net = make_q_network(state_shape, n_actions, q_net_type).to(device)

        if checkpoint.get('q_net') is None:
            raise ValueError("Checkpoint does not contain Q network weights (q_net). Cannot evaluate DDQN policy.")
        q_net.load_state_dict(checkpoint['q_net'])
        if verbose:
            print("[Evaluate] Loaded Q network weights from checkpoint")
        q_net.eval()

        episode_rewards = []
        episode_lengths = []

        if verbose:
            print(f"\n[Evaluate] Starting evaluation...")
            pbar = tqdm(range(n_episodes), desc="Evaluating")
        else:
            pbar = range(n_episodes)

        for episode in pbar:
            # Always seed the env from eval_seed (also when it was drawn at
            # random) so the logged seed actually determines the rollouts.
            episode_reward, episode_length = _play_greedy_episode(
                env, q_net, device, n_actions, frame_stack, eval_seed + episode, render
            )
            episode_rewards.append(episode_reward)
            episode_lengths.append(episode_length)

            if verbose:
                pbar.set_postfix({
                    'reward': f'{episode_reward:.1f}',
                    'length': episode_length,
                    'avg_reward': f'{np.mean(episode_rewards):.1f}'
                })
    finally:
        env.close()

    results = _summarize_episodes(episode_rewards, episode_lengths, n_episodes)

    if verbose:
        print(f"\n[Evaluate] Evaluation Results:")
        print(f"  Mean Reward: {results['mean_reward']:.2f} ± {results['std_reward']:.2f}")
        print(f"  Min/Max Reward: {results['min_reward']:.1f} / {results['max_reward']:.1f}")
        print(f"  Mean Length: {results['mean_length']:.1f} ± {results['std_length']:.1f}")
        print(f"  Positive Reward Episodes: {results['positive_reward_episodes']}/{n_episodes} ({results['positive_reward_rate']*100:.1f}%)")

    return results


def main():
    """Parse CLI arguments, run the evaluation and return a process exit code.

    When ``--checkpoint`` is omitted, the newest checkpoint under the ``runs``
    directory next to this script is used; a ``ValueError`` is raised if there
    is none. Returns 0 on success and 1 if the evaluation raised (the error
    and its traceback are printed).
    """
    parser = argparse.ArgumentParser(description="Evaluate trained DDQN+FPVR policy using only Q network")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to checkpoint file (.pth). If not set, the latest checkpoint under fpvr_dqn_atari/runs/ is used.",
    )
    parser.add_argument("--n_episodes", type=int, default=10, help="Number of episodes to evaluate")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for evaluation")
    parser.add_argument("--render", action="store_true", help="Render environment during evaluation")
    parser.add_argument("--quiet", action="store_true", help="Suppress verbose output")
    args = parser.parse_args()

    # Resolve the checkpoint to evaluate, falling back to the newest one on disk.
    checkpoint_path = args.checkpoint
    if checkpoint_path is None:
        checkpoint_path = find_latest_checkpoint(os.path.join(_THIS_DIR, "runs"))
        if checkpoint_path is None:
            raise ValueError("No checkpoint found. Pass --checkpoint or run training to create fpvr_dqn_atari/runs/...")

    try:
        evaluate_policy(
            checkpoint_path=checkpoint_path,
            n_episodes=args.n_episodes,
            seed=args.seed,
            render=args.render,
            verbose=not args.quiet,
        )
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the exit code.
        print(f"[Error] Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    # Use sys.exit rather than the bare exit() helper: exit() is injected by
    # the `site` module for interactive sessions and is missing when Python
    # runs without it (e.g. `python -S`).
    sys.exit(main())

