#!/usr/bin/env python
"""Random-walk success-count comparison on KeyLockEnv.

Counts successful episodes (those that reach the goal) under random mixtures
of primitive actions and options, using previously saved option files.
"""
from __future__ import annotations
import argparse, glob, random
from pathlib import Path
from typing import List, Dict
import numpy as np
import matplotlib.pyplot as plt
import os
import key_lock_options as klo
from key_lock_env import KeyLockEnv
from train_keylock_qlearning import state_to_index


# ---------- option start / termination -----------------------
def option_can_start(q_row: np.ndarray) -> bool:
    """
    Decide whether an option may be initiated at the current state.

    Every state is a candidate (gridworld-style initiation), but the option
    is only offered when the largest entry of its Q-row is strictly positive.
    """
    best_value = q_row.max()
    return best_value > 0


def option_terminated(q_row: np.ndarray, L: int = 15) -> bool:
    """
    Decide whether a running option stops at the current state.

    The option stops deterministically once its Q-row has no positive entry;
    otherwise it stops at random with per-step probability 1/L. The random
    draw is only made when the deterministic test does not already stop it.
    """
    exhausted = q_row.max() <= 0
    if exhausted:
        # No positive value left: stop without consuming a random draw.
        return exhausted
    # Still useful here: stop stochastically with probability 1/L.
    return random.random() < 1.0 / L


# ---------- load option files --------------------------------
def load_option_groups(
    opt_type: str,
    out_dir: Path,
    outer: int,
    num_opts: int = None,
    sign: bool = True,
) -> Dict[str, List[np.ndarray]]:
    """
    Load saved option Q-tables, grouped by option family.

    Files are matched as ``keylock_<total>_<Tag>_*.npy`` with Tag one of
    RandomOption, EigenOpt, VPSOpt; each file is one option set ("group").

    Args:
        opt_type: "random", "eigen", "vps", or "all".
        out_dir: directory holding the .npy files. It is taken literally:
            glob metacharacters in the directory path are not expanded.
        outer: maximum number of groups to load per family.
        num_opts: if given, only files whose <total> matches are loaded;
            otherwise files with any <total> match.
        sign: when True, <total> is expected to be ``num_opts * 2`` (sign
            doubling) for all three families; when False, ``num_opts``.

    Returns:
        {"random": [...], "eigen": [...], "vps": [...]}: one list of loaded
        arrays per family, at most ``outer`` long. Files are ordered by name,
        with numeric parts compared as numbers. Families filtered out by
        ``opt_type`` or without matching files map to ``[]``.
    """
    import re  # local: only needed for the natural sort key below

    def _natural_key(path: Path):
        # re.split with a capture group alternates text/digit-run tokens, so
        # odd positions are digit runs. Comparing them as ints makes
        # "..._2.npy" sort before "..._10.npy"; plain string sorting would put
        # 10 first and files[:outer] would pick a non-contiguous set of groups.
        return [int(tok) if i % 2 else tok
                for i, tok in enumerate(re.split(r"(\d+)", path.name))]

    kinds = {"random": "RandomOption", "eigen": "EigenOpt", "vps": "VPSOpt"}
    results: Dict[str, List[np.ndarray]] = {k: [] for k in kinds}
    base = Path(out_dir)

    for k, tag in kinds.items():
        if opt_type != "all" and k != opt_type:
            continue

        if num_opts is not None:
            # key_lock_options.py saves all three (VPS, Eigen, Random) with
            # total = num_opts * 2 when sign=True.
            total_opts = num_opts * 2 if sign else num_opts
            name = f"keylock_{total_opts}_{tag}_*.npy"
        else:
            name = f"keylock_*_{tag}_*.npy"
        pattern = base / name  # for messages only

        # Path.glob expands only the file-name pattern, so characters such as
        # "[" in the directory path are matched literally.
        files = sorted(base.glob(name), key=_natural_key)
        if not files:
            print(f"[Warning] {k} (num_opts={num_opts}): no files found matching pattern {pattern}")
            continue
        if len(files) < outer:
            print(f"[Warning] {k} (num_opts={num_opts}): require {outer} groups, found {len(files)}. Using available {len(files)} groups.")

        selected = files[:outer]
        results[k] = [np.load(f) for f in selected]
        for f in selected:
            print(f"[Load] {f.name}")

    return results


