#!/usr/bin/env python
"""Visualize trained options in KeyLockEnv.

Loads option Q-tables and renders option rollouts in an on-screen window (render_mode="human").
Supports sequential execution of multiple options with automatic switching.
"""
from __future__ import annotations
import argparse, glob, os, time, random
from pathlib import Path
from typing import List, Optional
import numpy as np
import key_lock_options as klo
from key_lock_env import KeyLockEnv
from train_keylock_qlearning import state_to_index


# ---------- util: option start / termination ----------------
def option_can_start(q_row: np.ndarray) -> bool:
    """
    Report whether an option may be initiated in the current state.

    In principle every state belongs to the initiation set (gridworld-style),
    but the option is only treated as startable where the best of its local
    Q-values is strictly positive.
    """
    best_value = q_row.max()
    return best_value > 0


def option_terminated(q_row: np.ndarray, L: int = 15, *, stochastic: bool = False) -> bool:
    """
    Decide whether an option should stop at the current state.

    Gridworld-style termination: the option always terminates where its local
    Q-max is non-positive. The gridworld formulation additionally terminates
    with probability 1/L at every step; that stochastic rule is off by default
    and can be enabled with ``stochastic=True``.

    Args:
        q_row: the option's Q-values for the current state (one per action).
        L: expected option length for the stochastic rule; only consulted when
            ``stochastic`` is True and the Q-max is positive.
        stochastic: if True, also terminate with probability 1/L.

    Returns:
        True if the option should terminate, False otherwise.

    Raises:
        ValueError: if the stochastic rule is consulted with ``L <= 0``.
    """
    if q_row.max() <= 0:
        return True
    if not stochastic:
        return False
    if L <= 0:
        raise ValueError(f"L must be positive for stochastic termination, got {L}")
    return random.random() < 1.0 / L


# ---------- extract K from filename -------------------------
def _k_from_filename(path: str) -> Optional[int]:
    """
    Parse K out of a filename shaped like `keylock_{K}_{Method}_{outer}.npy`.

    Returns None when the basename does not follow that layout (fewer than
    four underscore-separated fields, or no `keylock` prefix) or when the K
    field is not an integer.
    """
    stem, _ext = os.path.splitext(os.path.basename(path))
    fields = stem.split("_")
    looks_valid = len(fields) >= 4 and fields[0] == "keylock"
    if not looks_valid:
        return None
    k_field = fields[1]
    try:
        k = int(k_field)
    except ValueError:
        k = None
    return k


# ---------- load option files --------------------------------
def load_option_files(
    opt_type: str,
    out_dir: Path,
    group_idx: int = 0,
    k_value: Optional[int] = None,
) -> Optional[np.ndarray]:
    """
    Load a single option file by type, group index, and/or K value.

    Candidate files are ordered by (K, path). Files that share a K therefore
    get a deterministic group index that does not depend on the order in
    which the filesystem returns glob results.

    Args:
        opt_type: Option type (random, eigen, vps); case-insensitive.
        out_dir: Directory containing option files.
        group_idx: Index into the ordered candidate files (default: 0). When
            ``k_value`` is given, it indexes only the files with that K.
            Too-large values fall back to the last file with a warning.
        k_value: If specified, only files with this K value are considered.

    Returns:
        Q: option Q-table (documented as shape (K, N, 6)), or None if no
        matching readable file is found.

    Raises:
        KeyError: if ``opt_type`` is not one of random/eigen/vps.
    """
    type_map = {
        "random": "RandomOption",
        "eigen":  "EigenOpt",
        "vps":    "VPSOpt",
    }
    key = type_map[opt_type.lower()]
    pattern = out_dir / f"keylock_*_{key}_*.npy"
    files = glob.glob(str(pattern))

    if not files:
        print(f"[Warning] No {key} files found in {out_dir}")
        return None

    files_with_k = []
    for f in files:
        k = _k_from_filename(f)
        if k is None:
            # Filename does not encode K: fall back to the array's first axis.
            # mmap_mode="r" maps the file instead of reading the whole table.
            try:
                k = np.load(f, mmap_mode="r").shape[0]
            except Exception as err:  # unreadable/corrupt file: skip, but say so
                print(f"[Warning] Skipping unreadable file {Path(f).name}: {err}")
                continue
        files_with_k.append((k, f))

    if not files_with_k:
        print(f"[Warning] No readable {key} files found in {out_dir}")
        return None

    # Sort by K numerically, then by path so files sharing a K have a stable order.
    files_with_k.sort()

    # Filter by k_value if specified
    if k_value is not None:
        matching_files = [(k, f) for k, f in files_with_k if k == k_value]
        if not matching_files:
            available_ks = sorted({k for k, _ in files_with_k})
            print(f"[Warning] No files found with K={k_value} for {key}")
            print(f"  Available K values: {available_ks}")
            return None
        files_with_k = matching_files

    # Select by group_idx (outer index)
    if group_idx >= len(files_with_k):
        print(f"[Warning] Requested group {group_idx}, but only {len(files_with_k)} files found. Using last file.")
        group_idx = len(files_with_k) - 1

    _, selected_file = files_with_k[group_idx]
    Q = np.load(selected_file)
    print(f"[Load] {Path(selected_file).name}  (K={Q.shape[0]})")
    return Q


