#!/usr/bin/env python
"""Coverage comparison under primitive vs. option-augmented random walks.

Reads saved option Q-tables and runs the same number of trials for each
method to produce state coverage curves and visitation heatmaps.
"""

import argparse, random, os, glob, ast
import numpy as np
import matplotlib.pyplot as plt
import math
from . import gridworld_options as go  # only need set_seed
from .bottleneck_env import SimpleEnv
from .generate_state_transition_matrix import build_state_transition_matrix
from .utils import BottleneckVisualization


def option_terminated(q_row, horizon: int = 15):
    """Decide whether a running option stops at the current state.

    The option terminates when its best Q-value at this state is
    non-positive. Otherwise it terminates at random with probability
    ``1 / horizon`` per call.

    Args:
        q_row: the option's Q-values at the current state (one per action).
        horizon: inverse per-call termination probability; must be positive.
            Defaults to 15.

    Returns:
        True if the option should stop here.

    Raises:
        ValueError: if ``horizon`` is not positive.
    """
    if horizon <= 0:
        raise ValueError(f"horizon must be positive, got {horizon!r}")
    return q_row.max() <= 0 or random.random() <= 1 / horizon


def available_options(Q: np.ndarray, s: int):
    """
    Return the ids of the options that may be started from state ``s``.

    Every state is formally in each option's initiation set. However, an
    option counts as startable at ``s`` only when its largest Q-value there
    is strictly positive.
    """
    return [
        opt_id
        for opt_id, opt_q in enumerate(Q)
        if opt_q[s].max() > 0
    ]


def random_walk_cover(T, free_idx, s0, max_steps):
    """Simulate a walk that uses primitive actions only.

    Each of the ``max_steps`` steps does the following:
      1. Marks the current state as visited and bumps its visit count.
      2. Appends the number of distinct visited states among ``free_idx``
         to the coverage curve.
      3. Moves through ``T`` with an action drawn by ``np.random.randint(4)``.

    Returns:
        curve: list with one coverage count per step.
        visit_c: per-state visit counts (int32, length ``T.shape[0]``).
    """
    n_states = T.shape[0]
    seen = np.zeros(n_states, dtype=bool)
    counts = np.zeros(n_states, dtype=np.int32)
    coverage = []
    state = s0
    for _step in range(max_steps):
        seen[state] = True
        counts[state] += 1
        coverage.append(seen[free_idx].sum())
        action = np.random.randint(4)
        state = int(T[state, action])
    return coverage, counts


def mixed_walk_cover(T, free_idx, Q, s0, p_prim, max_steps):
    """
    Random walk that uniformly selects from primitive actions and available options.

    At each decision point, uniformly chooses from:
    - 4 primitive actions (always available)
    - All options that can start at the current state (``available_options``)

    If no options are available, only primitive actions are considered.
    Each action/option has equal probability: 1 / (4 + num_available_options).

    A selected option follows its greedy policy ``argmax_a Q[opt, s, a]``.
    Termination (``option_terminated``) is tested before every option step,
    including the first. The option runs until it terminates or the step
    budget is used up.

    Every entry of the returned curve corresponds to exactly one time step.
    The state occupied at that step is recorded once; in particular, an
    option's start state is not counted a second time.

    Args:
        T: transition table; ``T[s, a]`` is the successor of state ``s``
            under primitive action ``a`` (actions 0..3 are used).
        free_idx: indices of the states counted for coverage.
        Q: option Q-tables indexed as ``Q[option, state, action]``.
        s0: start state.
        p_prim: currently ignored, since selection is uniform as described
            above. It is kept so existing call sites keep working.
        max_steps: number of time steps to record.

    Returns:
        (curve, visit_c): one coverage count per step (length ``max_steps``)
        and an int32 per-state visit-count array of length ``T.shape[0]``.
    """
    pol = np.argmax(Q, 2)  # greedy action of every option in every state
    n_prim = 4  # Number of primitive actions
    visited = np.zeros(T.shape[0], bool)
    visit_c = np.zeros_like(visited, dtype=np.int32)
    curve = []

    def record(state):
        # One time step spent in `state`: mark it, count it, extend the curve.
        visited[state] = True
        visit_c[state] += 1
        curve.append(visited[free_idx].sum())

    s = s0
    while len(curve) < max_steps:
        record(s)

        # Uniformly select from [0, total_choices-1]:
        # 0..n_prim-1 -> primitive action, n_prim.. -> avail_opts[choice - n_prim]
        avail_opts = available_options(Q, s)
        choice = random.randint(0, n_prim + len(avail_opts) - 1)

        if choice < n_prim:
            s = int(T[s, choice])
            continue

        opt = avail_opts[choice - n_prim]
        # The start state was recorded just above, so the option's first
        # step is taken without recording it again. Each later state is
        # recorded before the option acts from it. The state where the
        # option stops is recorded by the next pass of the outer loop.
        first_step = True
        while (not option_terminated(Q[opt, s])) and len(curve) < max_steps:
            if not first_step:
                record(s)
            first_step = False
            s = int(T[s, pol[opt, s]])
    return curve, visit_c


