#!/usr/bin/env python
"""Q-learning training script for KeyLockEnv.

This script trains a Q-learning agent to complete the complex key-lock task:
1. Pick up yellow key
2. Open yellow door
3. Pick up blue key
4. Open blue door
5. Reach the goal

The script plots the training reward curve.
"""

import numpy as np
import matplotlib.pyplot as plt
import random
import glob
import os
from pathlib import Path
from typing import Optional, Dict, List
from key_lock_env import KeyLockEnv


def state_to_index(x, y, dir, yellow_door_open, blue_door_open, yellow_key_on_map, blue_key_on_map, size=15):
    """
    Flatten a KeyLock state into one integer using a mixed-radix encoding.

    Components are packed most-significant first in the order
    x, y, dir, yellow_door_open, blue_door_open, yellow_key_on_map,
    blue_key_on_map, with radices size, size, 4, 2, 2, 2, 2.
    (The layout is meant to stay consistent with key_lock_options.py.)

    Args:
        x, y: position (expected 0 to size-1)
        dir: direction (expected 0-3)
        yellow_door_open: 0 or 1
        blue_door_open: 0 or 1
        yellow_key_on_map: 0 or 1
        blue_key_on_map: 0 or 1
        size: grid size

    Returns:
        state_index: flattened state index as a Python int
    """
    # (digit, radix) pairs for everything below the leading x digit.
    lower_digits = (
        (y, size),
        (dir, 4),
        (yellow_door_open, 2),
        (blue_door_open, 2),
        (yellow_key_on_map, 2),
        (blue_key_on_map, 2),
    )
    flat = x
    for digit, radix in lower_digits:
        flat = flat * radix + digit
    return int(flat)


def index_to_state(state_index, size=15):
    """
    Decode a flattened index back into its state components.

    The index is first reduced modulo the table size
    (size * size * 4 * 2 * 2 * 2 * 2), then the digits are peeled off from
    least significant (blue_key_on_map) to most significant (x).

    Args:
        state_index: flattened state index
        size: grid size

    Returns:
        (x, y, dir, yellow_door_open, blue_door_open, yellow_key_on_map, blue_key_on_map) tuple
    """
    remaining = state_index % (size * size * 4 * 2 * 2 * 2 * 2)

    # Each divmod strips one digit: quotient carries on, remainder is the digit.
    remaining, blue_key_on_map = divmod(remaining, 2)
    remaining, yellow_key_on_map = divmod(remaining, 2)
    remaining, blue_door_open = divmod(remaining, 2)
    remaining, yellow_door_open = divmod(remaining, 2)
    remaining, heading = divmod(remaining, 4)
    remaining, y = divmod(remaining, size)
    x = remaining % size

    return (
        x,
        y,
        heading,
        yellow_door_open,
        blue_door_open,
        yellow_key_on_map,
        blue_key_on_map,
    )


def option_can_start(q_row: np.ndarray) -> bool:
    """Return whether an option may be initiated here.

    An option is startable when the largest entry of its Q-row for the
    current state is strictly positive.
    """
    best_value = q_row.max()
    return best_value > 0


def option_terminated(q_row: np.ndarray, L: int = 15) -> bool:
    """Return whether a running option should stop at the current state.

    The option stops when its Q-row has no strictly positive entry, or
    otherwise at random with probability 1/L. The random draw is only made
    when the Q-row check does not already terminate the option.
    """
    exhausted = q_row.max() <= 0
    if exhausted:
        return exhausted
    return random.random() < 1.0 / L


def epsilon_greedy_action(Q, state_idx, epsilon, num_actions=6, Qopt=None):
    """
    Pick a primitive action or an option with an epsilon-greedy rule.

    Options are only considered when they are startable at ``state_idx``
    (see ``option_can_start``). An option's value is the maximum of its
    Q-row at this state. Ties between the best option and the best primitive
    go to the primitive.

    Args:
        Q: Q-table for primitive actions (num_states, num_actions)
        state_idx: current state index
        epsilon: exploration probability
        num_actions: number of primitive actions (default: 6)
        Qopt: Option Q-table (K, num_states, num_actions) or None

    Returns:
        action: selected action (0 to num_actions-1 for primitives, num_actions+oid for options)
        is_option: True if action is an option, False if primitive
    """
    option_count = Qopt.shape[0] if Qopt is not None else 0
    available = [
        oid for oid in range(option_count)
        if option_can_start(Qopt[oid, state_idx])
    ]

    # Exactly one uniform draw decides explore vs. exploit on every call.
    if random.random() < epsilon:
        pool = list(range(num_actions))
        pool.extend(num_actions + oid for oid in available)
        picked = random.choice(pool)
        return picked, picked >= num_actions

    primitive_row = Q[state_idx]
    greedy_primitive = int(np.argmax(primitive_row))
    if not available:
        return greedy_primitive, False

    option_scores = np.array([Qopt[oid, state_idx].max() for oid in available])
    top = np.argmax(option_scores)
    # Strict comparison: an option must beat the best primitive to be chosen.
    if option_scores[top] > np.max(primitive_row):
        return num_actions + available[top], True
    return greedy_primitive, False


