# evaluate_pass_n.py 

import argparse
import json
import os
import random
from typing import Any, Callable, Dict, Optional, Sequence

import numpy as np
import torch
import gymnasium as gym
from gymnasium.wrappers import RecordVideo

import minigrid  # noqa: F401 -- not referenced by name; presumably imported for its import-time side effects

from train_utils import setup_environment as setup_environment_multiseed
from utils import BabyAI_BC, prepare_obs

# Per-game settings keyed by the short game code accepted on the command line.
# ``use_text`` is the default for the model's text-input option for that game;
# every listed game currently enables it. Each game gets its own inner dict.
GAME_CODE_TO_INFO = {
    code: {"use_text": True}
    for code in ("open", "pickup", "goto", "unlock", "synthseq", "bosslevel")
}

def success_from_info_or_reward(terminated: bool, truncated: bool, info: Dict[str, Any], ep_reward: float) -> bool:
    if isinstance(info, dict) and "success" in info:
        return bool(info["success"])
    return bool(terminated and ep_reward > 0.0)

def parse_args():
    """Define the pass@n evaluation CLI and parse it from ``sys.argv``.

    Returns:
        argparse.Namespace holding the evaluation options.
    """
    cli = argparse.ArgumentParser(
        description="Evaluate a trained BabyAI agent using the pass@n metric (env matches ppo.py)."
    )
    add = cli.add_argument

    # Model checkpoint and task selection.
    add("--weights", type=str, required=True,
        help="Path to the model weights (.pt) to evaluate.")
    add("--game_code", type=str, default="bosslevel",
        help="Short game code, e.g., open / pickup / goto / unlock.")

    # Seed sweep and per-rollout budget.
    add("--num_env_seeds", type=int, default=1000,
        help="Number of unique environment seeds to evaluate on.")
    add("--n_rollouts_per_env", type=int, required=True,
        help="Number of rollouts per environment seed (the 'n' in pass@n).")
    add("--start_seed", type=int, default=30000,
        help="Starting seed for environment generation.")
    add("--max_steps_per_episode", type=int, default=100)
    add("--temperature", type=float, default=1.0)

    # Optional video capture; ``{game_code}`` in the directory is filled in later.
    add("--save_video", action="store_true")
    add("--video_dir", type=str, default="eval_videos_{game_code}")

    # Runtime placement and text-input mode.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    add("--device", type=str, default=fallback_device)
    add("--use_text", type=str, default="auto", choices=["auto", "true", "false"])

    return cli.parse_args()


@torch.no_grad()
def run_single_rollout(
    env: gym.Env,
    actor: torch.nn.Module,
    device: torch.device,
    use_text: bool,
    max_steps: Optional[int],
    temperature: float,
    prepare_obs_fn: Callable[..., Dict[str, torch.Tensor]],
    obs: Dict[str, Any],
) -> bool:
    """Run one episode by sampling actions from ``actor`` and return success status.

    Args:
        env: Environment that the caller has already reset; ``obs`` is its
            current observation.
        actor: Policy network; ``actor(batch)`` is treated as action logits.
        device: Forwarded to ``prepare_obs_fn``.
        use_text: Forwarded to ``prepare_obs_fn``.
        max_steps: Maximum number of ``env.step`` calls made here. ``None``
            means run until the env terminates or truncates. A value <= 0
            takes no steps.
        temperature: Logits are divided by ``max(1e-6, temperature)`` before
            sampling from a categorical distribution.
        prepare_obs_fn: Called as ``prepare_obs_fn(obs, device=..., use_text=...)``
            to build the actor's input batch.
        obs: Initial observation.

    Returns:
        ``success_from_info_or_reward`` evaluated on the last step's flags,
        the last ``info`` and the summed reward.
    """
    # Loop-invariant; the floor avoids division by zero or sign flips for <= 0 input.
    temp = max(1e-6, float(temperature))

    steps = 0
    total_ep_reward = 0.0
    last_info: Dict[str, Any] = {}
    terminated = False
    truncated = False

    # Stop at the step cap (checked before stepping, so the cap is exact),
    # or earlier when the env reports the episode is over.
    while max_steps is None or steps < max_steps:
        batch = prepare_obs_fn(obs, device=device, use_text=use_text)
        actor_logits = actor(batch) / temp
        # .item() assumes the batch holds a single observation — TODO confirm with prepare_obs.
        action = torch.distributions.Categorical(logits=actor_logits).sample().item()

        obs, reward, terminated, truncated, info = env.step(action)
        last_info = info
        total_ep_reward += float(reward)
        steps += 1

        if terminated or truncated:
            break

    return bool(success_from_info_or_reward(terminated, truncated, last_info, total_ep_reward))

