"""
SR target ablation for FPVR in tabular MiniGrid-FourRooms (exploration-only).

This script matches the experimental logic of `gridworld/fourrooms_exploration.py`,
but compares only three variants of the SR TD target used in the action-conditioned
successor representation update inside FPVR:

  - min:    use M[s', a_min] where a_min minimizes FPVR(s',a) under current C
  - mean:   use mean_a M[s', a]
  - current:use M[s', a'] where a' is the *actually taken* next action at s'

Outputs (saved under gridworld/results/ by default):
  1) cumulative coverage curves
  2) windowed coverage curves (periodically reset)
  3) visitation heatmaps (1x3 subplots, shared colorbar on the right)
"""

from __future__ import annotations

import argparse
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Local env (optional dependency: minigrid). Keep `-h` usable without installing it.
try:
    from bottleneck_env import SimpleEnv  # type: ignore
except Exception:
    SimpleEnv = None  # type: ignore

from utils import BottleneckVisualization


# Line color per method label. Keys must match the labels of the `methods`
# mapping built in main(); values are color specs handed to matplotlib there.
METHOD_COLORS = {
    "FPVR (SR target = min)": "C0",
    "FPVR (SR target = mean)": "C1",
    "FPVR (SR target = current)": "C2",
}

# Alias for an integer action id (column index into the transition table).
Action = int  # 0:up, 1:down, 2:left, 3:right


def _pos_to_state(x: int, y: int, width: int) -> int:
    """Flatten grid coordinates ``(x, y)`` into a row-major state index."""
    row_offset = y * width
    return int(row_offset + x)


def _state_to_pos(s: int, width: int) -> Tuple[int, int]:
    """Recover grid coordinates ``(x, y)`` from a row-major state index."""
    row, col = divmod(s, width)
    return int(col), int(row)


def build_transition_from_env(env: SimpleEnv) -> Tuple[np.ndarray, np.ndarray]:
    """
    Derive a tabular model of the grid held by ``env``.

    Returns ``(T, wall_mask)`` where ``T[s, a]`` is the deterministic successor
    state of taking action ``a`` in state ``s`` and ``wall_mask`` is an (H, W)
    boolean array (True = blocked). A cell counts as blocked when it holds an
    object whose ``can_overlap()`` is falsy. Moves that would leave the grid or
    enter a blocked cell keep the agent in place; blocked states self-loop.
    """
    width = int(env.width)
    height = int(env.height)

    # Blocked-cell lookup indexed [x, y]; it is transposed to (H, W) on return.
    blocked_xy = np.zeros((width, height), dtype=bool)
    for row in range(height):
        for col in range(width):
            obj = env.grid.get(col, row)
            if obj is None or obj.can_overlap():
                continue
            blocked_xy[col, row] = True

    # Displacement (dx, dy) per action id, in order: up, down, left, right.
    moves = ((0, -1), (0, 1), (-1, 0), (1, 0))
    n_states = width * height
    T = np.zeros((n_states, len(moves)), dtype=np.int32)
    for state in range(n_states):
        col, row = _state_to_pos(state, width)
        if blocked_xy[col, row]:
            T[state, :] = state
            continue
        for action, (d_col, d_row) in enumerate(moves):
            n_col, n_row = col + d_col, row + d_row
            in_bounds = (0 <= n_col < width) and (0 <= n_row < height)
            if in_bounds and not blocked_xy[n_col, n_row]:
                T[state, action] = _pos_to_state(n_col, n_row, width)
            else:
                T[state, action] = state

    return T, blocked_xy.T  # (H,W)


def _first_free_state(wall_mask_hw: np.ndarray) -> int:
    """
    Return the state index of the first unblocked cell of an (H, W) wall mask,
    scanning rows top to bottom and columns left to right.

    Raises RuntimeError when every cell is blocked.
    """
    _, n_cols = wall_mask_hw.shape
    # argwhere lists matching (row, col) pairs in row-major order, so the first
    # entry is the first free cell of a row-by-row scan.
    free_cells = np.argwhere(np.logical_not(wall_mask_hw))
    if free_cells.shape[0] == 0:
        raise RuntimeError("No free states in the environment.")
    row, col = free_cells[0]
    return _pos_to_state(int(col), int(row), n_cols)


