"""
Cover-time experiment (boxplot) for 5 exploration methods on bottleneck_env.py.

Cover time = the number of steps until *all reachable states* (from the start state)
have been visited at least once.

We reuse the exact same policies/updates as fourrooms_exploration.py, but change the
metric + visualization.
"""

from __future__ import annotations

import argparse
import os
from typing import Dict, List, Tuple

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Optional dependency: minigrid env
try:
    from bottleneck_env import SimpleEnv  # type: ignore
except Exception:
    SimpleEnv = None  # type: ignore

# Reuse implementations from the coverage script (keeps training/exploration identical).
from fourrooms_exploration import (
    RunConfig,
    METHOD_COLORS,
    build_transition_from_env,
    run_fpvr_direct,
    run_fpvr_reward_sarsa,
    run_sp_reward_sarsa,
    run_random_walk,
    _pos_to_state,
    _state_to_pos,
)


def reachable_states_from_start(T: np.ndarray, wall_mask_hw: np.ndarray, start_state: int) -> np.ndarray:
    """
    Flood-fill the transition table outward from ``start_state``.

    Args:
        T: table indexed as ``T[s, a]`` giving the successor state of action ``a`` in state ``s``.
        wall_mask_hw: 2-D mask indexed ``[y, x]``; a truthy entry marks a blocked (wall/door) cell.
        start_state: flat index of the start cell.

    Returns:
        Boolean array of length ``T.shape[0]`` that is True for the start state and for every
        state reachable from it through in-grid, non-blocked successors.

    Raises:
        ValueError: if the start cell itself is blocked.
    """
    num_states, num_actions = T.shape
    cols = wall_mask_hw.shape[1]
    origin = int(start_state)

    seen = np.zeros((num_states,), dtype=bool)

    ox, oy = _state_to_pos(origin, cols)
    if wall_mask_hw[oy, ox]:
        raise ValueError("start_state is blocked (wall/door).")

    seen[origin] = True
    frontier = [origin]
    # Only the final set matters, not the visiting order, so a LIFO stack suffices.
    while frontier:
        current = frontier.pop()
        for act in range(num_actions):
            succ = int(T[current, act])
            px, py = _state_to_pos(succ, cols)
            inside = 0 <= py < wall_mask_hw.shape[0] and 0 <= px < cols
            # Skip successors that leave the grid, hit a blocked cell, or were already found.
            if not inside or wall_mask_hw[py, px] or seen[succ]:
                continue
            seen[succ] = True
            frontier.append(succ)
    return seen


def cover_time_from_curve(cov: np.ndarray, target_n: int) -> int:
    """
    Convert a cumulative coverage curve into a cover time.

    Args:
        cov: 1-D array-like; ``cov[t]`` is the number of unique visited states up to step ``t``
            (0-based index). NumPy arrays, lists and tuples are all accepted.
        target_n: number of unique states that counts as "fully covered".

    Returns:
        Cover time in steps (1-based): the first step at which ``cov`` reaches ``target_n``.
        Returns -1 if that never happens within the curve's horizon, including for an empty curve.
    """
    # np.asarray lets plain Python sequences work; comparing a list with an int raises TypeError.
    hits = np.flatnonzero(np.asarray(cov) >= int(target_n))
    if hits.size == 0:
        return -1
    return int(hits[0]) + 1


def _build_arg_parser() -> argparse.ArgumentParser:
    """
    Return the command-line parser for this experiment.

    Hyperparameter defaults are meant to match those used in fourrooms_exploration.py.
    """
    p = argparse.ArgumentParser(description="Cover-time boxplot for exploration methods (fourrooms).")
    p.add_argument("--n_runs", type=int, default=50, help="Number of independent trials per method.")
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--max_steps", type=int, default=30000, help="Max steps per trial (censored if not covered).")
    p.add_argument("--start_x", type=int, default=1)
    p.add_argument("--start_y", type=int, default=1)

    # Match hyperparams used in fourrooms_exploration.py
    p.add_argument("--gamma_sr", type=float, default=0.9)
    p.add_argument("--alpha_sr", type=float, default=0.1)
    p.add_argument("--c_decay", type=float, default=0.999)
    p.add_argument("--fpvr_beta", type=float, default=10.0)
    p.add_argument(
        "--fpvr_sr_target",
        type=str,
        default="min",
        choices=["mean", "min"],
        help=(
            "SR TD target expectation over a' at next state s'. "
            "mean: use mean_a M[s',a]; min: use M[s',a_min] where a_min minimizes FPVR(s',a)."
        ),
    )
    p.add_argument("--eps", type=float, default=0.1)
    p.add_argument("--sarsa_alpha", type=float, default=0.1)
    p.add_argument("--sarsa_gamma", type=float, default=0.99)
    p.add_argument("--r_fpvr_scale", type=float, default=1.0)
    p.add_argument("--r_sp_scale", type=float, default=1.0)
    p.add_argument("--r_sr_scale", type=float, default=1.0)
    p.add_argument("--sr_l1_eps", type=float, default=1e-6)

    p.add_argument("--out_png", type=str, default="fourrooms_cover_time.png")
    p.add_argument(
        "--out_eps",
        type=str,
        default=None,
        help="EPS output path (default: --out_png with the extension replaced by .eps).",
    )
    return p


