"""
SR-target ablation for FPVR (deep visual FPVR) in visual MiniGrid-Maze.

This script mirrors the experimental logic / outputs of `visual_gridworld/exploration_comparison.py`,
but compares only three FPVR variants that differ in the SR TD target used in the successor feature
update (implemented by `FPVRVisualAgent.sf_target_mode`):

  - min:     "min_redundancy"  (choose a' minimizing next-state redundancy)
  - mean:    "uniform_policy"  (uniform mean over actions)
  - current: "current_policy"  (use the actually taken next action a')

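Outputs (saved under visual_gridworld/results_sr_target_test/ by default; override with --out_dir):
  1) cumulative coverage curves (png/eps)
  2) windowed coverage curves (png/eps; only produced when --coverage_reset_interval > 0)
  3) visitation heatmaps: 1x3 subplots with a shared colorbar on the right (png)
Each (variant, seed) run also writes coverage.npy, coverage_windowed.npy, counts.npy and
config.json into <out_dir>/<tag>_seed<seed>/.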
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Work around occasional Windows OpenMP runtime duplication (torch + MKL/numpy).
# Placed before `import torch` so the variable is already set when torch is imported;
# a value the user exported explicitly is left untouched.
if os.name == "nt" and "KMP_DUPLICATE_LIB_OK" not in os.environ:
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch

# Ensure project root for absolute imports: _THIS_DIR is this script's directory and
# _ROOT_DIR its parent. _ROOT_DIR must be on sys.path for the `visual_gridworld.*` imports
# below to resolve (e.g. when the script is run directly).
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_ROOT_DIR = os.path.dirname(_THIS_DIR)
if _ROOT_DIR not in sys.path:
    sys.path.insert(0, _ROOT_DIR)

from visual_gridworld.visual_minigrid import SimpleEnv
from visual_gridworld.fpvr_agent import FPVRVisualAgent
from visual_gridworld.config import Config


# One fixed matplotlib cycle colour per SR-target variant, so curves and legends
# use the same colour for a given variant in every figure.
METHOD_COLORS = dict(
    zip(
        (
            "FPVR (SR target = min)",
            "FPVR (SR target = mean)",
            "FPVR (SR target = current)",
        ),
        ("C0", "C1", "C2"),
    )
)


def _win_long_path(path: str) -> str:
    r"""
    Return an absolute path that Windows APIs can open even beyond the MAX_PATH limit.

    Over-long paths can surface as FileNotFoundError (Errno 2) on Windows. The
    extended-length prefix ``\\?\`` (``\\?\UNC\`` for network shares) avoids that on
    systems that support it. On non-Windows platforms the plain absolute path is returned.
    """
    full = os.path.abspath(str(path))

    # Non-Windows, or already in extended-length form: nothing to add.
    if os.name != "nt" or full.startswith("\\\\?\\"):
        return full

    # UNC share (\\server\share\...) -> \\?\UNC\server\share\...  (done regardless of length)
    if full.startswith("\\\\"):
        return "\\\\?\\UNC\\" + full.lstrip("\\")

    # Local drive path (C:\...): only prefix once it gets close to MAX_PATH.
    return "\\\\?\\" + full if len(full) >= 240 else full


def _lighten(color, amount=0.5):
    """
    Blend ``color`` towards white and return the result as an (r, g, b) tuple.

    ``amount`` in [0, 1] moves the HLS lightness from its original value (0) to pure
    white (1). If the colour cannot be resolved, or matplotlib is unavailable, the input
    is returned unchanged (best effort, used only for cosmetic band colours).
    """
    try:
        import colorsys
        import matplotlib.colors as mcolors

        resolved = mcolors.cnames.get(color, color)
        hue, light, sat = colorsys.rgb_to_hls(*mcolors.to_rgb(resolved))
        return colorsys.hls_to_rgb(hue, 1 - amount * (1 - light), sat)
    except Exception:
        return color


@dataclass
class ExpConfig:
    """Experiment settings shared by every (variant, seed) run; main() fills them from CLI args."""

    env_size: int = 20  # grid side length passed to SimpleEnv(size=...)
    seed_base: int = 1  # first seed; runs use seeds seed_base .. seed_base + n_seeds - 1
    n_seeds: int = 10  # number of seeds per SR-target variant
    total_steps: int = 9000  # `steps` passed to run_fpvr_training for each run
    coverage_reset_interval: int = 3000  # passed as reset_interval; <= 0 skips the windowed plot
    out_dir: str = ""  # root directory for per-run folders, plots and config.json


def load_coverage_data(cov_path: str, cov_win_path: str, total_points: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load one run's cumulative and windowed coverage curves, each fitted to ``total_points``.

    An off-by-one length is tolerated. One extra trailing point is dropped. One missing
    point is filled by repeating the last value, or 0 for an empty array.

    Raises:
        FileNotFoundError: if either file is missing; the message carries the caller's path.
        ValueError: if a curve's length differs from ``total_points`` by more than one.
    """

    def _read(display_path: str) -> np.ndarray:
        # Read through a file handle opened with the (possibly) long-path-prefixed name,
        # so Windows extended-length paths also work for numpy.
        long_path = _win_long_path(display_path)
        try:
            with open(long_path, "rb") as fh:
                return np.load(fh, allow_pickle=False)
        except FileNotFoundError as err:
            raise FileNotFoundError(display_path) from err

    def _fit(arr: np.ndarray, target: int) -> np.ndarray:
        n = len(arr)
        if n == target:
            return arr
        if n == target + 1:
            return arr[:target]
        if n == target - 1:
            padded = np.zeros(target, dtype=arr.dtype)
            padded[:n] = arr
            padded[n:] = arr[-1] if n else 0
            return padded
        raise ValueError(f"Unexpected length: {n} vs expected {target}")

    cumulative = _read(cov_path)
    windowed = _read(cov_win_path)
    return _fit(cumulative, total_points), _fit(windowed, total_points)


def get_wall_positions(env_size: int) -> np.ndarray:
    """
    Detect the maze's wall cells by instantiating the environment once.

    Args:
        env_size: grid side length passed to ``SimpleEnv(size=...)``.

    Returns:
        A ``[env_size, env_size]`` bool array indexed ``[y, x]`` (MiniGrid ``(x, y)`` ->
        numpy ``[row, col]``). It is True where the cell's class name contains "Wall".
        Detection is best effort: if the environment cannot be built or inspected, a
        warning is printed and an all-False mask is returned, so no cells get masked.
    """
    empty = np.zeros((env_size, env_size), dtype=bool)
    try:
        env = SimpleEnv(size=env_size, render_mode=None, highlight=False)
    except Exception as e:
        print(f"[Warning] Could not build env for wall detection ({e!r}); walls will not be masked.")
        return empty

    try:
        env.reset()
        wall_mask = np.zeros((env_size, env_size), dtype=bool)
        for x in range(env_size):
            for y in range(env_size):
                cell = env.grid.get(x, y)
                if cell is not None and "Wall" in type(cell).__name__:
                    wall_mask[y, x] = True  # MiniGrid(x,y) -> numpy[y,x]
        return wall_mask
    except Exception as e:
        print(f"[Warning] Wall detection failed ({e!r}); walls will not be masked.")
        return empty
    finally:
        # Always release the env. A failing close() must not discard a mask we already built.
        try:
            env.close()
        except Exception:
            pass


def plot_curves(
    curves_by_method: Dict[str, List[np.ndarray]],
    *,
    ylabel: str,
    out_png: str,
    out_eps: str,
    reset_interval: Optional[int],
) -> None:
    """
    Plot mean +/- std curves (one line per SR-target variant) and save them as PNG and EPS.

    Args:
        curves_by_method: method display name -> list of per-seed 1-D curves. A method's
            curves are stacked, so they must share one length. Only the three names in
            the fixed plotting order below are drawn; other keys are ignored.
        ylabel: y-axis label.
        out_png: PNG output path (parent directory is created if needed).
        out_eps: EPS output path (parent directory is created if needed).
        reset_interval: when a positive int, dashed vertical lines are drawn at each
            multiple of it (coverage resets). None or <= 0 draws none.
    """
    os.makedirs(_win_long_path(os.path.dirname(os.path.abspath(out_png))), exist_ok=True)
    os.makedirs(_win_long_path(os.path.dirname(os.path.abspath(out_eps))), exist_ok=True)

    order = ["FPVR (SR target = min)", "FPVR (SR target = mean)", "FPVR (SR target = current)"]
    label_fs = 18
    tick_fs = 14
    legend_fs = 14

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        plotted_lengths: List[int] = []
        for name in order:
            runs = curves_by_method.get(name, [])
            if len(runs) == 0:
                continue
            arr = np.stack(runs, axis=0)  # one row per seed
            mean = arr.mean(axis=0)
            std = arr.std(axis=0)
            x = np.arange(len(mean))
            color = METHOD_COLORS.get(name, "C0")
            band_color = _lighten(color, amount=0.88)
            ax.fill_between(x, mean - std, mean + std, color=band_color, linewidth=0.0, alpha=0.3, zorder=1)
            ax.plot(x, mean, label=name, linewidth=2.6, color=color, zorder=2)
            plotted_lengths.append(len(mean))

        # Reset markers span the longest curve actually drawn. Guarding on plotted_lengths
        # avoids max() over an empty sequence when no curves are available.
        if reset_interval is not None and int(reset_interval) > 0 and plotted_lengths:
            step = int(reset_interval)
            for r in range(step, max(plotted_lengths), step):
                ax.axvline(x=r, color="darkred", linestyle="--", alpha=0.7, linewidth=1.5, zorder=0)

        ax.set_xlabel("Steps", fontsize=label_fs)
        ax.set_ylabel(ylabel, fontsize=label_fs)
        ax.tick_params(axis="both", which="major", labelsize=tick_fs)
        ax.grid(alpha=0.3)
        if plotted_lengths:
            # Only add a legend when there are labelled artists to show.
            ax.legend(fontsize=legend_fs, frameon=True, loc="lower right")
        fig.tight_layout()

        # Matplotlib/PIL may fail on Windows long paths if given a filename string,
        # so save through file handles opened with the long-path prefix.
        with open(_win_long_path(out_png), "wb") as f:
            fig.savefig(f, format="png", dpi=200)
        with open(_win_long_path(out_eps), "wb") as f:
            fig.savefig(f, format="eps")
    finally:
        # Release the figure even if saving fails (figures otherwise accumulate in pyplot).
        plt.close(fig)


def plot_heatmaps_1x3(
    counts_by_method: Dict[str, List[np.ndarray]],
    *,
    out_png: str,
    env_size: int,
) -> None:
    """
    Save a 1x3 figure of mean visitation heatmaps with a single shared colorbar.

    Args:
        counts_by_method: method display name -> list of per-seed visit-count arrays. Each
            method's arrays are stacked and averaged over seeds. Only the three names in
            the fixed order below are used. Panels are filled left to right with the
            available methods, and unused panels are blanked.
        out_png: output PNG path (parent directory is created if needed).
        env_size: grid side length. It is used to locate walls (drawn gray) and to set
            the image extent and ticks.

    All panels share one colour scale: the min/max over every non-wall cell of every
    panel. When no counts are available, a warning is printed and nothing is written.
    """
    os.makedirs(_win_long_path(os.path.dirname(os.path.abspath(out_png))), exist_ok=True)
    wall_mask = get_wall_positions(env_size)
    has_walls = bool(wall_mask.any())

    order = ["FPVR (SR target = min)", "FPVR (SR target = mean)", "FPVR (SR target = current)"]
    mats: List[np.ndarray] = []
    names: List[str] = []
    for n in order:
        runs = counts_by_method.get(n, [])
        if len(runs) == 0:
            continue
        # mean() and astype() both produce fresh arrays, so masking never mutates the inputs.
        mean_counts = np.stack(runs, axis=0).mean(axis=0).astype(float)
        if has_walls:
            if mean_counts.shape == wall_mask.shape:
                mean_counts[wall_mask] = np.nan  # NaN cells use the colormap's "bad" colour
            else:
                print(
                    f"[Warning] Counts shape {mean_counts.shape} does not match wall mask "
                    f"shape {wall_mask.shape}; walls not masked for {n}."
                )
        mats.append(mean_counts)
        names.append(n)

    if len(mats) == 0:
        print("[Warning] No counts available; skipped heatmap plotting.")
        return

    # Shared colour limits over all finite (non-wall) cells. Fall back to a valid range
    # if every cell is masked, instead of passing NaN limits to imshow.
    finite_vals = np.concatenate([m[np.isfinite(m)].ravel() for m in mats])
    if finite_vals.size > 0:
        vmin, vmax = float(finite_vals.min()), float(finite_vals.max())
    else:
        vmin, vmax = 0.0, 1.0

    fig, axes = plt.subplots(
        1,
        4,
        figsize=(13.8, 4.8),
        gridspec_kw={"width_ratios": [1, 1, 1, 0.06], "wspace": 0.35},
    )
    try:
        plot_axes = [axes[0], axes[1], axes[2]]
        # The 4th grid column only reserves horizontal room for the colorbar.
        axes[3].set_visible(False)

        cmap = plt.get_cmap("hot").copy()
        cmap.set_bad("gray")
        title_fs = 12
        tick_fs = 10

        im = None
        for i, ax in enumerate(plot_axes):
            if i >= len(mats):
                ax.axis("off")
                continue
            im = ax.imshow(
                mats[i],
                cmap=cmap,
                origin="upper",
                interpolation="nearest",
                extent=[0, env_size, env_size, 0],
                vmin=vmin,
                vmax=vmax,
            )
            ax.set_title(names[i], fontsize=title_fs)
            ax.set_xticks(range(0, env_size + 1, 2))
            ax.set_yticks(range(0, env_size + 1, 2))
            ax.set_xlim(0, env_size)
            ax.set_ylim(env_size, 0)
            ax.grid(color="lightgray", linewidth=0.5)
            ax.tick_params(axis="both", which="major", labelsize=tick_fs)
            ax.set_xticklabels([])
            ax.set_yticklabels([])

        # Finalize the layout BEFORE placing the colorbar. The colorbar axes uses absolute
        # figure coordinates, so it must be measured against the third panel's final
        # position, i.e. after subplots_adjust and after the image's equal aspect is applied.
        # Otherwise it ends up misaligned with, and overlapping, the last panel.
        fig.subplots_adjust(top=0.92, left=0.03, right=0.985, bottom=0.05)
        if im is not None:
            axes[2].apply_aspect()
            pos = axes[2].get_position()
            cbar_ax = fig.add_axes([pos.x1 + 0.050, pos.y0, 0.015, pos.height])
            cbar = fig.colorbar(im, cax=cbar_ax)
            cbar.ax.tick_params(labelsize=12)
            cbar.set_label("Mean visit count", fontsize=12)

        # Matplotlib/PIL may fail on Windows long paths if given a filename string.
        with open(_win_long_path(out_png), "wb") as f:
            fig.savefig(f, format="png", dpi=200)
    finally:
        plt.close(fig)
    print(f"[OK] Saved heatmap grid: {out_png}")


def run_deep_fpvr_one(
    *,
    cfg: ExpConfig,
    seed: int,
    sf_target_mode: str,
    out_dir: str,
) -> Tuple[str, str, str]:
    """
    Run one FPVR training with a specified SR target mode and save its results.

    Args:
        cfg: experiment settings (env size, training steps, coverage reset interval).
        seed: seed for Python ``random``, NumPy, PyTorch and the environment reset.
        sf_target_mode: SR TD-target mode forwarded to ``FPVRVisualAgent(sf_target_mode=...)``.
        out_dir: per-run output directory (created if missing).

    Agent hyper-parameters come from the project's default ``Config()``. The function
    writes ``coverage.npy`` (cumulative), ``coverage_windowed.npy``, ``counts.npy`` and
    ``config.json`` into ``out_dir``.

    Returns:
        (coverage_path, coverage_windowed_path, counts_path).
    """
    import builtins
    import random
    import types

    # Ensure output directory exists before saving any files.
    os.makedirs(_win_long_path(out_dir), exist_ok=True)

    # `fpvr_run` lives next to this script. Add its directory once, instead of prepending
    # a duplicate sys.path entry on every call.
    if _THIS_DIR not in sys.path:
        sys.path.insert(0, _THIS_DIR)
    from fpvr_run import preprocess as preprocess_run
    from fpvr_run import run_fpvr_training

    # Seed every RNG the run may draw from, for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    env = SimpleEnv(size=cfg.env_size, render_mode="rgb_array", highlight=False)
    try:
        obs, _ = env.reset(seed=seed)
        obs_shape = preprocess_run(obs).shape
        n_actions = int(env.action_space.n)

        default_config = Config()
        agent = FPVRVisualAgent(
            obs_shape=obs_shape,
            n_actions=n_actions,
            lr=default_config.lr,
            phi_dim=default_config.phi_dim,
            psi_dim=default_config.psi_dim,
            sf_gamma=default_config.sf_gamma,
            beta=default_config.beta,
            capacity=default_config.capacity,
            batch_size=default_config.batch_size,
            update_after=default_config.update_after,
            update_every=default_config.update_every,
            whitening_update_every=default_config.whitening_update_every,
            sf_target_mode=str(sf_target_mode),
        )

        # The training loop expects a global `config` (for lambda_c, print_interval).
        # Install a minimal one in builtins, then restore whatever was there before.
        config_ns = types.SimpleNamespace(
            lambda_c=float(default_config.lambda_c),
            print_interval=0,
        )
        missing = object()
        prev_config = getattr(builtins, "config", missing)
        builtins.config = config_ns
        try:
            res = run_fpvr_training(
                env=env,
                agent=agent,
                steps=int(cfg.total_steps),
                k_train=int(default_config.k_train),
                visualize=False,
                render_delay=0.0,
                feat_dim=int(default_config.phi_dim),
                env_seed=int(seed),
                reset_interval=int(cfg.coverage_reset_interval),
            )
        finally:
            if prev_config is missing:
                if hasattr(builtins, "config"):
                    delattr(builtins, "config")
            else:
                builtins.config = prev_config
    finally:
        # Close the env even if reset/agent construction/training raised.
        env.close()

    def _series(key: str) -> np.ndarray:
        # Fall back to the generic "coverage" entry only when `key` is absent. Using
        # res.get(key, res["coverage"]) would evaluate the fallback eagerly and raise
        # KeyError when only the specific key exists.
        return np.asarray(res[key] if key in res else res["coverage"], dtype=np.int32)

    cov = _series("coverage_cumulative")
    cov_win = _series("coverage_windowed")
    counts = np.asarray(res["counts"], dtype=np.int32)

    cov_path = os.path.join(out_dir, "coverage.npy")
    win_path = os.path.join(out_dir, "coverage_windowed.npy")
    cnt_path = os.path.join(out_dir, "counts.npy")
    # (Extra safety) Make sure parent directories exist (also handles long paths on Windows).
    os.makedirs(_win_long_path(os.path.dirname(os.path.abspath(cov_path))), exist_ok=True)

    # Use file handles so we can apply the long-path prefix on Windows reliably.
    with open(_win_long_path(cov_path), "wb") as f:
        np.save(f, cov)
    with open(_win_long_path(win_path), "wb") as f:
        np.save(f, cov_win)
    with open(_win_long_path(cnt_path), "wb") as f:
        np.save(f, counts)

    with open(_win_long_path(os.path.join(out_dir, "config.json")), "w", encoding="utf-8") as f:
        json.dump(
            {
                "seed": int(seed),
                "env_size": int(cfg.env_size),
                "total_steps": int(cfg.total_steps),
                "coverage_reset_interval": int(cfg.coverage_reset_interval),
                "sf_target_mode": str(sf_target_mode),
            },
            f,
            indent=2,
        )

    return cov_path, win_path, cnt_path


def main() -> int:
    """
    CLI entry point: train every (seed, SR-target variant) pair, then plot the aggregates.

    For each seed in ``[seed_base, seed_base + n_seeds)`` and each variant, FPVR is trained
    once, with results under ``<out_dir>/<tag>_seed<seed>/``, and the saved arrays are
    reloaded. The function then writes to ``out_dir``:
      - the cumulative-coverage plot;
      - the windowed-coverage plot, only when ``--coverage_reset_interval > 0``;
      - the 1x3 heatmap figure.

    Returns:
        Process exit code (always 0; failures propagate as exceptions).
    """
    p = argparse.ArgumentParser(description="Visual FPVR SR-target comparison (deep FPVR only).")
    p.add_argument("--env_size", type=int, default=20)
    p.add_argument("--n_seeds", type=int, default=10)
    p.add_argument("--seed_base", type=int, default=1)
    p.add_argument("--total_steps", type=int, default=9000)
    p.add_argument("--coverage_reset_interval", type=int, default=3000)
    p.add_argument(
        "--out_dir",
        type=str,
        default=os.path.join(_THIS_DIR, "results_sr_target_test"),
    )
    args = p.parse_args()

    cfg = ExpConfig(
        env_size=int(args.env_size),
        seed_base=int(args.seed_base),
        n_seeds=int(args.n_seeds),
        total_steps=int(args.total_steps),
        coverage_reset_interval=int(args.coverage_reset_interval),
        out_dir=str(args.out_dir),
    )
    os.makedirs(_win_long_path(cfg.out_dir), exist_ok=True)

    # Save the parsed CLI arguments for this sweep.
    with open(_win_long_path(os.path.join(cfg.out_dir, "config.json")), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2)

    # (display name, sf_target_mode passed to the agent, short tag used in run-dir names)
    variants = [
        ("FPVR (SR target = min)", "min_redundancy", "min"),
        ("FPVR (SR target = mean)", "uniform_policy", "mean"),
        ("FPVR (SR target = current)", "current_policy", "current"),
    ]

    all_curves: Dict[str, List[np.ndarray]] = {name: [] for (name, _mode, _tag) in variants}
    all_curves_win: Dict[str, List[np.ndarray]] = {name: [] for (name, _mode, _tag) in variants}
    all_counts: Dict[str, List[np.ndarray]] = {name: [] for (name, _mode, _tag) in variants}

    for seed in range(cfg.seed_base, cfg.seed_base + cfg.n_seeds):
        for name, mode, tag in variants:
            run_dir = os.path.join(cfg.out_dir, f"{tag}_seed{seed}")
            print(f"[Run] {name} seed={seed} sf_target_mode={mode}")
            cov_path, win_path, cnt_path = run_deep_fpvr_one(
                cfg=cfg,
                seed=int(seed),
                sf_target_mode=str(mode),
                out_dir=run_dir,
            )

            cov, cov_win = load_coverage_data(cov_path, win_path, cfg.total_steps + 1)
            with open(_win_long_path(cnt_path), "rb") as f:
                counts = np.load(f, allow_pickle=False)

            all_curves[name].append(cov)
            all_curves_win[name].append(cov_win)
            all_counts[name].append(counts)

    def _eps_for(png_path: str) -> str:
        # Swap only the final extension. str.replace(".png", ".eps") would also rewrite
        # any ".png" occurring in a directory name.
        return os.path.splitext(png_path)[0] + ".eps"

    # Plots
    out_png = os.path.join(cfg.out_dir, "coverage_sr_target.png")
    out_eps = _eps_for(out_png)
    plot_curves(all_curves, ylabel="Number of visited states", out_png=out_png, out_eps=out_eps, reset_interval=None)
    print(f"[OK] Saved cumulative coverage plot: {out_png}")

    if int(cfg.coverage_reset_interval) > 0:
        win_png = os.path.join(cfg.out_dir, f"coverage_sr_target_reset{int(cfg.coverage_reset_interval)}.png")
        win_eps = _eps_for(win_png)
        plot_curves(
            all_curves_win,
            ylabel=f"Visited states (reset every {int(cfg.coverage_reset_interval)} steps)",
            out_png=win_png,
            out_eps=win_eps,
            reset_interval=int(cfg.coverage_reset_interval),
        )
        print(f"[OK] Saved windowed coverage plot: {win_png}")

    heat_png = os.path.join(cfg.out_dir, "heatmap_sr_target_all.png")
    plot_heatmaps_1x3(all_counts, out_png=heat_png, env_size=int(cfg.env_size))
    return 0


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())