def _zscore(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardize ``x`` to zero mean and unit scale; ``eps`` keeps the divisor non-zero."""
    center = float(np.mean(x))
    spread = float(np.std(x))
    centered = x - center
    return centered / (spread + eps)


def _cosine(u: np.ndarray, v: np.ndarray, eps: float = 1e-12) -> float:
    """
    Cosine similarity of ``u`` and ``v``.

    ``eps`` is added to each norm, so an all-zero vector gives 0.0 rather than
    a division by zero.
    """
    norm_u = float(np.linalg.norm(u) + eps)
    norm_v = float(np.linalg.norm(v) + eps)
    inner = np.dot(u, v)
    return float(inner / (norm_u * norm_v))


def _softmax_sample_from_cost(costs: np.ndarray, *, beta: float, rng: np.random.Generator) -> Action:
    """
    Sample an action index with probability proportional to ``exp(-beta * cost)``.

    Lower cost means higher probability. Costs and logits are shifted before
    exponentiation, which leaves the distribution unchanged but avoids overflow.
    """
    shifted = costs - float(np.min(costs))
    scores = -float(beta) * shifted
    scores = scores - float(np.max(scores))
    weights = np.exp(scores)
    # Clip the normalizer so a fully underflowed weight vector cannot divide by zero.
    probs = weights / np.clip(np.sum(weights), 1e-12, None)
    n_choices = costs.shape[0]
    return int(rng.choice(np.arange(n_choices), p=probs))


@dataclass
class RunConfig:
    """Hyperparameters shared by every rollout of the SR-target comparison."""

    # Number of environment steps per rollout.
    total_steps: int = 3000
    # Number of independent runs per method; run i is seeded with `seed + i`.
    n_seeds: int = 50
    # Base random seed (also used to reset the environment in main()).
    seed: int = 78

    # SR learning
    # Discount applied to the bootstrapped successor vector in the TD target.
    gamma_sr: float = 0.9
    # Step size of the SR update.
    alpha_sr: float = 0.1
    # Per-step multiplicative decay of the past-visitation vector C.
    c_decay: float = 0.999

    # Action selection: factor multiplying the negated, z-scored costs in the
    # softmax logits (larger values concentrate probability on low-cost actions).
    fpvr_beta: float = 10.0

    # Start state: requested start cell. If either coordinate is None or the
    # cell is blocked, the rollout starts from the first free cell instead.
    start_x: Optional[int] = 1
    start_y: Optional[int] = 1

    # For visualization: windowed coverage resets every N steps (0 disables).
    coverage_reset_interval: int = 1000


def run_fpvr_direct_sr_target(
    T: np.ndarray,
    wall_mask_hw: np.ndarray,
    cfg: RunConfig,
    rng: np.random.Generator,
    *,
    sr_target: str,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Run one FPVR exploration rollout; only the SR TD-target variant differs.

    Args:
        T: Transition table of shape (n_states, n_actions) with T[s, a] = s'.
        wall_mask_hw: (H, W) boolean mask, True where a cell is blocked.
        cfg: Rollout hyperparameters.
        rng: Random generator used for action sampling.
        sr_target: Bootstrap variant, case-insensitive. One of
            "min"     - M[s', a*] with a* minimizing cosine(M[s', a], C),
            "mean"    - mean over actions of M[s', a],
            "current" - M[s', a'] with a' the action actually sampled at s'.

    Returns:
        (cov, cov_win, visit_counts): cumulative number of distinct visited
        states per step, the same count restarted every
        ``cfg.coverage_reset_interval`` steps, and per-state visit counts.

    Raises:
        ValueError: if ``sr_target`` is not one of the supported variants.
    """
    target_mode = str(sr_target).lower()
    # Fail loudly: silently treating a typo as "mean" would mislabel the ablation.
    if target_mode not in ("min", "mean", "current"):
        raise ValueError(
            f"Unknown sr_target {sr_target!r}; expected one of 'min', 'mean', 'current'."
        )

    n_states, n_act = T.shape
    height, width = wall_mask_hw.shape

    gamma = float(cfg.gamma_sr)
    alpha = float(cfg.alpha_sr)
    decay = float(cfg.c_decay)
    beta = float(cfg.fpvr_beta)
    reset_interval = int(cfg.coverage_reset_interval)

    # Action-conditioned SR (successor occupancy vector over states)
    M = np.zeros((n_states, n_act, n_states), dtype=np.float32)
    # Discounted past visitation (vector over states)
    C = np.zeros((n_states,), dtype=np.float32)
    # State visitation counts during training
    visit_counts = np.zeros((n_states,), dtype=np.int64)

    visited = set()
    cov = np.zeros((cfg.total_steps,), dtype=np.int32)
    visited_win = set()
    cov_win = np.zeros((cfg.total_steps,), dtype=np.int32)

    # Start state: use the requested cell only when it lies inside the grid and
    # is free. The explicit bounds check prevents negative coordinates from
    # wrapping around in the mask lookup and out-of-range ones from raising.
    sx, sy = cfg.start_x, cfg.start_y
    start_is_usable = (
        sx is not None
        and sy is not None
        and 0 <= sx < width
        and 0 <= sy < height
        and not wall_mask_hw[sy, sx]
    )
    if start_is_usable:
        s = _pos_to_state(sx, sy, width)
    else:
        s = _first_free_state(wall_mask_hw)

    prev_s: int | None = None
    prev_a: int | None = None

    for t in range(cfg.total_steps):
        visit_counts[s] += 1
        visited.add(s)
        visited_win.add(s)
        C *= decay
        C[s] += 1.0

        # action costs = redundancy between future SR and past C
        raw_costs = np.empty((n_act,), dtype=np.float32)
        for a_idx in range(n_act):
            raw_costs[a_idx] = _cosine(M[s, a_idx, :], C)
        costs = _zscore(raw_costs)

        a = _softmax_sample_from_cost(costs, beta=beta, rng=rng)

        # delayed SR update for (prev_s, prev_a) using current state as s'
        if prev_s is not None and prev_a is not None:
            e = np.zeros((n_states,), dtype=np.float32)
            e[s] = 1.0

            if target_mode == "min":
                # a' minimizing FPVR(s', a') under the current C. M and C have
                # not changed since raw_costs was computed, so reuse it rather
                # than recomputing the same cosines.
                a_min = int(np.argmin(raw_costs))
                psi_exp = M[s, a_min, :]
            elif target_mode == "current":
                # Use the action actually taken at s' (current state).
                psi_exp = M[s, a, :]
            else:
                # mean: uniform policy expectation
                psi_exp = M[s, :, :].mean(axis=0)

            # `target` is a fresh array, so the update is safe even when
            # psi_exp is a view of the very row being overwritten.
            target = e + gamma * psi_exp
            M[prev_s, prev_a, :] = (1.0 - alpha) * M[prev_s, prev_a, :] + alpha * target

        sn = int(T[s, a])
        prev_s, prev_a = s, a
        s = sn

        cov[t] = int(len(visited))
        cov_win[t] = int(len(visited_win))
        if reset_interval > 0 and ((t + 1) % reset_interval == 0):
            # Start a new window, seeded with the state the agent now occupies.
            visited_win.clear()
            visited_win.add(s)

    return cov, cov_win, visit_counts


def _mean_std(curves: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Stack equally-shaped curves and return their element-wise mean and std across curves."""
    stacked = np.stack(curves, axis=0)
    mean_curve = np.mean(stacked, axis=0)
    std_curve = np.std(stacked, axis=0)
    return mean_curve, std_curve


def main() -> int:
    """
    Command-line entry point for the SR-target comparison.

    Parses the options, runs every SR-target variant for ``--n_seeds`` runs
    (run ``i`` of each variant shares the seed ``--seed + i``), then writes:
      * the cumulative coverage plot (PNG + EPS),
      * the windowed coverage plot (PNG + EPS) when the reset interval is > 0,
      * one ``.npy`` of mean visit counts per variant and a combined heatmap PNG.

    Returns:
        0 on success.

    Raises:
        ModuleNotFoundError: if the gridworld environment could not be imported.
    """
    p = argparse.ArgumentParser(description="Fourrooms FPVR SR-target comparison (coverage + heatmaps).")
    p.add_argument("--total_steps", type=int, default=3000)
    p.add_argument("--n_seeds", type=int, default=50)
    p.add_argument("--seed", type=int, default=78)
    p.add_argument("--start_x", type=int, default=1)
    p.add_argument("--start_y", type=int, default=1)

    # SR / FPVR
    p.add_argument("--gamma_sr", type=float, default=0.9)
    p.add_argument("--alpha_sr", type=float, default=0.1)
    p.add_argument("--c_decay", type=float, default=0.999)
    p.add_argument(
        "--fpvr_beta",
        type=float,
        default=10.0,
        help="Softmax inverse temperature for FPVR action sampling (larger = greedier).",
    )

    # Coverage visualization (windowed coverage)
    p.add_argument(
        "--coverage_reset_interval",
        type=int,
        default=1000,
        help="Reset the windowed coverage every N steps (0 disables). Default: 1000.",
    )

    # Output paths: by default, save under gridworld/results/ (not repo root).
    _DEFAULT_RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
    p.add_argument("--out_png", type=str, default=os.path.join(_DEFAULT_RESULTS_DIR, "fourrooms_sr_target.png"))
    p.add_argument(
        "--out_eps",
        type=str,
        default=None,
        help="Optional EPS output path. If not set, will save an .eps next to --out_png.",
    )
    args = p.parse_args()

    # Reject values that would otherwise only fail later with an opaque NumPy
    # error (stacking zero curves / allocating a negative-length array).
    if args.n_seeds < 1:
        p.error("--n_seeds must be >= 1")
    if args.total_steps < 0:
        p.error("--total_steps must be >= 0")

    # Check the optional dependency before touching the filesystem so a failed
    # run leaves no empty output directory behind.
    if SimpleEnv is None:
        raise ModuleNotFoundError(
            "Missing optional dependency for gridworld/minigrid environment (minigrid). "
            "Please install it (e.g., `pip install minigrid`) to run fourrooms_sr_target_test."
        )

    def _resolve_out_path(path: str) -> str:
        # Relative paths are interpreted relative to the default results directory.
        if os.path.isabs(path):
            return path
        return os.path.join(_DEFAULT_RESULTS_DIR, path)

    args.out_png = _resolve_out_path(str(args.out_png))
    if args.out_eps is not None:
        args.out_eps = _resolve_out_path(str(args.out_eps))
    os.makedirs(os.path.dirname(os.path.abspath(args.out_png)), exist_ok=True)

    cfg = RunConfig(
        total_steps=int(args.total_steps),
        n_seeds=int(args.n_seeds),
        seed=int(args.seed),
        gamma_sr=float(args.gamma_sr),
        alpha_sr=float(args.alpha_sr),
        c_decay=float(args.c_decay),
        fpvr_beta=float(args.fpvr_beta),
        start_x=int(args.start_x) if args.start_x is not None else None,
        start_y=int(args.start_y) if args.start_y is not None else None,
        coverage_reset_interval=int(args.coverage_reset_interval),
    )

    # Build the tabular model once; close the env even if extraction fails.
    env = SimpleEnv(render_mode=None)
    try:
        env.reset(seed=cfg.seed)
        T, wall_mask_hw = build_transition_from_env(env)
    finally:
        env.close()

    methods = {
        "FPVR (SR target = min)": "min",
        "FPVR (SR target = mean)": "mean",
        "FPVR (SR target = current)": "current",
    }

    all_curves: Dict[str, List[np.ndarray]] = {k: [] for k in methods}
    all_curves_win: Dict[str, List[np.ndarray]] = {k: [] for k in methods}
    all_visits: Dict[str, List[np.ndarray]] = {k: [] for k in methods}

    # For fairness across SR-target variants: use the same RNG seed per method per run.
    for i in range(cfg.n_seeds):
        base_seed = int(cfg.seed + i)
        for name, target in methods.items():
            rng = np.random.default_rng(base_seed)
            cov_full, cov_win, visit_counts = run_fpvr_direct_sr_target(T, wall_mask_hw, cfg, rng, sr_target=str(target))
            all_curves[name].append(cov_full)
            all_curves_win[name].append(cov_win)
            all_visits[name].append(visit_counts)

    # ---------------- Plot styling (paper-ready) ----------------
    label_fs = 18
    tick_fs = 14
    legend_fs = 14

    def plot_and_save(curves_dict: Dict[str, List[np.ndarray]], *, ylabel: str, out_png: str, out_eps: str) -> None:
        """Plot mean +/- std coverage curves per method and save as PNG and EPS."""
        fig, ax = plt.subplots(figsize=(9, 5.4))
        x = np.arange(cfg.total_steps)

        def _lighten(color, amount: float = 0.85):
            # Blend towards white; gives an opaque band instead of an alpha fill.
            r, g, b = mcolors.to_rgb(color)
            return (r + (1 - r) * amount, g + (1 - g) * amount, b + (1 - b) * amount)

        for idx, (name, curves) in enumerate(curves_dict.items()):
            mean, std = _mean_std(curves)
            color = METHOD_COLORS.get(name, f"C{idx % 10}")
            band_color = _lighten(color, amount=0.88)
            ax.fill_between(x, mean - std, mean + std, color=band_color, linewidth=0.0, zorder=1)
            ax.plot(x, mean, label=name, linewidth=2.6, color=color, zorder=2)

        ax.set_xlabel("Steps", fontsize=label_fs)
        ax.set_ylabel(ylabel, fontsize=label_fs)
        ax.tick_params(axis="both", which="major", labelsize=tick_fs)
        ax.grid(alpha=0.3)
        ax.legend(fontsize=legend_fs, frameon=True, loc="lower right")
        fig.tight_layout()
        fig.savefig(out_png, dpi=200)
        fig.savefig(out_eps, format="eps")
        plt.close(fig)

    out_eps = args.out_eps
    if out_eps is None:
        stem, _ext = os.path.splitext(args.out_png)
        out_eps = stem + ".eps"
    os.makedirs(os.path.dirname(os.path.abspath(out_eps)), exist_ok=True)

    plot_and_save(all_curves, ylabel="Number of visited states", out_png=args.out_png, out_eps=out_eps)
    print(f"[OK] Saved plot to: {args.out_png}")
    print(f"[OK] Saved plot to: {out_eps}")

    if int(cfg.coverage_reset_interval) > 0:
        stem_png, _ = os.path.splitext(args.out_png)
        out_png_win = f"{stem_png}_reset{int(cfg.coverage_reset_interval)}.png"
        stem_eps, _ = os.path.splitext(out_eps)
        out_eps_win = f"{stem_eps}_reset{int(cfg.coverage_reset_interval)}.eps"
        plot_and_save(
            all_curves_win,
            ylabel=f"Visited states (reset every {int(cfg.coverage_reset_interval)} steps)",
            out_png=out_png_win,
            out_eps=out_eps_win,
        )
        print(f"[OK] Saved plot to: {out_png_win}")
        print(f"[OK] Saved plot to: {out_eps_win}")

    # ---------------- Heatmaps: 1x3 with shared colorbar ----------------
    stem_png, _ = os.path.splitext(args.out_png)
    # Bound up front so the cleanup below never depends on a swallowed NameError.
    env_vis = None
    try:
        env_vis = SimpleEnv(render_mode=None)
        env_vis.reset(seed=int(cfg.seed))
        _ = BottleneckVisualization(env_vis)  # kept for parity with fourrooms_exploration imports

        # Method label -> filename-safe slug: spaces become '_', "=(),"' are dropped.
        slug_table = str.maketrans({" ": "_", "=": "", "(": "", ")": "", ",": ""})

        # Mean visitation vectors and save .npy per method
        v_mean_by_name: Dict[str, np.ndarray] = {}
        for name in methods:
            v_stack = np.stack(all_visits[name], axis=0).astype(np.float64)
            v_mean = v_stack.mean(axis=0)
            v_mean_by_name[name] = v_mean
            np.save(f"{stem_png}_visits_{name.translate(slug_table)}.npy", v_mean.astype(np.float64))

        H, W = int(wall_mask_hw.shape[0]), int(wall_mask_hw.shape[1])

        def _vector_to_hw_matrix(v: np.ndarray) -> np.ndarray:
            # Row-major unflatten (state = x + y * W); blocked cells, and any
            # cell the vector does not cover, stay NaN so the colormap's "bad"
            # color is used for them.
            flat = np.full((H * W,), np.nan, dtype=np.float64)
            n_covered = min(int(v.shape[0]), H * W)
            flat[:n_covered] = v[:n_covered]
            mat = flat.reshape(H, W)
            mat[np.asarray(wall_mask_hw, dtype=bool)] = np.nan
            return mat

        names_to_plot = list(methods.keys())
        mats = [_vector_to_hw_matrix(v_mean_by_name[n]) for n in names_to_plot]
        # Shared color scale across the three panels.
        vmin = float(np.nanmin([np.nanmin(m) for m in mats]))
        vmax = float(np.nanmax([np.nanmax(m) for m in mats]))

        fig, axes = plt.subplots(
            1,
            4,
            figsize=(13.8, 4.8),
            gridspec_kw={"width_ratios": [1, 1, 1, 0.06], "wspace": 0.35},
        )
        plot_axes = [axes[0], axes[1], axes[2]]
        axes[3].set_visible(False)
        # NOTE: statement order in this layout section is kept as-is; the
        # colorbar rectangle is derived from axes[2] before subplots_adjust.
        cbar_ax = fig.add_axes(
            [
                axes[2].get_position().x1 + 0.050,
                axes[2].get_position().y0,
                0.015,
                axes[2].get_position().height,
            ]
        )

        cmap = plt.get_cmap("hot").copy()
        cmap.set_bad("gray")
        title_fs = 12
        tick_fs_hm = 10

        im = None
        for i, ax in enumerate(plot_axes):
            name = names_to_plot[i]
            data = mats[i]
            im = ax.imshow(
                data,
                cmap=cmap,
                origin="upper",
                interpolation="nearest",
                extent=[0, W, H, 0],
                vmin=vmin,
                vmax=vmax,
            )
            ax.set_title(name, fontsize=title_fs)
            ax.set_xticks(np.arange(0, W + 1, 1))
            ax.set_yticks(np.arange(0, H + 1, 1))
            ax.set_xlim(0, W)
            ax.set_ylim(H, 0)
            ax.grid(color="lightgray", linewidth=0.5)
            ax.tick_params(axis="both", which="major", labelsize=tick_fs_hm)
            ax.set_xticklabels([])
            ax.set_yticklabels([])

        if im is not None:
            cbar = fig.colorbar(im, cax=cbar_ax)
            cbar.ax.tick_params(labelsize=12)
            cbar.set_label("Mean visit count", fontsize=12)

        fig.subplots_adjust(top=0.92, left=0.03, right=0.985, bottom=0.05)
        out_all = f"{stem_png}_visits_all.png"
        fig.savefig(out_all, dpi=200)
        plt.close(fig)
        print(f"[OK] Saved combined visitation heatmap to: {out_all}")
    finally:
        if env_vis is not None:
            try:
                env_vis.close()
            except Exception as exc:
                # Best-effort cleanup: report it, but never mask the real outcome.
                print(f"[WARN] Failed to close visualization env: {exc}")

    return 0


# Run main() only when executed as a script; its return value becomes the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())

