import os
import argparse
from typing import Dict, Any, Optional, Callable, List, Tuple, Set
import json
import random

import numpy as np
import torch
import gymnasium as gym
from gymnasium.wrappers import RecordVideo

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

import minigrid 
from minigrid.core.world_object import Wall, Door, Key, Ball, Box, Goal  

from ..train_utils import setup_environment as setup_environment_multiseed

from pretrain import MLPPolicy
from ..utils import prepare_obs, HIDDEN_DIM  

# Per-game settings keyed by game code. Only "minigrid" is registered here,
# with text observations disabled ("use_text" is False).
PPO_GAME_INFO = {"minigrid": {"use_text": False}}

def success_from_info_or_reward(terminated: bool, truncated: bool, info: Dict[str, Any], ep_reward: float) -> bool:
    """Decide whether a finished episode counts as a success.

    An explicit ``info["success"]`` entry takes precedence whenever ``info``
    is a dict that contains that key. Otherwise the episode is a success only
    if it terminated with a strictly positive cumulative reward.

    ``truncated`` is part of the signature but is not consulted.
    """
    has_explicit_flag = isinstance(info, dict) and "success" in info
    if not has_explicit_flag:
        # No explicit signal: fall back to "terminated with positive return".
        return bool(terminated and ep_reward > 0.0)
    return bool(info["success"])

def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Both weight paths are required; every other option has a default. The
    ``--video_dir``, ``--vis_dir`` and ``--save_json`` defaults contain a
    literal ``{game_code}`` placeholder that is left unformatted here.

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description="Generalization evaluation for MiniGrid using MLPPolicy: "
                    "(1) unique rooms visited in N rollouts/seed, "
                    "(2) test finetuned policy from those room starts on the same map."
    )

    # Checkpoints for the two policies.
    parser.add_argument(
        "--pretrained_weights",
        type=str,
        required=True,
        help="Path to PRETRAINED MLPPolicy state_dict (.pt) used for initial rollouts.",
    )
    parser.add_argument(
        "--finetuned_weights",
        type=str,
        required=True,
        help="Path to FINETUNED MLPPolicy state_dict (.pt) used for room-start tests.",
    )

    # Sampling temperatures for each policy.
    parser.add_argument("--pretrained_temperature", type=float, default=1.0)
    parser.add_argument("--finetuned_temperature", type=float, default=1.0)

    # Environment / rollout configuration.
    parser.add_argument(
        "--game_code",
        type=str,
        default="minigrid",
        help="minigrid (or other MiniGrid-like task wired via train_utils).",
    )
    parser.add_argument(
        "--n_rollouts_per_seed",
        type=int,
        default=100,
        help="Number of rollouts to run per seed for room-visit counting.",
    )
    parser.add_argument("--max_steps_per_episode", type=int, default=100)
    parser.add_argument("--save_video", action="store_true")
    parser.add_argument("--video_dir", type=str, default="eval_videos_{game_code}")

    # Visualization options.
    parser.add_argument(
        "--visualize_rollouts",
        action="store_true",
        help="Save one image per seed with overlays of successful trajectories.",
    )
    parser.add_argument(
        "--vis_dir",
        type=str,
        default="vis_minigrid_{game_code}",
        help="Directory to save rollout visualizations.",
    )
    parser.add_argument(
        "--bg_lighten",
        type=float,
        default=0.8,
        help="Blend background toward white by this amount in [0,1]. 0=no change.",
    )

    # Default device is chosen when the parser is built, based on CUDA availability.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--device", type=str, default=default_device)

    parser.add_argument(
        "--save_json",
        type=str,
        default="generalization_eval_{game_code}.json",
        help="Path to save a JSON with per-seed results (format string with {game_code} allowed).",
    )
    return parser.parse_args()

def _allow_numpy_for_weights_only():
    """Allow-list NumPy internals so torch.load(weights_only=True) can unpickle older RNG states."""
    import numpy as _np
    import torch as _torch

    safe = []
    try:
        reconstruct = _np._core.multiarray._reconstruct  
    except Exception:
        reconstruct = _np.core.multiarray._reconstruct  
    safe += [reconstruct, _np.ndarray, _np.dtype, _np.generic]

    try:
        ndt_mod = _np.dtypes
        safe += [getattr(ndt_mod, n) for n in dir(ndt_mod) if n.endswith("DType")]
    except Exception:
        pass

    _torch.serialization.add_safe_globals(safe)