# ---------- one episode: random walk -------------------------
def run_episode_random(
    env: KeyLockEnv,
    size: int,
    Qopt: np.ndarray | None,
    max_len: int = 500,
) -> bool:
    """
    Run one random-walk episode that mixes primitive actions and options.

    At each decision point the agent draws uniformly (``random.choice``) from
    the Ap primitive actions plus every option that ``option_can_start``
    accepts at the current state. A chosen option then repeatedly executes its
    greedy action (argmax over the last axis of ``Qopt[oid, s]``) until
    ``option_terminated`` fires, the step budget runs out, or the episode ends.

    Args:
        env: environment with ``reset() -> (obs, info)`` and
            ``step(a) -> (obs, reward, term, trunc, info)``, as unpacked below.
        size: grid size forwarded to ``state_to_index``.
        Qopt: stacked option Q-tables indexed as ``Qopt[oid, s]`` with the
            action axis last (``np.argmax(Qopt, 2)``); presumably shaped
            (num_options, num_states, num_actions) -- TODO confirm.
            ``None`` gives a pure primitive random walk.
        max_len: step budget; each environment step counts toward it,
            whether taken directly or inside an option.

    Returns:
        True if the agent reaches the goal (reward == 1.0) within
        max_len steps, False otherwise.
    """
    obs, info = env.reset()
    # NOTE(review): assumes obs carries the 7 state fields state_to_index
    # expects -- TODO confirm against KeyLockEnv.
    s = state_to_index(
        obs[0], obs[1], obs[2], obs[3], obs[4], obs[5], obs[6], size
    )
    Ap = 6  # Number of primitive actions (up, down, left, right, pickup, toggle)
    success = False
    steps = 0
    K = 0 if Qopt is None else Qopt.shape[0]
    policy = np.argmax(Qopt, 2) if K else None  # (K, S)

    term, trunc = False, False
    while steps < max_len:
        # ----- sample a primitive action or a *startable* option -----
        if K > 0:
            startable_opts = [
                oid for oid in range(K) if option_can_start(Qopt[oid, s])
            ]
        else:
            startable_opts = []

        # Ids 0..Ap-1 are primitives; Ap + oid encodes option oid.
        candidates = list(range(Ap)) + [Ap + oid for oid in startable_opts]
        a = random.choice(candidates)  # primitives 0..Ap-1, options Ap+oid..

        # ----- primitive step --------------------------------------
        # When K == 0 or nothing is startable, candidates holds only
        # primitives, so ``a < Ap`` alone already selects this branch.
        if a < Ap or K == 0 or not startable_opts:
            next_obs, r, term, trunc, _ = env.step(a)
            # NOTE(review): success requires the step reward to equal exactly
            # 1.0; if KeyLockEnv scales the goal reward (e.g. by elapsed
            # steps) this never fires -- confirm against KeyLockEnv.
            if r == 1.0:  # task success (reached goal)
                success = True
                break
            if term or trunc:
                break
            s = state_to_index(
                next_obs[0], next_obs[1], next_obs[2],
                next_obs[3], next_obs[4], next_obs[5], next_obs[6], size
            )
            steps += 1
        # ----- option rollout --------------------------------------
        else:
            oid = a - Ap
            # Termination is tested before every action, including the first.
            # The option was startable at s, so its Q-max is > 0 here and only
            # the random 1/L check can fire: a chosen option may execute zero
            # actions, in which case the outer loop just samples again.
            while steps < max_len and not option_terminated(Qopt[oid, s]):
                ain = int(policy[oid, s])  # option's greedy primitive action
                next_obs, r, term, trunc, _ = env.step(ain)
                if r == 1.0:
                    success = True
                    steps += 1
                    break
                s = state_to_index(
                    next_obs[0], next_obs[1], next_obs[2],
                    next_obs[3], next_obs[4], next_obs[5], next_obs[6], size
                )
                steps += 1
                if term or trunc:
                    break
            # Propagate goal / termination / truncation out of the outer loop.
            if success or term or trunc:
                break
    return success


