"""
Four-rooms exploration comparison on gridworld/bottleneck_env.py.

Methods compared (coverage = # unique visited states vs steps):
  1) FPVR (direct action selection by minimizing future-past visitation redundancy)
  1b) FPVR with fixed SR (optional; two-stage: pretrain SR with random walk, then freeze SR)
  2) FPVR intrinsic reward + SARSA
  3) Successor-Predecessor (SP) intrinsic reward + SARSA
  4) SR-novelty intrinsic reward (1 / ||SR(s,a)||_1) + SARSA
  5) Random walk

State space: tabular grid cells (x,y) flattened to s = y*W + x over the full W*H grid.
Wall/blocked cells keep their index but only self-loop and are never entered from a free cell.
We build a deterministic transition table from the MiniGrid layout.

Usage:
  python fourrooms_exploration.py --total_steps 3000 --n_seeds 50
"""

from __future__ import annotations

import argparse
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Local env (optional dependency: minigrid). Keep `-h` usable without installing it.
try:
    from bottleneck_env import SimpleEnv  # type: ignore
except Exception:
    SimpleEnv = None  # type: ignore

from utils import BottleneckVisualization


# ---------------------------------------------------------------------------
# Consistent color mapping across all figures (keep stable even if methods are
# added/removed). This matches the default matplotlib cycle order used when all
# 5 methods are plotted together:
#   FPVR (blue), r^{FP}+SARSA (orange), SP (green), SR (red), Random (purple).
# ---------------------------------------------------------------------------
# Legend label -> matplotlib cycle color ("C<n>"). Looked up by label at plot time.
METHOD_COLORS = {
    "FPVR": "C0",
    # Optional ablation: takes the slot after the five core methods so that
    # enabling it does not shift any of their colors.
    "FPVR with fixed SR": "C5",
    "$r^{FP}$ + SARSA": "C1",
    "SP+SARSA": "C2",
    "SR+SARSA": "C3",
    "Random Walk": "C4",
}


Action = int  # 0:up, 1:down, 2:left, 3:right


def _pos_to_state(x: int, y: int, width: int) -> int:
    return int(y * width + x)


