#!/usr/bin/env python
"""Coverage comparison under primitive vs. option-augmented random walks for KeyLockEnv.

Reads saved option Q-tables and runs the same number of trials for each
method to produce state coverage curves and visitation statistics.
"""
import argparse, random, os, glob, ast
import numpy as np
import matplotlib.pyplot as plt
import math
import key_lock_options as klo              # only need set_seed
from key_lock_env import KeyLockEnv
from train_keylock_qlearning import state_to_index, index_to_state


# ---------------- utility functions -------------------------
def option_terminated(q_row, L=15):
    """Decide whether a running option stops at the current state.

    Mirrors the rule used by keylock_option_random_walk: the option ends
    when its best Q-value in ``q_row`` is non-positive; otherwise it ends
    with probability ``1 / L``. No random number is drawn in the first case.
    """
    exhausted = q_row.max() <= 0
    if exhausted:
        return exhausted
    return random.random() < 1.0 / L


def available_options(Q: np.ndarray, s: int):
    """
    Return, in ascending order, the option ids that may start at state ``s``.

    Every option can nominally be initiated anywhere, but an option is only
    treated as startable when its largest Q-value at ``s`` is strictly positive.
    """
    startable = []
    for option_id in range(Q.shape[0]):
        if Q[option_id, s].max() > 0:
            startable.append(option_id)
    return startable


def random_walk_cover(env, size, N, s0, max_steps):
    """Run a uniform random walk over primitive actions and track state coverage.

    The walk starts from ``env.reset()``; ``s0`` is accepted for interface
    parity with ``mixed_walk_cover`` but is not used. On each of the
    ``max_steps`` iterations the current state is recorded *before* acting,
    then one of the six primitive actions is sampled uniformly. When an
    episode terminates or is truncated, the environment is reset and the walk
    continues from the reset state (the terminal state itself is not recorded).

    Returns:
        curve: list whose entry t is the number of distinct states seen so far.
        visit_c: int32 array of length N with per-state visit counts.
    """
    def to_index(o):
        return state_to_index(o[0], o[1], o[2], o[3], o[4], o[5], o[6], size)

    seen = np.zeros(N, bool)
    counts = np.zeros_like(seen, dtype=np.int32)
    coverage = []

    first_obs, _info = env.reset()
    state = to_index(first_obs)

    for _step in range(max_steps):
        seen[state] = True
        counts[state] += 1
        # Every state counts toward coverage in this environment.
        coverage.append(seen.sum())

        # Uniform primitive action id in [0, 5] (up, down, left, right, pickup, toggle).
        action = random.randint(0, 5)
        obs, _reward, terminated, truncated, _info = env.step(action)
        if terminated or truncated:
            obs, _info = env.reset()
        state = to_index(obs)

    return coverage, counts