# ---------- main experiment ----------------------------------
def evaluate_groups(
    env: KeyLockEnv,
    size: int,
    groups: Dict[str, List[np.ndarray]],
    outer: int,
    episodes: int,
    max_len: int,
) -> Dict[str, List[int]]:
    """
    Count successful random-walk episodes for the baseline and every group.

    Args:
        env: environment reused (reset per episode) for every rollout.
        size: grid size forwarded to ``run_episode_random``.
        groups: {type: [Q_group0, Q_group1, ...]}, as produced by
            ``load_option_groups``.
        outer: number of primitive-baseline repeats; repeat ``i`` is seeded
            with ``klo.set_seed(i)``.
        episodes: episodes per group.
        max_len: step budget per episode.

    Returns:
        {type: [succ_cnt_group0, succ_cnt_group1, ...]} with a "primitive"
        entry (one count per baseline repeat) followed by one entry per key
        of ``groups``. Option group ``g`` is seeded with ``klo.set_seed(g)``.
    """
    results: Dict[str, List[int]] = {"primitive": []}
    results.update({k: [] for k in groups})

    def _count_successes(Qopt) -> int:
        # Qopt=None turns run_episode_random into a pure primitive random
        # walk: a uniform draw among the primitive actions each step, stopping
        # on success, termination/truncation, or max_len. Sharing one rollout
        # keeps the baseline and the option runs from drifting apart.
        return sum(
            1 for _ in range(episodes)
            if run_episode_random(env, size, Qopt, max_len)
        )

    # --- primitive baseline ---
    for outer_idx in range(outer):
        klo.set_seed(outer_idx)
        succ = _count_successes(None)
        results["primitive"].append(succ)
        print(f"[primitive] group {outer_idx}: {succ}/{episodes} successes")

    # --- each option family ---
    for k, lst in groups.items():
        for gidx, Q in enumerate(lst):
            klo.set_seed(gidx)
            succ = _count_successes(Q)
            results[k].append(succ)
            print(f"[{k}] group {gidx}: {succ}/{episodes} successes")
    return results


# ---------- box-plot ----------------------------------------
def plot_box(results: Dict[str, List[int]], episodes: int, save_path: str = None, num_opts: int = None):
    """
    Draw one box plot of per-group success counts for each action-set type.

    Args:
        results: {type: [success counts]} as returned by ``evaluate_groups``;
            types that are missing or have no counts are left out.
        episodes: episodes per group (used in the y-axis label).
        save_path: if given, the figure is also written there at 300 dpi.
        num_opts: if given, appended to the title.
    """
    labels = {"primitive": "Primitive",
              "random":    "Random",
              "eigen":     "Eigen",
              "vps":       "VPS"}
    order = ["primitive", "random", "eigen", "vps"]
    # .get() tolerates result dicts that lack some families.
    shown = [k for k in order if results.get(k)]
    if not shown:
        # boxplot would fail on empty data; nothing meaningful to draw.
        print("[Warning] No results to plot")
        return
    data = [results[k] for k in shown]
    labels_list = [labels[k] for k in shown]

    plt.figure(figsize=(6, 4))
    plt.boxplot(data, showmeans=True)
    # Boxes sit at x = 1..n by default. Label the ticks directly rather than
    # via boxplot(labels=...), which Matplotlib 3.9 renamed to tick_labels.
    plt.xticks(range(1, len(labels_list) + 1), labels_list, fontsize=14)
    plt.ylabel(f"Successful Trials / {episodes} Episodes", fontsize=14)
    title = "KeyLockEnv Task Success Rate"
    if num_opts is not None:
        title += f" (num_options={num_opts})"
    plt.title(title, fontsize=16)
    plt.yticks(fontsize=14)
    plt.grid(alpha=.3, axis="y")
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\n[Save] Box plot saved to {save_path}")
    plt.show()