def pad_curve(curve, L):
    """Pad or truncate a curve to length ``L`` (as floats) for array stacking.

    Curves with at least ``L`` entries are truncated to their first ``L``
    values. Shorter curves are right-padded by repeating their last value,
    which for a cumulative coverage curve is its plateau. An empty curve is
    padded with zeros ("nothing visited") instead of raising ``IndexError``.
    """
    n = len(curve)
    if n >= L:
        return np.asarray(curve[:L], float)
    fill = curve[-1] if n else 0.0
    return np.concatenate([np.asarray(curve, float), np.full(L - n, fill, float)])


def get_saved_option_files(method: str, out_dir: str):
    """
    Locate saved option files for ``method`` inside ``out_dir``.

    Expected filename pattern:
      gridworld_*_{method}_*.npy

    ``out_dir`` and ``method`` are matched literally. Glob metacharacters in
    them (``[``, ``]``, ``*``, ``?``) are escaped, so a directory such as
    ``runs[1]`` is searched as-is rather than read as a character class.

    Returns:
        Sorted list of matching paths (empty if none).
    """
    pat = os.path.join(glob.escape(out_dir), f"gridworld_*_{glob.escape(method)}_*.npy")
    return sorted(glob.glob(pat))


def _coerce_option_num(value, raw) -> int:
    """Convert one parsed --test_option_num entry to ``int``.

    Strings are parsed with ``int()`` first and then as a Python literal, so
    "0x10" and "4.0" are accepted. Booleans and non-integral numbers are
    rejected rather than silently coerced (``True`` -> 1, ``2.5`` -> 2).
    """
    msg = f"Unsupported --test_option_num format: {raw!r}"
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            pass
        try:
            value = ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError, RecursionError) as err:
            raise ValueError(msg) from err
    if isinstance(value, bool):  # bool subclasses int; reject it explicitly
        raise ValueError(msg)
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    raise ValueError(msg)


def _parse_test_option_num(x: str) -> list[int]:
    """Parse --test_option_num list.

    Examples:
      - "[2,4,6]" or "(2, 4, 6)"   (Python list/tuple literal)
      - "2,4,6" or "2 4 6"         (comma- and/or whitespace-separated)
      - "4"                        (single value)

    Raises:
        ValueError: if the value is empty, cannot be parsed, contains a
            non-integer entry (booleans and fractional numbers included),
            or yields no values at all (e.g. "[]").
    """
    s = str(x).strip()
    if not s:
        raise ValueError("--test_option_num is empty")

    if s.startswith(("[", "(")):
        try:
            val = ast.literal_eval(s)
        except (ValueError, TypeError, SyntaxError, RecursionError) as err:
            raise ValueError(f"Unsupported --test_option_num format: {x!r}") from err
        # "(4)" evaluates to a bare int rather than a tuple.
        items = list(val) if isinstance(val, (list, tuple)) else [val]
    else:
        items = s.replace(",", " ").split()

    result = [_coerce_option_num(v, x) for v in items]
    if not result:
        raise ValueError(f"--test_option_num contains no values: {x!r}")
    return result


def _k_from_filename(path: str) -> int | None:
    """Extract K from `gridworld_{K}_{Method}_{outer}.npy`."""
    base = os.path.basename(path)
    stem = os.path.splitext(base)[0]
    parts = stem.split("_")
    if len(parts) < 4 or parts[0] != "gridworld":
        return None
    try:
        return int(parts[1])
    except ValueError:
        return None


def _find_saved_k(requested_base_k: int, available_total_ks: set[int]) -> int | None:
    """Map requested base k (e.g. 4) to saved total K in files.

    IMPORTANT: In this codebase, `gridworld_options.py` typically trains with
    `sign=True`, so it saves total K = 2*k_base for VPS/Eigen (and for Random
    in this repo's convention as well). When both K and 2K exist on disk, we
    *prefer 2K* to avoid ambiguity (e.g., base=8 should map to saved K=16 even
    if there are also saved K=8 files from base=4).
    """
    k = int(requested_base_k)
    if (2 * k) in available_total_ks:
        return 2 * k
    if k in available_total_ks:
        return k
    return None


