#!/usr/bin/env python
"""Train and visualize VPS options specifically for Taxi-v3.

Trains a VPSOptionAgent, saves the option Q-table, and then
visualizes rollouts per option by saving GIFs.
"""
import os
import numpy as np
import gymnasium as gym
from discrete_options import VPSOptionAgent
from PIL import Image, ImageDraw, ImageFont

# ---------------- Parameters ----------------
# Tunables for this run; each one is forwarded to the agent further down.
NUM_OPTIONS = 50        # passed to VPSOptionAgent as ``k``
PHASE1_EPS = 10_000     # passed to agent.train as ``phase1_eps``
PHASE2_EPOCHS = 10      # passed to agent.train as ``phase2_epochs``
SEED = 2025             # passed to VPSOptionAgent as ``seed``

# ---------------- Environment ----------------
# The environment is created in "rgb_array" render mode; taxi_render() below
# relies on env.render() to obtain a frame for an arbitrary state.
env = gym.make("Taxi-v3", render_mode="rgb_array")
agent = VPSOptionAgent(env, k=NUM_OPTIONS, sign=False, seed=SEED)

def taxi_decode(idx: int) -> str:
    """Describe a Taxi-v3 discrete state index as a readable string.

    The index is unpacked into four fields by the unwrapped environment's
    ``decode`` method (taxi row, taxi column, passenger location,
    destination), which are then embedded in the returned text.
    """
    row, col, passenger, destination = env.unwrapped.decode(idx)
    return "(taxi=({},{}), pass={}, dest={})".format(
        row, col, passenger, destination
    )

# State index → RGB frame
def _ascii_to_rgb(text_lines) -> np.ndarray:
    """Rasterize rows of text into a white-background RGB image (H,W,3) uint8."""
    import re  # local import: only needed on the text-rendering path

    # Drop ANSI colour escapes (e.g. "\x1b[43m") so they are neither drawn as
    # garbage glyphs nor counted when sizing the canvas.
    ansi_escape = re.compile(r"\x1b\[[0-9;]*m")
    clean_lines = [ansi_escape.sub("", str(line)) for line in text_lines]

    font = ImageFont.load_default()
    # Fixed cell size: 10 px per text row, 6 px per character, plus one spare
    # row/column of margin.  ``default=0`` keeps an empty frame from crashing.
    height = (len(clean_lines) + 1) * 10
    width = (max((len(line) for line in clean_lines), default=0) + 1) * 6
    img = Image.new("L", (width, height), color=255)  # white background
    draw = ImageDraw.Draw(img)
    for row, line in enumerate(clean_lines):
        draw.text((2, row * 10), line, fill=0, font=font)
    return np.array(img.convert("RGB"), dtype=np.uint8)


def taxi_render(idx: int) -> np.ndarray:
    """
    Convert a discrete Taxi-v3 state into an RGB ndarray (H,W,3).

    The unwrapped environment's state ``s`` is temporarily replaced by
    ``idx``, ``env.render()`` is called, and the previous state is restored
    afterwards -- even if rendering raises.

    Accepted frame formats from ``env.render()``:
      * a numeric ndarray, returned unchanged;
      * a plain ``str`` (text render modes), split into lines;
      * a list / tuple / ndarray of strings (text grid).
    Text frames are rasterized with PIL into a uint8 RGB image.

    Raises:
        ValueError: if ``env.render()`` returns anything else (including
            ``None``).
    """
    # ---- Temporarily set and restore env state ----
    base_env = env.unwrapped
    backup = base_env.s
    base_env.s = idx
    try:
        frame = env.render()
    finally:
        # Never leave the environment stuck in the probed state.
        base_env.s = backup

    # ---- Convert to RGB ndarray ----
    # (1) Already a numeric ndarray.  String arrays (kind "U"/"S") must not
    #     take this shortcut; they are rasterized below.
    if isinstance(frame, np.ndarray) and frame.dtype.kind not in {"O", "U", "S"}:
        return frame

    # (2) Text output: a single string, or a sequence / array of strings.
    if isinstance(frame, str):
        text_lines = frame.splitlines()
    elif isinstance(frame, (list, tuple)):
        text_lines = np.asarray(frame, dtype=str).tolist()
    elif isinstance(frame, np.ndarray):
        text_lines = frame.astype(str).tolist()
    else:
        raise ValueError("Unknown frame format from env.render()")

    return _ascii_to_rgb(text_lines)


# ---------------- Training ----------------
# Run the agent's training routine with the phase sizes configured above and
# an episode length cap of 200 steps.
print(">>> Training VPS Options on Taxi-v3 …")
agent.train(phase1_eps=PHASE1_EPS, phase2_epochs=PHASE2_EPOCHS, max_ep_len=200)

# ---------------- Save ----------------
# The output directory is derived from the target path and created on demand
# before the option Q-table is written.
save_path = "option_results/taxi_VPS_option_Q_test.npy"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
agent.save_option_Q(save_path)
print(f"[✓] option_Q saved to  {save_path}")

# ---------------- Bottleneck candidates ----------------
# Hand the agent a state decoder, a state renderer and an output directory.
print("\n=== Top-1 Potential Bottleneck States (decode row,col,pax,dst) ===")
agent.get_top_states(decode_fn=taxi_decode, render_fn=taxi_render, save_dir="vps_max_states")

# Release the environment's resources once everything has been reported.
env.close()
