"""
Reward-collection experiment on gridworld/bottleneck_env.py (tabular model).

Task (external reward):
  - At the beginning of each episode, the agent starts from a random free cell
    (non-wall) that is NOT the current goal cell.
  - A target goal cell exists on a free cell (not wall), with constraint: goal != start.
  - Stepping onto the goal yields reward +1 and the episode terminates early.
  - All other steps give a constant per-step reward `--step_penalty` (CLI default 0.0; use e.g. -0.01 for a penalty).
  - Episode ends when either (success) or (episode step limit == horizon).
  - Goal position refreshes every K *training steps* (can happen mid-episode).

Training protocol:
  - Train online for a fixed time-step budget `total_steps`.
  - On each episode end (success or horizon), reset the agent to a fresh random start cell (a free cell that is not the current goal).
  - Each method uses its own behavior policy for exploration, but the goal refresh schedule is
    deterministic from (seed, step//K) so it is identical across methods for the same seed.
  - If --random_refresh is enabled, we pre-generate a random goal list (per seed) before training
    so all methods share the exact same random goal at each refresh block.

Plotting protocol (what we plot):
  - We periodically evaluate the current Q-function (SARSA Q) at fixed training checkpoints.
  - Evaluation uses epsilon-greedy *on the extrinsic SARSA Q* (no intrinsic bias at evaluation time).
  - At each checkpoint, we run `n_eval_episodes` evaluation episodes and record the mean episode return.
  - The final curve shows mean ± std over training seeds, optionally smoothed over checkpoints.
"""

from __future__ import annotations

import argparse
import os
from typing import Dict, List, Tuple

import numpy as np
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

# Optional dependency: minigrid env
try:
    from bottleneck_env import SimpleEnv  # type: ignore
except Exception:
    SimpleEnv = None  # type: ignore

from fourrooms_exploration import _cosine, _pos_to_state, _state_to_pos, _zscore, build_transition_from_env
from utils import BottleneckVisualization


# Matplotlib color-cycle slot pinned to each of the five methods in this reward experiment.
# Colors are keyed by method name so they stay stable if methods are later added/removed.
REWARD_METHOD_COLORS = dict(
    [
        ("FPVR+SARSA", "C0"),  # blue
        ("rFP+SARSA", "C1"),  # orange
        ("SP+SARSA", "C2"),  # green
        ("SR+SARSA", "C3"),  # red
        ("SARSA", "C4"),  # purple
    ]
)

# Legend-only display names; the internal method keys used elsewhere are left unchanged.
REWARD_METHOD_DISPLAY_NAMES = dict([("rFP+SARSA", r"${r}^{FP}$+SARSA")])


def _pick_argmax_tiebreak(scores: np.ndarray, rng: np.random.Generator) -> int:
    """Return the index of a maximal entry of `scores`, breaking ties uniformly at random via `rng`."""
    best = float(np.max(scores))
    # Flat indices of every entry equal to the maximum (equivalent to np.flatnonzero).
    tied = np.nonzero(np.ravel(scores == best))[0]
    return int(rng.choice(tied))


def _pick_action_eps_greedy_scores(scores: np.ndarray, *, eps: float, rng: np.random.Generator) -> int:
    """
    Epsilon-greedy selection over a 1-D score vector.

    With probability `eps` return a uniformly random index in [0, len(scores)); otherwise
    return an argmax of `scores` with random tie-breaking. One `rng.random()` draw is always
    consumed first, so RNG streams stay aligned across callers.
    """
    explore = rng.random() < float(eps)
    if not explore:
        return _pick_argmax_tiebreak(scores, rng)
    n_choices = scores.shape[0]
    return int(rng.integers(0, n_choices))


def _softmax_from_neg_cost(costs: np.ndarray, beta: float, rng: np.random.Generator) -> int:
    """
    Sample an index from a Boltzmann distribution over negated costs.

    P(i) is proportional to exp(-beta * costs[i]): lower cost gives higher probability.
    Costs are shifted by their minimum and logits by their maximum for numerical stability;
    the normalizer is clipped at 1e-12 to avoid division by zero.
    """
    shifted = costs - float(np.min(costs))
    scaled = -float(beta) * shifted
    scaled = scaled - float(np.max(scaled))
    weights = np.exp(scaled)
    normalizer = np.clip(np.sum(weights), 1e-12, None)
    probs = weights / normalizer
    return int(rng.choice(np.arange(costs.shape[0]), p=probs))


def _free_states(wall_mask_hw: np.ndarray) -> np.ndarray:
    """
    Return the state indices (int32) of every non-wall cell of an [H, W] wall mask.

    Cells are enumerated in np.where's row-major order and converted with
    `_pos_to_state(x, y, width)`.
    """
    _height, width = wall_mask_hw.shape
    rows, cols = np.where(~wall_mask_hw)
    # np.where on an [H, W] mask yields (row, col) == (y, x); _pos_to_state expects (x, y, width).
    return np.array(
        [_pos_to_state(int(col), int(row), width) for row, col in zip(rows, cols)],
        dtype=np.int32,
    )