def mixed_walk_cover(env, size, N, Q, s0, max_steps):
    """
    Random walk that uniformly selects from primitive actions and available options.

    At each decision point, uniformly chooses from:
    - 6 primitive actions (always available)
    - All options that can start at the current state (``available_options``)

    If no options are available, only primitive actions are considered.
    Each action/option has equal probability: 1 / (6 + num_available_options).
    A chosen option follows its greedy policy ``argmax(Q[oid, s])`` until
    ``option_terminated`` fires, the episode ends, or the step budget runs out.

    Bookkeeping invariant: every entry of ``curve`` corresponds to exactly one
    observed state — the state after ``env.reset()`` or after one ``env.step`` —
    so the curve index counts environment steps, as in ``random_walk_cover``.
    An option that terminates before acting therefore consumes no time, and a
    new choice is drawn from the same state.

    Args:
        env: environment exposing ``reset()`` and a gym-style 5-tuple ``step()``.
        size: grid size forwarded to ``state_to_index``.
        N: number of state indices (length of the visit arrays).
        Q: option Q-values indexed ``Q[option, state, action]``.
        s0: unused; the walk starts from ``env.reset()``. Kept for interface
            compatibility.
        max_steps: maximum length of the returned coverage curve.

    Returns:
        (curve, visit_c): ``curve[t]`` is the number of distinct states recorded
        after t+1 observations; ``visit_c`` is an int32 array of per-state counts.
    """
    pol = np.argmax(Q, 2)  # greedy primitive action of each option in each state
    n_prim = 6  # Number of primitive actions (up, down, left, right, pickup, toggle)
    visited = np.zeros(N, bool)
    visit_c = np.zeros_like(visited, dtype=np.int32)
    curve = []

    def obs_index(obs):
        return state_to_index(
            obs[0], obs[1], obs[2], obs[3], obs[4], obs[5], obs[6], size
        )

    def record(state):
        # Record one observed state: mark it, count it, extend the curve by one step.
        visited[state] = True
        visit_c[state] += 1
        curve.append(visited.sum())

    def step(action):
        # Take one env step; on episode end, reset and return the reset state instead
        # (the terminal state itself is not recorded).
        next_obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            obs, info = env.reset()
            return obs_index(obs), True
        return obs_index(next_obs), False

    obs, info = env.reset()
    s = obs_index(obs)
    if max_steps > 0:
        record(s)

    while len(curve) < max_steps:
        # Same logic as keylock_option_random_walk: candidates = primitives + startable options, uniform choice
        startable_opts = available_options(Q, s)
        candidates = list(range(n_prim)) + [n_prim + oid for oid in startable_opts]
        a = random.choice(candidates)

        if a < n_prim:
            # Selected a primitive action: exactly one env step, one recorded state.
            s, _ = step(a)
            record(s)
            continue

        # Selected an option: oid = a - n_prim. The current state was already
        # recorded above, so only states reached by option steps are recorded here.
        oid = a - n_prim
        while (not option_terminated(Q[oid, s])) and len(curve) < max_steps:
            s, episode_done = step(pol[oid, s])
            record(s)
            if episode_done:
                break

    return curve, visit_c


def pad_curve(curve, L):
    """Pad or truncate a curve to exactly length L for array stacking.

    Curves longer than L are truncated. Shorter ones are right-padded by
    repeating their last value (for the coverage curves in this file that is
    the final distinct-state count). An empty curve is padded with zeros, i.e.
    treated as "nothing visited", instead of failing on ``curve[-1]``.

    Returns:
        1-D float ndarray of length L (for L >= 0).
    """
    arr = np.asarray(curve, dtype=float)
    if arr.size >= L:
        return arr[:L]
    fill = arr[-1] if arr.size else 0.0
    return np.concatenate([arr, np.full(L - arr.size, fill, float)])


def get_saved_option_files(method: str, out_dir: str):
    """Return the saved option files for ``method`` inside ``out_dir``, sorted by path.

    A file matches when its name fits the glob ``keylock_*_{method}_*.npy``.
    """
    pattern = f"keylock_*_{method}_*.npy"
    matches = glob.glob(os.path.join(out_dir, pattern))
    matches.sort()
    return matches


def _parse_test_option_num(x: str) -> list[int]:
    """Parse --test_option_num list.

    Examples:
      - "[2,4,6]"
      - "2,4,6"
      - "4"
    """
    s = str(x).strip()
    if not s:
        raise ValueError("--test_option_num is empty")
    if "," in s and not (s.startswith("[") or s.startswith("(")):
        return [int(p.strip()) for p in s.split(",") if p.strip()]
    try:
        val = ast.literal_eval(s)
    except Exception:
        return [int(s)]
    if isinstance(val, int):
        return [int(val)]
    if isinstance(val, (list, tuple)):
        return [int(v) for v in val]
    raise ValueError(f"Unsupported --test_option_num format: {x!r}")


def _k_from_filename(path: str) -> int | None:
    """Extract K from `keylock_{K}_{Method}_{outer}.npy`."""
    base = os.path.basename(path)
    stem = os.path.splitext(base)[0]
    parts = stem.split("_")
    if len(parts) < 4 or parts[0] != "keylock":
        return None
    try:
        return int(parts[1])
    except ValueError:
        return None