def list_available_options(out_dir: Path):
    """
    Print every option file in ``out_dir``, grouped by option type and K.

    K is parsed from the filename. When that fails, the array's first axis is
    used instead, and files that cannot be read are listed under "unknown".
    Within each K, files are numbered ("Group i") in sorted path order.

    Args:
        out_dir: Directory containing ``keylock_*_{Method}_*.npy`` files.
    """
    type_map = {
        "random": "RandomOption",
        "eigen":  "EigenOpt",
        "vps":    "VPSOpt",
    }

    def _num_options(path):
        """Return the first-axis length of the .npy at *path*, or None if unreadable."""
        try:
            # mmap_mode="r" maps the file instead of loading the full table.
            return np.load(path, mmap_mode="r").shape[0]
        except Exception:  # best-effort listing: report, don't crash
            return None

    print("\n" + "=" * 70)
    print("Available Option Files")
    print("=" * 70)

    for opt_type, key in type_map.items():
        files = glob.glob(str(out_dir / f"keylock_*_{key}_*.npy"))

        print(f"\n{opt_type.upper()} ({key}):")
        if not files:
            print("  (none)")
            continue

        # Group files by K value
        files_by_k = {}
        for f in files:
            k = _k_from_filename(f)
            if k is None:
                k = _num_options(f)
            if k is None:
                k = "unknown"
            files_by_k.setdefault(k, []).append(f)

        # Integer K values in numeric order first, then the "unknown" bucket.
        for k in sorted(files_by_k, key=lambda x: (isinstance(x, str), x)):
            print(f"  K={k}:")
            for i, f in enumerate(sorted(files_by_k[k])):
                actual_k = _num_options(f)
                if actual_k is None:
                    print(f"    Group {i}: {Path(f).name}  (could not load)")
                else:
                    print(f"    Group {i}: {Path(f).name}  (K={actual_k} options)")

    print("=" * 70 + "\n")


# ---------- visualize single option --------------------------
def _obs_to_state(obs, size: int):
    """Map the first seven observation entries (plus grid size) to a state index via state_to_index."""
    return state_to_index(
        obs[0], obs[1], obs[2], obs[3], obs[4], obs[5], obs[6], size
    )