@torch.no_grad()
def evaluate_pass_at_n(
    game_code: str,
    actor: torch.nn.Module,
    env: gym.Env,
    device: torch.device,
    use_text: bool,
    num_env_seeds: int,
    n_rollouts_per_env: int,
    start_seed: int,
    max_steps: int,
    temperature: float,
    prepare_obs_fn: Callable,
    seeds: Optional[Sequence[int]] = None,
) -> Dict[str, Any]:
    """Compute pass@n, the fraction of seeds solved by at least one of n rollouts.

    For every seed, a fresh environment is built with
    ``setup_environment_multiseed(game_code, max_steps=..., meta_fixed_seed=seed)``.
    It is reset with that same seed before each rollout. Up to
    ``n_rollouts_per_env`` sampled rollouts are run per seed, and the seed
    counts as passed as soon as one of them succeeds.

    Args:
        game_code: Short game code forwarded to ``setup_environment_multiseed``.
        actor: Policy network forwarded to ``run_single_rollout``.
        env: Not used for rollouts, since each seed gets its own environment.
            Kept so existing callers keep working.
        device: Forwarded to ``run_single_rollout``.
        use_text: Forwarded to ``run_single_rollout``.
        num_env_seeds: How many consecutive seeds to evaluate when ``seeds``
            is None.
        n_rollouts_per_env: The ``n`` in pass@n; must be >= 1.
        start_seed: First seed of the consecutive range used when ``seeds``
            is None.
        max_steps: Step cap passed to environment construction and to each
            rollout.
        temperature: Sampling temperature forwarded to ``run_single_rollout``.
        prepare_obs_fn: Observation-to-batch function forwarded to
            ``run_single_rollout``.
        seeds: Optional explicit seed list (e.g. a curated hard-seed set).
            When given, it overrides ``start_seed``/``num_env_seeds``.

    Returns:
        Dict with ``f"pass_at_{n}_rate"`` (float), ``"passed_environments"``
        (int) and ``"total_environments"`` (int). The rate is 0.0 when no
        seeds are evaluated.

    Raises:
        ValueError: If ``n_rollouts_per_env`` is less than 1.
    """
    if n_rollouts_per_env < 1:
        raise ValueError(f"n_rollouts_per_env must be >= 1, got {n_rollouts_per_env}")

    if seeds is None:
        seed_list = list(range(start_seed, start_seed + num_env_seeds))
    else:
        # Coerce to plain int (e.g. numpy integers loaded from a config file).
        seed_list = [int(s) for s in seeds]
    total = len(seed_list)

    passed_env_count = 0
    print(f"Starting pass@{n_rollouts_per_env} evaluation...")

    for i, current_seed in enumerate(seed_list, start=1):
        seed_env = setup_environment_multiseed(game_code, max_steps=max_steps, meta_fixed_seed=current_seed)
        try:
            env_is_passed = False
            for _ in range(n_rollouts_per_env):
                # Every attempt starts from a reset seeded with the same value.
                obs, _ = seed_env.reset(seed=current_seed)
                success = run_single_rollout(
                    env=seed_env,
                    actor=actor,
                    device=device,
                    use_text=use_text,
                    max_steps=max_steps,
                    temperature=temperature,
                    prepare_obs_fn=prepare_obs_fn,
                    obs=obs,
                )
                if success:
                    # One success is enough for pass@n; skip remaining attempts.
                    env_is_passed = True
                    break
        finally:
            # Close even if a rollout raises, so per-seed envs are not leaked.
            seed_env.close()

        if env_is_passed:
            passed_env_count += 1

        if i % 10 == 0 or i == total:
            current_pass_rate = passed_env_count / i
            print(f"  ... tested {i}/{total} seeds. "
                  f"Current pass rate: {current_pass_rate:.4f} ({passed_env_count}/{i})")

    pass_rate = passed_env_count / total if total > 0 else 0.0
    return {
        f"pass_at_{n_rollouts_per_env}_rate": float(pass_rate),
        "passed_environments": int(passed_env_count),
        "total_environments": int(total),
    }