def _maybe_extract_mlp_state(obj: Any) -> Dict[str, torch.Tensor]:
    if isinstance(obj, dict) and obj and all(isinstance(k, str) for k in obj.keys()):
        if "actor_state" in obj and isinstance(obj["actor_state"], dict):
            return obj["actor_state"]
        if all(torch.is_tensor(v) for v in obj.values()):
            return obj
        for k in ("state_dict", "model", "weights", "policy", "actor"):
            if k in obj and isinstance(obj[k], dict) and all(torch.is_tensor(v) for v in obj[k].values()):
                return obj[k]
    raise ValueError("Unrecognized checkpoint format for MLPPolicy. Expect a plain state_dict or a dict with 'actor_state'.")

def load_mlp_state_dict(path: str, device: torch.device) -> Dict[str, torch.Tensor]:
    """Load a checkpoint from ``path`` and return the MLPPolicy state_dict in it.

    The first attempt uses the restricted unpickler (``weights_only=True``)
    explicitly, so it is a safe load regardless of the installed torch
    version's default. If that attempt fails for a non-filesystem reason, a
    warning is printed and the checkpoint is re-read with
    ``weights_only=False``.

    Args:
        path: Path to the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The state_dict selected by ``_maybe_extract_mlp_state``.

    Raises:
        OSError: if the file cannot be opened (missing, unreadable, ...).
        ValueError: if the loaded object has an unrecognized layout.
    """
    _allow_numpy_for_weights_only()
    try:
        obj = torch.load(path, map_location=device, weights_only=True)
    except OSError:
        # A missing/unreadable file cannot be fixed by an unsafe retry.
        raise
    except Exception as e1:
        print(f"[warn] safe load failed: {e1}\nFalling back to weights_only=False.")
        # SECURITY NOTE: weights_only=False runs the full pickle machinery and
        # can execute arbitrary code embedded in the file. Only use this
        # script with checkpoints from a trusted source.
        obj = torch.load(path, map_location=device, weights_only=False)
    return _maybe_extract_mlp_state(obj)

