#!/usr/bin/env python
"""Visualize saved Taxi-v3 VPS options by writing per-option GIFs."""
import time
import os
import numpy as np
import gymnasium as gym
import imageio.v2 as imageio   # import at top of file

# ---------- Paths ----------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))   # directory holding this script
SAVE_DIR = os.path.join(BASE_DIR, "rollouts")           # per-option GIFs are written here
OPTION_PATH = os.path.join(
    BASE_DIR, "option_results", "taxi_VPS_option_Q_test.npy"
)                                                        # saved option Q-table
os.makedirs(SAVE_DIR, exist_ok=True)

# ---------- Configuration ----------
ENV_ID = "Taxi-v3"
REALTIME_SHOW = True   # If False, record only (no window)
FPS = 2                # Intended GIF frame rate
N_STEPS = 50           # Steps per option rollout
SLEEP_SEC = 0.1        # Render interval; adjust as needed

# ---------- Load ----------
# Load the saved option Q-table, failing with an explicit message if it is missing.
if os.path.exists(OPTION_PATH):
    option_Q = np.load(OPTION_PATH)    # presumably (k, 500, 6) = options x states x actions
else:
    raise FileNotFoundError(f"No file at {OPTION_PATH}")

n_opts, n_states, n_actions = option_Q.shape
print(f"Loaded {n_opts} options  (state dim={n_states}, action dim={n_actions})")

# Greedy policy per option: index of the best action along the last (action) axis.
opt_policy = option_Q.argmax(axis=2)   # (k, 500)

# ---------- Environment ----------
# Module-level env is only used by decode_state (env.unwrapped.decode); each
# rollout builds its own rgb_array env, so no render mode is requested here.
env = gym.make(ENV_ID)
rng = np.random.RandomState(2025)   # fixed-seed NumPy RNG for reproducible randomness

def option_terminated(q_row) -> bool:
    """Return whether an option should stop, given its Q-values for one state.

    The option counts as finished once no action value is positive, i.e. the
    largest entry of ``q_row`` is at most zero.
    """
    best_value = q_row.max()
    return best_value <= 0

def decode_state(idx: int) -> str:
    """Describe Taxi state ``idx`` as taxi position, passenger location and destination.

    Uses the module-level ``env`` to decode the flat state index.
    """
    decoded = env.unwrapped.decode(idx)
    taxi_row, taxi_col, pass_loc, dest = decoded
    return f"Taxi=({taxi_row},{taxi_col})  Pax={pass_loc}  Dest={dest}"

# ---------- Option rollout + GIF export ----------
def play_single_option(opt_id: int, option_Q: np.ndarray, opt_policy: np.ndarray):
    """
    Roll out one option greedily and save the rendered frames as a GIF.

    A fresh ``rgb_array`` environment is created for every call (and always
    closed afterwards) so no pygame surface is reused between rollouts.

    Args:
        opt_id: Index of the option to play (first axis of ``option_Q``).
        option_Q: Option Q-table indexed as ``option_Q[opt_id, state]``.
        opt_policy: Greedy action table indexed as ``opt_policy[opt_id, state]``.

    The rollout stops when the option terminates (``option_terminated`` on the
    current state's Q-values), when the environment terminates or truncates,
    or after ``N_STEPS`` actions. The GIF is written to
    ``SAVE_DIR/option_XX.gif`` at ``FPS`` frames per second and includes the
    final state reached.
    """
    # rgb_array render mode: render() returns an ndarray, no pygame GUI.
    env = gym.make(ENV_ID, render_mode="rgb_array")
    try:
        # Draw the reset seed from the module's seeded RNG so runs are reproducible.
        obs, _ = env.reset(seed=int(rng.randint(1_000_000)))
        state = int(obs)

        frames = [env.render()]         # ndarray (H, W, 3) of the initial state
        steps = 0
        while steps < N_STEPS and not option_terminated(option_Q[opt_id, state]):
            action = int(opt_policy[opt_id, state])
            next_obs, _, terminated, truncated, _ = env.step(action)
            state = int(next_obs)
            steps += 1
            # Render after every step so the last state reached is in the GIF.
            frames.append(env.render())
            # Pacing only matters for live viewing; record-only runs skip it.
            if REALTIME_SHOW:
                time.sleep(SLEEP_SEC)
            if terminated or truncated:
                break
    finally:
        env.close()

    # -------- Save animation (GIF) --------
    gif_path = os.path.join(SAVE_DIR, f"option_{opt_id:02d}.gif")
    imageio.mimsave(gif_path, frames, fps=FPS)
    print(f"[✓] option {opt_id} rollout saved → {gif_path}  ({len(frames)} frames)")


# ---------- Main loop ----------
if __name__ == "__main__":
    # Roll out and save one GIF per loaded option (first axis of option_Q).
    # Guarded so importing this module does not trigger every rollout.
    for opt_id in range(option_Q.shape[0]):
        play_single_option(opt_id, option_Q, opt_policy)