def _goal_for_block(
    *,
    wall_mask_hw: np.ndarray,
    start_state: int,
    free_states: np.ndarray,
    base_seed: int,
    block_idx: int,
    goal_schedule_states: List[int] | None = None,
) -> int:
    """
    Return the goal state for refresh block `block_idx`.

    If `goal_schedule_states` covers this block, that entry is used, and a ValueError is
    raised if it lies on a wall. Otherwise a goal is drawn uniformly from `free_states` with
    an RNG seeded only by (base_seed, block_idx). The schedule is therefore identical across
    methods for the same seed.

    `start_state` is accepted for interface compatibility and is not used. Episode starts are
    randomized separately, so no particular start is excluded here.
    """
    idx = int(block_idx)
    has_manual_entry = goal_schedule_states is not None and idx < len(goal_schedule_states)

    if not has_manual_entry:
        block_rng = np.random.default_rng(int(base_seed) + 1000003 * idx + 17)
        return int(block_rng.choice(free_states))

    goal = int(goal_schedule_states[idx])
    gx, gy = _state_to_pos(goal, int(wall_mask_hw.shape[1]))
    if wall_mask_hw[int(gy), int(gx)]:
        raise ValueError(f"goal_schedule block {int(block_idx)} is on a wall cell: goal=(x={gx},y={gy})")
    return goal


def _sample_episode_start_state(
    *,
    free_states: np.ndarray,
    goal_state: int,
    rng: np.random.Generator,
) -> int:
    """
    Draw an episode start uniformly from `free_states`, excluding `goal_state`.

    In the degenerate case where the goal is the only free state, the draw falls back
    to the full `free_states` array.
    """
    pool = free_states[free_states != int(goal_state)]
    if not pool.size:
        # Degenerate map: nothing but the goal is free, so allow starting on it.
        pool = free_states
    return int(rng.choice(pool))


def _parse_goal_schedule_text(s: str) -> List[Tuple[int, int]]:
    """
    Parse a semicolon-separated goal schedule string into (x, y) integer pairs.

    Example: "12,12; 2,5; 7,9" -> [(12, 12), (2, 5), (7, 9)].
    Whitespace around entries and fields is ignored, empty entries (e.g. a trailing ';')
    are skipped, and a blank string yields [].

    Raises:
        ValueError: if an entry is not exactly two comma-separated integers. The message
            names the offending entry.
    """
    out: List[Tuple[int, int]] = []
    for entry in (part.strip() for part in str(s).split(";")):
        if not entry:
            continue
        if "," not in entry:
            raise ValueError(f"Invalid goal_schedule entry '{entry}'. Expected 'x,y'.")
        fields = [t.strip() for t in entry.split(",")]
        # Reject e.g. "1,2,3", which previously surfaced as an opaque int() error on "2,3".
        if len(fields) != 2:
            raise ValueError(f"Invalid goal_schedule entry '{entry}'. Expected 'x,y'.")
        try:
            out.append((int(fields[0]), int(fields[1])))
        except ValueError as err:
            raise ValueError(f"Invalid goal_schedule entry '{entry}'. Expected integer 'x,y'.") from err
    return out


def _load_goal_schedule(
    *,
    goal_schedule: str,
    goal_schedule_file: str | None,
    width: int,
    wall_mask_hw: np.ndarray,
    start_state: int,
) -> List[int]:
    """
    Build the list of goal states (one per refresh block) from a file or inline text.

    If `goal_schedule_file` is given, it is read as UTF-8 with one goal per line, written
    "x y" or "x,y". Blank lines and lines starting with '#' are ignored. Otherwise
    `goal_schedule` is parsed with `_parse_goal_schedule_text`.

    Each (x, y) must lie inside the [H, W] `wall_mask_hw` grid and on a non-wall cell. It is
    converted with `_pos_to_state(x, y, width)`. `start_state` is accepted for interface
    compatibility and is not used.

    Raises:
        ValueError: on a malformed line, an out-of-grid coordinate, or a wall cell.
        OSError: if the schedule file cannot be opened.
    """
    pairs: List[Tuple[int, int]] = []
    if goal_schedule_file is not None:
        path = str(goal_schedule_file)
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                t = line.strip()
                if t == "" or t.startswith("#"):
                    continue
                toks = [a.strip() for a in t.split(",")] if "," in t else t.split()
                if len(toks) != 2:
                    raise ValueError(f"Invalid line in goal_schedule_file: '{t}' (expected 'x y' or 'x,y')")
                try:
                    pairs.append((int(toks[0]), int(toks[1])))
                except ValueError as err:
                    raise ValueError(
                        f"Invalid line in goal_schedule_file: '{t}' (expected 'x y' or 'x,y')"
                    ) from err
    else:
        pairs = _parse_goal_schedule_text(goal_schedule)

    grid_h, grid_w = int(wall_mask_hw.shape[0]), int(wall_mask_hw.shape[1])
    states: List[int] = []
    for i, (x, y) in enumerate(pairs):
        # Bounds check first: negative indices would otherwise silently wrap around in numpy
        # (passing the wall check on the wrong cell) and yield a bogus state index.
        if not (0 <= x < grid_w and 0 <= y < grid_h):
            raise ValueError(
                f"goal_schedule idx={i} is outside the grid: goal=(x={x},y={y}), grid WxH={grid_w}x{grid_h}"
            )
        if wall_mask_hw[int(y), int(x)]:
            raise ValueError(f"goal_schedule idx={i} is on a wall cell: goal=(x={x},y={y})")
        states.append(int(_pos_to_state(int(x), int(y), int(width))))
    return states