def _obs_to_feature(prepped: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Convert prepared obs to a single feature row for the MLP:
      - Flattened image (C,H,W) -> (C*H*W)
      - Direction one-hot (size 4)
    """
    img: torch.Tensor = prepped["image"]         
    x_img = img.reshape(1, -1).float()          
    dir_tensor: Optional[torch.Tensor] = prepped.get("direction")
    if dir_tensor is None:
        dir_tensor = torch.zeros((1,), dtype=torch.long, device=img.device)
    dir_tensor = dir_tensor.clamp(0, 3)
    dir_oh = torch.nn.functional.one_hot(dir_tensor, num_classes=4).float() 
    return torch.cat([x_img, dir_oh], dim=1)  

def _batch_to_features(
    obs_list: List[Dict[str, Any]],
    device: torch.device,
    use_text: bool = False 
) -> torch.Tensor:
    """Convert raw observations into one stacked feature matrix.

    Each observation goes through ``prepare_obs`` (with the given ``device``
    and ``use_text``) and then ``_obs_to_feature``. The resulting rows are
    concatenated along dim 0, in input order.
    """
    feature_rows = [
        _obs_to_feature(prepare_obs(raw_obs, device=device, use_text=use_text))
        for raw_obs in obs_list
    ]
    return torch.cat(feature_rows, dim=0)

def _find_fourrooms_splits(u) -> Tuple[int, int]:
    """Heuristically find the central vertical/horizontal wall lines in FourRooms."""
    W, H = int(u.width), int(u.height)
    vcands = []
    for x in range(1, W - 1):
        wall_cnt = gap_cnt = 0
        for y in range(1, H - 1):
            obj = u.grid.get(x, y)
            if isinstance(obj, Wall):
                wall_cnt += 1
            elif obj is None:
                gap_cnt += 1
        if wall_cnt >= (H - 3) and gap_cnt <= 2:
            vcands.append(x)
    vline = min(vcands, key=lambda xx: abs(xx - W // 2)) if vcands else (W // 2)

    hcands = []
    for y in range(1, H - 1):
        wall_cnt = gap_cnt = 0
        for x in range(1, W - 1):
            obj = u.grid.get(x, y)
            if isinstance(obj, Wall):
                wall_cnt += 1
            elif obj is None:
                gap_cnt += 1
        if wall_cnt >= (W - 3) and gap_cnt <= 2:
            hcands.append(y)
    hline = min(hcands, key=lambda yy: abs(yy - H // 2)) if hcands else (H // 2)
    return vline, hline


def get_agent_room_id(env_unwrapped) -> Optional[Tuple]:
    """Return a hashable identifier for the room the agent currently occupies.

    Strategies are tried in order; the first that yields a result wins:

    1. The env has ``grid`` and a tuple ``agent_pos``: return
       ``("4rooms", lr, tb)``, where ``lr``/``tb`` are 0 when the agent is
       left of / above the split lines from ``_find_fourrooms_splits`` and 1
       otherwise.
    2. The env has a callable ``pos2room``, ``pos_to_room`` or
       ``room_from_pos``: call it with ``agent_pos``. A 2-tuple of ints gives
       ``("rc", a, b)``; an object with ``top``/``left`` gives
       ``("tl", top, left)``.
    3. The env has a non-None ``agent_room`` with ``top``/``left``: return
       ``("tl", top, left)``.
    4. Fallback: ``("approx", int(x < W // 2), int(y < H // 2))``.

    Errors inside strategies 1, 2 and 4 are swallowed; ``None`` is returned
    only when the fallback itself fails.
    """
    core = env_unwrapped

    # Strategy 1: quadrant relative to the detected split lines.
    try:
        has_grid_and_tuple_pos = (
            hasattr(core, "grid")
            and hasattr(core, "agent_pos")
            and isinstance(core.agent_pos, tuple)
        )
        if has_grid_and_tuple_pos:
            split_x, split_y = _find_fourrooms_splits(core)
            col, row = int(core.agent_pos[0]), int(core.agent_pos[1])
            return ("4rooms", int(col >= split_x), int(row >= split_y))
    except Exception:
        pass

    # Strategy 2: ask the env for the room via whichever lookup it exposes.
    for lookup_name in ("pos2room", "pos_to_room", "room_from_pos"):
        if not (hasattr(core, lookup_name) and callable(getattr(core, lookup_name))):
            continue
        try:
            room = getattr(core, lookup_name)(core.agent_pos)
            is_int_pair = (
                isinstance(room, tuple)
                and len(room) == 2
                and all(isinstance(part, int) for part in room)
            )
            if is_int_pair:
                return ("rc", room[0], room[1])
            if hasattr(room, "top") and hasattr(room, "left"):
                return ("tl", int(room.top), int(room.left))
        except Exception:
            pass

    # Strategy 3: a cached room object on the env.
    if hasattr(core, "agent_room") and core.agent_room is not None:
        current_room = core.agent_room
        if hasattr(current_room, "top") and hasattr(current_room, "left"):
            return ("tl", int(current_room.top), int(current_room.left))

    # Strategy 4: coarse quadrant from the grid midpoint.
    try:
        grid_w, grid_h = int(core.width), int(core.height)
        col, row = core.agent_pos
        return ("approx", int(col < grid_w // 2), int(row < grid_h // 2))
    except Exception:
        return None

def _maybe_set_first(d: Dict[Any, Any], k, v):
    if k not in d:
        d[k] = v

@torch.no_grad()
def run_single_rollout(
    env: gym.Env,
    actor: torch.nn.Module,
    device: torch.device,
    max_steps: int,
    temperature: float,
    obs: Dict[str, Any],
    record_rooms: bool = False,
) -> Tuple[bool, Set[Tuple], List[Tuple[int, int]], Dict[Tuple, Tuple[int,int]]]:
    """Play one episode by sampling actions from ``actor``.

    The env is NOT reset here; ``obs`` must be the observation matching the
    env's current state. Actions are sampled from a categorical distribution
    over ``actor(features) / max(1e-6, temperature)``.

    Args:
        env: Environment to step; ``env.unwrapped.agent_pos`` is read each step.
        actor: Policy mapping a feature row to action logits.
        device: Device used when building features.
        max_steps: Step cap for this rollout; ``None`` disables the cap.
        temperature: Softmax temperature (floored at 1e-6).
        obs: Initial observation.
        record_rooms: When True, track room ids via ``get_agent_room_id``.

    Returns:
        ``(success, visited_room_ids, trajectory, room_first_pos)`` where
        ``trajectory`` lists agent positions starting with the initial one and
        ``room_first_pos`` maps each room id to the first tile seen in it
        during this rollout. The last two room structures stay empty when
        ``record_rooms`` is False.
    """
    rooms_seen: Set[Tuple] = set()
    first_tile_in_room: Dict[Tuple, Tuple[int,int]] = {}
    path: List[Tuple[int, int]] = []

    def _note_current_room() -> None:
        # Record the agent's current room and, if new, the tile it was first seen on.
        if not record_rooms:
            return
        room = get_agent_room_id(env.unwrapped)
        if room is None:
            return
        rooms_seen.add(room)
        _maybe_set_first(first_tile_in_room, room, tuple(env.unwrapped.agent_pos))

    # Account for the starting state before any action is taken.
    _note_current_room()
    path.append(tuple(env.unwrapped.agent_pos))

    terminated = False
    truncated = False
    final_info: Dict[str, Any] = {}
    reward_sum = 0.0
    steps_taken = 0
    episode_over = False

    while not episode_over:
        features = _batch_to_features([obs], device=device, use_text=False)
        scaled_logits = actor(features) / max(1e-6, float(temperature))
        action = torch.distributions.Categorical(logits=scaled_logits).sample().item()

        obs, reward, terminated, truncated, info = env.step(action)
        episode_over = bool(terminated or truncated)
        final_info = info
        reward_sum += float(reward)
        path.append(tuple(env.unwrapped.agent_pos))

        _note_current_room()

        steps_taken += 1
        # Local cap, independent of whatever limit the env itself applies.
        if max_steps is not None and steps_taken >= max_steps:
            break

    succeeded = bool(success_from_info_or_reward(terminated, truncated, final_info, reward_sum))
    return succeeded, rooms_seen, path, first_tile_in_room

def visualize_grid_and_trajectories(
    env: gym.Env,
    trajectories: List[List[Tuple[int, int]]],
    seed: int,
    output_path: str,
    bg_lighten: float = 0.8,
):
    """Render the seeded env and draw the given trajectories on top of it.

    The env is reset with ``seed`` and rendered once for the background, which
    is blended toward white by ``bg_lighten``. Each non-empty trajectory is
    drawn as a line with a circle at its first point and a star at its last,
    shifted sideways by a per-trajectory offset to reduce overlap. The figure
    is saved to ``output_path`` and closed; a confirmation line is printed.

    NOTE(review): tile pixel size comes from ``env.unwrapped.tile_size``
    (default 32) and is assumed to match the rendered frame -- confirm.
    """
    env.reset(seed=seed)
    frame = env.render()
    if bg_lighten and bg_lighten > 0.0:
        frame_f32 = frame.astype(np.float32)
        blended = (1.0 - bg_lighten) * frame_f32 + bg_lighten * 255.0
        frame = np.clip(blended, 0, 255).astype(np.uint8)

    _frame_h, _frame_w = frame.shape[:2]  # not used further
    cell_px = getattr(env.unwrapped, "tile_size", 32)

    fig, ax = plt.subplots(figsize=(10, 10), dpi=150)
    ax.imshow(frame)
    ax.axis("off")

    def _title_save_close(title_text: str) -> None:
        # Shared tail: title the axes, write the file, release the figure.
        ax.set_title(title_text, fontsize=16, pad=10)
        plt.tight_layout()
        plt.savefig(output_path, bbox_inches="tight")
        plt.close(fig)
        print(f"  -> Saved visualization to {output_path}")

    if not trajectories:
        _title_save_close(f"Successful Rollouts for Seed: {seed}\n(0 successful trajectories)")
        return

    def _shift_sideways(points: List[Tuple[float, float]], rank: int, total: int, step_px: float = 8.0):
        # Offset the polyline perpendicular to its start->end direction, with
        # a magnitude proportional to the trajectory's distance from the
        # middle rank, so the bundle is spread symmetrically.
        if not points:
            return points
        spread = rank - (total - 1) / 2.0
        (start_x, start_y), (end_x, end_y) = points[0], points[-1]
        dx, dy = (end_x - start_x), (end_y - start_y)
        length = (dx * dx + dy * dy) ** 0.5
        if length < 1e-6:
            # Start and end coincide: fall back to a horizontal offset.
            unit_x, unit_y = 1.0, 0.0
        else:
            unit_x, unit_y = -dy / length, dx / length
        shift_x, shift_y = unit_x * step_px * spread, unit_y * step_px * spread
        return [(px + shift_x, py + shift_y) for (px, py) in points]

    # Warm color ramp, sampled once per trajectory.
    ramp_anchors = ["#b91c1c", "#ef4444", "#fb923c", "#fbbf24", "#fde047"]
    ramp = LinearSegmentedColormap.from_list("roasted", ramp_anchors, N=256)
    n_traj = len(trajectories)
    ramp_positions = np.linspace(0.25, 0.85, max(1, n_traj)) ** 1.2
    palette = [ramp(float(pos)) for pos in ramp_positions]

    for rank, tiles in enumerate(trajectories):
        if not tiles:
            continue
        # Tile centres in pixel coordinates.
        centres = [((tx + 0.5) * cell_px, (ty + 0.5) * cell_px) for (tx, ty) in tiles]
        shifted = _shift_sideways(centres, rank=rank, total=n_traj, step_px=8.0)
        xs = [pt[0] for pt in shifted]
        ys = [pt[1] for pt in shifted]

        ax.plot(xs, ys, linewidth=3.5, alpha=0.95, color=palette[rank], zorder=3)
        ax.plot(xs[0], ys[0], marker="o", markersize=11,
                markerfacecolor="#0000cd", markeredgecolor="#00008b", markeredgewidth=2.2, zorder=4)
        ax.plot(xs[-1], ys[-1], marker="*", markersize=18,
                markerfacecolor="#22c55e", markeredgecolor="#14532d", markeredgewidth=2.2, zorder=5)

    _title_save_close(
        f"Successful Rollouts for Seed: {seed}\n({len(trajectories)} successful trajectories)"
    )

@torch.no_grad()
def evaluate_generalization(
    actor_pre: torch.nn.Module,        
    actor_fin: torch.nn.Module,       
    seeds: List[int],
    game_code: str,
    device: torch.device,
    n_rollouts_per_seed: int,
    max_steps: int,
    temperature_pre: float,
    temperature_fin: float,
    visualize: bool,
    vis_dir: str,
    bg_lighten: float,
) -> Dict[str, Any]:
    """Run the two-phase generalization evaluation for every seed.

    Phase 1 (per seed): run ``n_rollouts_per_seed`` rollouts with
    ``actor_pre`` from the seeded reset, collecting the set of room ids
    visited and, per room, the first tile on which it was seen.

    Phase 2 (per seed): for every visited room (sorted order), reset the env
    with the same seed, move the agent onto that room's first-seen tile with
    direction 0, and run one rollout with ``actor_fin``. Successful
    trajectories are collected and optionally visualized.

    Args:
        actor_pre: Policy used for the exploration rollouts.
        actor_fin: Policy tested from the room start tiles.
        seeds: Seeds to evaluate; each gets its own environment.
        game_code: Forwarded to ``setup_environment_multiseed``.
        device: Device used when building features.
        n_rollouts_per_seed: Number of phase-1 rollouts per seed.
        max_steps: Step cap passed to both the env factory and the rollouts.
        temperature_pre: Sampling temperature for ``actor_pre``.
        temperature_fin: Sampling temperature for ``actor_fin``.
        visualize: When True, save one overlay image per seed into ``vis_dir``.
        vis_dir: Output directory for visualizations (created if needed).
        bg_lighten: Background lightening factor for the visualization.

    Returns:
        A dict with ``"game_code"``, a ``"per_seed"`` list (seed, number of
        unique rooms, per-room results, per-seed success rate) and a
        ``"summary"`` with averages over seeds.
    """
    all_results: Dict[str, Any] = {"game_code": game_code, "per_seed": []}

    if visualize:
        os.makedirs(vis_dir, exist_ok=True)

    for idx, seed in enumerate(seeds, 1):
        print(f"\n=== Seed {seed} ({idx}/{len(seeds)}) ===")
        env = setup_environment_multiseed(game_code, max_steps=max_steps, meta_fixed_seed=seed, render_mode="rgb_array")

        # try/finally guarantees the env is closed even if a rollout or the
        # visualization raises part-way through this seed.
        try:
            # ---- Phase 1: discover rooms with the pretrained policy ----
            unique_rooms: Set[Tuple] = set()
            room_start_pos: Dict[Tuple, Tuple[int,int]] = {}
            for _ in range(n_rollouts_per_seed):
                obs, _info = env.reset(seed=seed)
                _, visited, _, first_pos_map = run_single_rollout(
                    env=env,
                    actor=actor_pre,
                    device=device,
                    max_steps=max_steps,
                    temperature=temperature_pre,
                    obs=obs,
                    record_rooms=True,
                )
                unique_rooms |= visited
                # Keep the first tile ever seen for each room across rollouts.
                for rid, pos in first_pos_map.items():
                    _maybe_set_first(room_start_pos, rid, pos)

            print(f"  Unique rooms visited across {n_rollouts_per_seed} rollouts: {len(unique_rooms)}")

            # ---- Phase 2: test the finetuned policy from each room start ----
            test_room_ids = sorted(unique_rooms)
            room_results: List[Dict[str, Any]] = []
            successful_trajectories: List[List[Tuple[int, int]]] = []

            for k, rid in enumerate(test_room_ids, 1):
                # Seeded reset restores the same map; its observation is used
                # unless the agent is successfully moved below.
                obs = env.reset(seed=seed)[0]
                start_xy = room_start_pos.get(rid, None)
                if start_xy is not None:
                    try:
                        env.unwrapped.agent_pos = tuple(start_xy)
                        if hasattr(env.unwrapped, "agent_dir"):
                            env.unwrapped.agent_dir = 0
                        if hasattr(env.unwrapped, "gen_obs"):
                            # Regenerate the observation for the new position.
                            obs = env.unwrapped.gen_obs()
                        else:
                            # Cannot rebuild the observation: undo the move.
                            obs = env.reset(seed=seed)[0]
                    except Exception as e:
                        print(f"    [warn] Failed to set start pos for room {rid}: {e}")
                        obs = env.reset(seed=seed)[0]

                success, _, trajectory, _ = run_single_rollout(
                    env=env,
                    actor=actor_fin,
                    device=device,
                    max_steps=max_steps,
                    temperature=temperature_fin,
                    obs=obs,
                    record_rooms=True,
                )
                if success:
                    successful_trajectories.append(trajectory)

                room_results.append({
                    "room_id": rid,
                    "success_from_room_start": bool(success),
                })
                if k % 10 == 0 or k == len(test_room_ids):
                    sr = sum(1 for rr in room_results if rr["success_from_room_start"])
                    print(f"    ... tested {k}/{len(test_room_ids)} visited rooms | success so far: {sr}/{k}")

            if visualize:
                vis_path = os.path.join(vis_dir, f"seed_{seed}_rollouts.png")
                visualize_grid_and_trajectories(env, successful_trajectories, seed, vis_path, bg_lighten=bg_lighten)
        finally:
            env.close()

        n_success = sum(1 for rr in room_results if rr["success_from_room_start"])
        per_seed_entry = {
            "seed": int(seed),
            "unique_rooms_visited_across_rollouts": int(len(unique_rooms)),
            "room_start_results": room_results,
            # max(1, ...) avoids division by zero when no room was visited.
            "room_start_success_rate": float(n_success) / max(1, len(room_results)),
        }
        all_results["per_seed"].append(per_seed_entry)

    per_seed = all_results["per_seed"]
    avg_unique_rooms = float(np.mean([e["unique_rooms_visited_across_rollouts"] for e in per_seed])) if per_seed else 0.0
    avg_room_start_success = float(np.mean([e["room_start_success_rate"] for e in per_seed])) if per_seed else 0.0
    all_results["summary"] = {
        "num_seeds": len(seeds),
        "avg_unique_rooms_visited": avg_unique_rooms,
        "avg_room_start_success_rate": avg_room_start_success,
    }
    return all_results

def main():
    """Script entry point: load both policies, evaluate, print and save results.

    Steps: parse CLI args, resolve the evaluation seeds, seed the RNGs, probe
    a throwaway env for the action count and feature width, load the
    pretrained and finetuned ``MLPPolicy`` weights, run
    ``evaluate_generalization``, print a summary and dump the results as JSON.

    Raises:
        RuntimeError: if no module-level ``config`` is available.
        ValueError: if ``config["evaluate_seeds"]`` is empty.
    """
    args = parse_args()

    # NOTE(review): `config` is neither defined nor imported anywhere in this
    # file, so the original `config["evaluate_seeds"]` lookup failed with a
    # bare NameError. Resolve it from the module globals (in case it is
    # injected elsewhere) and fail with an actionable message otherwise.
    # TODO: wire `config` to its real source (import or CLI option).
    cfg = globals().get("config")
    if cfg is None:
        raise RuntimeError(
            "No module-level `config` is defined; expected a mapping with an "
            "'evaluate_seeds' entry listing the seeds to evaluate."
        )
    seeds = list(cfg["evaluate_seeds"])
    if not seeds:
        raise ValueError("config['evaluate_seeds'] is empty; nothing to evaluate.")

    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Probe a throwaway env for the model's input/output sizes; always close it.
    dummy_env = setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=seeds[0])
    try:
        n_actions = dummy_env.action_space.n
        dummy_obs, _ = dummy_env.reset()
        input_dim = int(_batch_to_features([dummy_obs], device=device, use_text=False).shape[1])
    finally:
        dummy_env.close()

    def _load_actor(label: str, weights_path: str) -> torch.nn.Module:
        # Build an MLPPolicy, load its weights strictly, and switch to eval mode.
        print(f"Loading {label} MLP actor from: {weights_path}")
        state = load_mlp_state_dict(weights_path, device=device)
        actor = MLPPolicy(input_dim=input_dim, output_dim=n_actions, hidden_dim=HIDDEN_DIM).to(device)
        actor.load_state_dict(state, strict=True)
        actor.eval()
        return actor

    actor_pre = _load_actor("PRETRAINED", args.pretrained_weights)
    actor_fin = _load_actor("FINETUNED", args.finetuned_weights)

    vis_dir_formatted = args.vis_dir.format(game_code=args.game_code)

    results = evaluate_generalization(
        actor_pre=actor_pre,
        actor_fin=actor_fin,
        seeds=seeds,
        game_code=args.game_code,
        device=device,
        n_rollouts_per_seed=args.n_rollouts_per_seed,
        max_steps=args.max_steps_per_episode,
        temperature_pre=args.pretrained_temperature,
        temperature_fin=args.finetuned_temperature,
        visualize=args.visualize_rollouts,
        vis_dir=vis_dir_formatted,
        bg_lighten=args.bg_lighten,
    )

    print("\n" + "=" * 40)
    print("  Generalization Evaluation Summary")
    print("=" * 40)
    print(f"Seeds evaluated: {results['summary']['num_seeds']}")
    print(f"Avg unique rooms visited (over {args.n_rollouts_per_seed} rollouts/seed): "
          f"{results['summary']['avg_unique_rooms_visited']:.2f}")
    print(f"Avg success-from-room-start rate: "
          f"{100.0 * results['summary']['avg_room_start_success_rate']:.2f}%")
    print("=" * 40)

    out_path = args.save_json.format(game_code=args.game_code)
    # Create the target directory so a long evaluation is not lost to a
    # missing folder at write time.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Saved per-seed details to: {out_path}")
    if args.visualize_rollouts:
        print(f"Visualizations saved in: {vis_dir_formatted}")

if __name__ == "__main__":
    main()
