import os
import argparse
from typing import Dict, Any, Optional, Callable, List, Tuple, Set, Union
import json
import random

import numpy as np
import torch
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from matplotlib.colors import LinearSegmentedColormap

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches
import minigrid  
from minigrid.core.world_object import Wall, Door, Key, Ball, Box, Goal
from train_utils import setup_environment as setup_environment_multiseed
from utils import BabyAI_BC, prepare_obs

# Per-game settings keyed by short game code. Each entry currently carries a
# single "use_text" flag (presumably whether the policy for that game consumes
# the text instruction -- confirm against the training code).
PPO_GAME_INFO = {
    code: {"use_text": needs_text}
    for code, needs_text in (
        ("open", False),
        ("pickup", False),
        ("goto", True),
        ("unlock", False),
        ("bosslevel", True),
        ("synthseq", True),
    )
}

def success_from_info_or_reward(terminated: bool, truncated: bool, info: Dict[str, Any], ep_reward: float) -> bool:
    """Decide whether an episode counts as a success.

    An explicit ``info["success"]`` entry always wins. Without it, the episode
    is a success only if it terminated and accumulated a strictly positive
    reward. ``truncated`` is accepted for signature symmetry but is not read.
    """
    has_explicit_flag = isinstance(info, dict) and "success" in info
    if not has_explicit_flag:
        # No flag from the environment: infer success from termination + reward.
        return bool(terminated and ep_reward > 0.0)
    return bool(info["success"])

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the evaluation options. ``--video_dir``,
        ``--vis_dir`` and ``--save_json`` may contain a ``{game_code}``
        placeholder.

    Note:
        ``--config_dir`` is declared here because ``main()`` reads
        ``args.config_dir`` to locate ``<game_code>.yaml``; without the option
        that attribute lookup fails. The default directory name ("configs") is
        an assumption -- adjust it to the project layout if needed.
    """
    p = argparse.ArgumentParser(
        description="Generalization evaluation: (1) unique-rooms-visited in N rollouts/seed, "
                    "(2) test finetuned policy from those visited rooms on the SAME map."
    )
    p.add_argument("--pretrained_weights", type=str, required=True,
                   help="Path to PRETRAINED policy checkpoint/state_dict (.pt) used for initial rollouts.")
    p.add_argument("--finetuned_weights", type=str, required=True,
                   help="Path to FINETUNED policy checkpoint/state_dict (.pt) used for room-start tests.")
    p.add_argument("--pretrained_temperature", type=float, default=1.0)
    p.add_argument("--finetuned_temperature", type=float, default=1.0)
    p.add_argument("--game_code", type=str, default="open",
                   help="Short game code, e.g., open / pickup / goto / unlock / bosslevel / synthseq.")
    p.add_argument("--config_dir", type=str, default="configs",
                   help="Directory containing the per-game YAML config (<game_code>.yaml) with the seed lists.")
    p.add_argument("--n_rollouts_per_seed", type=int, default=100,
                   help="Number of rollouts to run per seed for room-visit counting.")
    p.add_argument("--max_steps_per_episode", type=int, default=100)
    p.add_argument("--save_video", action="store_true")
    p.add_argument("--video_dir", type=str, default="eval_videos_{game_code}")
    p.add_argument("--visualize_rollouts", action="store_true",
                   help="Save one image per seed with overlays of successful trajectories.")
    p.add_argument("--vis_dir", type=str, default="plots_{game_code}_100",
                   help="Directory to save rollout visualizations.")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--use_text", type=str, default="true", choices=["auto", "true", "false"])
    p.add_argument("--save_json", type=str, default="generalization_eval_{game_code}.json",
                   help="Path to save a JSON with per-seed results (format string with {game_code} allowed).")
    p.add_argument("--bg_lighten", type=float, default=0.8,
                   help="Blend background toward white by this amount in [0,1]. 0=no change, 0.5=quite light.")
    return p.parse_args()

def _allow_numpy_for_weights_only():
    """Allow-list NumPy internals so torch.load(weights_only=True) can unpickle older RNG states.

    Registers NumPy's array-reconstruction function, ``ndarray``, ``dtype``,
    ``generic`` and (when the ``numpy.dtypes`` module exists) every ``*DType``
    class with ``torch.serialization.add_safe_globals``.

    This is best-effort: on torch builds that do not provide
    ``add_safe_globals`` the function returns without doing anything, so the
    caller can still attempt a load instead of crashing here.
    """
    import numpy as _np
    import torch as _torch

    register = getattr(_torch.serialization, "add_safe_globals", None)
    if register is None:
        # Nothing to register with on this torch build.
        return

    # The private module holding `_reconstruct` is `numpy._core` on newer
    # NumPy and `numpy.core` on older releases; try the new location first.
    try:
        reconstruct = _np._core.multiarray._reconstruct
    except AttributeError:
        reconstruct = _np.core.multiarray._reconstruct
    safe = [reconstruct, _np.ndarray, _np.dtype, _np.generic]

    # `numpy.dtypes` is not present on every NumPy version; skip it if absent.
    ndt_mod = getattr(_np, "dtypes", None)
    if ndt_mod is not None:
        safe += [getattr(ndt_mod, n) for n in dir(ndt_mod) if n.endswith("DType")]

    register(safe)

def _maybe_extract_actor_state(obj: Any) -> Dict[str, torch.Tensor]:
    """Pull the actor ``state_dict`` out of a loaded checkpoint object.

    Accepted layouts, checked in this order (``obj`` must be a non-empty dict
    with only string keys):
      1. ``obj`` itself maps names to tensors -> returned as is.
      2. ``obj["actor_state"]`` is a dict -> returned without inspecting values.
      3. One of ``state_dict`` / ``model`` / ``weights`` / ``policy`` / ``actor``
         holds a dict whose values are all tensors -> that dict is returned.

    Raises:
        ValueError: if none of the layouts match.
    """
    def _only_tensors(mapping) -> bool:
        return all(torch.is_tensor(value) for value in mapping.values())

    is_str_keyed_dict = (
        isinstance(obj, dict)
        and bool(obj)
        and all(isinstance(key, str) for key in obj)
    )
    if is_str_keyed_dict:
        if _only_tensors(obj):
            return obj

        actor_state = obj.get("actor_state")
        if isinstance(actor_state, dict):
            return actor_state

        for wrapper_key in ("state_dict", "model", "weights", "policy", "actor"):
            candidate = obj.get(wrapper_key)
            if isinstance(candidate, dict) and _only_tensors(candidate):
                return candidate

    raise ValueError(
        "Unrecognized checkpoint format: expected an actor state_dict or an epoch checkpoint with 'actor_state'."
    )

def load_actor_state_dict(path: str, device: torch.device) -> Dict[str, torch.Tensor]:
    """Load a checkpoint from ``path`` and return the actor's state_dict.

    The file is first loaded with ``weights_only=True`` (after allow-listing
    the NumPy globals older checkpoints may contain). If that restricted load
    fails, a warning is printed and the file is loaded again with
    ``weights_only=False``.

    Args:
        path: Checkpoint file (.pt).
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The state_dict extracted by ``_maybe_extract_actor_state``.

    Raises:
        OSError: if the file cannot be opened (no unsafe retry is attempted).
        ValueError: if the checkpoint layout is not recognized.
    """
    _allow_numpy_for_weights_only()
    try:
        # Request the restricted unpickler explicitly: the default value of
        # `weights_only` differs between torch releases, so relying on it
        # would make the "safe" attempt unsafe on some versions.
        obj = torch.load(path, map_location=device, weights_only=True)
    except OSError:
        # Missing/unreadable file: retrying with the unsafe loader cannot help.
        raise
    except Exception as e1:
        print(f"[warn] safe load failed: {e1}\nFalling back to weights_only=False.")
        # SECURITY: weights_only=False unpickles arbitrary objects and can
        # execute code embedded in the file. Only use with trusted checkpoints.
        obj = torch.load(path, map_location=device, weights_only=False)
    return _maybe_extract_actor_state(obj)


def parse_seeds_arg(seeds_arg: str) -> List[int]:
    """Turn a seeds specification into a list of ints.

    If ``seeds_arg`` names an existing file, the seeds are read from it; they
    may be separated by whitespace, commas, or newlines, and blank lines are
    ignored. Otherwise ``seeds_arg`` is treated as a comma-separated list.
    Empty tokens are skipped in both cases.
    """
    if not os.path.isfile(seeds_arg):
        return [int(tok.strip()) for tok in seeds_arg.split(",") if tok.strip()]

    collected: List[int] = []
    with open(seeds_arg, "r") as handle:
        for raw_line in handle:
            # str.split() with no argument drops surrounding whitespace and
            # yields nothing for blank lines.
            for chunk in raw_line.split():
                collected.extend(
                    int(tok.strip()) for tok in chunk.split(",") if tok.strip()
                )
    return collected


def get_agent_room_id(env_unwrapped) -> Optional[Tuple]:
    """
    Return a hashable room ID for the agent's current room, or None.

    Strategies, tried in order:
      1. A position->room lookup method on the env (``pos2room``,
         ``pos_to_room`` or ``room_from_pos``). It is called first as
         ``fn(agent_pos)`` and, if that does not yield a usable result, as
         ``fn(x, y)``, so both calling conventions are supported.
      2. An ``agent_room`` attribute on the env.
      3. Approximation from ``agent_pos // room_size``.

    A room value is converted to an ID as follows:
      - 2-tuple of integers            -> ("rc", a, b)
      - object with ``top`` and ``left`` -> ("tl", top, left)
      - object whose ``top`` is a pair   -> ("top", top[0], top[1])

    The lookup is best-effort: errors raised by the env's helpers are
    swallowed and the next strategy is tried.
    """
    import operator  # local import: only this function needs it

    u = env_unwrapped

    def _int_pair(value):
        # Accept any integer-like entries (e.g. NumPy integers), not only `int`.
        if not isinstance(value, (tuple, list)) or len(value) != 2:
            return None
        try:
            return operator.index(value[0]), operator.index(value[1])
        except TypeError:
            return None

    def _room_to_id(room):
        if isinstance(room, tuple):
            rc = _int_pair(room)
            if rc is not None:
                return ("rc", rc[0], rc[1])
        if hasattr(room, "top"):
            if hasattr(room, "left"):
                return ("tl", int(room.top), int(room.left))
            # Room object exposing only its top corner as a coordinate pair.
            corner = _int_pair(room.top)
            if corner is not None:
                return ("top", corner[0], corner[1])
        return None

    pos = getattr(u, "agent_pos", None)
    if pos is not None:
        call_styles = [(pos,)]
        try:
            unpacked = tuple(pos)
        except TypeError:
            unpacked = ()
        if len(unpacked) == 2:
            call_styles.append(unpacked)

        for name in ("pos2room", "pos_to_room", "room_from_pos"):
            fn = getattr(u, name, None)
            if not callable(fn):
                continue
            for call_args in call_styles:
                try:
                    rid = _room_to_id(fn(*call_args))
                except Exception:
                    # Wrong arity, failed internal assertion, unexpected
                    # return type, ...: try the next call style / method.
                    continue
                if rid is not None:
                    return rid

    room = getattr(u, "agent_room", None)
    if room is not None:
        try:
            rid = _room_to_id(room)
        except Exception:
            rid = None
        if rid is not None:
            return rid

    try:
        x, y = u.agent_pos
        rs = getattr(u, "room_size", None)
        if rs is not None and rs > 0:
            # NOTE(review): coarse fallback; assumes rooms tile the grid every
            # `room_size` cells, which may not hold if adjacent rooms share
            # walls -- confirm for the env in use.
            return ("approx", int(x // rs), int(y // rs))
    except Exception:
        pass

    return None


def _maybe_set_first(d: Dict[Any, Any], k, v):
    """Store ``v`` under ``k`` in ``d`` only if ``k`` is not already present."""
    # First writer wins; an existing entry is never overwritten.
    d.setdefault(k, v)


@torch.no_grad()
def run_single_rollout(
    env: gym.Env,
    actor: torch.nn.Module,
    device: torch.device,
    use_text: bool,
    max_steps: int,
    temperature: float,
    prepare_obs_fn: Callable[[Dict[str, Any], torch.device, bool], Dict[str, torch.Tensor]],
    obs: Dict[str, Any],
    record_rooms: bool = False,
) -> Tuple[bool, Set[Tuple], List[Tuple[int, int]], Dict[Tuple, Tuple[int,int]]]:
    """
    Runs one episode starting from the env's current state and ``obs``.

    The env is NOT reset here; the caller resets (and optionally repositions
    the agent) and passes the matching observation.

    Actions are sampled from a categorical distribution over the actor's
    logits divided by ``temperature`` (floored at 1e-6). The episode ends when
    the env reports terminated/truncated, or after ``max_steps`` steps when
    ``max_steps`` is not None.

    Returns: (success, visited_room_ids, trajectory, room_first_pos)
      - success: result of ``success_from_info_or_reward`` on the final step
      - visited_room_ids: room IDs seen (empty unless ``record_rooms``)
      - trajectory: agent (x, y) at the start and after every step, stored as
        tuples of Python ints so entries never alias env-owned objects
      - room_first_pos: maps room_id -> first-seen (x,y) tile during this
        rollout (empty unless ``record_rooms``)
    """
    base_env = env.unwrapped
    # Loop-invariant: the divisor applied to the logits at every step.
    temp = max(1e-6, float(temperature))

    total_ep_reward = 0.0
    last_info: Dict[str, Any] = {}
    terminated = False
    truncated = False
    visited_rooms: Set[Tuple] = set()
    trajectory: List[Tuple[int, int]] = []
    room_first_pos: Dict[Tuple, Tuple[int,int]] = {}

    def _agent_xy() -> Tuple[int, int]:
        # Snapshot the position as an immutable (int, int) pair.
        x, y = base_env.agent_pos
        return int(x), int(y)

    def _record_position() -> None:
        xy = _agent_xy()
        trajectory.append(xy)
        if record_rooms:
            rid = get_agent_room_id(base_env)
            if rid is not None:
                visited_rooms.add(rid)
                _maybe_set_first(room_first_pos, rid, xy)

    _record_position()

    steps = 0
    done = False
    while not done:
        batch = prepare_obs_fn(obs, device=device, use_text=use_text)
        logits = actor(batch) / temp
        action = torch.distributions.Categorical(logits=logits).sample().item()

        obs, reward, terminated, truncated, info = env.step(action)
        done = bool(terminated or truncated)
        last_info = info
        total_ep_reward += float(reward)
        _record_position()

        steps += 1
        # Hard cap enforced here in addition to whatever limit the env applies.
        if max_steps is not None and steps >= max_steps:
            break

    success = success_from_info_or_reward(terminated, truncated, last_info, total_ep_reward)
    return success, visited_rooms, trajectory, room_first_pos


# Visualization Helper
def visualize_grid_and_trajectories(
    env: gym.Env,
    trajectories: List[List[Tuple[int, int]]],
    seed: int,
    output_path: str,
    bg_lighten: float = 0.8,
):
    """
    Renders the env and overlays successful trajectories, saving a PNG-style
    figure to ``output_path``.

    - The env is reset with ``seed`` and rendered to obtain the background;
      ``bg_lighten`` blends that image toward white (0 disables the blend).
    - Start = circle (o), final position = star (*).
    - Each trajectory is shifted by a small, unique offset perpendicular to
      its overall start->end direction so paths don't overlap visually.
    - With no trajectories, only the background and title are saved.

    Raises:
        RuntimeError: if ``env.render()`` returns None (no image to draw on).
    """
    env.reset(seed=seed)
    background_img = env.render()
    if background_img is None:
        raise RuntimeError(
            "env.render() returned None; the environment must be created with render_mode='rgb_array'."
        )
    if bg_lighten and bg_lighten > 0.0:
        img = background_img.astype(np.float32)
        background_img = np.clip((1.0 - bg_lighten) * img + bg_lighten * 255.0, 0, 255).astype(np.uint8)

    # Pixel size of one grid cell; falls back to 32 if the env does not expose it.
    tile_size = getattr(env.unwrapped, "tile_size", 32)

    trajectories = list(trajectories or [])
    num_trajs = len(trajectories)

    def offset_traj(px_coords: List[Tuple[float, float]],
                    idx: int, n: int, base_px: float = 8.0):
        """
        Translate the whole trajectory by a constant screen-space offset.
        The offset direction is the normal to the overall start->end vector.
        """
        if not px_coords:
            return px_coords

        # Centered index so the offsets are spread symmetrically around zero.
        s = (idx - (n - 1) / 2.0)
        x0, y0 = px_coords[0]
        x1, y1 = px_coords[-1]
        vx, vy = (x1 - x0), (y1 - y0)
        norm = (vx * vx + vy * vy) ** 0.5

        if norm < 1e-6:
            # Start and end coincide: no direction to be normal to, shift along x.
            nx, ny = 1.0, 0.0
        else:
            nx, ny = -vy / norm, vx / norm

        ox, oy = nx * base_px * s, ny * base_px * s
        return [(x + ox, y + oy) for (x, y) in px_coords]

    # Red -> yellow gradient; one color per trajectory sampled from its middle range.
    roasted = LinearSegmentedColormap.from_list(
        "roasted",
        ["#b91c1c", "#ef4444", "#E23A2E", "#fb923c", "#fbbf24", "#F59E0B", "#fde047"],
        N=256,
    )
    t = np.linspace(0.25, 0.85, max(1, num_trajs)) ** 1.2
    color_list = [roasted(float(x)) for x in t]

    fig, ax = plt.subplots(figsize=(10, 10), dpi=150)
    try:
        ax.imshow(background_img)
        ax.axis("off")

        for i, traj in enumerate(trajectories):
            if not traj:
                continue

            # Grid cell -> pixel coordinates of the cell center.
            px_coords = [((x + 0.5) * tile_size, (y + 0.5) * tile_size) for (x, y) in traj]
            px_coords = offset_traj(px_coords, idx=i, n=num_trajs, base_px=8.0)

            x_coords = [p[0] for p in px_coords]
            y_coords = [p[1] for p in px_coords]

            ax.plot(x_coords, y_coords, linewidth=3.5, alpha=0.95, color=color_list[i], zorder=3)
            ax.plot(
                x_coords[0], y_coords[0],
                marker="o", markersize=11,
                markerfacecolor="#0000cd",
                markeredgecolor="#00008b",
                markeredgewidth=2.2,
                zorder=4
            )
            ax.plot(
                x_coords[-1], y_coords[-1],
                marker="*", markersize=18,
                markerfacecolor="#22c55e",
                markeredgecolor="#14532d",
                markeredgewidth=2.2,
                zorder=5
            )

        ax.set_title(
            f"Successful Rollouts for Seed: {seed}\n({num_trajs} successful trajectories)",
            fontsize=16, pad=10
        )
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"  -> Saved native-style visualization to {output_path}")


# Generalization evaluation
def _reset_to_room_start(env, seed, rid, start_xy):
    """Reset ``env`` with ``seed`` and, if possible, move the agent to ``start_xy``; return the observation to act on."""
    obs = env.reset(seed=seed)[0]
    base_env = env.unwrapped
    if start_xy is None or not hasattr(base_env, "gen_obs"):
        # No recorded tile, or no way to regenerate the observation after
        # moving the agent: start from the env's default start state.
        return obs
    try:
        base_env.agent_pos = tuple(start_xy)
        if hasattr(base_env, "agent_dir"):
            base_env.agent_dir = 0
        # NOTE(review): gen_obs() is called on the unwrapped env, so any
        # observation wrappers are bypassed -- confirm this matches what
        # prepare_obs expects.
        return base_env.gen_obs()
    except Exception as e:
        print(f"    [warn] Failed to set start pos for room {rid}: {e}")
        # Undo any partial repositioning by resetting again.
        return env.reset(seed=seed)[0]


@torch.no_grad()
def evaluate_generalization(
    actor_pre: torch.nn.Module,
    actor_fin: torch.nn.Module,
    seeds: List[int],
    game_code: str,
    device: torch.device,
    use_text_pre: bool,
    use_text_fin: bool,
    n_rollouts_per_seed: int,
    max_steps: int,
    temperature_pre: float,
    temperature_fin: float,
    visualize: bool,
    vis_dir: str,
    bg_lighten: float,
) -> Dict[str, Any]:
    """Run the two-phase generalization evaluation for every seed.

    Per seed (each reset uses the same seed):
      1. Exploration: ``actor_pre`` runs ``n_rollouts_per_seed`` rollouts; the
         union of visited room IDs is collected together with the first tile
         at which each room was seen.
      2. Room-start tests: for every visited room (sorted by ID) the env is
         reset, the agent is placed on that room's recorded tile facing
         direction 0, and ``actor_fin`` runs one rollout.
      3. If ``visualize`` is set, successful test trajectories are drawn into
         ``<vis_dir>/seed_<seed>_rollouts.png``.

    The environment of each seed is closed even if an error occurs.

    Returns:
        Dict with ``game_code``, a ``per_seed`` list (seed, number of unique
        rooms, per-room success flags, success rate) and a ``summary`` with
        the number of seeds and the per-seed averages.
    """
    all_results: Dict[str, Any] = {"game_code": game_code, "per_seed": []}

    if visualize:
        os.makedirs(vis_dir, exist_ok=True)

    for idx, seed in enumerate(seeds, 1):
        print(f"\n=== Seed {seed} ({idx}/{len(seeds)}) ===")
        env = setup_environment_multiseed(game_code, max_steps=max_steps, meta_fixed_seed=seed, render_mode="rgb_array")
        try:
            # Phase 1: exploration with the pretrained policy.
            unique_rooms: Set[Tuple] = set()
            room_start_pos: Dict[Tuple, Tuple[int,int]] = {}
            for _ in range(n_rollouts_per_seed):
                obs, _ = env.reset(seed=seed)
                _, visited, _, first_pos_map = run_single_rollout(
                    env=env,
                    actor=actor_pre,
                    device=device,
                    use_text=use_text_pre,
                    max_steps=max_steps,
                    temperature=temperature_pre,
                    prepare_obs_fn=prepare_obs,
                    obs=obs,
                    record_rooms=True,
                )
                unique_rooms |= visited
                for rid, pos in first_pos_map.items():
                    # Keep the tile from the earliest rollout that reached the room.
                    _maybe_set_first(room_start_pos, rid, pos)

            print(f"  Unique rooms visited across {n_rollouts_per_seed} rollouts: {len(unique_rooms)}")

            # Phase 2: finetuned policy started from each visited room.
            test_room_ids = sorted(unique_rooms)
            room_results: List[Dict[str, Any]] = []
            successful_trajectories: List[List[Tuple[int, int]]] = []
            n_success = 0

            for k, rid in enumerate(test_room_ids, 1):
                obs = _reset_to_room_start(env, seed, rid, room_start_pos.get(rid, None))

                # Room bookkeeping is not needed here: only success and the
                # trajectory of the test rollout are used.
                success, _, trajectory, _ = run_single_rollout(
                    env=env,
                    actor=actor_fin,
                    device=device,
                    use_text=use_text_fin,
                    max_steps=max_steps,
                    temperature=temperature_fin,
                    prepare_obs_fn=prepare_obs,
                    obs=obs,
                    record_rooms=False,
                )

                if success:
                    n_success += 1
                    successful_trajectories.append(trajectory)

                room_results.append({
                    "room_id": rid,
                    "success_from_room_start": bool(success),
                })
                if k % 10 == 0 or k == len(test_room_ids):
                    print(f"    ... tested {k}/{len(test_room_ids)} visited rooms | success so far: {n_success}/{k}")

            if visualize:
                vis_path = os.path.join(vis_dir, f"seed_{seed}_rollouts.png")
                visualize_grid_and_trajectories(
                    env,
                    successful_trajectories,
                    seed,
                    vis_path,
                    bg_lighten=bg_lighten,
                )
        finally:
            env.close()

        all_results["per_seed"].append({
            "seed": int(seed),
            "unique_rooms_visited_across_rollouts": int(len(unique_rooms)),
            "room_start_results": room_results,
            # max(1, ...) avoids dividing by zero when no room was recorded.
            "room_start_success_rate": float(n_success) / max(1, len(room_results)),
        })

    per_seed = all_results["per_seed"]
    avg_unique_rooms = float(np.mean([e["unique_rooms_visited_across_rollouts"] for e in per_seed])) if per_seed else 0.0
    avg_room_start_success = float(np.mean([e["room_start_success_rate"] for e in per_seed])) if per_seed else 0.0
    all_results["summary"] = {
        "num_seeds": len(seeds),
        "avg_unique_rooms_visited": avg_unique_rooms,
        "avg_room_start_success_rate": avg_room_start_success,
    }
    return all_results


# Main
def _load_actor(label: str, weights_path: str, n_actions: int, device: torch.device):
    """Load one policy checkpoint and return ``(actor in eval mode, uses_text)``."""
    print(f"Loading {label} actor from: {weights_path}")
    state = load_actor_state_dict(weights_path, device=device)
    # Text usage and vocabulary size are read off the checkpoint itself: a
    # token-embedding table is present only when the policy consumes text.
    uses_text = "tok_emb.weight" in state
    vocab_size = int(state["tok_emb.weight"].shape[0]) if uses_text else 200
    actor = BabyAI_BC(n_actions=n_actions, use_text=uses_text, vocab_size=vocab_size).to(device)
    actor.load_state_dict(state, strict=True)
    actor.eval()
    return actor, uses_text


def main():
    """Script entry point.

    Reads ``<config_dir>/<game_code>.yaml``, takes its ``evaluate_seeds`` list,
    loads the pretrained and finetuned actors, runs
    ``evaluate_generalization`` and writes the results as JSON to the path
    given by ``--save_json``. Returns early (after printing an error) if the
    config file is missing or contains no evaluation seeds.
    """
    # PyYAML is only needed to read the config file; the module-level imports
    # of this file do not include it, so it is imported here.
    import yaml

    args = parse_args()
    torch.manual_seed(0)
    np.random.seed(0)

    # Fall back to a default directory when the parser does not define
    # --config_dir (NOTE(review): "configs" is an assumed default).
    config_dir = getattr(args, "config_dir", "configs")
    config_filename = f"{args.game_code}.yaml"
    config_path = os.path.join(config_dir, config_filename)
    print(f"--- Loading configuration from: {config_path} ---")

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"Error: Configuration file not found at {config_path}")
        return

    seeds = config.get("evaluate_seeds") if isinstance(config, dict) else None
    if not seeds:
        # seeds[0] is needed below to build the probe environment.
        print(f"Error: no 'evaluate_seeds' found in {config_path}")
        return
    seeds = list(seeds)

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Whether each actor uses text is derived from its checkpoint in
    # _load_actor; the --use_text option is parsed but not consulted here
    # (previously it was resolved into a local that nothing read).

    # Temporary env, used only to read the size of the action space.
    dummy_env = setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=seeds[0])
    try:
        n_actions = dummy_env.action_space.n
    finally:
        dummy_env.close()

    actor_pre, use_text_pre = _load_actor("PRETRAINED", args.pretrained_weights, n_actions, device)
    actor_fin, use_text_fin = _load_actor("FINETUNED", args.finetuned_weights, n_actions, device)

    vis_dir_formatted = args.vis_dir.format(game_code=args.game_code)

    results = evaluate_generalization(
        actor_pre=actor_pre,
        actor_fin=actor_fin,
        seeds=seeds,
        game_code=args.game_code,
        device=device,
        use_text_pre=use_text_pre,
        use_text_fin=use_text_fin,
        n_rollouts_per_seed=args.n_rollouts_per_seed,
        max_steps=args.max_steps_per_episode,
        temperature_pre=args.pretrained_temperature,
        temperature_fin=args.finetuned_temperature,
        visualize=args.visualize_rollouts,
        vis_dir=vis_dir_formatted,
        bg_lighten=args.bg_lighten,
    )

    print("\n" + "=" * 40)
    print("  Generalization Evaluation Summary")
    print("=" * 40)
    print(f"Seeds evaluated: {results['summary']['num_seeds']}")
    print(f"Avg unique rooms visited (over {args.n_rollouts_per_seed} rollouts/seed): "
          f"{results['summary']['avg_unique_rooms_visited']:.2f}")
    print(f"Avg success-from-room-start rate: "
          f"{100.0 * results['summary']['avg_room_start_success_rate']:.2f}%")
    print("=" * 40)

    out_path = args.save_json.format(game_code=args.game_code)
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Saved per-seed details to: {out_path}")
    if args.visualize_rollouts:
        print(f"Visualizations saved in: {vis_dir_formatted}")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()