#!/usr/bin/env python
"""Visualize Random/Eigen/VPS options interactively in Gymnasium.

Loads option Q-tables and renders each option rollout in a human window
until the option terminates or `max_len` is reached.
"""
from __future__ import annotations
import argparse, glob, os, time, random
from pathlib import Path
from typing import List
import numpy as np
import gymnasium as gym

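# ---------- util: option start / termination ----------------
# Legacy deterministic termination rule, kept for reference only; it is
# superseded by the stochastic `option_terminated` defined below.
# def option_terminated(q_row: np.ndarray) -> bool:
#     """Option terminates if all Q values ≤ 0."""
#     return q_row.max() <= 0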

def option_can_start(q_row: np.ndarray) -> bool:
    """
    Decide whether an option may be initiated in the current state.

    Every state is a candidate initiation state (gridworld-style), but the
    option only counts as startable when the largest entry of `q_row`
    (the option's Q-values at that state) is strictly positive.
    """
    best_value = q_row.max()
    return 0 < best_value


def option_terminated(q_row: np.ndarray, L: int = 15) -> bool:
    """
    Decide whether a running option stops at the current state.

    Gridworld-style rule: stop deterministically when no entry of `q_row`
    is positive; otherwise stop at random with probability 1/L, using a
    single `random.random()` draw from the global `random` module. The
    random draw is skipped entirely when the deterministic condition holds.
    """
    no_positive_value = q_row.max() <= 0
    if no_positive_value:
        return no_positive_value
    return random.random() < 1.0 / L


# ---------- load *.npy files --------------------------------
def load_option_files(
    env_id: str,
    opt_type: str,
    out_dir: Path,
    outer: int
) -> List[np.ndarray]:
    """
    Load the first `outer` option Q-table files for one env / option type.

    Files are matched by ``<out_dir>/<env_id>_*_<Key>_*.npy``, where <Key>
    is ``RandomOpt``, ``EigenOpt`` or ``VPSOpt`` for `opt_type` ``"random"``,
    ``"eigen"`` or ``"vps"`` (case-insensitive). Matches are taken in
    lexicographic (string) order of their paths.

    Args:
        env_id: environment ID used as the filename prefix.
        opt_type: one of "random", "eigen", "vps" (any case).
        out_dir: directory containing the *.npy files.
        outer: number of files (option groups) to load; must be >= 0.

    Returns:
        Exactly `outer` arrays, as returned by `np.load`.

    Raises:
        KeyError: `opt_type` is not one of the known option types.
        ValueError: `outer` is negative (a negative slice would otherwise
            silently drop files from the end of the list).
        RuntimeError: fewer than `outer` matching files exist.
    """
    if outer < 0:
        raise ValueError(f"outer must be non-negative, got {outer}")
    type_map = {
        "random": "RandomOpt",
        "eigen":  "EigenOpt",
        "vps":    "VPSOpt",
    }
    key = type_map[opt_type.lower()]
    # Escape the literal parts of the pattern so directory or env names that
    # contain glob metacharacters ('*', '?', '[') are matched verbatim.
    pattern = os.path.join(
        glob.escape(str(out_dir)),
        f"{glob.escape(env_id)}_*_{key}_*.npy",
    )
    files = sorted(glob.glob(pattern))
    if len(files) < outer:
        raise RuntimeError(
            f"outer={outer} requested, but found only {len(files)} {key} files"
        )
    print(f"[Load] {len(files)} {key} files, only using the first {outer} groups")
    return [np.load(f) for f in files[:outer]]


# ---------- main visualization ------------------------------
def visualize(
    env_id: str,
    Q_groups: List[np.ndarray],
    max_len: int
):
    """
    Interactively roll out every option of every group in a human-rendered env.

    For each option the user is prompted to press ENTER. A fresh environment
    is then created with ``render_mode="human"``, reset with a random seed,
    and driven by the option's greedy policy until one of these happens:
    `option_terminated` fires, the environment reports terminated/truncated,
    or `max_len` steps elapse. The reason for stopping is printed.

    Args:
        env_id: Gymnasium environment ID passed to `gym.make`.
        Q_groups: option Q-tables, each indexed as Q[option, state, action]
            (the greedy policy is the argmax over axis 2). Observations are
            cast with `int` and used as state indices, so a discrete
            observation space is assumed.
        max_len: maximum number of environment steps per option rollout.
    """
    for gid, Q in enumerate(Q_groups):
        K = Q.shape[0]
        # Greedy action for every (option, state) pair: argmax over actions.
        policy = np.argmax(Q, 2)  # (K, S)

        print(f"\n=== Group {gid} • {K} options ===")
        for oid in range(K):
            print(f"[{gid}:{oid}]  press ENTER to start this option …", end="")
            input()

            env = gym.make(env_id, render_mode="human")
            # try/finally releases the render window even if the rollout is
            # interrupted (e.g. Ctrl+C) or env.reset/env.step raises.
            try:
                s, _ = env.reset(seed=random.randint(0, 1_000_000))
                s = int(s)
                steps = 0
                while steps < max_len:
                    if option_terminated(Q[oid, s]):
                        print(f" option {oid} terminated at step {steps}")
                        break
                    a = int(policy[oid, s])
                    s, _, term, trunc, _ = env.step(a)
                    s = int(s)
                    steps += 1
                    if term or trunc:
                        print(f" environment done at step {steps}")
                        break
                else:
                    # Loop ran out without a break: the step budget was hit.
                    print(f" reached max_len={max_len} at step {steps}")
            finally:
                env.close()
            time.sleep(1)  # short pause for readability


# ----------------------------- CLI --------------------------
def main():
    """
    Command-line entry point.

    Parses the CLI flags, loads option Q-tables from
    ``<directory of this script>/<out_dir>/<env>/`` and passes them to
    `visualize`.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--env",
        default="Taxi-v3",
        help="Gymnasium discrete environment ID",
    )
    cli.add_argument(
        "--opt_type",
        default="vps",
        choices=["random", "eigen", "vps"],
        help="Which option type to visualize",
    )
    cli.add_argument(
        "--out_dir",
        default="option_results",
        help="Directory (under this folder) that stores *.npy option files",
    )
    cli.add_argument(
        "--outer",
        type=int,
        default=1,
        help="Number of option groups to visualize",
    )
    cli.add_argument(
        "--max_len",
        type=int,
        default=200,
        help="Maximum steps to run each option",
    )
    opts = cli.parse_args()

    # Option files are looked up relative to this script's own directory:
    # `<script dir>/<out_dir>/<env id>/`.
    here = os.path.dirname(os.path.abspath(__file__))
    option_dir = Path(os.path.join(here, opts.out_dir, opts.env))

    groups = load_option_files(opts.env, opts.opt_type, option_dir, opts.outer)
    visualize(opts.env, groups, opts.max_len)


if __name__ == "__main__":
    main()