def visualize_single_option(
    env: KeyLockEnv,
    size: int,
    Q: np.ndarray,
    option_id: int,
    max_steps: int = 200,
    render_delay: float = 0.3,
    auto_switch: bool = False,
    reset_env: bool = True,
):
    """
    Visualize execution of a single option's greedy policy.

    If the option cannot start at the current state (Q-max <= 0), up to 50
    uniformly random actions are taken first to look for a state where it
    can. Execution proceeds whether or not one is found.

    Args:
        env: KeyLockEnv instance with render_mode="human"
        size: grid size, forwarded to ``state_to_index``
        Q: option Q-table (K, N, 6)
        option_id: which option to execute (0 to K-1)
        max_steps: maximum policy steps for this option
        render_delay: delay between steps (seconds)
        auto_switch: value returned when ``max_steps`` is exhausted
        reset_env: if True, reset environment before starting option;
            otherwise continue from the environment's current observation

    Returns:
        True if a step yields reward 1.0 (treated as reaching the goal).
        ``auto_switch`` if the full step budget is used without the episode
        or the option ending. False otherwise: invalid ``option_id``, the
        episode ended without that reward, or the option terminated.
    """
    K = Q.shape[0]
    if option_id < 0 or option_id >= K:
        print(f"[Error] Invalid option_id {option_id}. Must be 0 to {K-1}.")
        return False

    # Greedy policy of this option only: (N, 6) -> (N,). Slicing first avoids
    # an argmax over every option's table.
    policy = np.argmax(Q[option_id], axis=1)

    if reset_env:
        obs, _ = env.reset()
        env.render()
        time.sleep(0.5)  # Pause after reset to see initial state
    else:
        # Continue without resetting. NOTE(review): relies on KeyLockEnv's
        # private _get_obs() presumably returning the same format as step().
        obs = env._get_obs()

    s = _obs_to_state(obs, size)

    print(f"\n[Option {option_id}] Starting execution at state {s}...")
    time.sleep(0.3)  # Pause before starting option execution

    # Check if option can start at initial state
    if not option_can_start(Q[option_id, s]):
        print(f"[Option {option_id}] Cannot start at initial state {s} (Q-max <= 0)")
        print(f"[Option {option_id}] Attempting to find a startable state by random exploration...")

        max_explore_steps = 50
        for explore_step in range(max_explore_steps):
            a_random = random.randint(0, 5)  # uniform over action indices 0..5
            next_obs, reward, terminated, truncated, _ = env.step(a_random)
            s = _obs_to_state(next_obs, size)
            env.render()
            time.sleep(render_delay * 0.5)  # Faster exploration

            if reward == 1.0:
                print(f"[Option {option_id}] Goal reached during exploration!")
                return True

            if terminated or truncated:
                print(f"[Option {option_id}] Environment terminated during exploration")
                return False

            if option_can_start(Q[option_id, s]):
                print(f"[Option {option_id}] Found startable state {s} after {explore_step + 1} exploration steps")
                break
        else:
            # for/else: runs only when no startable state was found.
            print(f"[Option {option_id}] Could not find startable state after {max_explore_steps} steps")
            print(f"[Option {option_id}] Proceeding anyway (may execute with Q-max <= 0)...")

    steps = 0
    while steps < max_steps:
        a = int(policy[s])

        next_obs, reward, terminated, truncated, _ = env.step(a)
        s = _obs_to_state(next_obs, size)
        steps += 1

        env.render()
        if render_delay > 0:
            time.sleep(render_delay)

        # NOTE(review): goal detection assumes the env pays exactly 1.0 at the
        # goal; confirm against KeyLockEnv's reward function.
        if reward == 1.0:
            print(f"[Option {option_id}] ✓ Goal reached at step {steps}!")
            time.sleep(1.0)  # Pause longer when goal is reached
            return True

        if terminated or truncated:
            print(f"[Option {option_id}] Environment terminated at step {steps}")
            time.sleep(0.5)  # Pause when environment terminates
            break

        # Check option termination at the newly reached state
        if option_terminated(Q[option_id, s]):
            print(f"[Option {option_id}] Terminated at step {steps}")
            time.sleep(0.5)  # Pause when option terminates
            break
    else:
        # while/else runs only when the step budget ran out, never after a
        # `break`, so stopping on the very last step is not misreported here.
        print(f"[Option {option_id}] Reached max_steps ({max_steps})")
        time.sleep(0.5)  # Pause when max steps reached
        return auto_switch

    time.sleep(0.3)  # Pause before returning
    return False


# ---------- sequential visualization -------------------------
def visualize_sequential_options(
    env: KeyLockEnv,
    size: int,
    Q: np.ndarray,
    option_ids: List[int],
    max_steps_per_option: int = 200,
    render_delay: float = 0.1,
    auto_switch: bool = True,
    pause_between_options: float = 1.0,
    reset_between_options: bool = False,
):
    """
    Play a list of options back-to-back in one rendered environment.

    The environment is reset and displayed once before anything runs. Each
    entry of ``option_ids`` is then handed to ``visualize_single_option``.
    IDs outside ``[0, K-1]`` are reported and skipped.

    Args:
        env: KeyLockEnv instance with render_mode="human"
        size: grid size
        Q: option Q-table (K, N, 6)
        option_ids: option IDs to execute, in order
        max_steps_per_option: step budget for each option
        render_delay: delay between rendered steps (seconds)
        auto_switch: forwarded to each option. When False, a True result
            from an option stops the whole sequence.
        pause_between_options: pause between consecutive options (seconds)
        reset_between_options: reset the env before every option, not just
            before the first list entry
    """
    num_options = Q.shape[0]
    banner = "=" * 70

    header_lines = [
        "\n" + banner,
        "Sequential Option Visualization",
        banner,
        f"Total options in Q-table: {num_options}",
        f"Options to execute: {option_ids}",
        f"Max steps per option: {max_steps_per_option}",
        f"Auto-switch: {auto_switch}",
        f"Reset between options: {reset_between_options}",
        banner,
    ]
    for line in header_lines:
        print(line)

    # Show the freshly reset environment for a moment before running anything.
    env.reset()
    env.render()
    time.sleep(1.0)

    total = len(option_ids)
    last_position = total - 1
    for position, option_id in enumerate(option_ids):
        if not 0 <= option_id < num_options:
            print(f"\n[Skipping] Invalid option_id {option_id}. Must be 0 to {num_options-1}.")
            continue

        print(f"\n--- Option {position + 1}/{total}: Option ID {option_id} ---")

        # The first list entry always resets; later ones only on request.
        should_reset = position == 0 or reset_between_options

        reached_goal = visualize_single_option(
            env, size, Q, option_id, max_steps_per_option,
            render_delay, auto_switch, reset_env=should_reset,
        )

        # NOTE(review): visualize_single_option also returns its auto_switch
        # value when an option merely exhausts its step budget, so a True
        # result is not always a reached goal. Also, unless
        # reset_between_options is set, later options keep stepping the same
        # env even after its episode has ended.
        if reached_goal and not auto_switch:
            print("\n✓ Goal reached! Stopping visualization.")
            break

        more_to_come = position < last_position
        if more_to_come and pause_between_options > 0:
            print(f"\nPausing {pause_between_options}s before next option...")
            env.render()  # redraw once more before pausing
            time.sleep(pause_between_options)