def main():
    """Script entry point: parse args, load the actor, run pass@n and print a summary."""
    args = parse_args()

    # Seed every RNG this process may draw from so sampled rollouts are reproducible.
    rng_seed = args.start_seed % (2**32 - 1)
    random.seed(rng_seed)
    np.random.seed(rng_seed)
    torch.manual_seed(rng_seed)

    device = torch.device(args.device)
    print(f"Using device: {device}")

    if args.use_text == "auto":
        # Game codes missing from the table default to a text-free model.
        use_text = GAME_CODE_TO_INFO.get(args.game_code, {"use_text": False})["use_text"]
    else:
        use_text = (args.use_text == "true")

    print(f"Creating evaluation environment via setup_environment(...) (game={args.game_code})")
    env = setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=args.start_seed)

    try:
        if args.save_video:
            video_dir_formatted = args.video_dir.format(game_code=args.game_code)
            print(f"Videos will be saved to: {video_dir_formatted}")
            os.makedirs(video_dir_formatted, exist_ok=True)
            env = RecordVideo(env, video_folder=video_dir_formatted,
                              episode_trigger=lambda x: True,
                              name_prefix=f"eval-{args.game_code}")

        # This env is only used to size the action head.
        n_actions = env.action_space.n

        print(f"Loading actor model from: {args.weights}")
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints
        # (torch >= 1.13 accepts weights_only=True for plain state dicts).
        state = torch.load(args.weights, map_location=device)

        # A text-enabled checkpoint is recognised by its token-embedding table.
        ckpt_uses_text = ("tok_emb.weight" in state)
        if ckpt_uses_text != use_text:
            print(f"Warning: use_text={use_text} but the checkpoint "
                  f"{'contains' if ckpt_uses_text else 'does not contain'} 'tok_emb.weight'; "
                  f"strict state-dict loading may fail.")
        vocab_size = int(state["tok_emb.weight"].shape[0]) if ckpt_uses_text else 200

        actor = BabyAI_BC(n_actions=n_actions, use_text=use_text, vocab_size=(vocab_size if use_text else 200)).to(device)
        actor.load_state_dict(state, strict=True)
        actor.eval()

        results = evaluate_pass_at_n(
            game_code=args.game_code,
            actor=actor,
            env=env,
            device=device,
            use_text=use_text,
            num_env_seeds=args.num_env_seeds,
            n_rollouts_per_env=args.n_rollouts_per_env,
            start_seed=args.start_seed,
            max_steps=args.max_steps_per_episode,
            temperature=args.temperature,
            prepare_obs_fn=prepare_obs,
        )
    finally:
        # Release the env (and any video writer wrapping it) even on failure.
        env.close()

    n = args.n_rollouts_per_env
    print("\n" + "="*35)
    print(f"      Pass@{n} Evaluation Results")
    print("="*35)
    print(f"Passed Environments:  {results['passed_environments']} / {results['total_environments']}")
    print(f"Pass@{n} Rate:           {results[f'pass_at_{n}_rate']:.4f}")
    print("="*35)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
