#!/usr/bin/env python
"""Evaluate trained networks on MiniGrid and plot per-option heat-maps.

For each reachable cell, and averaged over the four agent orientations,
compute and save:
  • φ_i(s) from the VPSNet
  • V_i(s) from the ValueNet
  • r_rand_i(s) from a fixed Random Fourier Features layer
Optionally (with ``--grid``), stitch all options into one overview grid per quantity.
"""
import argparse, os, torch, numpy as np, matplotlib.pyplot as plt
from bottleneck_env           import SimpleEnv
import matplotlib as mpl
from discrete_VPS_option.gridworld.utils import BottleneckVisualization
from continuous_vps_agent     import ContinuousVPSAgent
from networks                 import AtariCNN, VPSNet, ValueNet, RFFLayer  # ← RFFLayer

# ---------------- CLI ----------------
# Command-line interface. Every option has help text so `--help` is self-explanatory.
parser = argparse.ArgumentParser(
    description="Evaluate trained VPS / value / RFF networks on the grid "
                "environment and save per-option heat-maps.")
parser.add_argument("--ckpt",     default="./gridworld/outputs/networks.pt",
                    help="trained network checkpoint (.pt)")
parser.add_argument("--out_dir",  default="vps_maps", type=str,
                    help="directory for the generated PNG files (created if missing)")
parser.add_argument("--device",   default="cpu",
                    help="torch device used for inference, e.g. 'cpu' or 'cuda:0'")
parser.add_argument("--grid",     action="store_true",
                    help="stitch all options into one overview grid")
args = parser.parse_args()
os.makedirs(args.out_dir, exist_ok=True)
dev = torch.device(args.device)

# ---------------- 1. Load models ----------------
ckpt = torch.load(args.ckpt, map_location=dev)
k_opt = ckpt["k"]                      # number of options (k)
stack_len = ckpt["frame_stack_len"]    # input channels of every AtariCNN backbone


def _restore_backbone(state_key):
    """Build an ``AtariCNN`` on ``dev``, load ``ckpt[state_key]`` and set eval mode."""
    net = AtariCNN(in_channels=stack_len).to(dev)
    net.load_state_dict(ckpt[state_key])
    net.eval()
    return net


# ---- φ(s) network ----
backbone_vps = _restore_backbone("backbone_vps")
vps_net = VPSNet(k_opt, backbone_vps).to(dev)
vps_net.head.load_state_dict(ckpt["vps_head"])
vps_net.eval()

# ---- V(s) network ----
backbone_val = _restore_backbone("backbone_val")
val_net = ValueNet(k_opt, backbone_val).to(dev)
val_net.head.load_state_dict(ckpt["value_head"])
val_net.eval()

# ---- RFF random reward ----
backbone_rff = _restore_backbone("backbone_rff")
rff_layer = RFFLayer(backbone_rff.flat_dim, k_opt).to(dev)
rff_layer.load_state_dict(ckpt["rff"])
# The scale is restored explicitly; presumably it is not part of RFFLayer's
# state_dict — confirm against the RFFLayer definition.
rff_layer.scale = ckpt["rff_scale"]
rff_layer.eval()

# ---------------- 2. Environment & pre-processing helper ----------------
env = SimpleEnv(render_mode="rgb_array", highlight=False)
env.reset(seed=0)

# Helper agent: this script only uses its ``_preprocess`` method (see
# ``preprocess`` below). buffer_cap=1 presumably keeps its replay buffer
# minimal — confirm against ContinuousVPSAgent.
dummy = ContinuousVPSAgent(
    env,
    k_options=k_opt,
    frame_stack_len=stack_len,
    device=args.device,
    buffer_cap=1,
)

@torch.no_grad()
def preprocess(obs_rgb):
    """Turn one rendered RGB frame into a single-sample network input.

    The frame goes through the agent's own ``_preprocess``. The original note
    says this yields a gray-scale 84x84 tensor of shape (1, 84, 84); that is
    not verified here. The result is then repeated ``stack_len`` times along
    dim 0, so one still image stands in for a full frame stack. Finally a
    leading batch axis is added, giving (1, C, 84, 84).
    """
    frame = dummy._preprocess(obs_rgb)
    frames = torch.cat((frame,) * stack_len, dim=0)
    return frames[None]

H, W = env.height, env.width

# Per-option accumulators indexed [option, x, y]; they are transposed to (H, W) for plotting.
sum_phi = np.zeros((k_opt, W, H), dtype=np.float32)
sum_val = np.zeros_like(sum_phi)
sum_rff = np.zeros_like(sum_phi)              # random (RFF) reward
cnt = np.zeros((W, H), dtype=np.int32)        # set to 1 once a cell has been evaluated

# ---------------- 3. Locate walkable cells ----------------
# Prefer a private ``_grid`` when the env defines one, otherwise use ``grid``.
if hasattr(env, "_grid"):
    grid = env._grid
else:
    grid = env.grid

def cell_at(x: int, y: int):
    """Return whatever the module-level ``grid`` stores at position (x, y).

    Two layouts are supported: a nested list indexed ``grid[x][y]``, or any
    object exposing ``get(x, y)``. Any other type raises ``TypeError``.
    """
    if not isinstance(grid, list):
        if not hasattr(grid, "get"):
            raise TypeError("Unknown grid type.")
        return grid.get(x, y)
    return grid[x][y]

