#!/usr/bin/env python
"""Coverage comparison under primitive vs. option-augmented random walks.

Reads saved option Q-tables and runs the same number of trials for each
method to produce state coverage curves and visitation heatmaps.
"""
import argparse, random, os, glob
import numpy as np
import matplotlib.pyplot as plt
import math
import gridworld_options as go              # only need set_seed
from bottleneck_env import SimpleEnv
from generate_state_transition_matrix import build_state_transition_matrix
from utils import BottleneckVisualization


# ---------------- utility functions -------------------------
def option_terminated(q_row, L=15):
    """Decide whether the running option should stop in the current state.

    The option terminates when every Q-value in ``q_row`` is ≤ 0, or
    otherwise at random with probability ``1 / L``.

    Args:
        q_row: array of the option's Q-values for the current state
            (one entry per action); must be non-empty.
        L: positive number; the random termination chance per call is
            ``1 / L``. Defaults to 15, the value previously hard-coded.

    Returns:
        bool: True if the option should terminate.
    """
    # Short-circuit on purpose: the RNG is only consumed when the
    # Q-values alone do not already force termination.
    return bool(q_row.max() <= 0 or random.random() <= 1 / L)


def random_walk_cover(T, free_idx, s0, max_steps):
    """Run a uniformly random primitive-action walk and track coverage.

    Args:
        T: transition table indexed as ``T[state, action]`` giving the
            next state; one of 4 actions is drawn per step.
        free_idx: indices of the states that count towards coverage.
        s0: start state.
        max_steps: number of time steps to simulate.

    Returns:
        (coverage, counts): ``coverage`` is a list with one entry per time
        step holding the number of ``free_idx`` states visited so far;
        ``counts`` is an int32 array of per-state visit counts.
    """
    n_states = T.shape[0]
    seen = np.zeros(n_states, dtype=bool)
    counts = np.zeros(n_states, dtype=np.int32)
    coverage = []

    state = s0
    step = 0
    while step < max_steps:
        # Record the state occupied at this time step ...
        seen[state] = True
        counts[state] += 1
        coverage.append(seen[free_idx].sum())
        # ... then move with a uniformly drawn primitive action.
        action = np.random.randint(4)
        state = int(T[state, action])
        step += 1

    return coverage, counts


def mixed_walk_cover(T, free_idx, Q, s0, p_prim, max_steps):
    """
    Random walk that mixes primitive actions (prob. p_prim)
    and randomly chosen options (prob. 1-p_prim).

    Args:
        T: transition table indexed as ``T[state, action]`` giving the
            next state; primitive steps draw one of 4 actions uniformly.
        free_idx: indices of the states that count towards coverage.
        Q: option Q-tables indexed as ``Q[option, state, action]``; each
            option follows its greedy (arg-max) action.
        s0: start state.
        p_prim: probability of taking a primitive action at a decision point.
        max_steps: number of time steps to simulate.

    Returns:
        (curve, visit_c): ``curve`` is a list with one entry per time step
        holding the number of ``free_idx`` states visited so far (exactly
        ``max_steps`` entries for positive ``max_steps``); ``visit_c`` is an
        int32 array of per-state visit counts.

    Every time step records the occupied state exactly once and is followed
    by at most one move. The only step without a move is when a freshly
    chosen option terminates immediately in the current state.
    """
    k, pol = Q.shape[0], np.argmax(Q, 2)   # number of options, greedy policy
    visited = np.zeros(T.shape[0], bool)
    visit_c = np.zeros_like(visited, dtype=np.int32)
    curve, s = [], s0

    def _record(state):
        """Account for one time step spent in ``state``."""
        visited[state] = True
        visit_c[state] += 1
        curve.append(visited[free_idx].sum())

    while len(curve) < max_steps:
        _record(s)

        if random.random() < p_prim:           # primitive step
            s = int(T[s, np.random.randint(4)])
            continue

        # option rollout
        opt = random.randrange(k)
        # The option's start state was already recorded above; recording it
        # again would count it twice and waste a time step per rollout.
        at_start = True
        while (not option_terminated(Q[opt, s])) and len(curve) < max_steps:
            if at_start:
                at_start = False
            else:
                _record(s)
            s = int(T[s, pol[opt, s]])
        # The state where the option stopped is recorded by the next
        # iteration of the outer loop.
    return curve, visit_c


def pad_curve(curve, L):
    """Pad or truncate a curve to length L for array stacking.

    Args:
        curve: 1-D sequence of numbers (may be empty).
        L: target length.

    Returns:
        Float array: the first ``L`` entries of ``curve`` when it is long
        enough, otherwise ``curve`` extended to length ``L`` by repeating
        its last value. An empty curve has no last value and is padded
        with 0.0 instead of raising ``IndexError``.
    """
    values = np.asarray(curve, dtype=float)
    n = len(values)
    if n >= L:
        return values[:L]
    fill = values[-1] if n else 0.0
    return np.concatenate([values, np.full(L - n, fill, float)])