# ---------- interactive selection ---------------------------
def parse_option_ids(spec: str, K: int) -> List[int]:
    """
    Parse an option-ID spec such as ``"0,2,5"`` or ``"0-4,7"``.

    Each comma-separated part is either a single ID or an inclusive
    ``start-end`` range. IDs outside ``[0, K-1]`` are dropped with a warning,
    whether they come from single numbers or from ranges. Unparsable parts and
    reversed ranges are also reported and skipped.

    Args:
        spec: the raw comma-separated specification.
        K: number of available options.

    Returns:
        Sorted, de-duplicated list of valid option IDs (possibly empty).
    """
    selected = set()
    for part in spec.split(','):
        part = part.strip()
        if '-' in part:
            # Range: e.g., "0-4"
            try:
                start_s, end_s = part.split('-')
                start, end = int(start_s.strip()), int(end_s.strip())
            except ValueError:
                print(f"[Warning] Invalid range: {part}")
                continue
            if start > end:
                print(f"[Warning] Invalid range: {part}")
                continue
            # Clip the bounds rather than expanding the raw range, so huge
            # ranges stay cheap and invalid IDs never leak out.
            lo, hi = max(start, 0), min(end, K - 1)
            if (lo, hi) != (start, end):
                print(f"[Warning] Range {part} partly outside [0, {K-1}]; out-of-range IDs dropped")
            selected.update(range(lo, hi + 1))
        else:
            # Single number
            try:
                opt_id = int(part)
            except ValueError:
                print(f"[Warning] Invalid option ID: {part}")
                continue
            if 0 <= opt_id < K:
                selected.add(opt_id)
            else:
                print(f"[Warning] Option ID {opt_id} out of range [0, {K-1}]")
    return sorted(selected)


def interactive_select_options(Q: np.ndarray) -> List[int]:
    """
    Interactively select which options to visualize.

    Prompts on stdin. An empty answer, or an answer with no valid IDs,
    selects every option.

    Args:
        Q: option Q-table; only its first axis (number of options K) is used.

    Returns:
        Sorted list of selected option IDs, all within [0, K-1].
    """
    K = Q.shape[0]
    print(f"\nAvailable options: 0 to {K-1}")
    print("Enter option IDs separated by commas (e.g., '0,2,5' or '0-4' for range)")
    print("Or press ENTER to visualize all options sequentially.")

    user_input = input("Option IDs: ").strip()

    if not user_input:
        # Default: all options
        return list(range(K))

    option_ids = parse_option_ids(user_input, K)

    if not option_ids:
        print("[Warning] No valid options selected. Using all options.")
        return list(range(K))

    return option_ids


