#!/usr/bin/env python
"""Visualize Random/Eigen/VPS options interactively in Gymnasium.

Loads option Q-tables and renders each option rollout in a human window
until the option terminates, the environment episode ends, or `max_len`
steps have been taken.
"""
from __future__ import annotations
import argparse, glob, os, time, random
from pathlib import Path
from typing import List
import numpy as np
import gymnasium as gym

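# ---------- util: check whether an option terminates --------
# Earlier deterministic rule, kept for reference only; it is superseded by
# the stochastic definition that follows.
# def option_terminated(q_row: np.ndarray) -> bool:
#     """Option terminates if all Q values ≤ 0."""
#     return q_row.max() <= 0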
def option_terminated(q_row: np.ndarray, L: int = 30) -> bool:
    """Decide stochastically whether the running option should stop now.

    Each call draws one sample from the module-level ``random`` generator and
    reports termination when that sample falls below ``1 / L``.

    Args:
        q_row: Accepted for interface compatibility; not consulted here.
        L: Denominator of the per-step stop probability.

    Returns:
        True when the sampled value is below ``1 / L``, otherwise False.
    """
    # Draw first, then divide, so the generator is advanced exactly once
    # per call regardless of what follows.
    sample = random.random()
    stop_probability = 1.0 / L
    return sample < stop_probability


# ---------- load *.npy files --------------------------------
def load_option_files(
    env_id: str,
    opt_type: str,
    out_dir: Path,
    outer: int
) -> List[np.ndarray]:
    """Load the first `outer` saved option arrays for an environment.

    Files are matched by the pattern ``{env_id}_*_{key}_*.npy`` inside
    `out_dir`, where ``key`` is the file-name tag for `opt_type`, and are
    taken in lexicographic path order.

    Args:
        env_id: Environment ID that prefixes the file names.
        opt_type: One of ``"random"``, ``"eigen"`` or ``"vps"``
            (case-insensitive).
        out_dir: Directory that holds the ``*.npy`` files.
        outer: Number of files (option groups) to load.

    Returns:
        A list of exactly `outer` arrays, one per loaded file.

    Raises:
        KeyError: If `opt_type` is not a known option type.
        ValueError: If `outer` is negative.
        RuntimeError: If fewer than `outer` matching files exist.
    """
    type_map = {
        "random": "RandomOpt",
        "eigen":  "EigenOpt",
        "vps":    "VPSOpt",
    }
    key = type_map[opt_type.lower()]

    # A negative count would pass the "too few files" check and then slice
    # from the end of the list, silently loading the wrong number of groups.
    if outer < 0:
        raise ValueError(f"outer must be non-negative, got {outer}")

    # Escape the literal parts so glob metacharacters in the directory name
    # or environment ID ("[", "*", "?") are matched verbatim rather than
    # being interpreted as wildcards.
    pattern = os.path.join(
        glob.escape(str(out_dir)),
        f"{glob.escape(env_id)}_*_{key}_*.npy",
    )
    files = sorted(glob.glob(pattern))
    if len(files) < outer:
        raise RuntimeError(
            f"outer={outer} requested, but found only {len(files)} {key} files"
        )
    print(f"[Load] {len(files)} {key} files, only using the first {outer} groups")
    return [np.load(f) for f in files[:outer]]


# ---------- main visualization ------------------------------
def visualize(
    env_id: str,
    Q_groups: List[np.ndarray],
    max_len: int
) -> None:
    """Render every option of every group as an interactive rollout.

    For each option the user is prompted to press ENTER, a fresh environment
    is created in ``human`` render mode and reset with a random seed, and the
    option's greedy policy (argmax over the last axis of its Q-table) is
    followed until ``option_terminated`` fires, the environment reports
    termination or truncation, or ``max_len`` steps have been taken.

    Args:
        env_id: Gymnasium environment ID passed to ``gym.make``.
        Q_groups: One array per option group, indexed as
            ``Q[option, state, action]``.
        max_len: Upper bound on environment steps per option.
    """
    for gid, Q in enumerate(Q_groups):
        K = Q.shape[0]
        policy = np.argmax(Q, axis=2)  # greedy action per (option, state)

        print(f"\n=== Group {gid} • {K} options ===")
        for oid in range(K):
            print(f"[{gid}:{oid}]  press ENTER to start this option …", end="")
            input()

            env = gym.make(env_id, render_mode="human")
            # Guarantee the render window is released even if the rollout
            # is interrupted (Ctrl-C) or fails part-way through.
            try:
                s, _ = env.reset(seed=random.randint(0, 1_000_000))
                s = int(s)
                steps = 0
                while steps < max_len:
                    # Termination is checked before acting, so an option
                    # may stop without taking any step.
                    if option_terminated(Q[oid, s]):
                        print(f" option {oid} terminated at step {steps}")
                        break
                    a = int(policy[oid, s])
                    s, _, term, trunc, _ = env.step(a)
                    s = int(s)
                    steps += 1
                    if term or trunc:
                        print(f" environment done at step {steps}")
                        break
            finally:
                env.close()
            time.sleep(1)  # short pause for readability


# ----------------------------- CLI --------------------------
def main():
    """Parse command-line options, load the option files, and visualize them."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env",
        default="Taxi-v3",
        help="Gymnasium discrete environment ID",
    )
    parser.add_argument(
        "--opt_type",
        default="vps",
        choices=["random", "eigen", "vps"],
        help="Which option type to visualize",
    )
    parser.add_argument(
        "--out_dir",
        default="option_results",
        help="Directory that stores *.npy option files",
    )
    parser.add_argument(
        "--outer",
        type=int,
        default=1,
        help="Number of option groups to visualize",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=200,
        help="Maximum steps to run each option",
    )
    cli = parser.parse_args()

    # Load first so a missing or short file set fails before any window opens.
    option_dir = Path(cli.out_dir)
    groups = load_option_files(cli.env, cli.opt_type, option_dir, cli.outer)
    visualize(cli.env, groups, cli.max_len)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