def plot_box_multi_subplots(
    all_results: Dict[int, Dict[str, List[int]]],
    episodes: int,
    save_path: str = None,
):
    """
    Plot one box-plot subplot per num_options value, side by side.

    Each subplot compares the primitive baseline with the option families
    present for that num_options value; all subplots share the y-axis.

    Args:
        all_results: {num_options: {type: [success counts]}}; subplots are
            ordered by ascending num_options.
        episodes: episodes per group (used in the y-axis label).
        save_path: if given, the figure is also written there at 300 dpi.
    """
    labels = {"primitive": "Primitive",
              "random":    "Random",
              "eigen":     "Eigen",
              "vps":       "VPS"}
    order = ["primitive", "random", "eigen", "vps"]

    num_plots = len(all_results)
    if num_plots == 0:
        print("[Warning] No results to plot")
        return

    # One row, shared y-axis. squeeze=False always returns a 2-D array of
    # Axes, so one and many subplots are handled alike and every created
    # axis is used (no spare subplots to hide).
    fig, axes = plt.subplots(1, num_plots, figsize=(5 * num_plots, 4),
                             sharey=True, squeeze=False)

    for idx, (ax, num_opts) in enumerate(zip(axes[0], sorted(all_results))):
        results = all_results[num_opts]
        shown = [k for k in order if results.get(k)]
        if shown:  # boxplot would fail on empty data
            ax.boxplot([results[k] for k in shown], showmeans=True)
            # Boxes sit at x = 1..n by default. Label the ticks directly
            # rather than via boxplot(labels=...), which Matplotlib 3.9
            # renamed to tick_labels.
            ax.set_xticks(range(1, len(shown) + 1))
            ax.set_xticklabels([labels[k] for k in shown])
        if idx == 0:
            ax.set_ylabel(f"Successful Trials / {episodes} Episodes", fontsize=12)
        ax.set_title(f"num_options = {num_opts}", fontsize=14)
        ax.tick_params(axis='x', labelsize=12)
        ax.tick_params(axis='y', labelsize=12)
        ax.grid(alpha=.3, axis="y")

    fig.suptitle("KeyLockEnv Task Success Rate Comparison", fontsize=16, y=1.02)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\n[Save] Multi-subplot box plot saved to {save_path}")
    plt.show()


# ---------------- CLI ---------------------------------------
def _str2bool(text: str) -> bool:
    """
    Convert a command-line string such as "True", "false", "1" or "no" to bool.

    Used instead of ``type=bool``, which maps every non-empty string
    (including "False") to True.
    """
    value = text.strip().lower()
    if value in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if value in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def _parse_num_opts(text: str) -> List[int]:
    """
    Parse the --num_opts value into a non-empty list of ints.

    Accepts "4", "4,8,16" and "[4,8,16]" (whitespace around items is ignored).
    Raises argparse.ArgumentTypeError so argparse reports a usage error
    instead of a traceback.
    """
    body = text.strip()
    if body.startswith("[") and body.endswith("]"):
        body = body[1:-1]
    try:
        values = [int(tok) for tok in body.split(",") if tok.strip()]
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f"invalid --num_opts value {text!r}: {err}") from err
    if not values:
        raise argparse.ArgumentTypeError(
            f"--num_opts must list at least one number, got {text!r}")
    return values


def _print_summary(num_opts: int, results: Dict[str, List[int]], episodes: int) -> None:
    """Print mean/std/min/max success counts per action-set type."""
    print(f"\n[Summary] num_options = {num_opts}")
    print("-" * 70)
    for k, success_counts in results.items():
        if not success_counts:
            continue  # family filtered out or no option files found
        mean_success = np.mean(success_counts)
        std_success = np.std(success_counts)
        success_rate = mean_success / episodes * 100
        print(f"{k.capitalize()}:")
        print(f"  Mean successes: {mean_success:.1f} ± {std_success:.2f} / {episodes}")
        print(f"  Success rate: {success_rate:.2f}%")
        print(f"  Min: {min(success_counts)}, Max: {max(success_counts)}")
        print()