def get_saved_option_files(method: str, out_dir: str):
    """
    Locate saved option files. Expected filename pattern:
      gridworld_*_{method}_*.npy

    Args:
        method: method tag embedded in the file name.
        out_dir: directory to search (not searched recursively).

    Returns:
        Sorted list of matching paths; empty if nothing matches.
    """
    # Escape the caller-supplied parts so that glob metacharacters in a
    # directory or method name (e.g. "results[v1]") are matched literally;
    # only the two "*" in the file-name template act as wildcards.
    pat = os.path.join(glob.escape(out_dir),
                       f"gridworld_*_{glob.escape(method)}_*.npy")
    return sorted(glob.glob(pat))


# ----------------------------- main -------------------------
def main():
    """Run the coverage experiment end to end and show the figures.

    Steps:
      1. Parse the command-line options.
      2. Build the environment and its state-transition table.
      3. Load every saved option Q-table found in ``--out_dir``.
      4. Run the primitive baseline and each option-augmented walk,
         ``--inner`` times per option set.
      5. Plot mean ± std coverage curves and one visitation heat map
         per method.

    Raises:
        RuntimeError: if no option file is found for any method.
    """
    p = argparse.ArgumentParser()
    g = p.add_argument
    g("--inner", type=int, default=10, help="# runs per option set")
    g("--max_steps", type=int, default=3000)
    g("--p_prim", type=float, default=0.16666,
      help="probability of choosing a primitive action")
    g("--out_dir", default="option_results")
    args = p.parse_args()

    # ---------- environment ----------
    env = SimpleEnv(render_mode=None)
    env.reset()
    T, wall = build_state_transition_matrix(env)
    # Cells whose wall entry is 0 are the ones counted for coverage.
    free_idx = np.where(wall.flatten() == 0)[0]
    N = T.shape[0]

    os.makedirs(args.out_dir, exist_ok=True)

    methods_with_opt = ["RandomOption", "EigenOpt", "VPSOpt"]
    Q_sets = {m: [] for m in methods_with_opt}

    # ---------- load options ----------
    for m in methods_with_opt:
        files = get_saved_option_files(m, args.out_dir)
        if not files:
            print(f"[Warning] no saved {m} options found in '{args.out_dir}'")
        for f in files:
            Q_sets[m].append(np.load(f))
            print(f"[Load] {f}  (K={Q_sets[m][-1].shape[0]})")

    # n_outer equals the largest count among option files
    n_outer = max((len(v) for v in Q_sets.values()), default=0)
    if n_outer == 0:
        raise RuntimeError("No option files found — run the training script first!")

    # ---------------- experiments ----------------
    methods_plot = ["Random", *methods_with_opt]
    curves = {m: [] for m in methods_plot}
    visits = {m: np.zeros(N, np.int64) for m in methods_plot}

    total_runs = n_outer * args.inner
    print(f"[Run]  {total_runs} runs per method …\n")

    # -- primitive baseline --
    # Re-seeded per outer index, mirroring the option loop below.
    for outer in range(n_outer):
        go.set_seed(outer)
        for _ in range(args.inner):
            s0 = int(np.random.choice(free_idx))
            c, v = random_walk_cover(T, free_idx, s0, args.max_steps)
            curves["Random"].append(c)
            visits["Random"] += v

    # -- three option types --
    for m in methods_with_opt:
        for outer_idx, Q in enumerate(Q_sets[m]):
            go.set_seed(outer_idx)
            for _ in range(args.inner):
                s0 = int(np.random.choice(free_idx))
                c, v = mixed_walk_cover(T, free_idx, Q, s0,
                                        args.p_prim, args.max_steps)
                curves[m].append(c)
                visits[m] += v

    # ---------------- coverage curves ----------------
    L = args.max_steps
    colors = {"Random": "tab:blue",
              "RandomOption": "tab:purple",
              "EigenOpt": "tab:orange",
              "VPSOpt": "tab:green"}

    legend_labels = {"Random": "Primitive Action",
                     "RandomOption": "Random Option",
                     "EigenOpt": "Eigenoption",
                     "VPSOpt": "VPS Option"}

    plt.figure(figsize=(8.5, 4))
    x = np.arange(L)
    for m in methods_plot:
        if not curves[m]:
            continue  # nothing was run for this method (no option files)
        # One row per run; mean and std are taken across runs.
        mat = np.vstack([pad_curve(c, L) for c in curves[m]])
        mu, sig = mat.mean(0), mat.std(0)
        plt.plot(x, mu, color=colors[m], lw=1.7, label=legend_labels[m])
        plt.fill_between(x, mu - sig, mu + sig, color=colors[m], alpha=.15)
    plt.xlabel("Time Steps", fontsize=16)
    plt.ylabel("Number of Visited States", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.title("State Coverage Curves", fontsize=16)
    plt.grid(alpha=.3)
    plt.xlim(-5, args.max_steps)
    plt.legend(fontsize=16)
    plt.tight_layout()
    plt.show()

    # ---------------- visitation heat maps ----------------
    viz = BottleneckVisualization(env)
    for m in methods_plot:
        if not curves[m]:
            continue

        viz.plot_2d_heatmap(
            visits[m].astype(float),
            topk=0,
            color_bar=True,
            title=f"State Visit-Count • {legend_labels[m]}"
        )
    print("\n[✓] Experiment finished.")


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
