#!/usr/bin/env python
"""Random-walk success-count comparison on Taxi-v3.

Counts successful episodes (pickup+drop-off) under random mixtures of
primitive actions and options, using option files saved previously.
"""
from __future__ import annotations
import argparse, glob, random
from pathlib import Path
from typing import List, Dict
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
import os

# ---------- option start / termination -----------------------
def option_can_start(q_row: np.ndarray) -> bool:
    """
    Decide whether an option may be launched from a given state.

    Every state is in the initiation set in principle. In practice an
    option is only offered when the best value in its Q-row for that state
    is strictly positive.

    Args:
        q_row: the option's action values at the current state.

    Returns:
        True iff max(q_row) > 0.
    """
    peak = q_row.max()
    return peak > 0


def option_terminated(q_row: np.ndarray, L: int = 15) -> bool:
    """
    Decide whether a running option stops at the current state.

    The option stops deterministically when its best Q-value here is not
    positive. Otherwise it stops at random with probability 1/L.

    Args:
        q_row: the option's action values at the current state.
        L:     inverse per-check stopping probability (default 15).

    Returns:
        True if the option should terminate now.
    """
    exhausted = q_row.max() <= 0
    if exhausted:
        # Deterministic stop; no random number is drawn in this case.
        return exhausted
    return random.random() < 1.0 / L


# ---------- load option files --------------------------------
def load_option_groups(
    env_id: str,
    opt_type: str,
    out_dir: Path,
    outer: int,
) -> Dict[str, List[np.ndarray]]:
    """
    Load the first `outer` saved option sets for each requested option type.

    Files are matched by ``<out_dir>/<env_id>_*_<Tag>_*.npy``, where Tag is
    RandomOpt / EigenOpt / VPSOpt. They are sorted by path string before
    the first `outer` are taken.

    Args:
        env_id:   environment id used as the filename prefix.
        opt_type: "random", "eigen", "vps", or "all".
        out_dir:  directory containing the .npy option files.
        outer:    number of option sets (files) to load per type.

    Returns:
        {type: [Q_group0, Q_group1, …]} with keys "random", "eigen" and
        "vps". Types not selected by `opt_type` map to an empty list.

    Raises:
        ValueError:   if `opt_type` is not a known type or "all".
        RuntimeError: if fewer than `outer` files exist for a selected type.
    """
    kinds = {"random": "RandomOpt", "eigen": "EigenOpt", "vps": "VPSOpt"}
    if opt_type != "all" and opt_type not in kinds:
        # A typo would otherwise silently produce an empty experiment.
        raise ValueError(
            f"unknown opt_type {opt_type!r}; expected one of "
            f"{sorted(kinds)} or 'all'"
        )
    results: Dict[str, List[np.ndarray]] = {k: [] for k in kinds}

    # Escape the literal parts so characters such as '[' or '*' in the
    # directory or env id are not interpreted as glob wildcards.
    base = Path(glob.escape(str(out_dir)))
    env_part = glob.escape(env_id)

    for k, tag in kinds.items():
        if opt_type != "all" and k != opt_type:
            continue
        pattern = base / f"{env_part}_*_{tag}_*.npy"
        files = sorted(glob.glob(str(pattern)))
        if len(files) < outer:
            raise RuntimeError(f"{k}: require {outer} groups, found {len(files)}")
        for f in files[:outer]:
            results[k].append(np.load(f))
            print(f"[Load] {Path(f).name}")

    return results


# ---------- one episode: random walk -------------------------
def run_episode_random(
    env: gym.Env,
    Qopt: np.ndarray | None,
    max_len: int = 200,
) -> bool:
    """
    Play one episode as a uniform random walk over primitives and options.

    At every decision point one candidate is drawn uniformly from:
      * all primitive actions of `env`, and
      * every option whose Q-max at the current state is positive
        (see option_can_start).
    A chosen option follows its greedy policy, the argmax over the last
    axis of Qopt, until option_terminated fires. Because termination is
    checked before each step, an option can end after zero steps. In that
    case the walk simply draws again without using up budget.

    Args:
        env:     environment whose step() returns a 5-tuple
                 (obs, reward, terminated, truncated, info).
        Qopt:    option Q-tables indexed [option, state, action], or None
                 for a primitives-only walk.
        max_len: budget of environment steps.

    Returns:
        True as soon as a step yields reward 20 (pickup + drop-off).
        False if the budget runs out or the env terminates/truncates first.
    """
    state, _ = env.reset()
    state = int(state)
    n_prim = env.action_space.n
    n_opts = Qopt.shape[0] if Qopt is not None else 0
    greedy = None
    if n_opts:
        greedy = np.argmax(Qopt, 2)  # greedy[o, s] = option o's action at s
    used = 0

    while used < max_len:
        # Offset option ids by n_prim so a single uniform draw covers both.
        choices = list(range(n_prim))
        for oid in range(n_opts):
            if option_can_start(Qopt[oid, state]):
                choices.append(n_prim + oid)
        pick = random.choice(choices)

        if pick < n_prim:
            nxt, rew, done, cut, _ = env.step(pick)
            if rew == 20:
                return True
            if done or cut:
                return False
            state = int(nxt)
            used += 1
            continue

        # Option rollout: follow the greedy policy until termination fires.
        oid = pick - n_prim
        ended = False
        while used < max_len and not option_terminated(Qopt[oid, state]):
            nxt, rew, done, cut, _ = env.step(int(greedy[oid, state]))
            used += 1
            if rew == 20:
                return True
            state = int(nxt)
            if done or cut:
                ended = True
                break
        if ended:
            return False
    return False