# ----------------------------- CLI --------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    pa = argparse.ArgumentParser(description="Visualize trained options in KeyLockEnv")
    pa.add_argument("--opt_type", default="vps",
                    choices=["random", "eigen", "vps"],
                    help="Which option type to visualize")
    pa.add_argument("--out_dir", default="keylock_option_results",
                    help="Directory that stores *.npy option files "
                         "(relative paths are resolved against this script's directory)")
    pa.add_argument("--group", type=int, default=0,
                    help="Index of the option file (outer index) among the files "
                         "matching --k (default: %(default)s)")
    pa.add_argument("--k", type=int, default=16,
                    help="Only consider option files with this K value (number of "
                         "options); --group then selects among them "
                         "(default: %(default)s). Example: --k 32")
    pa.add_argument("--option_ids", type=str, default=None,
                    help="Comma-separated option IDs (e.g., '0,2,5') or range (e.g., '0-4'). "
                         "If not specified, will prompt interactively.")
    pa.add_argument("--max_steps", type=int, default=1000,
                    help="Maximum steps per option (default: %(default)s)")
    pa.add_argument("--render_delay", type=float, default=0.3,
                    help="Delay between steps in seconds (default: %(default)s)")
    pa.add_argument("--auto_switch", action="store_true",
                    help="Automatically switch to next option after max_steps")
    pa.add_argument("--pause", type=float, default=2.0,
                    help="Pause time between options in seconds (default: %(default)s)")
    pa.add_argument("--reset_between", action="store_true",
                    help="Reset environment between options (default: continue from current state)")
    pa.add_argument("--size", type=int, default=15, help="grid size")
    pa.add_argument("--yellow_key_pos", type=int, nargs=2, default=[12, 3],
                    help="yellow key position (x, y)")
    pa.add_argument("--yellow_door_pos", type=int, nargs=2, default=[3, 8],
                    help="yellow door position (x, y)")
    pa.add_argument("--blue_key_pos", type=int, nargs=2, default=[12, 12],
                    help="blue key position (x, y)")
    pa.add_argument("--blue_door_pos", type=int, nargs=2, default=[9, 3],
                    help="blue door position (x, y)")
    pa.add_argument("--goal_pos", type=int, nargs=2, default=[3, 12],
                    help="goal position (x, y)")
    pa.add_argument("--list", action="store_true",
                    help="List all available option files and exit")
    return pa


def main():
    """Parse CLI arguments, load one option file, and visualize the selected options."""
    args = _build_arg_parser().parse_args()

    # Resolve option directory relative to this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(script_dir, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # List available options if requested
    if args.list:
        list_available_options(Path(out_dir))
        return

    # Load option file
    Q = load_option_files(args.opt_type, Path(out_dir), args.group, k_value=args.k)
    if Q is None:
        print("\n[Error] Could not load option file. Use --list to see available files.")
        return

    K = Q.shape[0]
    print(f"\nLoaded option Q-table: shape {Q.shape} ({K} options)")

    # Select options to visualize
    if args.option_ids:
        # Parse from command line; every accepted ID lies in [0, K-1].
        selected = set()
        for part in args.option_ids.split(','):
            part = part.strip()
            if '-' in part:
                try:
                    start_s, end_s = part.split('-')
                    start, end = int(start_s.strip()), int(end_s.strip())
                except ValueError:
                    print(f"[Warning] Invalid range: {part}")
                    continue
                if start > end:
                    print(f"[Warning] Invalid range: {part}")
                    continue
                # Clip to valid IDs so out-of-range values never reach the visualizer.
                lo, hi = max(start, 0), min(end, K - 1)
                if (lo, hi) != (start, end):
                    print(f"[Warning] Range {part} partly outside [0, {K-1}]; out-of-range IDs dropped")
                selected.update(range(lo, hi + 1))
            else:
                try:
                    opt_id = int(part)
                except ValueError:
                    print(f"[Warning] Invalid option ID: {part}")
                    continue
                if 0 <= opt_id < K:
                    selected.add(opt_id)
                else:
                    print(f"[Warning] Option ID {opt_id} out of range [0, {K-1}]")
        option_ids = sorted(selected)
        if not option_ids:
            print("[Error] No valid options selected.")
            return
    else:
        # Interactive selection
        option_ids = interactive_select_options(Q)

    print(f"\nSelected options: {option_ids}")
    input("\nPress ENTER to start visualization...")

    # Create visualization environment
    env = KeyLockEnv(
        size=args.size,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        yellow_key_pos=tuple(args.yellow_key_pos),
        yellow_door_pos=tuple(args.yellow_door_pos),
        blue_key_pos=tuple(args.blue_key_pos),
        blue_door_pos=tuple(args.blue_door_pos),
        goal_pos=tuple(args.goal_pos),
        render_mode="human",
        highlight=False,
    )

    # Visualize; always close the window, even on Ctrl-C
    try:
        visualize_sequential_options(
            env, args.size, Q, option_ids,
            max_steps_per_option=args.max_steps,
            render_delay=args.render_delay,
            auto_switch=args.auto_switch,
            pause_between_options=args.pause,
            reset_between_options=args.reset_between,
        )
    except KeyboardInterrupt:
        print("\n\n[Interrupted] Visualization stopped by user.")
    finally:
        env.close()
        print("\n[✓] Visualization complete.")


if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not on import.
    main()