def _collect_cover_times(
    methods: dict,
    T: np.ndarray,
    wall_mask_hw: np.ndarray,
    cfg: RunConfig,
    target_n: int,
    n_runs: int,
    seed: int,
    max_steps: int,
) -> Tuple[Dict[str, List[int]], Dict[str, int]]:
    """
    Run every method ``n_runs`` times and gather its cover times.

    Each method is called as ``fn(T, wall_mask_hw, cfg, rng)`` and must return a 3-tuple whose
    first element is the coverage curve consumed by ``cover_time_from_curve``.

    Returns:
        ``(times_by_method, censored_by_method)``: the list of cover times per method name, and
        how many of those trials were censored. A censored trial is one that never reached
        ``target_n`` and is recorded as ``max_steps + 1``.
    """
    times_by_method: Dict[str, List[int]] = {name: [] for name in methods}
    censored_by_method: Dict[str, int] = {name: 0 for name in methods}

    for run_i in range(n_runs):
        for m_idx, (name, fn) in enumerate(methods.items()):
            # Deterministic per-(run, method) RNG seed, so any single trial can be reproduced.
            rng = np.random.default_rng(seed + 1000 * run_i + 17 * m_idx)
            cov_full, _cov_win, _visit_counts = fn(T, wall_mask_hw, cfg, rng)
            ct = cover_time_from_curve(cov_full, target_n)
            if ct < 0:
                # Censor at max_steps + 1 so the boxplot still shows "did not cover".
                ct = max_steps + 1
                censored_by_method[name] += 1
            times_by_method[name].append(int(ct))
    return times_by_method, censored_by_method


def _blend_on_white(color, alpha: float) -> Tuple[float, float, float]:
    """
    Return the opaque RGB that ``color`` drawn at opacity ``alpha`` shows over a white background.

    Real transparency is avoided because the PostScript (EPS) backend cannot render alpha: it
    would draw translucent artists fully opaque, so the EPS would not match the PNG. This
    assumes the default white axes/figure background.
    """
    from matplotlib.colors import to_rgb

    r, g, b = to_rgb(color)
    return (1.0 - alpha * (1.0 - r), 1.0 - alpha * (1.0 - g), 1.0 - alpha * (1.0 - b))


def _plot_cover_times(
    times_by_method: Dict[str, List[int]],
    out_png: str,
    out_eps: str | None,
) -> None:
    """
    Draw one box per method and save the figure as PNG (200 dpi) and EPS.

    If ``out_eps`` is None, the EPS path is derived from ``out_png`` by swapping the extension.
    """
    label_fs = 18
    tick_fs = 14
    box_alpha = 0.35  # apparent opacity of box fill/outline (pre-blended, see _blend_on_white)
    grid_alpha = 0.25

    labels = list(times_by_method.keys())
    data = [times_by_method[name] for name in labels]

    fig, ax = plt.subplots(figsize=(10, 5.2))
    try:
        bp = ax.boxplot(
            data,
            showfliers=True,
            patch_artist=True,
            medianprops=dict(color="black", linewidth=1.8),
            boxprops=dict(linewidth=1.3),
            whiskerprops=dict(linewidth=1.3),
            capprops=dict(linewidth=1.3),
        )
        # Boxes sit at the default positions 1..N. Label the ticks explicitly instead of passing
        # boxplot(labels=...), which Matplotlib 3.9 renamed to tick_labels= and deprecated.
        ax.set_xticks(range(1, len(labels) + 1))
        ax.set_xticklabels(labels)

        # Color boxes and add a legend. Use the SAME fixed mapping as fourrooms_exploration.py
        # so removing a method won't shift colors (e.g., Random Walk stays purple).
        edge = _blend_on_white("black", box_alpha)
        handles = []
        for i, (patch, name) in enumerate(zip(bp["boxes"], labels)):
            face = _blend_on_white(METHOD_COLORS.get(name, f"C{i % 10}"), box_alpha)
            patch.set_facecolor(face)
            patch.set_edgecolor(edge)
            handles.append(mpatches.Patch(facecolor=face, edgecolor=edge, label=name))

        ax.set_ylabel("Cover time (steps)", fontsize=label_fs)
        ax.tick_params(axis="y", which="major", labelsize=tick_fs)
        ax.tick_params(axis="x", which="major", labelsize=tick_fs, rotation=15)
        ax.grid(axis="y", color=_blend_on_white(plt.rcParams["grid.color"], grid_alpha))
        ax.legend(handles=handles, loc="best", fontsize=12, frameon=True)

        fig.tight_layout()
        fig.savefig(out_png, dpi=200)
        print(f"[OK] Saved plot to: {out_png}")

        if out_eps is None:
            stem, _ext = os.path.splitext(out_png)
            out_eps = stem + ".eps"
        fig.savefig(out_eps, format="eps")
        print(f"[OK] Saved plot to: {out_eps}")
    finally:
        # Release the figure even if saving fails (pyplot keeps figures alive until closed).
        plt.close(fig)