def _moving_average(x: np.ndarray, window: int) -> np.ndarray:
    """
    Trailing moving average of a 1-D curve, returned with the same length as `x`.

    The first `window - 1` positions (where a full window is not yet available) are filled
    with the first full-window average, for cleaner plots. The result is float32.

    Returns `x` unchanged when `window <= 1`, or when `x` has fewer than `window` points
    (including empty input), because no full window exists to average.
    """
    if window <= 1:
        return x
    w = int(window)
    arr = np.asarray(x)
    # np.convolve(mode="valid") swaps its operands when the kernel is longer than the signal
    # (output length max(M, N) - min(M, N) + 1), so a too-large window must be caught here.
    # Otherwise the result would have the wrong length and meaningless values.
    if arr.size < w:
        return x
    kernel = np.ones((w,), dtype=np.float32) / float(w)
    y = np.convolve(arr, kernel, mode="valid").astype(np.float32)
    # 'valid' yields len(x) - w + 1 points; left-pad with the first value to restore len(x).
    return np.concatenate([np.full((w - 1,), y[0], dtype=np.float32), y], axis=0)


def _eval_policy_eps_greedy(
    *,
    Q: np.ndarray,
    T: np.ndarray,
    wall_mask_hw: np.ndarray,
    start_state: int,
    horizon: int,
    n_eval_episodes: int,
    eps_eval: float,
    step_penalty: float,
    goal_state: int,
    base_seed: int,
    eval_seed_offset: int,
) -> float:
    """
    Return the mean undiscounted return of epsilon-greedy play on `Q` toward a fixed goal.

    Each of `n_eval_episodes` episodes starts on a random free cell other than `goal_state`
    and lasts at most `horizon` steps. Every step earns `step_penalty`, except the step that
    enters `goal_state`, which earns +1.0 and ends the episode. The goal is held fixed for
    the whole evaluation (no refresh), so the score reflects the current training task.

    Randomness comes from a private generator seeded with base_seed + eval_seed_offset.
    `start_state` is accepted for interface compatibility and is not used.
    Returns 0.0 when no episodes are run.
    """
    _, n_actions = T.shape
    eval_rng = np.random.default_rng(int(base_seed) + int(eval_seed_offset))
    start_pool = _free_states(wall_mask_hw)
    goal = int(goal_state)
    penalty = float(step_penalty)
    explore_p = float(eps_eval)

    episode_returns: List[float] = []
    for _episode in range(int(n_eval_episodes)):
        state = _sample_episode_start_state(free_states=start_pool, goal_state=goal, rng=eval_rng)
        total = 0.0
        for _t in range(int(horizon)):
            # Epsilon-greedy on the extrinsic Q only (no intrinsic bias at evaluation time).
            if eval_rng.random() < explore_p:
                action = int(eval_rng.integers(0, n_actions))
            else:
                action = _pick_argmax_tiebreak(Q[state], eval_rng)
            next_state = int(T[state, action])
            if next_state == goal:
                total += 1.0
                break
            total += penalty
            state = next_state
        episode_returns.append(float(total))

    if not episode_returns:
        return 0.0
    return float(np.mean(episode_returns))