def _find_saved_k(requested_base_k: int, available_total_ks: set[int]) -> int | None:
    """Map requested base k (e.g. 4) to saved total K in files.

    IMPORTANT: In this codebase, `key_lock_options.py` typically trains with
    `sign=True`, so it saves total K = 2*k_base for VPS/Eigen (and for Random
    in this repo's convention as well). When both K and 2K exist on disk, we
    *prefer 2K* to avoid ambiguity (e.g., base=8 should map to saved K=16 even
    if there are also saved K=8 files from base=4).
    """
    k = int(requested_base_k)
    if (2 * k) in available_total_ks:
        return 2 * k
    if k in available_total_ks:
        return k
    return None


# ----------------------------- main -------------------------
def main():
    """Run the coverage experiment: load options, roll out walks, plot and report.

    Steps:
      1. Parse CLI arguments and build a ``KeyLockEnv`` from them.
      2. Load every saved option Q-table for RandomOption / EigenOpt / VPSOpt
         from ``--out_dir`` (resolved relative to this script). Tables are
         grouped by the total option count K parsed from the filename, falling
         back to ``Q.shape[0]``.
      3. Map each requested base k from ``--test_option_num`` to a saved K.
         A request is skipped with a warning if no saved K matches it. It is
         also skipped if it resolves to a K that an earlier request already
         selected.
      4. Run the primitive random walk and the option-augmented walk,
         ``--inner`` times per saved option set. Seeding goes through
         ``klo.set_seed`` with the outer index.
      5. Save and show mean ± std coverage curves, then print visitation
         statistics.

    Raises:
        RuntimeError: if no option files exist, or none match the requested k.
        ValueError: if ``--test_option_num`` cannot be parsed.
    """
    p = argparse.ArgumentParser()
    g = p.add_argument
    g("--inner", type=int, default=1, help="# runs per option set")
    g("--max_steps", type=int, default=2000)
    g("--out_dir", default="keylock_option_results")
    g(
        "--test_option_num",
        type=str,
        default="[4,8,16]",
        help='Option counts (base k) to test. Examples: "4", "2,4,6", "[2,4,6]".',
    )
    g("--size", type=int, default=15, help="grid size")
    g("--yellow_key_pos", type=int, nargs=2, default=[12, 3], help="yellow key position (x, y)")
    g("--yellow_door_pos", type=int, nargs=2, default=[3, 8], help="yellow door position (x, y)")
    g("--blue_key_pos", type=int, nargs=2, default=[12, 12], help="blue key position (x, y)")
    g("--blue_door_pos", type=int, nargs=2, default=[9, 3], help="blue door position (x, y)")
    g("--goal_pos", type=int, nargs=2, default=[3, 12], help="goal position (x, y)")
    args = p.parse_args()

    # ---------- environment ----------
    env = KeyLockEnv(
        size=args.size,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        yellow_key_pos=tuple(args.yellow_key_pos),
        yellow_door_pos=tuple(args.yellow_door_pos),
        blue_key_pos=tuple(args.blue_key_pos),
        blue_door_pos=tuple(args.blue_door_pos),
        goal_pos=tuple(args.goal_pos),
        render_mode=None,
    )
    env.reset()

    # State space size: size * size * 4 * 2 * 2 * 2 * 2
    N = args.size * args.size * 4 * 2 * 2 * 2 * 2

    # Match the saving convention in key_lock_options.py:
    # always resolve the output directory relative to this script,
    # not the current working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(script_dir, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    methods_with_opt = ["RandomOption", "EigenOpt", "VPSOpt"]
    methods_plot = ["Random", *methods_with_opt]

    # ---------- load options grouped by saved total K ----------
    Q_by_method_and_K: dict[str, dict[int, list[np.ndarray]]] = {m: {} for m in methods_with_opt}
    available_total_ks: set[int] = set()

    for m in methods_with_opt:
        files = get_saved_option_files(m, out_dir)
        if not files:
            print(f"[Warning] no saved {m} options found in '{out_dir}'")
        for f in files:
            Q = np.load(f)
            K = _k_from_filename(f)
            if K is None:
                K = int(Q.shape[0])
            Q_by_method_and_K[m].setdefault(int(K), []).append(Q)
            available_total_ks.add(int(K))
            print(f"[Load] {f}  (K={Q.shape[0]})")

    if not available_total_ks:
        raise RuntimeError("No option files found — run the training script first!")

    base_ks = _parse_test_option_num(args.test_option_num)
    pairs: list[tuple[int, int]] = []  # (base_k, saved_total_K)
    skipped: list[int] = []
    duplicates: list[tuple[int, int]] = []  # (base_k, saved_total_K) dropped as repeats
    selected_ks: set[int] = set()
    for bk in base_ks:
        bk_i = int(bk)
        K = _find_saved_k(bk_i, available_total_ks)
        if K is None:
            skipped.append(bk_i)
            continue
        if K in selected_ks:
            # Two requests resolved to the same saved K (e.g. base 8 → 16 via 2k and
            # base 16 → 16 directly when 32 is absent). Keeping both would run that K
            # twice into the same (method, K) slot — re-seeded with the same seeds —
            # and plot the same data under two labels, so keep only the first.
            duplicates.append((bk_i, int(K)))
            continue
        selected_ks.add(int(K))
        pairs.append((bk_i, int(K)))
    if skipped:
        print(
            f"[Warning] requested base k not found in saved files (also tried 2*k): {skipped}. "
            f"Available saved K values: {sorted(available_total_ks)}"
        )
    if duplicates:
        print(
            f"[Warning] requested base k resolved to an already-selected saved K and was skipped "
            f"(base→savedK): {duplicates}"
        )
    if not pairs:
        raise RuntimeError(
            f"No option files found for requested base k list={base_ks}. "
            f"Available saved K values: {sorted(available_total_ks)}"
        )
    print(f"\n[Test] base k list={base_ks}  → (base→savedK)={pairs}\n")
    base_k_of = {K: bk for bk, K in pairs}  # saved K → requested base k (unique after dedupe)

    # ---------------- experiments ----------------
    # curves_by_key[(method, K)] = list[curves]; primitive baseline is method="Random" (no K)
    curves_by_key: dict[tuple[str, int] | tuple[str, None], list[list[int]]] = {}
    visits_by_key: dict[tuple[str, int] | tuple[str, None], np.ndarray] = {}

    # Primitive baseline (single curve series)
    curves_by_key[("Random", None)] = []
    visits_by_key[("Random", None)] = np.zeros(N, np.int64)

    # For options: create slots per (method, saved_total_K)
    for _, K in pairs:
        for m in methods_with_opt:
            curves_by_key[(m, K)] = []
            visits_by_key[(m, K)] = np.zeros(N, np.int64)

    # Primitive baseline run count: match the maximum n_outer across all (method, K) we test
    n_outer_prim = 0
    for _, K in pairs:
        for m in methods_with_opt:
            n_outer_prim = max(n_outer_prim, len(Q_by_method_and_K[m].get(K, [])))
    if n_outer_prim == 0:
        raise RuntimeError("No option sets found for requested K values.")

    total_runs = n_outer_prim * args.inner
    print(f"[Run]  {total_runs} runs per method …\n")

    # -- primitive baseline --
    for outer in range(n_outer_prim):
        klo.set_seed(outer)
        for _ in range(args.inner):
            # Reset environment to get random initial state
            obs, info = env.reset()
            s0 = state_to_index(
                obs[0], obs[1], obs[2], obs[3], obs[4], obs[5], obs[6], args.size
            )
            c, v = random_walk_cover(env, args.size, N, s0, args.max_steps)
            curves_by_key[("Random", None)].append(c)
            visits_by_key[("Random", None)] += v

    # -- three option types, grouped by K --
    for _, K in pairs:
        for m in methods_with_opt:
            Q_sets = Q_by_method_and_K[m].get(K, [])
            for outer_idx, Q in enumerate(Q_sets):
                klo.set_seed(outer_idx)
                for _ in range(args.inner):
                    # Reset environment to get random initial state
                    obs, info = env.reset()
                    s0 = state_to_index(
                        obs[0], obs[1], obs[2], obs[3], obs[4], obs[5], obs[6], args.size
                    )
                    c, v = mixed_walk_cover(env, args.size, N, Q, s0, args.max_steps)
                    curves_by_key[(m, K)].append(c)
                    visits_by_key[(m, K)] += v

    # ---------------- coverage curves ----------------
    L = args.max_steps
    colors = {"Random": "tab:blue",
              "RandomOption": "tab:purple",
              "EigenOpt": "tab:orange",
              "VPSOpt": "tab:green"}

    legend_labels = {"Random": "Primitive Action",
                     "RandomOption": "Random Option",
                     "EigenOpt": "Eigenoption",
                     "VPSOpt": "VPS Option"}

    linestyles = ["-", "--", ":", "-."]
    plt.figure(figsize=(9.5, 4.2))
    x = np.arange(L)
    # Primitive baseline (single line)
    runs_prim = curves_by_key.get(("Random", None), [])
    if runs_prim:
        mat = np.vstack([pad_curve(c, L) for c in runs_prim])
        mu, sig = mat.mean(0), mat.std(0)
        plt.plot(x, mu, color=colors["Random"], lw=1.9, label=legend_labels["Random"])
        plt.fill_between(x, mu - sig, mu + sig, color=colors["Random"], alpha=.15)

    # Options: plot each (method, K) on the same figure
    for k_idx, (base_k, K) in enumerate(pairs):
        ls = linestyles[k_idx % len(linestyles)]
        for m in methods_with_opt:
            runs = curves_by_key.get((m, K), [])
            if not runs:
                continue
            mat = np.vstack([pad_curve(c, L) for c in runs])
            mu, sig = mat.mean(0), mat.std(0)
            # Label by the *requested base-k* so it matches the CLI expectation.
            lbl = f"{legend_labels[m]} (k={base_k})"
            plt.plot(x, mu, color=colors[m], lw=1.7, linestyle=ls, label=lbl)
            plt.fill_between(x, mu - sig, mu + sig, color=colors[m], alpha=.10)
    plt.xlabel("Time Steps", fontsize=16)
    plt.ylabel("Number of Visited States", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.title("State Coverage Curves (KeyLockEnv)", fontsize=16)
    plt.grid(alpha=.3)
    plt.xlim(-5, args.max_steps)
    plt.legend(fontsize=11, ncol=2)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "keylock_coverage_curves.png"), dpi=300, bbox_inches='tight')
    print(f"\n[Save] Coverage curves saved to {os.path.join(out_dir, 'keylock_coverage_curves.png')}")
    plt.show()

    # ---------------- print statistics ----------------
    print("\n" + "=" * 70)
    print("State Coverage Statistics")
    print("=" * 70)
    for (meth, K), v in visits_by_key.items():
        if v.sum() == 0:
            continue
        total_visits = v.sum()
        unique_states = (v > 0).sum()
        coverage_pct = 100 * unique_states / N
        if meth == "Random":
            print(f"{legend_labels['Random']}:")
        else:
            # Report the requested base k (as in the plot legend) alongside the saved K.
            print(f"{legend_labels[meth]} (k={base_k_of[K]}, K={K}):")
        print(f"  Total visits: {total_visits}")
        print(f"  Unique states visited: {unique_states} / {N} ({coverage_pct:.2f}%)")
        print(f"  Average visits per state: {total_visits / N:.2f}")
        print()

    print("\n[✓] Experiment finished.")


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