# ---------- main experiment ----------------------------------
def evaluate_groups(
    env_id: str,
    groups: Dict[str, List[np.ndarray]],
    outer: int,
    episodes: int,
    max_len: int,
) -> Dict[str, List[int]]:
    """
    Count random-walk successes for the primitive baseline and every option set.

    Each group gets a fresh environment and `episodes` runs of
    run_episode_random. The primitive baseline uses the same episode
    protocol with no options (Qopt=None) and is repeated `outer` times, so
    its spread is comparable to the option groups.

    Args:
        env_id:   Gymnasium environment id passed to gym.make.
        groups:   {type: [Q_group0, Q_group1, …]} as returned by
                  load_option_groups. Empty lists are allowed.
        outer:    number of repeats of the primitive-only baseline.
        episodes: random-walk episodes per group.
        max_len:  per-episode step budget.

    Returns:
        {"primitive": [succ_0, …, succ_{outer-1}], type: [succ_group0, …], …}
        where each count is out of `episodes`.
    """
    from contextlib import closing

    def _count_successes(Q: np.ndarray | None) -> int:
        # Fresh env per group. closing() releases it even if an episode raises.
        with closing(gym.make(env_id)) as env:
            return sum(
                run_episode_random(env, Q, max_len) for _ in range(episodes)
            )

    results: Dict[str, List[int]] = {"primitive": []}
    results.update({k: [] for k in groups})

    # --- primitive baseline (no options) ---
    for gidx in range(outer):
        succ = _count_successes(None)
        results["primitive"].append(succ)
        print(f"[primitive] group {gidx}: {succ}/{episodes} successes")

    # --- each option family ---
    for k, lst in groups.items():
        for gidx, Q in enumerate(lst):
            succ = _count_successes(Q)
            results[k].append(succ)
            print(f"[{k}] group {gidx}: {succ}/{episodes} successes")
    return results


# ---------- box-plot ----------------------------------------
def plot_box(results: Dict[str, List[int]], episodes: int, env_id: str):
    """
    Draw a box plot of per-group success counts, one box per method.

    Methods appear left-to-right as Primitive, Random, Eigen, VPS. A method
    that is missing from `results` or has no groups is skipped. Means are
    marked in addition to medians. The figure is then shown with plt.show().

    Args:
        results:  {method: [succ_group0, …]} as returned by evaluate_groups.
        episodes: episodes per group (used only in the y-axis label).
        env_id:   environment id, used as the plot title.
    """
    labels = {"primitive": "Primitive",
              "random":    "Random",
              "eigen":     "Eigen",
              "vps":       "VPS"}
    order = ["primitive", "random", "eigen", "vps"]
    # .get() tolerates result dicts that lack some methods entirely.
    present = [k for k in order if results.get(k)]
    if not present:
        print("[plot_box] no results to plot")
        return

    plt.figure(figsize=(6, 4))
    # Boxes sit at the default positions 1..N. Tick labels are set with
    # xticks rather than boxplot(labels=...): Matplotlib 3.9 renamed that
    # keyword to `tick_labels` and deprecated the old name, while xticks
    # works on both old and new releases.
    plt.boxplot([results[k] for k in present], showmeans=True)
    plt.xticks(range(1, len(present) + 1),
               [labels[k] for k in present],
               fontsize=14)
    plt.ylabel(f"Successful Trials / {episodes} Episodes", fontsize=14)
    plt.title(env_id, fontsize=16)
    plt.yticks(fontsize=14)
    plt.grid(alpha=.3, axis="y")
    plt.tight_layout()
    plt.show()


# ---------------- CLI ---------------------------------------
def main():
    """Parse CLI flags, load saved option groups, evaluate them, then plot."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="Taxi-v3")
    parser.add_argument("--out_dir", default="option_results")
    parser.add_argument("--opt_type",
                        choices=["random", "eigen", "vps", "all"],
                        default="all",
                        help="which option type(s) to test")
    parser.add_argument("--outer", type=int, default=5,
                        help="# option groups to load / per type")
    parser.add_argument("--episodes", type=int, default=1000,
                        help="# random-walk episodes per group")
    parser.add_argument("--max_len", type=int, default=100)
    cfg = parser.parse_args()

    # Option files are expected at <this script's dir>/<out_dir>/<env>/.
    # An absolute --out_dir overrides the script directory (os.path.join).
    here = os.path.dirname(os.path.abspath(__file__))
    env_dir = Path(os.path.join(here, cfg.out_dir, cfg.env))

    groups = load_option_groups(cfg.env, cfg.opt_type, env_dir, cfg.outer)
    results = evaluate_groups(
        cfg.env, groups, cfg.outer, cfg.episodes, cfg.max_len
    )
    plot_box(results, cfg.episodes, cfg.env)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, never on import.
    main()