def load_option_file(opt_type: str, out_dir: Path, group_idx: int = 0, num_opts: Optional[int] = None, sign: bool = True) -> Optional[np.ndarray]:
    """
    Load a single option file.

    Files are looked up as ``keylock_<total>_<Tag>_*.npy`` inside ``out_dir``,
    sorted by path, and the ``group_idx``-th match is loaded with ``np.load``.

    Args:
        opt_type: Option type ("random", "eigen", "vps"), case-insensitive
        out_dir: Directory containing option files (a Path; a plain str path
            is accepted as well)
        group_idx: Group index (outer index) to select (default: 0). If it is
            past the last match, the last file is used and a warning is printed.
        num_opts: Number of base options (if None, loads any available)
        sign: Whether options use sign doubling; when True the file name is
            expected to carry ``num_opts * 2`` as its total option count

    Returns:
        Q: the loaded array (expected shape (K, N, 6) — not checked here), or
        None if the option type is unknown or no file matches
    """
    type_map = {
        "random": "RandomOption",
        "eigen": "EigenOpt",
        "vps": "VPSOpt",
    }
    tag = type_map.get(opt_type.lower())
    if tag is None:
        return None

    # Coerce so that string paths (e.g. straight from argparse) work with `/`.
    out_dir = Path(out_dir)

    if num_opts is not None:
        # key_lock_options.py saves all three (VPS, Eigen, Random) with total = num_opts * 2 when sign=True
        total_opts = num_opts * 2 if sign else num_opts
        name_pattern = f"keylock_{total_opts}_{tag}_*.npy"
    else:
        name_pattern = f"keylock_*_{tag}_*.npy"

    # Path.glob treats only the file-name part as a pattern, so glob
    # metacharacters such as "[" in the directory path are matched literally.
    files = sorted(str(match) for match in out_dir.glob(name_pattern))
    if not files:
        print(f"[Warning] No {tag} files found matching pattern {out_dir / name_pattern}")
        return None

    if group_idx >= len(files):
        print(f"[Warning] Requested group {group_idx}, but only {len(files)} files found. Using last file.")
        group_idx = len(files) - 1

    selected_file = files[group_idx]
    Q = np.load(selected_file)
    print(f"[Load] {Path(selected_file).name}  (K={Q.shape[0]} options)")
    return Q


def _episode_obs_to_index(obs, size: int) -> int:
    """Flatten the first seven observation components into a Q-table row index."""
    return state_to_index(
        obs[0], obs[1], obs[2], obs[3], obs[4], obs[5], obs[6], size
    )


def _episode_bootstrap_value(Q: np.ndarray, Qopt: Optional[np.ndarray], state_idx: int):
    """Return the best value at ``state_idx`` over primitives and startable options."""
    best = np.max(Q[state_idx])
    if Qopt is None:
        return best
    option_values = [
        Qopt[oid, state_idx].max()
        for oid in range(Qopt.shape[0])
        if option_can_start(Qopt[oid, state_idx])
    ]
    if option_values:
        best = max(best, max(option_values))
    return best


def _episode_td_update(
    Q: np.ndarray,
    Qopt: Optional[np.ndarray],
    state_idx: int,
    action: int,
    reward: float,
    next_state_idx: int,
    terminated: bool,
    gamma: float,
    learning_rate: float,
) -> None:
    """Apply one in-place Q-learning update to ``Q[state_idx, action]``."""
    if terminated:
        target = reward
    else:
        # Also reached on time-limit truncation: the next state is not a true
        # terminal state, so its value is still bootstrapped.
        target = reward + gamma * _episode_bootstrap_value(Q, Qopt, next_state_idx)
    Q[state_idx, action] += learning_rate * (target - Q[state_idx, action])