def _state_to_pos(s: int, width: int) -> Tuple[int, int]:
    return int(s % width), int(s // width)


def build_transition_from_env(env: SimpleEnv) -> Tuple[np.ndarray, np.ndarray]:
    """
    Tabulate the grid dynamics of `env` for the four moves up/down/left/right.

    Returns:
        T: int32 array of shape (width*height, 4) with T[s, a] = next state index.
           A move that leaves the grid or targets a blocked cell keeps the agent
           in place; blocked cells themselves self-loop under every action.
        wall_mask: bool array of shape (height, width), True where the cell is blocked.
           A cell is blocked when it holds an object whose `can_overlap()` is false.
    """
    width = int(env.width)
    height = int(env.height)
    n_actions = 4
    n_states = width * height

    # Blocked-cell mask, indexed [x, y] while we build the table.
    blocked_xy = np.zeros((width, height), dtype=bool)
    for y in range(height):
        for x in range(width):
            obj = env.grid.get(x, y)
            if obj is None:
                continue
            if not obj.can_overlap():
                blocked_xy[x, y] = True

    # (dx, dy) per action index: 0 up, 1 down, 2 left, 3 right.
    deltas = ((0, -1), (0, 1), (-1, 0), (1, 0))

    T = np.zeros((n_states, n_actions), dtype=np.int32)
    for s in range(n_states):
        x, y = _state_to_pos(s, width)
        if blocked_xy[x, y]:
            T[s, :] = s
            continue
        for a, (dx, dy) in enumerate(deltas):
            nx, ny = x + dx, y + dy
            in_bounds = (0 <= nx < width) and (0 <= ny < height)
            if in_bounds and not blocked_xy[nx, ny]:
                T[s, a] = _pos_to_state(nx, ny, width)
            else:
                T[s, a] = s

    # Callers index the mask as [y, x], so hand it back transposed to (H, W).
    return T, blocked_xy.T


def _first_free_state(wall_mask_hw: np.ndarray) -> int:
    """
    Return the state index of the first unblocked cell in row-major scan order.

    Raises:
        RuntimeError: if every cell of the (H, W) mask is blocked.
    """
    h, w = wall_mask_hw.shape
    free_cells = ((x, y) for y in range(h) for x in range(w) if not wall_mask_hw[y, x])
    first = next(free_cells, None)
    if first is None:
        raise RuntimeError("No free states in the environment.")
    fx, fy = first
    return _pos_to_state(fx, fy, w)


def _pick_action_eps_greedy(q_row: np.ndarray, *, eps: float, rng: np.random.Generator) -> Action:
    if rng.random() < eps:
        return int(rng.integers(0, q_row.shape[0]))
    maxv = float(np.max(q_row))
    cand = np.flatnonzero(q_row == maxv)
    return int(rng.choice(cand))


def _pick_action_eps_greedy_min(q_row: np.ndarray, *, eps: float, rng: np.random.Generator) -> Action:
    """ε-greedy for *minimization* (treat q_row as costs)."""
    if rng.random() < eps:
        return int(rng.integers(0, q_row.shape[0]))
    minv = float(np.min(q_row))
    cand = np.flatnonzero(q_row == minv)
    return int(rng.choice(cand))


def _zscore(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    mu = float(np.mean(x))
    sd = float(np.std(x))
    return (x - mu) / (sd + eps)


def _cosine(u: np.ndarray, v: np.ndarray, eps: float = 1e-12) -> float:
    un = float(np.linalg.norm(u) + eps)
    vn = float(np.linalg.norm(v) + eps)
    return float(np.dot(u, v) / (un * vn))


def _softmax_sample_from_cost(costs: np.ndarray, *, beta: float, rng: np.random.Generator) -> Action:
    # lower cost -> higher probability
    c = costs - float(np.min(costs))
    logits = -float(beta) * c
    logits = logits - float(np.max(logits))
    p = np.exp(logits)
    p = p / np.clip(np.sum(p), 1e-12, None)
    return int(rng.choice(np.arange(costs.shape[0]), p=p))


@dataclass
class RunConfig:
    """Hyperparameters shared by the exploration runners in this file (one instance per experiment)."""

    # Environment steps simulated per run; also the length of each returned coverage curve.
    total_steps: int = 3000
    # Number of independent runs; run i is seeded with `seed + i`.
    n_seeds: int = 50
    # Base seed (also passed to the environment reset).
    seed: int = 78

    # FPVR / SR learning
    # Discount factor in the successor-representation (SR) TD target.
    gamma_sr: float = 0.9
    # SR learning rate.
    alpha_sr: float = 0.1
    # Per-step multiplicative decay of the past-visitation vector C.
    c_decay: float = 0.999
    fpvr_sr_target: str = "min"  # {"mean","min"}: expectation over a' in SR target
    # Two-stage ablation: pretrain a fixed SR under random walk, then freeze it in FPVR.
    fixed_sr_pretrain_steps: int = 30000

    # Action selection temperature for direct FPVR
    # (used as the softmax `beta`: larger values sample more greedily).
    fpvr_beta: float = 10.0

    # SARSA
    # Q-learning rate.
    sarsa_alpha: float = 0.1
    # Q discount factor.
    sarsa_gamma: float = 0.99
    # Exploration probability for ε-greedy action selection.
    eps: float = 0.1

    # Intrinsic reward scaling
    r_fpvr_scale: float = 1.0
    r_sp_scale: float = 1.0
    r_sr_scale: float = 1.0
    # Added to the SR L1 norm in the SR-novelty reward denominator.
    sr_l1_eps: float = 1e-6

    # Start state
    # If either coordinate is None, or the cell is blocked, runners fall back to
    # the first free cell in row-major order.
    start_x: Optional[int] = 1
    start_y: Optional[int] = 1
    # For visualization: windowed coverage resets every N steps (0 disables).
    coverage_reset_interval: int = 1000


def run_fpvr_direct(
    T: np.ndarray, wall_mask_hw: np.ndarray, cfg: RunConfig, rng: np.random.Generator
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Method 1: FPVR direct action selection using cosine(M[s,a], C).

    At every step the agent scores each action by the cosine similarity between its
    action-conditioned SR row M[s, a, :] (predicted future occupancy) and the decayed
    past-visitation vector C, z-scores those costs, and samples an action from a
    softmax that favors low redundancy. M is learned online with a one-step-delayed
    TD update.

    Args:
        T: transition table, T[s, a] = next state.
        wall_mask_hw: (H, W) boolean mask, True = blocked cell.
        cfg: run hyperparameters.
        rng: random generator used for action sampling (advanced by this call).

    Returns:
        cov: cov[t] = number of distinct states visited up to and including step t.
        cov_win: same, but the visited set is cleared every
            `cfg.coverage_reset_interval` steps (when that interval is > 0).
        visit_counts: per-state visit counts over the whole run.
    """
    n_states, n_act = T.shape
    width = wall_mask_hw.shape[1]

    # Action-conditioned SR (successor occupancy vector over states)
    M = np.zeros((n_states, n_act, n_states), dtype=np.float32)
    # Discounted past visitation (vector over states)
    C = np.zeros((n_states,), dtype=np.float32)
    # State visitation counts during training
    visit_counts = np.zeros((n_states,), dtype=np.int64)

    visited = set()
    cov = np.zeros((cfg.total_steps,), dtype=np.int32)
    visited_win = set()
    cov_win = np.zeros((cfg.total_steps,), dtype=np.int32)

    # init start state
    if cfg.start_x is not None and cfg.start_y is not None and not wall_mask_hw[cfg.start_y, cfg.start_x]:
        s = _pos_to_state(cfg.start_x, cfg.start_y, width)
    else:
        s = _first_free_state(wall_mask_hw)

    # Loop invariants, resolved once instead of on every step.
    c_decay = float(cfg.c_decay)
    beta = float(cfg.fpvr_beta)
    gamma_sr = float(cfg.gamma_sr)
    alpha_sr = float(cfg.alpha_sr)
    use_min_target = str(getattr(cfg, "fpvr_sr_target", "mean")).lower() == "min"
    reset_every = int(cfg.coverage_reset_interval)

    prev_s = None
    prev_a = None

    for t in range(cfg.total_steps):
        visit_counts[s] += 1
        # visitation updates
        visited.add(s)
        visited_win.add(s)
        C *= c_decay
        C[s] += 1.0

        # action costs = redundancy between future SR and past C
        raw_costs = np.empty((n_act,), dtype=np.float32)
        for a in range(n_act):
            raw_costs[a] = _cosine(M[s, a, :], C)

        a = _softmax_sample_from_cost(_zscore(raw_costs), beta=beta, rng=rng)

        # delayed SR update for (prev_s, prev_a) using current state as s'
        if prev_s is not None and prev_a is not None:
            e = np.zeros((n_states,), dtype=np.float32)
            e[s] = 1.0
            if use_min_target:
                # Choose a' that minimizes FPVR(s', a') under the current C. Neither M nor C
                # has changed since `raw_costs` was computed above, so those values are
                # reused rather than recomputing the same cosines.
                a_min = int(np.argmin(raw_costs))
                psi_exp = M[s, a_min, :]
            else:
                # Uniform policy expectation
                psi_exp = M[s, :, :].mean(axis=0)
            target = e + gamma_sr * psi_exp
            M[prev_s, prev_a, :] = (1.0 - alpha_sr) * M[prev_s, prev_a, :] + alpha_sr * target

        sn = int(T[s, a])
        prev_s, prev_a = s, a
        s = sn

        cov[t] = int(len(visited))
        cov_win[t] = int(len(visited_win))
        if reset_every > 0 and ((t + 1) % reset_every == 0):
            # Start the next window from the state the agent is about to occupy.
            visited_win.clear()
            visited_win.add(s)

    return cov, cov_win, visit_counts


def _pretrain_action_sr_random_walk(
    T: np.ndarray, wall_mask_hw: np.ndarray, *, steps: int, gamma_sr: float, alpha_sr: float, rng: np.random.Generator, start_s: int
) -> np.ndarray:
    """
    Pretrain an action-conditioned SR table M[s,a,:] using a uniform-random behavior policy.

    We use the same 1-step TD update as in `run_fpvr_direct`, but fix the next-state
    expectation to the uniform mean over actions (policy = uniform random).
    """
    n_states, n_act = T.shape
    M = np.zeros((n_states, n_act, n_states), dtype=np.float32)
    s = int(start_s)
    prev_s: int | None = None
    prev_a: int | None = None

    for _t in range(int(steps)):
        a = int(rng.integers(0, n_act))

        # delayed SR update for (prev_s, prev_a) using current state as s'
        if prev_s is not None and prev_a is not None:
            e = np.zeros((n_states,), dtype=np.float32)
            e[s] = 1.0
            psi_exp = M[s, :, :].mean(axis=0)  # uniform policy expectation
            target = e + float(gamma_sr) * psi_exp
            M[prev_s, prev_a, :] = (1.0 - float(alpha_sr)) * M[prev_s, prev_a, :] + float(alpha_sr) * target

        sn = int(T[s, a])
        prev_s, prev_a = s, a
        s = sn

    return M


def run_fpvr_fixed_sr(
    T: np.ndarray, wall_mask_hw: np.ndarray, cfg: RunConfig, rng: np.random.Generator
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Two-stage FPVR ablation: learn the SR under a random walk, then explore with it frozen.

    Stage 1 runs `cfg.fixed_sr_pretrain_steps` uniform-random steps from the start
    state and fits an action-conditioned SR table M[s, a, :].
    Stage 2 restarts from the same start state and applies the FPVR softmax action
    rule for `cfg.total_steps` steps without ever updating M.

    The caller's `rng` is never advanced: all sampling goes through a private
    generator initialized from a copy of `rng`'s current state, so the random
    streams seen by the other methods are unaffected by whether this one runs.

    Returns:
        (cov, cov_win, visit_counts) for stage 2 only, with the same meaning as in
        the other runners.
    """
    import copy

    # Private generator that starts from a snapshot of the caller's state.
    local_rng = np.random.default_rng()
    local_rng.bit_generator.state = copy.deepcopy(rng.bit_generator.state)

    n_states, n_act = T.shape
    width = wall_mask_hw.shape[1]

    # Start state: the configured cell when usable, else the first free cell.
    start_usable = (
        cfg.start_x is not None
        and cfg.start_y is not None
        and not wall_mask_hw[cfg.start_y, cfg.start_x]
    )
    if start_usable:
        start_s = _pos_to_state(cfg.start_x, cfg.start_y, width)
    else:
        start_s = _first_free_state(wall_mask_hw)

    # ---------- Stage 1: SR pretraining ----------
    pretrain_steps = int(getattr(cfg, "fixed_sr_pretrain_steps", 30000))
    M = _pretrain_action_sr_random_walk(
        T,
        wall_mask_hw,
        steps=pretrain_steps,
        gamma_sr=float(cfg.gamma_sr),
        alpha_sr=float(cfg.alpha_sr),
        rng=local_rng,
        start_s=int(start_s),
    )

    # ---------- Stage 2: FPVR action selection with the frozen M ----------
    decay = float(cfg.c_decay)
    beta = float(cfg.fpvr_beta)
    reset_every = int(cfg.coverage_reset_interval)

    C = np.zeros((n_states,), dtype=np.float32)
    visit_counts = np.zeros((n_states,), dtype=np.int64)
    cov = np.zeros((cfg.total_steps,), dtype=np.int32)
    cov_win = np.zeros((cfg.total_steps,), dtype=np.int32)
    visited = set()
    visited_win = set()

    s = int(start_s)
    for t in range(cfg.total_steps):
        visit_counts[s] += 1
        visited.add(s)
        visited_win.add(s)

        C *= decay
        C[s] += 1.0

        # Redundancy of each action's predicted future with the decayed past.
        costs = np.array([_cosine(M[s, act, :], C) for act in range(n_act)], dtype=np.float32)
        chosen = _softmax_sample_from_cost(_zscore(costs), beta=beta, rng=local_rng)

        s = int(T[s, chosen])
        cov[t] = int(len(visited))
        cov_win[t] = int(len(visited_win))
        if reset_every > 0 and (t + 1) % reset_every == 0:
            visited_win.clear()
            visited_win.add(s)

    return cov, cov_win, visit_counts


def run_fpvr_reward_sarsa(
    T: np.ndarray, wall_mask_hw: np.ndarray, cfg: RunConfig, rng: np.random.Generator
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Method 2: FPVR-equivalent intrinsic signal + SARSA.

    Paper alignment:
    The *unnormalized* FPVR can be written as:
        <m^π(s_t,a_t), c_t> = sum_{s'} M^π(s_t,a_t,s') * C_t(s')
    i.e. the action-value of the *non-stationary* per-state signal r_t(s) = C_t(s).

    The per-step signal used here is therefore
        cost_t = r_fpvr_scale * C_t(s_{t+1})

    It is treated as a *cost*: Q is a cost-to-go table and actions are chosen
    ε-greedily by the *minimum* Q value, which steers the agent away from recently
    visited states. (Maximizing the same quantity would instead encourage revisits.)

    Returns:
        (cov, cov_win, visit_counts), with the same meaning as in the other runners.
    """
    n_states, n_act = T.shape
    width = wall_mask_hw.shape[1]

    eps = float(cfg.eps)
    decay = float(cfg.c_decay)
    cost_scale = float(cfg.r_fpvr_scale)
    gamma_q = float(cfg.sarsa_gamma)
    alpha_q = float(cfg.sarsa_alpha)
    reset_every = int(cfg.coverage_reset_interval)

    # Cost-to-go estimates (smaller is better) and the decayed visitation vector.
    Q = np.zeros((n_states, n_act), dtype=np.float32)
    C = np.zeros((n_states,), dtype=np.float32)
    visit_counts = np.zeros((n_states,), dtype=np.int64)

    cov = np.zeros((cfg.total_steps,), dtype=np.int32)
    cov_win = np.zeros((cfg.total_steps,), dtype=np.int32)
    visited = set()
    visited_win = set()

    start_usable = (
        cfg.start_x is not None
        and cfg.start_y is not None
        and not wall_mask_hw[cfg.start_y, cfg.start_x]
    )
    if start_usable:
        s = _pos_to_state(cfg.start_x, cfg.start_y, width)
    else:
        s = _first_free_state(wall_mask_hw)

    # SARSA is on-policy: the first action is drawn before the loop starts.
    a = _pick_action_eps_greedy_min(Q[s], eps=eps, rng=rng)

    for t in range(cfg.total_steps):
        visit_counts[s] += 1
        visited.add(s)
        visited_win.add(s)
        C *= decay
        C[s] += 1.0

        sn = int(T[s, a])
        # Visit cost of the state we are about to enter, read after C was updated for s.
        cost = cost_scale * float(C[sn])
        a_next = _pick_action_eps_greedy_min(Q[sn], eps=eps, rng=rng)

        # SARSA backup toward cost + gamma * Q(s', a').
        td_target = cost + gamma_q * float(Q[sn, a_next])
        Q[s, a] = (1.0 - alpha_q) * Q[s, a] + alpha_q * td_target

        s, a = sn, a_next
        cov[t] = int(len(visited))
        cov_win[t] = int(len(visited_win))
        if reset_every > 0 and (t + 1) % reset_every == 0:
            visited_win.clear()
            visited_win.add(s)

    return cov, cov_win, visit_counts


def run_sp_reward_sarsa(
    T: np.ndarray, wall_mask_hw: np.ndarray, cfg: RunConfig, rng: np.random.Generator
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Method 3: Successor-Predecessor (SP) intrinsic reward + SARSA.

    A state-level SR table M_ss is learned with TD(0). After each SR update, the
    reward for the transition s -> s' is
        r = r_sp_scale * ( M_ss[s, s'] - sum_i M_ss[i, s'] )
    and an ε-greedy SARSA agent maximizes it.

    Returns:
        (cov, cov_win, visit_counts), with the same meaning as in the other runners.
    """
    n_states, n_act = T.shape
    width = wall_mask_hw.shape[1]

    eps = float(cfg.eps)
    gamma_sr = float(cfg.gamma_sr)
    alpha_sr = float(cfg.alpha_sr)
    gamma_q = float(cfg.sarsa_gamma)
    alpha_q = float(cfg.sarsa_alpha)
    reward_scale = float(cfg.r_sp_scale)
    reset_every = int(cfg.coverage_reset_interval)

    Q = np.zeros((n_states, n_act), dtype=np.float32)
    M_ss = np.zeros((n_states, n_states), dtype=np.float32)  # state SR
    visit_counts = np.zeros((n_states,), dtype=np.int64)

    cov = np.zeros((cfg.total_steps,), dtype=np.int32)
    cov_win = np.zeros((cfg.total_steps,), dtype=np.int32)
    visited = set()
    visited_win = set()

    start_usable = (
        cfg.start_x is not None
        and cfg.start_y is not None
        and not wall_mask_hw[cfg.start_y, cfg.start_x]
    )
    if start_usable:
        s = _pos_to_state(cfg.start_x, cfg.start_y, width)
    else:
        s = _first_free_state(wall_mask_hw)

    a = _pick_action_eps_greedy(Q[s], eps=eps, rng=rng)

    for t in range(cfg.total_steps):
        visit_counts[s] += 1
        visited.add(s)
        visited_win.add(s)

        sn = int(T[s, a])

        # TD(0) update of the SR row for s; this happens before the reward is read.
        onehot = np.zeros((n_states,), dtype=np.float32)
        onehot[sn] = 1.0
        target_m = onehot + gamma_sr * M_ss[sn, :]
        M_ss[s, :] = (1.0 - alpha_sr) * M_ss[s, :] + alpha_sr * target_m

        # SP reward: successor term minus the column total for s'.
        into_next = float(M_ss[:, sn].sum())
        r_int = reward_scale * (float(M_ss[s, sn]) - into_next)

        a_next = _pick_action_eps_greedy(Q[sn], eps=eps, rng=rng)
        td_target = r_int + gamma_q * float(Q[sn, a_next])
        Q[s, a] = (1.0 - alpha_q) * Q[s, a] + alpha_q * td_target

        s, a = sn, a_next
        cov[t] = int(len(visited))
        cov_win[t] = int(len(visited_win))
        if reset_every > 0 and (t + 1) % reset_every == 0:
            visited_win.clear()
            visited_win.add(s)

    return cov, cov_win, visit_counts


def run_sr_l1_reward_sarsa(
    T: np.ndarray, wall_mask_hw: np.ndarray, cfg: RunConfig, rng: np.random.Generator
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Method 4: SR-novelty intrinsic reward + SARSA.

    Reward for a transition (s, a, s'):
        r_t = r_sr_scale / (||ψ(s')||_1 + sr_l1_eps)
    where ψ(s') is the current estimate of the next state's SR vector, M_ss[s', :].

    The state SR table is learned with TD(0):
        M_ss[s,:] <- (1-α) M_ss[s,:] + α ( e_{s'} + γ_sr M_ss[s',:] )
    and this update is applied *after* the reward and Q backup for the step, so the
    reward always reads the pre-update estimate.

    Returns:
        (cov, cov_win, visit_counts), with the same meaning as in the other runners.
    """
    n_states, n_act = T.shape
    width = wall_mask_hw.shape[1]

    eps = float(cfg.eps)
    gamma_sr = float(cfg.gamma_sr)
    alpha_sr = float(cfg.alpha_sr)
    gamma_q = float(cfg.sarsa_gamma)
    alpha_q = float(cfg.sarsa_alpha)
    reward_scale = float(cfg.r_sr_scale)
    l1_floor = float(cfg.sr_l1_eps)
    reset_every = int(cfg.coverage_reset_interval)

    Q = np.zeros((n_states, n_act), dtype=np.float32)
    M_ss = np.zeros((n_states, n_states), dtype=np.float32)  # state SR
    visit_counts = np.zeros((n_states,), dtype=np.int64)

    cov = np.zeros((cfg.total_steps,), dtype=np.int32)
    cov_win = np.zeros((cfg.total_steps,), dtype=np.int32)
    visited = set()
    visited_win = set()

    start_usable = (
        cfg.start_x is not None
        and cfg.start_y is not None
        and not wall_mask_hw[cfg.start_y, cfg.start_x]
    )
    if start_usable:
        s = _pos_to_state(cfg.start_x, cfg.start_y, width)
    else:
        s = _first_free_state(wall_mask_hw)

    a = _pick_action_eps_greedy(Q[s], eps=eps, rng=rng)

    for t in range(cfg.total_steps):
        visit_counts[s] += 1
        visited.add(s)
        visited_win.add(s)
        sn = int(T[s, a])

        # Novelty reward: inverse L1 mass of the next state's SR row.
        sr_l1_next = float(np.sum(np.abs(M_ss[sn, :])))
        r_int = reward_scale / (sr_l1_next + l1_floor)

        a_next = _pick_action_eps_greedy(Q[sn], eps=eps, rng=rng)

        td_target = r_int + gamma_q * float(Q[sn, a_next])
        Q[s, a] = (1.0 - alpha_q) * Q[s, a] + alpha_q * td_target

        # TD(0) update of the SR row for s.
        onehot = np.zeros((n_states,), dtype=np.float32)
        onehot[sn] = 1.0
        sr_target = onehot + gamma_sr * M_ss[sn, :]
        M_ss[s, :] = (1.0 - alpha_sr) * M_ss[s, :] + alpha_sr * sr_target

        s, a = sn, a_next
        cov[t] = int(len(visited))
        cov_win[t] = int(len(visited_win))
        if reset_every > 0 and (t + 1) % reset_every == 0:
            visited_win.clear()
            visited_win.add(s)

    return cov, cov_win, visit_counts


def run_random_walk(
    T: np.ndarray, wall_mask_hw: np.ndarray, cfg: RunConfig, rng: np.random.Generator
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Method 5: uniform-random action baseline.

    Returns:
        (cov, cov_win, visit_counts), with the same meaning as in the other runners.
    """
    n_states, n_act = T.shape
    width = wall_mask_hw.shape[1]
    reset_every = int(cfg.coverage_reset_interval)

    visit_counts = np.zeros((n_states,), dtype=np.int64)
    cov = np.zeros((cfg.total_steps,), dtype=np.int32)
    cov_win = np.zeros((cfg.total_steps,), dtype=np.int32)
    visited = set()
    visited_win = set()

    start_usable = (
        cfg.start_x is not None
        and cfg.start_y is not None
        and not wall_mask_hw[cfg.start_y, cfg.start_x]
    )
    if start_usable:
        s = _pos_to_state(cfg.start_x, cfg.start_y, width)
    else:
        s = _first_free_state(wall_mask_hw)

    for t in range(cfg.total_steps):
        visit_counts[s] += 1
        visited.add(s)
        visited_win.add(s)

        move = int(rng.integers(0, n_act))
        s = int(T[s, move])

        cov[t] = int(len(visited))
        cov_win[t] = int(len(visited_win))
        if reset_every > 0 and (t + 1) % reset_every == 0:
            visited_win.clear()
            visited_win.add(s)

    return cov, cov_win, visit_counts


def _mean_std(curves: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    arr = np.stack(curves, axis=0)
    return arr.mean(axis=0), arr.std(axis=0)


def main() -> int:
    p = argparse.ArgumentParser(description="Fourrooms exploration comparison (coverage curves).")
    p.add_argument("--total_steps", type=int, default=500)
    p.add_argument("--n_seeds", type=int, default=50)
    p.add_argument("--seed", type=int, default=78)
    p.add_argument("--start_x", type=int, default=1)
    p.add_argument("--start_y", type=int, default=1)

    # FPVR/SR
    p.add_argument("--gamma_sr", type=float, default=0.9)
    p.add_argument("--alpha_sr", type=float, default=0.1)
    p.add_argument("--c_decay", type=float, default=0.999)
    p.add_argument("--fpvr_beta", type=float, default=10.0, help="Softmax temperature for direct FPVR.")
    p.add_argument(
        "--fpvr_sr_target",
        type=str,
        default="min",
        choices=["mean", "min"],
        help="SR TD target expectation over a' at next state s'. "
             "mean: use mean_a M[s',a]; min: use M[s',a_min] where a_min minimizes FPVR(s',a).",
    )
    # Boolean switch. Declared as a plain option with default=False, this argument
    # required a value and stored it as a string, so the bare flag was a parse error
    # and even `--compare_fixed False` / `--compare_fixed 0` was truthy downstream.
    # nargs="?" + const=True makes the bare flag mean True while still accepting the
    # older `--compare_fixed <value>` spelling, now parsed as a real boolean.
    p.add_argument(
        "--compare_fixed",
        nargs="?",
        const=True,
        default=False,
        type=lambda v: str(v).strip().lower() in ("1", "true", "t", "yes", "y", "on"),
        metavar="BOOL",
        help="If set, add an extra method: FPVR with fixed SR (two-stage: random-walk SR pretraining, then frozen SR).",
    )
    p.add_argument(
        "--fixed_sr_pretrain_steps",
        type=int,
        default=30000,
        help="Stage-1 SR pretraining steps for --compare_fixed (default: 30000).",
    )

    # SARSA
    p.add_argument("--eps", type=float, default=0.1, help="Epsilon for epsilon-greedy (SARSA methods).")
    p.add_argument("--sarsa_alpha", type=float, default=0.1)
    p.add_argument("--sarsa_gamma", type=float, default=0.99)

    # Reward scales
    p.add_argument("--r_fpvr_scale", type=float, default=1.0)
    p.add_argument("--r_sp_scale", type=float, default=1.0)
    p.add_argument("--r_sr_scale", type=float, default=1.0)
    p.add_argument("--sr_l1_eps", type=float, default=1e-6)

    # Coverage visualization (windowed coverage)
    p.add_argument(
        "--coverage_reset_interval",
        type=int,
        default=1000,
        help="Reset the windowed coverage every N steps (0 disables). Default: 1000.",
    )

    # Output paths: by default, save under gridworld/results/ (not repo root).
    _DEFAULT_RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
    p.add_argument("--out_png", type=str, default=os.path.join(_DEFAULT_RESULTS_DIR, "fourrooms_coverage.png"))
    p.add_argument(
        "--out_eps",
        type=str,
        default=None,
        help="Optional EPS output path. If not set, will save an .eps next to --out_png.",
    )
    args = p.parse_args()

    def _resolve_out_path(path: str) -> str:
        """
        Anchor an output path in the default results directory.

        An absolute path is returned unchanged; a relative one is joined onto
        the default results directory.
        """
        return path if os.path.isabs(path) else os.path.join(_DEFAULT_RESULTS_DIR, path)

    # Normalize output paths (so passing a relative filename won't end up in repo root).
    args.out_png = _resolve_out_path(str(args.out_png))
    if args.out_eps is not None:
        args.out_eps = _resolve_out_path(str(args.out_eps))

    # Ensure output directory exists.
    os.makedirs(os.path.dirname(os.path.abspath(args.out_png)), exist_ok=True)

    if SimpleEnv is None:
        raise ModuleNotFoundError(
            "Missing optional dependency for gridworld/minigrid environment (minigrid). "
            "Please install it (e.g., `pip install minigrid`) to run fourrooms_exploration."
        )

    cfg = RunConfig(
        total_steps=int(args.total_steps),
        n_seeds=int(args.n_seeds),
        seed=int(args.seed),
        gamma_sr=float(args.gamma_sr),
        alpha_sr=float(args.alpha_sr),
        c_decay=float(args.c_decay),
        fpvr_sr_target=str(args.fpvr_sr_target),
        fixed_sr_pretrain_steps=int(args.fixed_sr_pretrain_steps),
        fpvr_beta=float(args.fpvr_beta),
        sarsa_alpha=float(args.sarsa_alpha),
        sarsa_gamma=float(args.sarsa_gamma),
        eps=float(args.eps),
        r_fpvr_scale=float(args.r_fpvr_scale),
        r_sp_scale=float(args.r_sp_scale),
        r_sr_scale=float(args.r_sr_scale),
        sr_l1_eps=float(args.sr_l1_eps),
        start_x=int(args.start_x) if args.start_x is not None else None,
        start_y=int(args.start_y) if args.start_y is not None else None,
        coverage_reset_interval=int(args.coverage_reset_interval),
    )

    env = SimpleEnv(render_mode=None)
    env.reset(seed=cfg.seed)
    T, wall_mask_hw = build_transition_from_env(env)
    env.close()

    methods = {
        "FPVR": run_fpvr_direct,
        # Optional two-stage ablation
        **({"FPVR with fixed SR": run_fpvr_fixed_sr} if bool(getattr(args, "compare_fixed", False)) else {}),
        "$r^{FP}$ + SARSA": run_fpvr_reward_sarsa,
        "SP+SARSA": run_sp_reward_sarsa,
        "SR+SARSA": run_sr_l1_reward_sarsa,
        "Random Walk": run_random_walk,
    }

    all_curves: Dict[str, List[np.ndarray]] = {k: [] for k in methods.keys()}
    all_curves_win: Dict[str, List[np.ndarray]] = {k: [] for k in methods.keys()}
    all_visits: Dict[str, List[np.ndarray]] = {k: [] for k in methods.keys()}
    for i in range(cfg.n_seeds):
        rng = np.random.default_rng(cfg.seed + i)
        for name, fn in methods.items():
            cov_full, cov_win, visit_counts = fn(T, wall_mask_hw, cfg, rng)
            all_curves[name].append(cov_full)
            all_curves_win[name].append(cov_win)
            all_visits[name].append(visit_counts)

    # ---------------- Plot styling (paper-ready) ----------------
    label_fs = 18
    tick_fs = 14
    legend_fs = 14
    title_fs = 18

    def plot_and_save(curves_dict: Dict[str, List[np.ndarray]], *, ylabel: str, out_png: str, out_eps: str):
        """
        Draw one mean ± std figure (one line per method) and save it as PNG and EPS.

        Args:
            curves_dict: method label -> list of per-seed curves of length cfg.total_steps.
            ylabel: y-axis label.
            out_png: destination of the PNG (saved at 200 dpi).
            out_eps: destination of the EPS.
        """

        def _lighten(color, amount: float = 0.85):
            # Blend every RGB channel toward white by `amount`.
            return tuple(ch + (1 - ch) * amount for ch in mcolors.to_rgb(color))

        fig, ax = plt.subplots(figsize=(9, 5.4))
        steps_axis = np.arange(cfg.total_steps)

        for idx, name in enumerate(curves_dict):
            mean, std = _mean_std(curves_dict[name])
            # Stable per-method color; unknown labels fall back to the default cycle.
            line_color = METHOD_COLORS.get(name, f"C{idx % 10}")
            # EPS/PDF friendliness: draw the band in an opaque lightened color
            # instead of relying on alpha transparency.
            ax.fill_between(
                steps_axis,
                mean - std,
                mean + std,
                color=_lighten(line_color, amount=0.88),
                linewidth=0.0,
                zorder=1,
            )
            ax.plot(steps_axis, mean, label=name, linewidth=2.6, color=line_color, zorder=2)

        ax.set_xlabel("Steps", fontsize=label_fs)
        ax.set_ylabel(ylabel, fontsize=label_fs)
        ax.tick_params(axis="both", which="major", labelsize=tick_fs)
        ax.grid(alpha=0.3)
        ax.legend(fontsize=legend_fs, frameon=True, loc="lower right")
        fig.tight_layout()

        fig.savefig(out_png, dpi=200)
        fig.savefig(out_eps, format="eps")
        plt.close(fig)

    # Plot 1: cumulative coverage (same as before)
    out_eps = args.out_eps
    if out_eps is None:
        stem, _ext = os.path.splitext(args.out_png)
        out_eps = stem + ".eps"
    os.makedirs(os.path.dirname(os.path.abspath(out_eps)), exist_ok=True)
    plot_and_save(all_curves, ylabel="Number of visited states", out_png=args.out_png, out_eps=out_eps)
    print(f"[OK] Saved plot to: {args.out_png}")
    print(f"[OK] Saved plot to: {out_eps}")

    # Plot 2: windowed coverage (periodic reset; visualization only)
    if int(cfg.coverage_reset_interval) > 0:
        stem_png, _ = os.path.splitext(args.out_png)
        out_png_win = f"{stem_png}_reset{int(cfg.coverage_reset_interval)}.png"
        stem_eps, _ = os.path.splitext(out_eps)
        out_eps_win = f"{stem_eps}_reset{int(cfg.coverage_reset_interval)}.eps"
        plot_and_save(
            all_curves_win,
            ylabel=f"Visited states (reset every {int(cfg.coverage_reset_interval)} steps)",
            out_png=out_png_win,
            out_eps=out_eps_win,
        )
        print(f"[OK] Saved plot to: {out_png_win}")
        print(f"[OK] Saved plot to: {out_eps_win}")

    # ---------------- State visitation heatmaps (per method) ----------------
    # Save mean visitation counts across seeds; use log1p for visualization.
    stem_png, _ = os.path.splitext(args.out_png)
    try:
        env_vis = SimpleEnv(render_mode=None)
        env_vis.reset(seed=int(cfg.seed))
        viz = BottleneckVisualization(env_vis)

        def _safe_name(name: str) -> str:
            out = str(name).replace(" ", "_")
            out = out.replace("+", "plus")
            out = out.replace("/", "_")
            out = out.replace("$", "")
            out = out.replace("{", "")
            out = out.replace("}", "")
            out = out.replace("^", "")
            return out

        # --- Collect mean visitation vectors (and still save per-method .npy) ---
        v_mean_by_name: Dict[str, np.ndarray] = {}
        for name in methods.keys():
            if not all_visits.get(name):
                continue
            v_stack = np.stack(all_visits[name], axis=0).astype(np.float64)
            v_mean = v_stack.mean(axis=0)
            v_mean_by_name[name] = v_mean
            np.save(f"{stem_png}_visits_{_safe_name(name)}.npy", v_mean.astype(np.float64))

        # --- Combined heatmap figure with shared colorbar ---
        # Layout:
        # - default (no --compare_fixed): 1x5
        # - with --compare_fixed: 2x3
        compare_fixed = bool(getattr(args, "compare_fixed", False))
        if compare_fixed:
            nrows, ncols = 2, 3
            desired_order = [
                "FPVR",
                "FPVR with fixed SR",
                "$r^{FP}$ + SARSA",
                "SP+SARSA",
                "SR+SARSA",
                "Random Walk",
            ]
        else:
            nrows, ncols = 1, 5
            desired_order = [
                "FPVR",
                "$r^{FP}$ + SARSA",
                "SP+SARSA",
                "SR+SARSA",
                "Random Walk",
            ]

        names_to_plot = [n for n in desired_order if n in v_mean_by_name]
        if len(names_to_plot) == 0:
            print("[Warning] No visitation data available; skipped heatmap plotting.")
        else:
            # Convert visitation vectors into matrices with wall cells as NaN (so they render as gray).
            H, W = int(wall_mask_hw.shape[0]), int(wall_mask_hw.shape[1])

            def _vector_to_hw_matrix(v: np.ndarray) -> np.ndarray:
                """Scatter a per-state vector onto an ``(H, W)`` float64 grid.

                Every cell starts as NaN. A cell ``(y, x)`` that is not flagged in
                ``wall_mask_hw`` receives ``v[x + y * W]`` when that index lies
                within ``v``'s first dimension; all other cells stay NaN.

                NOTE(review): presumably ``v`` is a 1-D per-state vector -- confirm.
                """
                n_entries = int(v.shape[0])
                grid = np.full((H, W), np.nan, dtype=np.float64)
                # np.ndindex walks (row, col) pairs in row-major order.
                for row, col in np.ndindex(H, W):
                    if wall_mask_hw[row, col]:
                        continue
                    flat_idx = int(col + row * W)
                    if flat_idx < 0 or flat_idx >= n_entries:
                        continue
                    grid[row, col] = float(v[flat_idx])
                return grid

            mats = []
            for n in names_to_plot:
                # Use RAW mean visit counts (not log-scaled).
                mats.append(_vector_to_hw_matrix(v_mean_by_name[n]))

            # Shared normalization across all subplots.
            vmin = float(np.nanmin([np.nanmin(m) for m in mats]))
            vmax = float(np.nanmax([np.nanmax(m) for m in mats]))

            # ------------------------------------------------------------------
            # Layout with a dedicated colorbar column (avoids overlap).
            # - no --compare_fixed: 1x5 heatmaps + 1 cbar column => 1x6
            # - with --compare_fixed: 2x3 heatmaps + 1 cbar column => 2x4
            # ------------------------------------------------------------------
            if compare_fixed:
                fig, axes = plt.subplots(
                    2,
                    4,
                    figsize=(13.2, 8.4),
                    gridspec_kw={"width_ratios": [1, 1, 1, 0.06], "wspace": 0.35, "hspace": 0.35},
                )
                plot_axes = [
                    axes[0, 0],
                    axes[0, 1],
                    axes[0, 2],
                    axes[1, 0],
                    axes[1, 1],
                    axes[1, 2],
                ]
                cax = axes[:, 3]
                # Hide the placeholder axes in the cbar column; a separate cbar axis is added via fig.add_axes.
                for ax in axes[:, 3]:
                    ax.set_visible(False)
                # Create a single cbar axis spanning both rows, positioned relative to the last heatmap column.
                # Put the colorbar slightly further to the right to avoid any overlap.
                cbar_ax = fig.add_axes(
                    [
                        axes[0, 2].get_position().x1 + 0.08,
                        axes[1, 2].get_position().y0,
                        0.015,
                        axes[0, 2].get_position().y1 - axes[1, 2].get_position().y0,
                    ]
                )
            else:
                fig, axes = plt.subplots(
                    1,
                    6,
                    figsize=(18.6, 4.6),
                    gridspec_kw={"width_ratios": [1, 1, 1, 1, 1, 0.06], "wspace": 0.35},
                )
                plot_axes = [axes[0], axes[1], axes[2], axes[3], axes[4]]
                axes[5].set_visible(False)
                # Put the colorbar slightly further to the right to avoid any overlap.
                cbar_ax = fig.add_axes(
                    [
                        axes[4].get_position().x1 + 0.050,
                        axes[4].get_position().y0,
                        0.015,
                        axes[4].get_position().height,
                    ]
                )

            cmap = plt.get_cmap("hot").copy()
            cmap.set_bad("gray")

            title_fs = 12  # smaller than the default single-plot titles
            tick_fs = 10

            im = None
            for i, ax in enumerate(plot_axes):
                if i >= len(names_to_plot):
                    ax.axis("off")
                    continue
                name = names_to_plot[i]
                data = mats[i]
                im = ax.imshow(
                    data,
                    cmap=cmap,
                    origin="upper",
                    interpolation="nearest",
                    # Align pixels to integer grid cell boundaries (avoids half-cell offset).
                    extent=[0, W, H, 0],
                    vmin=vmin,
                    vmax=vmax,
                )
                ax.set_title(name, fontsize=title_fs)
                ax.set_xticks(np.arange(0, W + 1, 1))
                ax.set_yticks(np.arange(0, H + 1, 1))
                ax.set_xlim(0, W)
                ax.set_ylim(H, 0)
                ax.grid(color="lightgray", linewidth=0.5)
                ax.tick_params(axis="both", which="major", labelsize=tick_fs)
                ax.set_xticklabels([])
                ax.set_yticklabels([])

            if im is not None:
                cbar = fig.colorbar(im, cax=cbar_ax)
                cbar.ax.tick_params(labelsize=12)
                cbar.set_label("Mean visit count", fontsize=12)

            # Add a bit more room for titles and the external colorbar.
            fig.subplots_adjust(top=0.92, left=0.03, right=0.985, bottom=0.05)
            out_all = f"{stem_png}_visits_all.png"
            fig.savefig(out_all, dpi=200)
            plt.close(fig)
            print(f"[OK] Saved combined visitation heatmap to: {out_all}")
    finally:
        try:
            env_vis.close()  # type: ignore[name-defined]
        except Exception:
            pass
    return 0


# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())