def main():
    """
    CLI entry point.

    Builds a KeyLockEnv from the command-line layout. For each requested
    num_options value it loads the option groups, evaluates random-walk
    success counts, and prints a summary. Finally it saves and shows a box
    plot: multi-subplot when several num_options values were evaluated,
    otherwise a single plot.
    """
    pa = argparse.ArgumentParser()
    pa.add_argument("--out_dir", default="keylock_option_results")
    pa.add_argument("--opt_type",
                    choices=["random", "eigen", "vps", "all"],
                    default="all",
                    help="which option type(s) to test")
    pa.add_argument("--outer", type=int, default=5,
                    help="# option groups to load / per type")
    pa.add_argument("--episodes", type=int, default=10,
                    help="# random-walk episodes per group")
    pa.add_argument("--max_len", type=int, default=200)
    pa.add_argument("--size", type=int, default=15, help="grid size")
    # argparse also runs string defaults through type=, so args.num_opts is
    # always a list of ints.
    pa.add_argument("--num_opts", type=_parse_num_opts, default="[4,8,16]",
                    help='Number(s) of options to test. Examples: "4", "[4,8,16]", "4,8,16"')
    pa.add_argument("--sign", type=_str2bool, default=True,
                    help="Whether option files use sign doubling, i.e. hold num_opts * 2 options (true/false)")
    pa.add_argument("--yellow_key_pos", type=int, nargs=2, default=[12, 3], help="yellow key position (x, y)")
    pa.add_argument("--yellow_door_pos", type=int, nargs=2, default=[3, 8], help="yellow door position (x, y)")
    pa.add_argument("--blue_key_pos", type=int, nargs=2, default=[12, 12], help="blue key position (x, y)")
    pa.add_argument("--blue_door_pos", type=int, nargs=2, default=[9, 3], help="blue door position (x, y)")
    pa.add_argument("--goal_pos", type=int, nargs=2, default=[3, 12], help="goal position (x, y)")
    args = pa.parse_args()

    num_opts_list = args.num_opts
    print(f"[Config] Testing num_options: {num_opts_list}")

    # Create environment
    env = KeyLockEnv(
        size=args.size,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        yellow_key_pos=tuple(args.yellow_key_pos),
        yellow_door_pos=tuple(args.yellow_door_pos),
        blue_key_pos=tuple(args.blue_key_pos),
        blue_door_pos=tuple(args.blue_door_pos),
        goal_pos=tuple(args.goal_pos),
        render_mode=None,
    )

    # Resolve --out_dir against this script's directory (an absolute
    # --out_dir is used as-is by os.path.join). Plots are also saved there.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(script_dir, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # Evaluate for each num_opts
    all_results = {}
    for num_opts in num_opts_list:
        print(f"\n{'='*70}")
        print(f"Evaluating num_options = {num_opts}")
        print(f"{'='*70}")

        groups = load_option_groups(
            args.opt_type,
            Path(out_dir),
            args.outer,
            num_opts=num_opts,
            sign=args.sign,
        )

        results = evaluate_groups(
            env,
            args.size,
            groups,
            args.outer,
            args.episodes,
            args.max_len,
        )

        all_results[num_opts] = results
        _print_summary(num_opts, results, args.episodes)

    # Plot multi-subplot comparison
    if len(all_results) > 1:
        save_path = os.path.join(out_dir, "keylock_success_rate_multi_subplot.png")
        plot_box_multi_subplots(all_results, args.episodes, save_path)
    else:
        # Single plot if only one (distinct) num_opts
        num_opts = num_opts_list[0]
        save_path = os.path.join(out_dir, f"keylock_success_rate_boxplot_numopts{num_opts}.png")
        plot_box(all_results[num_opts], args.episodes, save_path, num_opts=num_opts)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