def _run_one_episode(
    env: KeyLockEnv,
    Q: np.ndarray,
    Qopt: Optional[np.ndarray],
    option_policy: Optional[np.ndarray],
    size: int,
    max_steps_per_episode: int,
    epsilon: float,
    train: bool,
    gamma: float,
    learning_rate: float,
    num_actions: int = 6,
) -> tuple[float, int, bool]:
    """
    Run one episode. When train=False, use epsilon=0 (greedy) and do not update Q.

    Actions come from ``epsilon_greedy_action``. A selected option is executed
    by following ``option_policy[option_id, state]`` until ``option_terminated``
    fires, the environment ends the episode, or the step budget is used up.
    While training, every primitive step (inside or outside an option) gets a
    one-step Q-learning update on ``Q``; ``Qopt`` is only read.

    Args:
        env: environment exposing ``reset() -> (obs, info)`` and
            ``step(a) -> (obs, reward, terminated, truncated, info)``
        Q: primitive-action Q-table, updated in place when ``train`` is True
        Qopt: option Q-table (K, num_states, num_actions) or None
        option_policy: primitive action per (option, state); only used when an
            option is selected
        size: grid size passed to ``state_to_index``
        max_steps_per_episode: cap on primitive steps taken in this episode
        epsilon: exploration probability (ignored when ``train`` is False)
        train: whether to explore and update ``Q``
        gamma: discount factor
        learning_rate: Q-learning step size
        num_actions: number of primitive actions

    Returns:
        (episode_reward, episode_steps, episode_success); success means some
        step returned a reward equal to 1.0.
    """
    obs, _ = env.reset()
    state_idx = _episode_obs_to_index(obs, size)
    episode_reward = 0.0
    episode_steps = 0
    episode_success = False
    epsilon_eff = epsilon if train else 0.0

    while episode_steps < max_steps_per_episode:
        action, is_option = epsilon_greedy_action(
            Q, state_idx, epsilon_eff, num_actions, Qopt
        )

        if is_option:
            option_id = action - num_actions
            option_reward = 0.0
            option_steps = 0
            terminated = truncated = False

            # Budget check first so the termination test (which may consume a
            # random draw) is skipped once the step limit is reached.
            while (
                episode_steps + option_steps < max_steps_per_episode
                and not option_terminated(Qopt[option_id, state_idx])
            ):
                primitive_action = int(option_policy[option_id, state_idx])
                next_obs, reward, terminated, truncated, _ = env.step(
                    primitive_action
                )
                next_state_idx = _episode_obs_to_index(next_obs, size)
                option_reward += reward
                option_steps += 1

                if train:
                    _episode_td_update(
                        Q, Qopt, state_idx, primitive_action, reward,
                        next_state_idx, terminated, gamma, learning_rate,
                    )

                state_idx = next_state_idx
                if reward == 1.0:
                    episode_success = True
                if terminated or truncated:
                    break

            episode_reward += option_reward
            episode_steps += option_steps
            if option_steps == 0:
                # The option's random termination fired before its first step.
                # Nothing happened in the environment, so choose again rather
                # than abandoning the whole episode.
                continue
            if terminated or truncated:
                break
        else:
            next_obs, reward, terminated, truncated, _ = env.step(action)
            next_state_idx = _episode_obs_to_index(next_obs, size)

            if train:
                _episode_td_update(
                    Q, Qopt, state_idx, action, reward,
                    next_state_idx, terminated, gamma, learning_rate,
                )

            episode_reward += reward
            episode_steps += 1
            state_idx = next_state_idx
            if reward == 1.0:
                episode_success = True
            if terminated or truncated:
                break

    return episode_reward, episode_steps, episode_success