def main():
    """Run the coverage experiment and show coverage curves and visit heatmaps.

    Steps:
      1. Build the gridworld transition table from ``SimpleEnv``.
      2. Load saved option Q-tables per method from ``--out_dir`` (resolved
         relative to this script) and group them by saved total K.
      3. Map each requested base k to a saved K via ``_find_saved_k``. A base
         k that resolves to a saved K already selected by an earlier request
         is skipped with a warning, so the same option sets are not run and
         plotted twice.
      4. Run ``--inner`` walks per saved option set, plus the same number of
         primitive-only walks. Then plot mean ± std coverage curves and
         per-method visitation heatmaps.
    """
    p = argparse.ArgumentParser()
    g = p.add_argument
    g("--inner", type=int, default=10, help="# runs per option set")
    g("--max_steps", type=int, default=3000)
    g(
        "--p_prim",
        type=float,
        default=0.9,
        help=(
            "probability of choosing a primitive action "
            "(currently ignored: mixed walks pick uniformly among the 4 "
            "primitives and the options available at the current state)"
        ),
    )
    g("--out_dir", default="option_results")
    g(
        "--test_option_num",
        type=str,
        default="[4,8,16,32]",
        help='Option counts (base k) to test. Examples: "4", "2,4,6", "[2,4,6]".',
    )
    args = p.parse_args()

    # ---------- environment ----------
    env = SimpleEnv(render_mode=None)
    env.reset()
    T, wall = build_state_transition_matrix(env)
    free_idx = np.where(wall.flatten() == 0)[0]
    N = T.shape[0]

    # Match the saving convention in gridworld_options.py:
    # always resolve the output directory relative to this script,
    # not the current working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(script_dir, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    methods_with_opt = ["RandomOption", "EigenOpt", "VPSOpt"]

    # ---------- load options grouped by saved total K ----------
    Q_by_method_and_K: dict[str, dict[int, list[np.ndarray]]] = {m: {} for m in methods_with_opt}
    available_total_ks: set[int] = set()

    for m in methods_with_opt:
        files = get_saved_option_files(m, out_dir)
        if not files:
            print(f"[Warning] no saved {m} options found in '{out_dir}'")
        for f in files:
            Q = np.load(f)
            K = _k_from_filename(f)
            if K is None:
                # Non-conforming filename: fall back to the number of stored options.
                K = int(Q.shape[0])
            K = int(K)
            Q_by_method_and_K[m].setdefault(K, []).append(Q)
            available_total_ks.add(K)
            # Log the K used for grouping; it can differ from Q.shape[0].
            print(f"[Load] {f}  (K={K}, n_options={Q.shape[0]})")

    if not available_total_ks:
        raise RuntimeError("No option files found — run the training script first!")

    # ---------- resolve requested base k -> saved total K ----------
    base_ks = _parse_test_option_num(args.test_option_num)
    pairs: list[tuple[int, int]] = []  # (base_k, saved_total_K), saved K unique
    skipped: list[int] = []
    duplicates: list[tuple[int, int]] = []
    used_ks: set[int] = set()
    for bk in base_ks:
        bk_i = int(bk)
        K = _find_saved_k(bk_i, available_total_ks)
        if K is None:
            skipped.append(bk_i)
            continue
        K = int(K)
        if K in used_ks:
            # Two requests resolving to one saved K would share a single
            # (method, K) slot: runs would be duplicated into it and the same
            # data plotted twice under different labels.
            duplicates.append((bk_i, K))
            continue
        used_ks.add(K)
        pairs.append((bk_i, K))
    if skipped:
        print(
            f"[Warning] requested base k not found in saved files (also tried 2*k): {skipped}. "
            f"Available saved K values: {sorted(available_total_ks)}"
        )
    if duplicates:
        print(
            f"[Warning] requested base k resolving to an already-selected saved K were skipped "
            f"(base→savedK): {duplicates}"
        )
    if not pairs:
        raise RuntimeError(
            f"No option files found for requested base k list={base_ks}. "
            f"Available saved K values: {sorted(available_total_ks)}"
        )
    print(f"\n[Test] base k list={base_ks}  → (base→savedK)={pairs}\n")
    base_k_of = {K: bk for bk, K in pairs}  # saved K -> requested base k (unique)

    # ---------------- experiments ----------------
    curves_by_key: dict[tuple[str, int | None], list[list[int]]] = {}
    visits_by_key: dict[tuple[str, int | None], np.ndarray] = {}

    # Primitive baseline (single curve series)
    curves_by_key[("Random", None)] = []
    visits_by_key[("Random", None)] = np.zeros(N, np.int64)

    # For options: create slots per (method, saved_total_K)
    for _, K in pairs:
        for m in methods_with_opt:
            curves_by_key[(m, K)] = []
            visits_by_key[(m, K)] = np.zeros(N, np.int64)

    # Primitive baseline run count: match the maximum n_outer across all (method, K) we test
    n_outer_prim = 0
    for _, K in pairs:
        for m in methods_with_opt:
            n_outer_prim = max(n_outer_prim, len(Q_by_method_and_K[m].get(K, [])))
    if n_outer_prim == 0:
        raise RuntimeError("No option sets found for requested K values.")

    total_runs = n_outer_prim * args.inner
    print(f"[Run]  {total_runs} runs per method …\n")

    # -- primitive baseline --
    for outer in range(n_outer_prim):
        go.set_seed(outer)
        for _ in range(args.inner):
            s0 = int(np.random.choice(free_idx))
            c, v = random_walk_cover(T, free_idx, s0, args.max_steps)
            curves_by_key[("Random", None)].append(c)
            visits_by_key[("Random", None)] += v

    # -- three option types, grouped by K --
    for _, K in pairs:
        for m in methods_with_opt:
            Q_sets = Q_by_method_and_K[m].get(K, [])
            for outer_idx, Q in enumerate(Q_sets):
                go.set_seed(outer_idx)
                for _ in range(args.inner):
                    s0 = int(np.random.choice(free_idx))
                    c, v = mixed_walk_cover(T, free_idx, Q, s0, args.p_prim, args.max_steps)
                    curves_by_key[(m, K)].append(c)
                    visits_by_key[(m, K)] += v

    # ---------------- coverage curves ----------------
    L = args.max_steps
    colors = {
        "Random": "tab:blue",
        "RandomOption": "tab:purple",
        "EigenOpt": "tab:orange",
        "VPSOpt": "tab:green",
    }

    legend_labels = {
        "Random": "Primitive Action",
        "RandomOption": "Random Option",
        "EigenOpt": "Eigenoption",
        "VPSOpt": "VPS Option",
    }

    linestyles = ["-", "--", ":", "-."]
    plt.figure(figsize=(9.5, 4.2))
    x = np.arange(L)
    # Primitive baseline (single line)
    runs_prim = curves_by_key.get(("Random", None), [])
    if runs_prim:
        mat = np.vstack([pad_curve(c, L) for c in runs_prim])
        mu, sig = mat.mean(0), mat.std(0)
        plt.plot(x, mu, color=colors["Random"], lw=1.9, label=legend_labels["Random"])
        plt.fill_between(x, mu - sig, mu + sig, color=colors["Random"], alpha=0.15)

    # Options: plot each (method, K) on the same figure
    for k_idx, (base_k, K) in enumerate(pairs):
        ls = linestyles[k_idx % len(linestyles)]
        for m in methods_with_opt:
            runs = curves_by_key.get((m, K), [])
            if not runs:
                continue
            mat = np.vstack([pad_curve(c, L) for c in runs])
            mu, sig = mat.mean(0), mat.std(0)
            lbl = f"{legend_labels[m]} (k={base_k})"
            plt.plot(x, mu, color=colors[m], lw=1.7, linestyle=ls, label=lbl)
            plt.fill_between(x, mu - sig, mu + sig, color=colors[m], alpha=0.10)
    plt.xlabel("Time Steps", fontsize=16)
    plt.ylabel("Number of Visited States", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.title("State Coverage Curves (multiple K on one plot)", fontsize=16)
    plt.grid(alpha=0.3)
    plt.xlim(-5, args.max_steps)
    plt.legend(fontsize=11, ncol=2)
    plt.tight_layout()
    plt.show()

    # ---------------- visitation heat maps ----------------
    viz = BottleneckVisualization(env)
    for (meth, K), v in visits_by_key.items():
        if v.sum() == 0:
            continue
        if meth == "Random":
            title = f"State Visit-Count • {legend_labels['Random']}"
        else:
            # Show the requested base k (as in the curve legend) and the saved K.
            title = f"State Visit-Count • {legend_labels[meth]} (k={base_k_of[K]}, K={K})"
        viz.plot_2d_heatmap(
            v.astype(float),
            topk=0,
            color_bar=True,
            title=title,
        )
    print("\n[✓] Experiment finished.")


if __name__ == "__main__":
    # Script entry point. The module uses package-relative imports
    # (`from . import ...`), so it must be run as a module of its package
    # (e.g. `python -m <package>.<module>`) rather than as a bare file.
    main()