def main() -> int:
    """
    Run the cover-time experiment and save the boxplot (PNG + EPS).

    Builds the transition model once from the environment layout and counts the states
    reachable from the start cell. Each method is then run ``--n_runs`` times, and each coverage
    curve is turned into a cover time; uncovered trials are censored at ``max_steps + 1``.

    Returns:
        0 on success (used as the process exit code).

    Raises:
        ModuleNotFoundError: if ``bottleneck_env.SimpleEnv`` could not be imported.
        ValueError: if the start cell is a wall/door (raised by ``reachable_states_from_start``).
        SystemExit: via ``argparse`` for invalid command-line arguments.
    """
    p = _build_arg_parser()
    args = p.parse_args()

    n_runs = int(args.n_runs)
    max_steps = int(args.max_steps)
    if n_runs < 1:
        p.error("--n_runs must be >= 1")
    if max_steps < 1:
        p.error("--max_steps must be >= 1")

    if SimpleEnv is None:
        raise ModuleNotFoundError(
            "Missing optional dependency for gridworld/minigrid environment (minigrid). "
            "Please install it (e.g., `pip install minigrid`) to run fourrooms_cover_time."
        )

    # Build the transition model once from the env layout; close the env even on failure.
    env = SimpleEnv(render_mode=None)
    try:
        env.reset(seed=int(args.seed))
        T, wall_mask_hw = build_transition_from_env(env)
    finally:
        env.close()

    height, width = int(wall_mask_hw.shape[0]), int(wall_mask_hw.shape[1])
    start_x, start_y = int(args.start_x), int(args.start_y)
    # Out-of-grid coordinates could be silently misinterpreted by the flat state encoding or
    # by numpy's negative indexing, so reject them up front.
    if not (0 <= start_x < width and 0 <= start_y < height):
        p.error(f"start=({start_x},{start_y}) lies outside the {width}x{height} grid")

    start_state = _pos_to_state(start_x, start_y, width)
    reachable = reachable_states_from_start(T, wall_mask_hw, start_state)
    target_n = int(reachable.sum())
    print(f"[CoverTime] reachable states from start=({args.start_x},{args.start_y}): {target_n}")

    cfg = RunConfig(
        total_steps=max_steps,
        n_seeds=1,
        seed=int(args.seed),
        gamma_sr=float(args.gamma_sr),
        alpha_sr=float(args.alpha_sr),
        c_decay=float(args.c_decay),
        fpvr_beta=float(args.fpvr_beta),
        fpvr_sr_target=str(args.fpvr_sr_target),
        eps=float(args.eps),
        sarsa_alpha=float(args.sarsa_alpha),
        sarsa_gamma=float(args.sarsa_gamma),
        r_fpvr_scale=float(args.r_fpvr_scale),
        r_sp_scale=float(args.r_sp_scale),
        r_sr_scale=float(args.r_sr_scale),
        sr_l1_eps=float(args.sr_l1_eps),
        start_x=start_x,
        start_y=start_y,
        coverage_reset_interval=0,  # not used here
    )

    methods = {
        "FPVR": run_fpvr_direct,
        r"$r^{FP}$ + SARSA": run_fpvr_reward_sarsa,
        "SP+SARSA": run_sp_reward_sarsa,
        "Random Walk": run_random_walk,
    }

    times_by_method, censored_by_method = _collect_cover_times(
        methods, T, wall_mask_hw, cfg, target_n, n_runs, int(args.seed), max_steps
    )

    for name, cnum in censored_by_method.items():
        if cnum > 0:
            print(f"[CoverTime] {name}: censored {cnum}/{n_runs} (ct=max_steps+1)")

    _plot_cover_times(times_by_method, args.out_png, args.out_eps)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