def train_qlearning(
    num_episodes=1000,
    max_steps_per_episode=500,
    learning_rate=0.1,
    gamma=0.99,
    epsilon_start=1.0,
    epsilon_end=0.01,
    epsilon_decay=0.995,
    size=15,
    seed=None,
    verbose=True,
    Qopt=None,
    option_name="primitive",
    max_steps=None,
    eval_step_interval=5000,
    eval_trials=10,
):
    """
    Train Q-learning agent on KeyLockEnv.

    When max_steps is set: train until total_steps >= max_steps, evaluate every
    eval_step_interval steps (greedy, eval_trials episodes), return (Q, eval_rewards, eval_steps).
    When max_steps is None: run num_episodes, return (Q, episode_rewards, episode_lengths, success_rate_history).

    Args:
        num_episodes: number of training episodes (episode mode only)
        max_steps_per_episode: per-episode step cap, also passed to the environment
        learning_rate: Q-learning step size
        gamma: discount factor
        epsilon_start, epsilon_end, epsilon_decay: epsilon is multiplied by
            epsilon_decay after every training episode and floored at epsilon_end
        size: grid size
        seed: if not None, seeds both ``numpy.random`` and ``random``
        verbose: print a progress line every 50 episodes (episode mode only)
        Qopt: option Q-table (K, num_states, num_actions) or None for primitives only
        option_name: label used in the printed banner
        max_steps: total training-step budget; selects step mode when not None
        eval_step_interval: minimum number of training steps between evaluations
        eval_trials: greedy episodes averaged per evaluation point

    Raises:
        ValueError: in step mode, if eval_trials < 1, or if max_steps > 0 while
            max_steps_per_episode < 1 (training could never make progress).
    """
    step_mode = max_steps is not None
    if step_mode:
        if eval_trials < 1:
            raise ValueError("eval_trials must be >= 1 when max_steps is set")
        if max_steps > 0 and max_steps_per_episode < 1:
            raise ValueError(
                "max_steps_per_episode must be >= 1 when max_steps is set"
            )

    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    # Create environment
    # Set max_steps to match max_steps_per_episode to ensure environment truncates at the same limit
    env = KeyLockEnv(
        size=size,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        yellow_key_pos=(12, 3),
        yellow_door_pos=(3, 8),
        blue_key_pos=(12, 12),
        blue_door_pos=(9, 3),
        goal_pos=(3, 12),
        max_steps=max_steps_per_episode,  # Set environment max_steps to match training loop limit
        render_mode=None,
    )

    # Initialize Q-table for primitive actions
    # State space: size * size * 4 * 2 * 2 * 2 * 2
    # (x, y, dir, yellow_door_open, blue_door_open, yellow_key_on_map, blue_key_on_map)
    num_states = size * size * 4 * 2 * 2 * 2 * 2
    num_actions = 6  # Primitive actions
    Q = np.zeros((num_states, num_actions), dtype=np.float32)

    # Greedy primitive action of every option in every state, shape (K, N)
    option_policy = None
    if Qopt is not None:
        option_policy = np.argmax(Qopt, 2)

    def run_episode(eps, train):
        """Run one episode on the shared env/Q-table with the fixed hyper-parameters."""
        return _run_one_episode(
            env, Q, Qopt, option_policy, size, max_steps_per_episode,
            eps, train=train, gamma=gamma, learning_rate=learning_rate,
            num_actions=num_actions,
        )

    def evaluate():
        """Mean return of eval_trials greedy, non-learning episodes."""
        returns = [run_episode(0.0, False)[0] for _ in range(eval_trials)]
        return sum(returns) / len(returns)

    epsilon = epsilon_start

    print("=" * 70)
    print(f"Q-Learning Training for KeyLockEnv ({option_name})")
    print("=" * 70)
    print("Configuration:")
    if step_mode:
        print(f"  Max steps (total): {max_steps}")
        print(f"  Eval step interval: {eval_step_interval}")
        print(f"  Eval trials: {eval_trials}")
    else:
        print(f"  Number of episodes: {num_episodes}")
    print(f"  Max steps per episode: {max_steps_per_episode}")
    print(f"  Learning rate (alpha): {learning_rate}")
    print(f"  Discount factor (gamma): {gamma}")
    print(f"  Epsilon: {epsilon_start} -> {epsilon_end} (decay: {epsilon_decay})")
    print(f"  State space size: {num_states}")
    print(f"  Primitive actions: {num_actions}")
    if Qopt is not None:
        print(f"  Options available: {Qopt.shape[0]}")
    else:
        print("  Options available: 0 (primitive only)")
    print("=" * 70)
    print()

    # ----- Step-based training with periodic evaluation -----
    if step_mode:
        total_steps = 0
        last_eval_step = 0
        eval_rewards = []
        eval_steps_list = []

        while total_steps < max_steps:
            _, ep_steps, _ = run_episode(epsilon, True)
            total_steps += ep_steps
            epsilon = max(epsilon_end, epsilon * epsilon_decay)

            # Evaluations happen at episode boundaries, so the recorded step
            # counts are "at least eval_step_interval apart", not exact multiples.
            if total_steps - last_eval_step >= eval_step_interval:
                eval_rewards.append(evaluate())
                eval_steps_list.append(total_steps)
                last_eval_step = total_steps

        # Make sure the end of training is represented by an evaluation point.
        if not eval_steps_list or eval_steps_list[-1] < max_steps:
            eval_rewards.append(evaluate())
            eval_steps_list.append(total_steps)

        print("\n" + "=" * 70)
        print("Training Complete (periodic eval)")
        print("=" * 70)
        print(f"Total steps: {total_steps}")
        print(f"Eval points: {len(eval_rewards)}")
        print("=" * 70)
        return Q, eval_rewards, eval_steps_list

    # ----- Episode-based training (no periodic eval) -----
    episode_rewards = []
    episode_lengths = []
    success_history = []  # Track if each episode was successful
    success_rate_history = []  # Success rate over the trailing window
    window_size = 50

    for episode in range(num_episodes):
        episode_reward, episode_steps, episode_success = run_episode(epsilon, True)
        episode_rewards.append(episode_reward)
        episode_lengths.append(episode_steps)
        success_history.append(episode_success)
        epsilon = max(epsilon_end, epsilon * epsilon_decay)

        # Trailing window; shorter than window_size during the first episodes.
        recent_window = success_history[-window_size:]
        success_rate_history.append(sum(recent_window) / len(recent_window))

        if verbose and (episode + 1) % 50 == 0:
            recent_reward = np.mean(episode_rewards[-50:])
            recent_length = np.mean(episode_lengths[-50:])
            recent_success = sum(success_history[-50:])
            print(f"Episode {episode + 1}/{num_episodes} | "
                  f"Recent avg reward: {recent_reward:.3f} | "
                  f"Recent avg length: {recent_length:.1f} | "
                  f"Recent successes: {recent_success}/50 | "
                  f"Epsilon: {epsilon:.3f}")

    episodes_run = len(episode_rewards)
    total_successes = sum(success_history)
    # Guard the zero-episode case instead of dividing by zero.
    success_rate_final = total_successes / episodes_run if episodes_run else 0.0
    mean_reward = np.mean(episode_rewards) if episodes_run else float("nan")
    mean_length = np.mean(episode_lengths) if episodes_run else float("nan")

    print("\n" + "=" * 70)
    print("Training Complete")
    print("=" * 70)
    print(f"Total episodes: {num_episodes}")
    print(f"Successful episodes: {total_successes}")
    print(f"Final success rate: {100 * success_rate_final:.2f}%")
    print(f"Average reward: {mean_reward:.3f}")
    print(f"Average episode length: {mean_length:.1f}")

    if total_successes > 0:
        successful_rewards = [
            r for r, ok in zip(episode_rewards, success_history) if ok
        ]
        successful_lengths = [
            n for n, ok in zip(episode_lengths, success_history) if ok
        ]
        print("\nSuccessful episodes:")
        print(f"  Average reward: {np.mean(successful_rewards):.3f}")
        print(f"  Average length: {np.mean(successful_lengths):.1f}")
        print(f"  Min length: {min(successful_lengths)}")
        print(f"  Max length: {max(successful_lengths)}")

    print("=" * 70)
    return Q, episode_rewards, episode_lengths, success_rate_history