def _is_walkable(cell) -> bool:
    """Return True if the agent may be placed on ``cell``.

    * ``None`` (empty floor) is walkable.
    * An explicit ``walkable`` attribute wins when present (custom grid cells).
    * Otherwise fall back to a callable ``can_overlap()``. MiniGrid's
      ``WorldObj`` exposes this instead of ``walkable``; without the fallback,
      walls and other solid objects would be counted as free cells.
    * Anything else is treated as walkable, matching the previous default.
    """
    if cell is None:
        return True
    if hasattr(cell, "walkable"):
        return bool(cell.walkable)
    can_overlap = getattr(cell, "can_overlap", None)
    if callable(can_overlap):
        return bool(can_overlap())
    return True


# Cells the agent can stand on, row-major (y outer, x inner). Note that this is
# a per-cell walkability test, not a connectivity (reachability) search.
free_coords = [(x, y)
               for y in range(H)
               for x in range(W)
               if _is_walkable(cell_at(x, y))]

print(f"[Info] reachable cells: {len(free_coords)}")

# ---------------- 4. Traverse cells -------------------
for (x, y) in free_coords:
    # Running per-option sums over the four headings (float32, like the maps).
    acc_phi, acc_val, acc_rff = (np.zeros(k_opt, dtype=np.float32) for _ in range(3))

    # All four headings. MiniGrid documents 0=E, 1=S, 2=W, 3=N — confirm for
    # SimpleEnv. The order does not matter because the outputs are averaged.
    for dir_ in range(4):
        env.agent_pos = (x, y)
        env.agent_dir = dir_
        stack = preprocess(env.render())

        with torch.no_grad():
            acc_phi += vps_net(stack)[0].cpu().numpy()                  # (k,)
            acc_val += val_net(stack)[0].cpu().numpy()                  # (k,)
            acc_rff += rff_layer(backbone_rff(stack))[0].cpu().numpy()  # (k,)

    # Orientation-averaged values for this cell.
    sum_phi[:, x, y] = acc_phi / 4.0
    sum_val[:, x, y] = acc_val / 4.0
    sum_rff[:, x, y] = acc_rff / 4.0
    cnt[x, y] = 1

# ---------------- 5. Plot ----------------------
viz = BottleneckVisualization(env)
phi_maps, val_maps, rff_maps = [], [], []
reachable = cnt > 0   # unevaluated cells become NaN so they render blank

# (per-cell sums, list collecting the 2-D maps, title label, file suffix)
_map_kinds = (
    (sum_phi, phi_maps, "VPS(s)", "phi"),
    (sum_val, val_maps, "V(s)",   "val"),
    (sum_rff, rff_maps, "r(s)",   "rff"),
)

for i in range(k_opt):
    for sums, collected, label, suffix in _map_kinds:
        # (W, H) -> (H, W): rows are y, columns are x.
        heat = np.where(reachable, sums[i], np.nan).T
        collected.append(heat)

        viz.plot_2d_heatmap(heat.flatten(),
                            title=f"Option {i}  {label}", topk=0,
                            color_bar=True)
        out_path = os.path.join(args.out_dir, f"opt_{i:02d}_{suffix}.png")
        plt.savefig(out_path, dpi=300)
        plt.close()
        print(f"[✓] {out_path}")

# ---------------- 6. Stitch overview (optional) -------------
def merge(maps, fname, title):
    """Stitch per-option heat-maps into one grid figure with a shared colour bar.

    Parameters
    ----------
    maps  : list of 2-D arrays, one per option (NaN marks unevaluated cells).
    fname : file name, written inside ``args.out_dir``.
    title : figure super-title.

    All panels use one colour scale (the global finite min/max), so the single
    colour bar is correct for every panel. ``squeeze=False`` keeps ``axes`` a
    2-D array even for a single option.
    """
    n = len(maps)
    if n == 0:          # nothing to draw (and avoids division by zero below)
        return
    cols = int(np.ceil(np.sqrt(n)))
    rows = int(np.ceil(n / cols))

    finite = np.concatenate(
        [np.asarray(m, dtype=float)[np.isfinite(m)] for m in maps])
    vmin, vmax = (finite.min(), finite.max()) if finite.size else (None, None)

    # Constrained layout handles a colour bar that spans several axes;
    # tight_layout does not.
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows),
                             squeeze=False, constrained_layout=True)
    im = None
    for idx, ax in enumerate(axes.flat):
        if idx < n:
            im = ax.imshow(maps[idx], cmap="inferno", origin="upper",
                           vmin=vmin, vmax=vmax)
            ax.set_title(f"opt {idx}")
        ax.axis("off")
    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6)
    fig.suptitle(title)
    pth = os.path.join(args.out_dir, fname)
    fig.savefig(pth, dpi=300)
    plt.close(fig)
    print(f"[✓] merged grid → {pth}")


if args.grid:
    merge(phi_maps, "all_options_phi.png",  "All Options  φ(s)")
    merge(val_maps, "all_options_val.png",  "All Options  V(s)")
    merge(rff_maps, "all_options_rff.png",  "All Options  r_rand(s)")