def _train_one_method_with_eval(
    *,
    method: str,
    T: np.ndarray,
    wall_mask_hw: np.ndarray,
    start_state: int,
    total_steps: int,
    horizon: int,
    refresh_k_steps: int,
    eval_every_steps: int,
    n_eval_episodes: int,
    eps: float,
    eps_eval: float,
    q_alpha: float,
    q_gamma: float,
    gamma_sr: float,
    alpha_sr: float,
    c_decay: float,
    reset_c: bool,
    fpvr_zscore: bool,
    fpvr_type: str,
    fpvr_sr_target: str,
    fpvr_alpha: float,
    rfp_beta: float,
    sp_beta: float,
    sr_beta: float,
    step_penalty: float,
    free_states: np.ndarray,
    base_seed: int,
    goal_schedule_states: List[int] | None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Train one method online for `total_steps` environment steps, evaluating periodically.

    The agent learns a tabular SARSA Q-table on the extrinsic task:
      - rFP/SP/SR+SARSA add an intrinsic term to the TD reward.
      - FPVR+SARSA biases behavior action scores with an FPVR cost.
      - SARSA is plain epsilon-greedy.
    The goal is refreshed every `refresh_k_steps` training steps (possibly mid-episode) via
    `_goal_for_block`. The schedule therefore depends only on (base_seed, block index,
    goal_schedule_states) and is shared across methods.

    Checkpoints are 0, eval_every_steps, 2*eval_every_steps, ... and always include
    `total_steps`. Each checkpoint evaluates the current Q with epsilon-greedy (`eps_eval`,
    no intrinsic bias) on the goal active at that step.

    Returns:
      - eval_steps: int32 array of checkpoint steps
      - eval_means: float32 array of mean evaluation returns, aligned with eval_steps
      - visit_counts: int64 array [n_states] of training-time visit counts (every episode
        start state plus every next state)

    Raises:
      ValueError: if `method` is unknown or `eval_every_steps <= 0`. Also raised if
        `fpvr_type` is unknown when FPVR costs are computed.
    """
    if method not in ("FPVR+SARSA", "rFP+SARSA", "SP+SARSA", "SR+SARSA", "SARSA"):
        raise ValueError(f"Unknown method: {method}")
    if eval_every_steps <= 0:
        raise ValueError("eval_every_steps must be > 0")

    n_states, n_act = T.shape
    # IMPORTANT: use a method-independent RNG seed so that, under identical behavior policies
    # (e.g., rFP+SARSA with rfp_beta≈0 vs SARSA), the trajectories match for the same base_seed.
    # Different methods will still diverge once they take different actions, which is expected.
    rng = np.random.default_rng(int(base_seed) + 1337)

    Q = np.zeros((n_states, n_act), dtype=np.float32)
    # State visitation counts during training (over env states including walls; walls will stay 0).
    visit_counts = np.zeros((n_states,), dtype=np.int64)
    # Compute/maintain only what each method needs.
    need_C = method in ("FPVR+SARSA", "rFP+SARSA")
    need_M = method == "FPVR+SARSA"
    need_M_ss = method in ("SP+SARSA", "SR+SARSA")

    C = np.zeros((n_states,), dtype=np.float32) if need_C else None
    M = np.zeros((n_states, n_act, n_states), dtype=np.float32) if need_M else None
    M_ss = np.zeros((n_states, n_states), dtype=np.float32) if need_M_ss else None
    e_buf = np.zeros((n_states,), dtype=np.float32) if (need_M or need_M_ss) else None
    # The SR bootstrap rule for FPVR+SARSA is fixed for the whole run; decide it once.
    use_min_sr_target = str(fpvr_sr_target).lower() == "min"

    def fpvr_costs(s: int) -> np.ndarray:
        # FPVR cost per action at state s: similarity of the action-conditioned SR row M[s, a]
        # to the discounted visitation vector C.
        assert M is not None and C is not None
        costs = np.empty((n_act,), dtype=np.float32)
        for a in range(n_act):
            if fpvr_type == "cosine_similarity":
                costs[a] = _cosine(M[s, a, :], C)
            elif fpvr_type == "inner_product":
                costs[a] = float(np.dot(M[s, a, :], C))
            else:
                raise ValueError(f"Unknown fpvr_type: {fpvr_type}")
        return costs

    def pick_action_behavior(s0: int, *, eps_local: float) -> int:
        # SARSA and the intrinsic-reward methods (rFP/SP/SR) act epsilon-greedily on Q.
        # For rFP/SP/SR the intrinsic signal only enters the TD target below.
        if method != "FPVR+SARSA":
            return _pick_action_eps_greedy_scores(Q[s0], eps=float(eps_local), rng=rng)

        # FPVR+SARSA is the *Q-bias* variant: FPVR affects the decision directly.
        # Optionally z-score the costs across the action dimension at s0 to sharpen the
        # contrast between actions before subtracting them as a bias.
        costs = fpvr_costs(s0)
        fpvr_term = _zscore(costs) if bool(fpvr_zscore) else costs
        scores = Q[s0] - float(fpvr_alpha) * fpvr_term
        return _pick_action_eps_greedy_scores(scores, eps=float(eps_local), rng=rng)

    # goal state (refresh by training steps)
    train_step = 0
    block_idx = 0
    goal = _goal_for_block(
        wall_mask_hw=wall_mask_hw,
        start_state=start_state,
        free_states=free_states,
        base_seed=int(base_seed),
        block_idx=int(block_idx),
        goal_schedule_states=goal_schedule_states,
    )

    # episode state
    s = _sample_episode_start_state(free_states=free_states, goal_state=int(goal), rng=rng)
    ep_step = 0
    visit_counts[s] += 1  # count initial start state visit

    # checkpoints (always include total_steps itself)
    eval_steps = list(range(0, int(total_steps) + 1, int(eval_every_steps)))
    if not eval_steps or eval_steps[-1] != int(total_steps):
        eval_steps.append(int(total_steps))
    eval_means: List[float] = []
    next_eval_i = 0

    def maybe_eval():
        # Evaluate when the current training step is the next pending checkpoint.
        nonlocal next_eval_i
        if next_eval_i < len(eval_steps) and train_step == int(eval_steps[next_eval_i]):
            mean_ret = _eval_policy_eps_greedy(
                Q=Q,
                T=T,
                wall_mask_hw=wall_mask_hw,
                start_state=start_state,
                horizon=horizon,
                n_eval_episodes=n_eval_episodes,
                eps_eval=eps_eval,
                step_penalty=step_penalty,
                goal_state=int(goal),
                base_seed=base_seed,
                eval_seed_offset=99991 * next_eval_i,
            )
            eval_means.append(mean_ret)
            next_eval_i += 1

    while train_step < int(total_steps):
        # refresh goal by training steps (can happen mid-episode)
        if int(refresh_k_steps) > 0:
            new_block = int(train_step) // int(refresh_k_steps)
            if new_block != block_idx:
                block_idx = new_block
                goal = _goal_for_block(
                    wall_mask_hw=wall_mask_hw,
                    start_state=start_state,
                    free_states=free_states,
                    base_seed=int(base_seed),
                    block_idx=int(block_idx),
                    goal_schedule_states=goal_schedule_states,
                )
                # NOTE: we do NOT reset the SARSA Q-table on goal refresh.
                # SR-related tables (M, M_ss) are also kept across refreshes.

        # Episode-start hook: optionally reset the discounted visitation accumulator C.
        # This affects ALL methods that use C (FPVR+SARSA and rFP+SARSA).
        if bool(reset_c) and int(ep_step) == 0 and C is not None:
            C.fill(0.0)

        # Evaluate *after* applying refresh logic, so checkpoints at refresh boundaries
        # reflect the new goal.
        maybe_eval()

        # visitation update (discounted) - only for methods that use C
        if C is not None:
            C *= float(c_decay)
            C[s] += 1.0

        a = pick_action_behavior(s, eps_local=float(eps))
        sn = int(T[s, a])
        visit_counts[sn] += 1

        # ---------------- Extrinsic reward (task) ----------------
        r_ext = float(step_penalty)
        done_success = False
        if sn == int(goal):
            r_ext = 1.0
            done_success = True

        done = bool(done_success) or (int(ep_step) + 1 >= int(horizon))

        # ---------------- Intrinsic reward (tabular methods) ----------------
        # For these methods, intrinsic reward is injected into the SARSA TD target.
        # Sign convention depends on the method (some are rewards, some are penalties).
        r_int = 0.0
        r_total = float(r_ext)
        if method == "rFP+SARSA":
            # Visit-count novelty penalty: r_int = C_t(s_{t+1})
            # Use the current discounted visitation C BEFORE adding the next state's visit.
            assert C is not None
            r_int = float(C[sn])
            r_total = float(r_ext) - float(rfp_beta) * r_int
        elif method == "SP+SARSA":
            # SP intrinsic reward based on state-SR table:
            # r_int(s,a,s') = M_ss[s, s'] - sum_{x} M_ss[x, s']
            assert M_ss is not None
            col_sum = float(M_ss[:, sn].sum())
            r_int = float(M_ss[s, sn]) - col_sum
            r_total = float(r_ext) + float(sp_beta) * r_int
        elif method == "SR+SARSA":
            # SR-novelty intrinsic reward:
            # r_int(s,a,s') = 1 / (||psi(s')||_1 + eps)  with psi(s') = M_ss[s', :]
            assert M_ss is not None
            l1 = float(np.sum(np.abs(M_ss[sn, :])))
            r_int = 1.0 / (l1 + 1e-6)
            r_total = float(r_ext) + float(sr_beta) * r_int
        # FPVR+SARSA keeps extrinsic-only learning here; FPVR affects behavior via Q-bias.

        # SARSA update (terminal transitions do not bootstrap)
        if done:
            td_target = float(r_total)
        else:
            # NOTE(review): a_next is used only to bootstrap this target. The next loop
            # iteration draws a fresh action at sn via pick_action_behavior, so the executed
            # action can differ from a_next (not strictly on-policy SARSA). Left unchanged to
            # keep results and RNG streams comparable with earlier runs.
            a_next = pick_action_behavior(sn, eps_local=float(eps))
            td_target = float(r_total) + float(q_gamma) * float(Q[sn, a_next])
        Q[s, a] = (1.0 - float(q_alpha)) * Q[s, a] + float(q_alpha) * td_target

        # update SR tables (terminal-aware) - only when needed
        if need_M or need_M_ss:
            assert e_buf is not None
            e_buf.fill(0.0)
            e_buf[sn] = 1.0
            g = float(gamma_sr) * (0.0 if done else 1.0)

            if need_M:
                assert M is not None
                # SR TD target expectation over a' at next state s':
                # - mean: uniform expectation over actions
                # - min: choose a' that minimizes FPVR(s',a') under the current C
                if use_min_sr_target:
                    next_costs = fpvr_costs(int(sn))
                    a_min = int(np.argmin(next_costs))
                    psi_exp = M[sn, a_min, :]
                else:
                    psi_exp = M[sn, :, :].mean(axis=0)
                M[s, a, :] = (1.0 - float(alpha_sr)) * M[s, a, :] + float(alpha_sr) * (e_buf + g * psi_exp)

            if need_M_ss:
                assert M_ss is not None
                M_ss[s, :] = (1.0 - float(alpha_sr)) * M_ss[s, :] + float(alpha_sr) * (e_buf + g * M_ss[sn, :])

        # advance time
        train_step += 1
        ep_step += 1
        s = sn

        # episode reset on termination (C, if reset_c, is cleared by the episode-start hook above)
        if done:
            s = _sample_episode_start_state(free_states=free_states, goal_state=int(goal), rng=rng)
            ep_step = 0
            # Count the reset-to-start state as visited (episode boundary).
            visit_counts[s] += 1

    # The loop exits with train_step == total_steps. Without this call, the final checkpoint
    # (always present in eval_steps) would never be evaluated, dropping the final policy's score.
    maybe_eval()

    return (
        np.array(eval_steps[: len(eval_means)], dtype=np.int32),
        np.array(eval_means, dtype=np.float32),
        visit_counts.astype(np.int64),
    )


def _mean_std(curves: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Stack equal-length curves along a new leading axis and return (mean, std) over that axis."""
    stacked = np.stack(curves, axis=0)
    mean_curve = np.mean(stacked, axis=0)
    std_curve = np.std(stacked, axis=0)
    return mean_curve, std_curve


def _safe_name(name: str) -> str:
    """Return `name` as a string with spaces and '/' mapped to '_' and '+' spelled out as 'plus'."""
    # A single translate pass is equivalent to the chained replaces: no replacement
    # introduces a character that another rule would rewrite.
    table = str.maketrans({" ": "_", "+": "plus", "/": "_"})
    return str(name).translate(table)


def main() -> int:
    p = argparse.ArgumentParser(description="Reward collection experiment (key+door) on bottleneck_env layout.")
    p.add_argument("--total_steps", type=int, default=200000)
    p.add_argument(
        "--horizon",
        type=int,
        default=200,
        help="Max episode length (in steps). Default: 200.",
    )
    p.add_argument("--refresh_k", type=int, default=40000, help="Refresh goal every K training steps. Default: 2000.")
    p.add_argument(
        "--random_refresh",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "If enabled, ignore manual --goal_schedule/--goal_schedule_file and instead sample a random "
            "non-wall goal at each refresh block. The random goal list is generated once per seed and "
            "shared across all methods to ensure identical goals at the same training step."
        ),
    )
    p.add_argument(
        "--goal_schedule",
        type=str,
        default="13,1;12,12;1,11;1,13;13,4",
        help="Optional manual goal schedule per refresh block, e.g. '12,12; 2,5; 7,9'. Overrides sampling for those blocks.",
    )
    p.add_argument(
        "--goal_schedule_file",
        type=str,
        default=None,
        help="Optional path to a text file listing one goal per line as 'x y' or 'x,y' (comments with # allowed).",
    )
    p.add_argument("--n_seeds", type=int, default=10)
    p.add_argument("--seed", type=int, default=1)
    p.add_argument("--start_x", type=int, default=1)
    p.add_argument("--start_y", type=int, default=1)
    p.add_argument("--step_penalty", type=float, default=0.0, help="Per-step reward (except success +1). Default: -0.01.")
    p.add_argument("--eval_every", type=int, default=500, help="Evaluate every N training steps. Default: 1000.")
    p.add_argument("--n_eval_episodes", type=int, default=10, help="Number of evaluation episodes per checkpoint. Default: 30.")
    p.add_argument("--eps_eval", type=float, default=0.1, help="Epsilon for epsilon-greedy evaluation on SARSA Q. Default: 0.1.")
    p.add_argument("--smooth", type=int, default=1, help="Moving-average window over checkpoints. Default: 5 (1 disables).")

    # SARSA (extrinsic)
    p.add_argument("--eps", type=float, default=0.1)#0.1
    p.add_argument("--q_alpha", type=float, default=0.5)#0.5
    p.add_argument("--q_gamma", type=float, default=0.9)#0.9

    # SR learning for intrinsic signals
    p.add_argument("--gamma_sr", type=float, default=0.9)#0.9
    p.add_argument("--alpha_sr", type=float, default=0.1) #0.1
    p.add_argument("--c_decay", type=float, default=0.999) #0.999
    p.add_argument(
        "--reset_c",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to reset the discounted visitation accumulator C at episode boundaries. "
             "Use --no-reset-c to keep C across episodes.",
    )

    # Exploration scaling
    p.add_argument("--fpvr_alpha", type=float, default=0.01, help="Weight for FPVR redundancy term in Q_b (FPVR+SARSA).")
    p.add_argument(
        "--fpvr_type",
        type=str,
        default="cosine_similarity",
        choices=["cosine_similarity", "inner_product"],
        help="FPVR cost type used in FPVR+SARSA.",
    )
    p.add_argument(
        "--fpvr_sr_target",
        type=str,
        default="min",
        choices=["mean", "min"],
        help=(
            "SR TD target expectation over a' at next state s' (for FPVR+SARSA action-conditioned SR).\n"
            "mean: use mean_a M[s',a]; min: use M[s',a_min] where a_min minimizes FPVR(s',a)."
        ),
    )
    p.add_argument(
        "--fpvr_zscore",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="If enabled, apply action-wise z-score normalization to FPVR costs before using them as a Q-bias term. "
             "Use --no-fpvr-zscore to disable.",
    )
    p.add_argument(
        "--rfp_beta",
        type=float,
        default=0.001,
        help="Weight for the visit-count penalty term in the SARSA reward target (rFP+SARSA).",
    )
    p.add_argument("--sp_beta", type=float, default=0.0001, help="Weight for SP intrinsic term in score (SP+SARSA).")
    p.add_argument("--sr_beta", type=float, default=0.0001, help="Weight for SR intrinsic term in score (SR+SARSA).")
    p.add_argument("--fpvr_beta", type=float, default=1.0, help="(Unused) Kept for compatibility; direct FPVR baseline is disabled.")

    p.add_argument(
        "--methods",
        type=str,
        default='all',
        help=(
            "Which methods to run/plot. Comma-separated list or 'all'. "
            "Choices: FPVR+SARSA,rFP+SARSA,SP+SARSA,SR+SARSA,SARSA"
        ),
    )

    p.add_argument("--out_png", type=str, default="fourrooms_reward.png")
    p.add_argument("--out_eps", type=str, default=None)
    args = p.parse_args()

    # ---------------- Output paths: save under gridworld/results/ by default ----------------
    # Users often run this script from the repo root; in that case relative output paths
    # would otherwise end up in the outer working directory. We instead resolve outputs
    # into a local results directory next to this script (gridworld/results/),
    # unless the user passed an absolute path.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    results_dir = os.path.join(script_dir, "results")
    os.makedirs(results_dir, exist_ok=True)

    def _resolve_out(pth: str | None) -> str | None:
        if pth is None:
            return None
        pth = str(pth)
        if os.path.isabs(pth):
            os.makedirs(os.path.dirname(pth), exist_ok=True)
            return pth
        out = os.path.join(results_dir, pth)
        os.makedirs(os.path.dirname(out), exist_ok=True)
        return out

    args.out_png = _resolve_out(args.out_png)  # type: ignore[assignment]
    args.out_eps = _resolve_out(args.out_eps)  # type: ignore[assignment]

    if SimpleEnv is None:
        raise ModuleNotFoundError(
            "Missing optional dependency for gridworld/minigrid environment (minigrid). "
            "Please install it (e.g., `pip install minigrid`) to run fourrooms_reward_experiment."
        )

    # Build transition model once
    env = SimpleEnv(render_mode=None)
    env.reset(seed=int(args.seed))
    T, wall_mask_hw = build_transition_from_env(env)
    env.close()
    width = wall_mask_hw.shape[1]
    free_states = _free_states(wall_mask_hw)

    # Backward-compat: keep --start_x/--start_y but the episode start state is now randomized.
    # We still compute a valid start_state for any code paths that need a placeholder (e.g., schedule validation).
    start_state = _pos_to_state(int(args.start_x), int(args.start_y), width)
    sx, sy = int(args.start_x), int(args.start_y)
    if not (0 <= sx < int(width) and 0 <= sy < int(wall_mask_hw.shape[0])) or wall_mask_hw[sy, sx]:
        start_state = int(free_states[0])
    goal_schedule_states = _load_goal_schedule(
        goal_schedule=str(args.goal_schedule),
        goal_schedule_file=getattr(args, "goal_schedule_file", None),
        width=width,
        wall_mask_hw=wall_mask_hw,
        start_state=start_state,
    )

    # ---------------- Print goal refresh schedule (precomputed) ----------------
    # Goal refresh happens every K *training steps* and can occur mid-episode.
    refresh_k = int(args.refresh_k)
    total_steps = int(args.total_steps)
    if refresh_k <= 0:
        n_blocks = 1
    else:
        # Blocks cover step indices [0, total_steps-1]
        n_blocks = int((max(total_steps, 1) - 1) // refresh_k) + 1

    def _random_goal_schedule_for_seed(base_seed: int) -> List[int]:
        """
        Generate a full goal schedule of length n_blocks for a given seed.
        Sampling is uniform over free (non-wall) states and deterministic w.r.t. base_seed.
        """
        rng = np.random.default_rng(int(base_seed) + 424242)
        sched = rng.choice(free_states, size=int(n_blocks), replace=True)
        return [int(s) for s in sched.tolist()]

    goal_schedule_by_seed: Dict[int, List[int]] = {}
    if bool(args.random_refresh):
        for si in range(int(args.n_seeds)):
            base_seed = int(args.seed) + si
            goal_schedule_by_seed[int(base_seed)] = _random_goal_schedule_for_seed(int(base_seed))

    def _print_schedule_for_seed(base_seed: int) -> None:
        print(f"\n[Schedule] seed={base_seed} | refresh_k={refresh_k} | total_steps={total_steps} | n_refresh={n_blocks}")
        sched_override = goal_schedule_by_seed.get(int(base_seed)) if bool(args.random_refresh) else goal_schedule_states
        for b in range(int(n_blocks)):
            goal_s = _goal_for_block(
                wall_mask_hw=wall_mask_hw,
                start_state=start_state,
                free_states=free_states,
                base_seed=int(base_seed),
                block_idx=int(b),
                goal_schedule_states=sched_override,
            )
            gx, gy = _state_to_pos(int(goal_s), width)
            if refresh_k <= 0:
                step_l, step_r = 0, max(total_steps - 1, 0)
            else:
                step_l = int(b) * int(refresh_k)
                step_r = min(int((b + 1) * int(refresh_k) - 1), max(total_steps - 1, 0))
            print(
                f"  - block={b:03d} steps=[{step_l},{step_r}] "
                f"goal=(x={gx},y={gy})"
            )

    # Print for each training seed (schedule is shared across methods for the same seed).
    for si in range(int(args.n_seeds)):
        _print_schedule_for_seed(int(args.seed) + si)

    methods_all = [
        ("FPVR+SARSA", "FPVR+SARSA"),
        ("rFP+SARSA", "rFP+SARSA"),
        ("SP+SARSA", "SP+SARSA"),
        ("SR+SARSA", "SR+SARSA"),
        ("SARSA", "SARSA"),
    ]
    # Backward-compatible alias: older runs used this label for the plain SARSA baseline.
    _method_alias = {"Random Walk+SARSA": "SARSA"}
    allowed = {lbl for (lbl, _mid) in methods_all} | set(_method_alias.keys())

    sel = str(args.methods).strip()
    if sel.lower() == "all" or sel == "":
        methods = methods_all
    else:
        req = [s.strip() for s in sel.split(",") if s.strip()]
        req = [_method_alias.get(m, m) for m in req]
        unknown = [m for m in req if m not in allowed]
        if unknown:
            raise ValueError(f"Unknown --methods entries: {unknown}. Allowed: {sorted(list(allowed))}")
        # keep a stable order (methods_all order), regardless of user ordering
        req_set = set(req)
        methods = [pair for pair in methods_all if pair[0] in req_set]

    plot_order = [m[0] for m in methods]

    eval_steps_ref: np.ndarray | None = None
    reward_curves: Dict[str, List[np.ndarray]] = {name: [] for name in plot_order}
    visit_curves: Dict[str, List[np.ndarray]] = {name: [] for name in plot_order}

    for si in range(int(args.n_seeds)):
        base_seed = int(args.seed) + si
        for label, method_id in methods:
            sched_override = goal_schedule_by_seed.get(int(base_seed)) if bool(args.random_refresh) else goal_schedule_states
            steps_arr, eval_arr, visit_arr = _train_one_method_with_eval(
                method=method_id,
                T=T,
                wall_mask_hw=wall_mask_hw,
                start_state=start_state,
                total_steps=int(args.total_steps),
                horizon=int(args.horizon),
                refresh_k_steps=int(args.refresh_k),
                eval_every_steps=int(args.eval_every),
                n_eval_episodes=int(args.n_eval_episodes),
                eps=float(args.eps),
                eps_eval=float(args.eps_eval),
                q_alpha=float(args.q_alpha),
                q_gamma=float(args.q_gamma),
                gamma_sr=float(args.gamma_sr),
                alpha_sr=float(args.alpha_sr),
                c_decay=float(args.c_decay),
                reset_c=bool(args.reset_c),
                fpvr_zscore=bool(args.fpvr_zscore),
                fpvr_type=str(args.fpvr_type),
                fpvr_sr_target=str(args.fpvr_sr_target),
                fpvr_alpha=float(args.fpvr_alpha),
                rfp_beta=float(args.rfp_beta),
                sp_beta=float(args.sp_beta),
                sr_beta=float(args.sr_beta),
                step_penalty=float(args.step_penalty),
                free_states=free_states,
                base_seed=base_seed,
                goal_schedule_states=sched_override,
            )
            if eval_steps_ref is None:
                eval_steps_ref = steps_arr
            else:
                # Ensure checkpoint grids match
                m = min(eval_steps_ref.shape[0], steps_arr.shape[0])
                eval_steps_ref = eval_steps_ref[:m]
                steps_arr = steps_arr[:m]
                eval_arr = eval_arr[:m]
            reward_curves[label].append(eval_arr)
            visit_curves[label].append(visit_arr)

    # Plot (make axes/legend larger for readability)
    label_fs = 20
    tick_fs = 16
    legend_fs = 16

    # Align checkpoint counts across seeds by truncating to minimum length.
    min_len = None
    for name in plot_order:
        lens = [c.shape[0] for c in reward_curves[name]]
        if not lens:
            continue
        m = int(min(lens))
        min_len = m if min_len is None else min(min_len, m)
    if min_len is None or eval_steps_ref is None:
        raise RuntimeError("No curves collected.")

    eval_steps_ref = eval_steps_ref[: int(min_len)]
    fig, ax = plt.subplots(figsize=(18, 5.4))
    x = eval_steps_ref.astype(np.float32)
    stem, _ = os.path.splitext(args.out_png)

    # ---------------- Save reward curve data (for re-plotting) ----------------
    # Save aligned per-seed curves plus mean/std (and smoothed mean/std) for each method.
    reward_data: Dict[str, np.ndarray] = {}
    reward_data["eval_steps"] = eval_steps_ref.astype(np.int32)
    reward_data["methods"] = np.array(plot_order, dtype=object)

    for idx, name in enumerate(plot_order):
        curves = [c[: int(min_len)] for c in reward_curves[name]]
        stack = np.stack(curves, axis=0).astype(np.float32)  # [n_seeds, T]
        mean = stack.mean(axis=0)
        std = stack.std(axis=0)
        # Smooth the mean/std curves over checkpoints
        mean_s = _moving_average(mean.astype(np.float32), int(args.smooth))
        std_s = _moving_average(std.astype(np.float32), int(args.smooth))

        k = _safe_name(name)
        reward_data[f"{k}_curves"] = stack
        reward_data[f"{k}_mean"] = mean.astype(np.float32)
        reward_data[f"{k}_std"] = std.astype(np.float32)
        reward_data[f"{k}_mean_smooth"] = mean_s.astype(np.float32)
        reward_data[f"{k}_std_smooth"] = std_s.astype(np.float32)

        color = REWARD_METHOD_COLORS.get(name, f"C{idx % 10}")
        band_color = tuple(np.array(mcolors.to_rgb(color)) * 0.15 + 0.85)  # lighten toward white
        ax.fill_between(x, mean_s - std_s, mean_s + std_s, color=band_color, linewidth=0.0, zorder=1)
        disp_name = REWARD_METHOD_DISPLAY_NAMES.get(name, name)
        ax.plot(x, mean_s, label=disp_name, linewidth=2.6, color=color, zorder=2)

    try:
        np.savez_compressed(f"{stem}_reward_curves.npz", **reward_data)
    except Exception as e:
        print(f"[Warning] Failed to save reward curve data: {e}")

    ax.set_xlabel("Steps", fontsize=label_fs)
    ax.set_ylabel("Episode return", fontsize=label_fs)
    ax.tick_params(axis="both", which="major", labelsize=tick_fs)
    ax.grid(alpha=0.3)

    # Draw separators at key/door refresh boundaries (step-based).
    k_steps = int(args.refresh_k)
    if k_steps > 0:
        max_x = float(x.max()) if x.size else float(args.total_steps)
        boundary = k_steps
        while boundary < max_x:
            ax.axvline(boundary, color="#8B0000", linewidth=1.4, linestyle="--", zorder=0)
            boundary += k_steps

    ax.legend(fontsize=legend_fs, frameon=True, loc="lower right")
    fig.tight_layout()

    fig.savefig(args.out_png, dpi=200)
    out_eps = args.out_eps
    if out_eps is None:
        out_eps = stem + ".eps"
    fig.savefig(out_eps, format="eps")

    # ---------------- Save visitation heatmaps per method ----------------
    # We visualize mean visitation counts across seeds. Use log1p to reduce dynamic range.
    try:
        env_vis = SimpleEnv(render_mode=None)
        env_vis.reset(seed=int(args.seed))
        viz = BottleneckVisualization(env_vis)

        # reuse stem/_safe_name from above

        for name in plot_order:
            if not visit_curves.get(name):
                continue
            visits_stack = np.stack(visit_curves[name], axis=0).astype(np.float64)
            visits_mean = visits_stack.mean(axis=0)

            # Save raw mean counts
            np.save(f"{stem}_visits_{_safe_name(name)}.npy", visits_mean.astype(np.float64))

            # Plot log-scale heatmap
            fig_v = viz.plot_2d_heatmap(
                np.log1p(visits_mean).astype(np.float32),
                topk=32,
                title=f"State visitation (log1p) - {name}",
                color_bar=True,
                cmap_name="hot",
                show=False,
            )
            fig_v.savefig(f"{stem}_visits_{_safe_name(name)}.png", dpi=200)
            plt.close(fig_v)
    finally:
        try:
            env_vis.close()  # type: ignore[name-defined]
        except Exception:
            pass

    print(f"[OK] Saved: {args.out_png} and {out_eps}")
    return 0


if __name__ == "__main__":
    # Script entry point: run the experiment once and hand main()'s return
    # value to the interpreter as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)