def _legend_name(option_name: str) -> str:
    """Display name for legend: Primitive, Random, Eigen, VPS."""
    names = {"primitive": "Primitive", "random": "Random", "eigen": "Eigen", "vps": "VPS"}
    return names.get(option_name, option_name.capitalize())


def _get_mean_std(arr):
    """If arr is list of lists (multiple runs), return (mean, std). Else return (arr, None)."""
    if not arr or not isinstance(arr, (list, np.ndarray)):
        return np.array(arr), None
    first = arr[0]
    if isinstance(first, (list, np.ndarray)) and len(first) > 0 and not isinstance(first[0], (list, np.ndarray)):
        arr = np.array(arr)
        return np.mean(arr, axis=0), np.std(arr, axis=0)
    return np.array(arr), None


def _moving_average(arr: np.ndarray, k: int) -> np.ndarray:
    """1-D moving average; k<=1 returns input unchanged."""
    if k <= 1:
        return arr
    kernel = np.ones(k) / k
    return np.convolve(arr, kernel, mode="valid")


def _plot_eval_curves(
    all_results: Dict[str, Dict],
    save_path=None,
    window_size=50,
    color_map=None,
    fallback_colors=None,
):
    """Plot Eval Return vs Time Steps (periodic-eval mode). Mean ± std, optional smoothing.

    Args:
        all_results: maps option_name to a dict with 'eval_rewards' and
            'eval_steps'. Each is either a (runs, points) array-like or a
            single 1-D run; rows are averaged across runs.
        save_path: output image path; a default file name is used when falsy
        window_size: moving-average window, applied only when a curve has at
            least that many points
        color_map: option_name -> matplotlib color (defaults provided)
        fallback_colors: colors cycled for names missing from color_map
    """
    if color_map is None:
        color_map = {"primitive": "tab:blue", "random": "tab:purple", "eigen": "tab:orange", "vps": "tab:green"}
    if fallback_colors is None:
        fallback_colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray']

    plt.figure(figsize=(8, 4.5))
    for idx, (option_name, results) in enumerate(all_results.items()):
        # Accept lists as well as arrays, and promote a single run to shape
        # (1, points) so the across-run reductions below stay well defined.
        reward_runs = np.atleast_2d(np.asarray(results['eval_rewards']))
        step_runs = np.atleast_2d(np.asarray(results['eval_steps']))

        mean_rewards = reward_runs.mean(0)
        std_rewards = reward_runs.std(0)
        mean_steps = step_runs.mean(0).astype(float)
        if window_size > 1 and len(mean_rewards) >= window_size:
            mean_rewards = _moving_average(mean_rewards, window_size)
            std_rewards = _moving_average(std_rewards, window_size)
            # "valid" smoothing drops the first window_size - 1 points.
            mean_steps = mean_steps[window_size - 1:]

        color = color_map.get(option_name, fallback_colors[idx % len(fallback_colors)])
        plt.plot(mean_steps, mean_rewards, color=color, linewidth=2,
                 label=_legend_name(option_name))
        plt.fill_between(mean_steps, mean_rewards - std_rewards, mean_rewards + std_rewards,
                         color=color, alpha=0.25)

    plt.xlabel("Time Steps", fontsize=16)
    plt.ylabel("Return", fontsize=16)
    plt.title("KeyLockEnv Q-Learning (Periodic Eval)", fontsize=16, fontweight='bold')
    plt.legend(fontsize=14)
    plt.grid(alpha=0.3)
    plt.tick_params(axis='both', labelsize=14)
    plt.tight_layout()

    out_path = save_path if save_path else 'keylock_qlearning_eval_curves.png'
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"\nPlot saved to: {out_path}")
    plt.show()


def plot_training_curves(
    all_results: Dict[str, Dict],
    save_path=None,
    window_size=50,
):
    """
    Plot the training reward curve for one or more option types.

    If the first entry of ``all_results`` contains 'eval_rewards' and
    'eval_steps', plotting is delegated to ``_plot_eval_curves``. Otherwise a
    single figure of episode reward vs. episode is drawn (episode lengths and
    success rates are not plotted here). Curves are smoothed with a moving
    average when they have at least ``window_size`` points; legend names are
    Primitive, Random, Eigen, VPS. Multi-run input is drawn as mean ± std.

    Args:
        all_results: Dict mapping option_name to a dict holding
            'episode_rewards' as a list (single run) or a list of lists (runs),
            or holding 'eval_rewards' / 'eval_steps' for periodic-eval mode
        save_path: path to save the plot (optional)
        window_size: window size for moving average (<= 1 disables smoothing)

    Raises:
        ValueError: if all_results is empty
    """
    if not all_results:
        raise ValueError("all_results must contain at least one option type")

    # Match option_exploration_qlearning colors: primitive, random, eigen, vps
    color_map = {
        "primitive": "tab:blue",
        "random": "tab:purple",
        "eigen": "tab:orange",
        "vps": "tab:green",
    }
    fallback_colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray']

    first_results = next(iter(all_results.values()))
    if 'eval_rewards' in first_results and 'eval_steps' in first_results:
        _plot_eval_curves(all_results, save_path, window_size, color_map, fallback_colors)
        return

    # Episode-based mode: single plot (reward only)
    plt.figure(figsize=(8, 5))

    for idx, (option_name, results) in enumerate(all_results.items()):
        mean_curve, std_curve = _get_mean_std(results['episode_rewards'])
        # Size the x-axis from this curve so option types with different run
        # lengths can share one figure.
        n_points = len(mean_curve)
        color = color_map.get(option_name, fallback_colors[idx % len(fallback_colors)])

        if window_size > 1 and n_points >= window_size:
            # "valid" smoothing: the first point summarises episodes 1..window_size.
            x = np.arange(window_size, n_points + 1)
            y = _moving_average(mean_curve, window_size)
            band = None if std_curve is None else _moving_average(std_curve, window_size)
        else:
            x = np.arange(1, n_points + 1)
            y, band = mean_curve, std_curve

        plt.plot(x, y, color=color, linewidth=2, label=_legend_name(option_name))
        if band is not None:
            plt.fill_between(x, y - band, y + band, color=color, alpha=0.25)

    plt.xlabel('Episode', fontsize=16)
    plt.ylabel('Episode Reward', fontsize=16)
    plt.title('Q-Learning Training Curves Comparison for KeyLockEnv', fontsize=16, fontweight='bold')
    plt.legend(fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.tick_params(axis='both', labelsize=14)

    plt.tight_layout()

    out_path = save_path if save_path else 'keylock_qlearning_training_curves.png'
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"\nPlot saved to: {out_path}")

    plt.show()


def test_trained_agent(Q, num_test_episodes=10, size=15, verbose=True, max_steps=500):
    """
    Test the trained Q-learning agent.

    Runs greedy (argmax over ``Q``) primitive-action episodes on a freshly
    constructed KeyLockEnv; options are not used here.

    Args:
        Q: trained Q-table
        num_test_episodes: number of test episodes
        size: grid size
        verbose: print test results
        max_steps: cap on steps per test episode in this loop (default: 500).
            The environment is created with its own default step limit, so it
            may truncate an episode earlier.

    Returns:
        (test_rewards, test_lengths, test_successes): per-episode lists; an
        episode counts as a success when some step returned a reward of 1.0.
    """
    env = KeyLockEnv(
        size=size,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        yellow_key_pos=(12, 3),
        yellow_door_pos=(3, 8),
        blue_key_pos=(12, 12),
        blue_door_pos=(9, 3),
        goal_pos=(3, 12),
        render_mode=None,
    )

    test_rewards = []
    test_lengths = []
    test_successes = []

    if verbose:
        print("\n" + "=" * 70)
        print("Testing Trained Agent (Greedy Policy)")
        print("=" * 70)

    for episode in range(num_test_episodes):
        obs, info = env.reset()
        state_idx = state_to_index(
            obs[0], obs[1], obs[2], obs[3], obs[4], obs[5], obs[6], size
        )

        episode_reward = 0.0
        episode_steps = 0
        episode_success = False

        for _ in range(max_steps):
            # Greedy action (no exploration)
            action = int(np.argmax(Q[state_idx]))

            next_obs, reward, terminated, truncated, info = env.step(action)
            state_idx = state_to_index(
                next_obs[0], next_obs[1], next_obs[2],
                next_obs[3], next_obs[4], next_obs[5], next_obs[6], size
            )

            episode_reward += reward
            episode_steps += 1

            if reward == 1.0:
                episode_success = True

            if terminated or truncated:
                break

        test_rewards.append(episode_reward)
        test_lengths.append(episode_steps)
        test_successes.append(episode_success)

        if verbose:
            status = "✓ SUCCESS" if episode_success else "✗ FAILED"
            print(f"Episode {episode + 1}: {status} | "
                  f"Reward: {episode_reward:.2f} | "
                  f"Steps: {episode_steps}")

    if verbose:
        episodes_run = len(test_rewards)
        # Guard the zero-episode case instead of dividing by zero.
        success_pct = 100 * sum(test_successes) / episodes_run if episodes_run else 0.0
        mean_reward = np.mean(test_rewards) if episodes_run else float("nan")
        mean_length = np.mean(test_lengths) if episodes_run else float("nan")
        print("=" * 70)
        print("Test Results:")
        print(f"  Success rate: {success_pct:.1f}%")
        print(f"  Average reward: {mean_reward:.3f}")
        print(f"  Average length: {mean_length:.1f}")
        print("=" * 70)

    return test_rewards, test_lengths, test_successes


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as true, so boolean flags need an explicit converter.

    Args:
        value: raw command-line token (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    import argparse

    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy: under the previous ``type=bool`` parsing
    # it was the only token that produced False.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Create the command-line parser for the training script.

    Help strings use ``%(default)s`` so the documented default can never
    drift away from the actual default.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Q-learning training for KeyLockEnv with options")
    parser.add_argument(
        "--num_episodes",
        type=int,
        default=5000,
        help="Number of training episodes (default: %(default)s)"
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=200,
        help="Maximum steps per episode (default: %(default)s)"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=0.5,
        help="Learning rate (alpha) (default: %(default)s)"
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.99,
        help="Discount factor (default: %(default)s)"
    )
    parser.add_argument(
        "--epsilon_start",
        type=float,
        default=1.0,
        help="Initial epsilon (default: %(default)s)"
    )
    parser.add_argument(
        "--epsilon_end",
        type=float,
        default=0.1,
        help="Final epsilon (default: %(default)s)"
    )
    parser.add_argument(
        "--epsilon_decay",
        type=float,
        default=0.999,
        help="Epsilon decay rate (default: %(default)s)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Base random seed added to every run's seed (default: no offset)"
    )
    parser.add_argument(
        "--test_episodes",
        type=int,
        default=10,
        help="Number of test episodes (default: %(default)s)"
    )
    parser.add_argument(
        "--save_plot",
        type=str,
        default=None,
        help="Path to save plot (default: keylock_qlearning_training_curves.png)"
    )
    parser.add_argument(
        "--save_qtable",
        type=str,
        default=None,
        help="Path to save Q-table (optional)"
    )
    parser.add_argument(
        "--option_dir",
        type=str,
        default="keylock_option_results",
        help="Directory containing option files (default: %(default)s)"
    )
    parser.add_argument(
        "--option_types",
        type=str,
        nargs='+',
        default=["primitive", "random", "eigen", "vps"],
        choices=["primitive", "random", "eigen", "vps"],
        help="Option types to compare (default: primitive random eigen vps)"
    )
    parser.add_argument(
        "--num_opts",
        type=int,
        default=4,
        help="Number of base options k (default: %(default)s)"
    )
    parser.add_argument(
        "--outer",
        type=int,
        default=5,
        help="Number of option instances per type (primitive uses inner only)"
    )
    parser.add_argument(
        "--inner",
        type=int,
        default=1,
        help="Number of training runs per option instance (or per primitive)"
    )
    parser.add_argument(
        "--group_idx",
        type=int,
        default=0,
        help="Deprecated: ignored when using outer×inner; option group index (default: 0)"
    )
    parser.add_argument(
        "--sign",
        type=_str2bool,
        default=True,
        help="Whether options use sign doubling, given as True/False (default: %(default)s)"
    )
    parser.add_argument(
        "--total_steps",
        type=int,
        default=None,
        help="Total training steps per run; if set, use step-based training with periodic eval"
    )
    parser.add_argument(
        "--eval_step_interval",
        type=int,
        default=1000,
        help="Evaluate every N steps when --total_steps is set"
    )
    parser.add_argument(
        "--eval_trials",
        type=int,
        default=5,
        help="Number of greedy episodes per evaluation when --total_steps is set"
    )
    return parser


def _qtable_path_for(base_path, opt_type):
    """Return the per-option-type Q-table path derived from ``base_path``.

    The option type is inserted before a trailing ``.npy`` suffix. When the
    suffix is absent it is appended, so different option types never share
    (and overwrite) one file and the returned path is the file ``np.save``
    actually writes.
    """
    suffix = ".npy"
    stem = base_path[:-len(suffix)] if base_path.endswith(suffix) else base_path
    return f"{stem}_{opt_type}{suffix}"


def main():
    """Train, test and plot Q-learning agents for each requested option type.

    For every option type the script performs ``inner`` training runs
    (``primitive``) or ``outer`` x ``inner`` runs (one batch per loaded option
    group), collects the per-run curves, tests the Q-table of the last run and
    optionally saves it. Finally all collected curves are plotted together.

    When ``--total_steps`` is given, training is step-based with periodic
    evaluation and the evaluation curves are truncated to the shortest run so
    they can be stacked into arrays.

    Returns:
        0, used as the process exit status.
    """
    args = _build_arg_parser().parse_args()

    # Resolve option directory relative to this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    option_dir = os.path.join(script_dir, args.option_dir)

    # Fixed order for legend/colors (match option_exploration_qlearning)
    methods = ["primitive", "random", "eigen", "vps"]
    option_types = [m for m in methods if m in args.option_types]

    all_results = {}
    outer, inner = args.outer, args.inner
    use_eval_mode = args.total_steps is not None

    # --seed shifts every run's seed; without it the seeds are the historical
    # ones (inner index for primitive, group * 10000 + inner index otherwise).
    seed_base = 0 if args.seed is None else args.seed

    # Keyword arguments shared by every training run.
    train_kwargs = dict(
        num_episodes=args.num_episodes,
        max_steps_per_episode=args.max_steps,
        learning_rate=args.learning_rate,
        gamma=args.gamma,
        epsilon_start=args.epsilon_start,
        epsilon_end=args.epsilon_end,
        epsilon_decay=args.epsilon_decay,
    )
    if use_eval_mode:
        train_kwargs.update(
            max_steps=args.total_steps,
            eval_step_interval=args.eval_step_interval,
            eval_trials=args.eval_trials,
        )

    for opt_type in option_types:
        print(f"\n{'='*70}")
        print(f"Training with {opt_type} options (outer={outer}, inner={inner})")
        if use_eval_mode:
            print(f"Periodic eval: total_steps={args.total_steps}, interval={args.eval_step_interval}, trials={args.eval_trials}")
        print(f"{'='*70}\n")

        rewards_list = []
        lengths_list = []
        success_list = []
        steps_list = []  # used only in eval mode
        Q_last = None

        def run_and_record(seed_val, Qopt_val):
            """Run one training job, store its curves and return its Q-table."""
            out = train_qlearning(
                seed=seed_val,
                Qopt=Qopt_val,
                option_name=opt_type,
                **train_kwargs,
            )
            rewards_list.append(out[1])
            if use_eval_mode:
                steps_list.append(out[2])
            else:
                lengths_list.append(out[2])
                success_list.append(out[3])
            return out[0]

        if opt_type == "primitive":
            for inn in range(inner):
                Q_last = run_and_record(seed_base + inn, None)
        else:
            for g in range(outer):
                Qopt = load_option_file(
                    opt_type, Path(option_dir),
                    group_idx=g,
                    num_opts=args.num_opts,
                    sign=args.sign
                )
                if Qopt is None:
                    print(f"[Warning] Could not load {opt_type} options group {g}. Skipping group.")
                    continue
                for inn in range(inner):
                    Q_last = run_and_record(seed_base + g * 10000 + inn, Qopt)

        if not rewards_list:
            print(f"[Warning] No runs for {opt_type}. Skipping.")
            continue

        if use_eval_mode:
            # Runs may yield different numbers of evaluations; truncate to the
            # shortest so the curves stack into rectangular arrays.
            min_len = min(len(r) for r in rewards_list)
            num_runs = len(rewards_list)
            eval_rewards_arr = np.zeros((num_runs, min_len), dtype=np.float32)
            eval_steps_arr = np.zeros((num_runs, min_len), dtype=np.int64)
            for run_idx, (run_rewards, run_steps) in enumerate(zip(rewards_list, steps_list)):
                eval_rewards_arr[run_idx, :] = run_rewards[:min_len]
                eval_steps_arr[run_idx, :] = run_steps[:min_len]
            all_results[opt_type] = {
                'eval_rewards': eval_rewards_arr,
                'eval_steps': eval_steps_arr,
            }
        else:
            all_results[opt_type] = {
                'episode_rewards': rewards_list,
                'episode_lengths': lengths_list,
                'success_rate_history': success_list,
            }

        # Test and save using last run's Q
        if Q_last is not None:
            print(f"\n{'='*70}")
            print(f"Testing {opt_type} agent (last run)")
            print(f"{'='*70}")
            test_trained_agent(Q_last, num_test_episodes=args.test_episodes)
            if args.save_qtable:
                save_path = _qtable_path_for(args.save_qtable, opt_type)
                np.save(save_path, Q_last)
                print(f"Q-table saved to: {save_path}")

    # Plot training curves comparison (mean ± std, smoothed)
    if all_results:
        plot_training_curves(
            all_results,
            save_path=args.save_plot,
        )

    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status. The bare
    # ``exit`` builtin is injected by the ``site`` module for interactive use
    # and is not guaranteed to exist (e.g. under ``python -S``), so raise
    # SystemExit directly instead.
    raise SystemExit(main())
