#!/usr/bin/env python
"""Evaluate Primitive vs. Random/Eigen/VPS options on discrete Gym envs.

Loads option Q-tables saved by the training script and runs SMDP-Q
control, plotting reward curves over episodes. Supports multiple
independent option sets (outer) and seeds (inner).
"""
from __future__ import annotations
import argparse, random
from pathlib import Path
from typing import Dict, List
import glob, os
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt

# ----------------- default hyper-parameters -----------------
EPISODES: int = 1000        # training episodes per run
STEPS_PER_EP: int = 200     # cap on primitive environment steps per episode
EVAL_EVERY: int = 1         # evaluate after every EVAL_EVERY training episodes
EVAL_EPISODES: int = 10     # evaluation episodes per checkpoint (same value as the --eval_trials default)
EPSILON: float = 0.1        # ε for ε-greedy exploration during training
ALPHA: float = 0.1          # step size of the SMDP-Q update
GAMMA: float = 0.99         # discount factor

# ----------------- utility functions ------------------------
# def option_terminated(q_row: np.ndarray) -> bool:
#     return q_row.max() <= 0

def option_can_start(q_row: np.ndarray) -> bool:
    """Return whether an option may be initiated in the current state.

    Initiation follows the gridworld convention: every state is nominally in
    the initiation set, but the option is only offered where its own local
    Q-values contain at least one strictly positive entry.
    """
    best_value = q_row.max()
    return 0 < best_value


def option_terminated(q_row: np.ndarray, L: int = 15) -> bool:
    """Decide whether a running option stops at the current state.

    Two gridworld-style stopping rules are combined:

    * deterministic: the option's local Q-values are all non-positive;
    * stochastic: otherwise the option stops with probability ``1 / L``
      (a single ``random.random()`` draw, taken only when the deterministic
      rule did not fire).
    """
    exhausted = q_row.max() <= 0
    if exhausted:
        return exhausted
    coin = random.random()
    return coin < 1.0 / L


def random_argmax(qrow: np.ndarray, mask: np.ndarray | None = None) -> int:
    """Return the index of a maximal entry of ``qrow``, ties broken at random.

    Entries where ``mask`` is False are excluded by treating them as -inf.
    If every entry is masked out, all tie at -inf and any index may be
    returned. Tie-breaking uses the ``random`` module.
    """
    values = qrow if mask is None else np.where(mask, qrow, -np.inf)
    ties = np.flatnonzero(values == values.max())
    return int(random.choice(ties))

# ---------------- SMDP controller (SMDP-Q) ------------------
class SMDPAgent:
    """Tabular SMDP-Q controller over primitive actions and fixed options.

    ``self.Q`` has ``Ap + K`` columns per state: the first ``Ap`` are the
    environment's primitive actions, the remaining ``K`` are options. Option
    ``o`` executes the greedy policy of its own table ``Qopt[o]`` (frozen at
    construction) and stops according to ``option_terminated``.
    """

    def __init__(
        self,
        env: gym.Env,
        Qopt: np.ndarray | None,
        epsilon: float,
        option_p: float,
        alpha: float,
        gamma: float,
    ):
        """Build the agent.

        Args:
            env: environment with discrete observation and action spaces
                (``.n`` is read from both).
            Qopt: stacked option Q-tables indexed ``Qopt[o, s]`` -> values over
                primitive actions (presumably shape ``(K, S, Ap)``; verify
                against the training script), or ``None`` for a
                primitive-only agent.
            epsilon: ε for ε-greedy exploration while training.
            option_p: probability that an exploratory choice is an option.
            alpha: step size of the SMDP-Q update.
            gamma: discount factor.
        """
        self.env = env
        self.S = env.observation_space.n
        self.Ap = env.action_space.n

        self.Qopt = Qopt
        self.K = 0 if Qopt is None else Qopt.shape[0]

        # Q for primitive actions followed by options.
        self.Q = np.zeros((self.S, self.Ap + self.K), np.float32)
        self.eps, self.alpha, self.gamma = epsilon, alpha, gamma
        # When ε-exploration triggers, choose an option with probability p,
        # otherwise choose a primitive action. If no option is startable at
        # the current state, always fall back to a primitive action.
        self.option_p = float(option_p)

        # Freeze each option's greedy policy once, breaking ties at random.
        if self.K:
            self.opt_pi = np.zeros((self.K, self.S), dtype=int)
            for o in range(self.K):
                for s in range(self.S):
                    self.opt_pi[o, s] = random_argmax(Qopt[o, s])
        else:
            self.opt_pi = None

    # ------ ε-greedy over primitives and options -------------
    def _select_action(self, s: int, *, greedy: bool = False) -> int:
        """Pick a column of ``self.Q[s]`` (primitive action or ``Ap + option``).

        Options whose local Q-max at ``s`` is not strictly positive are masked
        out. With ``greedy=False`` an ε-exploration step picks a random
        startable option with probability ``option_p`` and otherwise a random
        primitive action.
        """
        mask = np.ones(self.Ap + self.K, bool)
        for oid in range(self.K):
            # Only allow options whose local Q-max is strictly positive.
            if not option_can_start(self.Qopt[oid, s]):
                mask[self.Ap + oid] = False

        # exploration
        if (not greedy) and random.random() < self.eps:
            if self.K and random.random() < self.option_p:
                startable = np.flatnonzero(mask[self.Ap :])
                if startable.size > 0:
                    return self.Ap + int(random.choice(startable))
            return int(random.randrange(self.Ap))

        # exploitation
        return random_argmax(self.Q[s], mask)

    # ------ run one episode ---------------------------------
    def run_episode(self, max_len: int, *, train: bool) -> float:
        """Run one episode of at most ``max_len`` primitive environment steps.

        Args:
            max_len: cap on primitive steps; every step taken inside an option
                counts individually.
            train: if True, act ε-greedily and apply SMDP-Q updates; if False,
                act greedily and leave ``self.Q`` untouched.

        Returns:
            Undiscounted sum of the environment rewards collected.
        """
        s, _ = self.env.reset()
        s = int(s)
        total_r, steps = 0.0, 0

        while steps < max_len:
            a = self._select_action(s, greedy=not train)
            term, trunc = False, False

            if a < self.Ap:
                # ---- primitive action: one-step Q-learning update ----
                sn, r, term, trunc, _ = self.env.step(a)
                sn = int(sn)
                total_r += r

                if train:
                    # Only true termination has zero future value; a time-limit
                    # truncation still bootstraps (gymnasium terminated/truncated).
                    boot = 0.0 if term else self.Q[sn].max()
                    td = r + self.gamma * boot - self.Q[s, a]
                    self.Q[s, a] += self.alpha * td

                s = sn
                steps += 1
            else:
                # ---- option rollout: one SMDP-Q update for the whole run ----
                oid = a - self.Ap
                R, dur = 0.0, 0
                s0 = s
                budget = max_len - steps
                while dur < budget:
                    # Termination is only evaluated after the first action: an
                    # option always executes at least once, otherwise a
                    # zero-length "execution" would update Q(s0, o) toward
                    # max Q(s0, .) without any experience or elapsed time.
                    if dur > 0 and option_terminated(self.Qopt[oid, s]):
                        break
                    ain = int(self.opt_pi[oid, s])
                    sn, r, term, trunc, _ = self.env.step(ain)
                    sn = int(sn)
                    total_r += r
                    R += (self.gamma ** dur) * r
                    dur += 1
                    s = sn
                    if term or trunc:
                        break

                if train:
                    # As above: only termination (not truncation) zeroes the tail.
                    boot = 0.0 if term else self.Q[s].max()
                    tgt = R + (self.gamma ** dur) * boot
                    self.Q[s0, a] += self.alpha * (tgt - self.Q[s0, a])

                steps += dur

            if term or trunc:
                break
        return total_r


# ---------------- option file loading -----------------------
def load_option_groups(
    env_id: str, out_dir: Path, outer: int
) -> Dict[str, List[np.ndarray]]:
    """Load the saved option Q-tables for ``env_id``, grouped by method.

    Files matching ``<env_id>_*_*Opt_*.npy`` in ``out_dir`` are classified as
    "random", "eigen" or "vps" (first match wins, case-insensitive) and the
    first ``outer`` of each kind, in sorted filename order, are returned.

    Raises:
        RuntimeError: if any kind has fewer than ``outer`` files.
    """
    kinds = ("random", "eigen", "vps")
    groups: Dict[str, List[np.ndarray]] = {k: [] for k in kinds}
    # sorted(): Path.glob yields files in filesystem-dependent order; sorting
    # makes "group g" (and therefore each seed's option set) reproducible.
    for f in sorted(out_dir.glob(f"{env_id}_*_*Opt_*.npy")):
        # Classify on the part after the env-id prefix so an environment name
        # that itself contains "random"/"eigen"/"vps" cannot misroute files.
        tag = f.name[len(env_id):].lower()
        for kind in kinds:
            if kind in tag:
                groups[kind].append(np.load(f))
                break

    for k, lst in groups.items():
        if len(lst) < outer:
            raise RuntimeError(f"[{env_id}] {k} groups={len(lst)} < outer={outer}")
        groups[k] = lst[:outer]
    return groups


# ---------------- run experiment ----------------------------
def run_experiment(
    env_id: str,
    groups,
    outer,
    inner,
    episodes,
    max_len,
    eps,
    option_p: float,
    alpha,
    gamma,
    eval_trials: int,
):
    """Train and periodically evaluate every method on ``env_id``.

    For each option group ``g < outer`` and each inner seed, the four methods
    (primitive-only, random, eigen and VPS options) are trained for
    ``episodes`` episodes. After every ``EVAL_EVERY`` training episodes the
    greedy policy is run ``eval_trials`` times and the mean return recorded.

    Returns:
        Dict mapping method name to a float32 array of shape
        ``(outer * inner, episodes // EVAL_EVERY)`` of mean evaluation returns.

    Raises:
        ValueError: if ``eval_trials`` < 1 (the mean would divide by zero).
    """
    if eval_trials < 1:
        raise ValueError(f"eval_trials must be >= 1, got {eval_trials}")

    methods = ["primitive", "random", "eigen", "vps"]
    n_points = episodes // EVAL_EVERY
    curves = {m: np.zeros((outer * inner, n_points), np.float32) for m in methods}

    run = 0
    for g in range(outer):
        Q_r, Q_e, Q_v = groups["random"][g], groups["eigen"][g], groups["vps"][g]

        for inn in range(inner):
            seed = g * 10000 + inn
            random.seed(seed)
            np.random.seed(seed)

            for meth, Qopt in [
                ("primitive", None),
                ("random", Q_r),
                ("eigen", Q_e),
                ("vps", Q_v),
            ]:
                env = gym.make(env_id)
                try:
                    # Gymnasium envs keep their own RNG and do not follow the
                    # global random/np.random seeds; seed it once here so the
                    # start-state sequence is reproducible (later resets
                    # continue this seeded stream).
                    env.reset(seed=seed)
                    ag = SMDPAgent(env, Qopt, eps, option_p, alpha, gamma)
                    for ep in range(episodes):
                        ag.run_episode(max_len, train=True)
                        if (ep + 1) % EVAL_EVERY == 0:
                            idx = ep // EVAL_EVERY
                            ret = sum(
                                ag.run_episode(max_len, train=False)
                                for _ in range(eval_trials)
                            ) / eval_trials
                            curves[meth][run, idx] = ret
                finally:
                    env.close()
            run += 1
            print(f"[{env_id}] group={g} seed={inn} finished")
    return curves


# ---------------- plotting helpers --------------------------
def _moving_average(arr: np.ndarray, k: int) -> np.ndarray:
    """Simple 1-D moving average; k=1 returns the input unchanged."""
    if k <= 1:
        return arr
    kernel = np.ones(k) / k
    return np.convolve(arr, kernel, mode="valid")  # length T-k+1


def plot_curves(
    curves: Dict[str, np.ndarray],
    env_id: str,
    win: int = 1,
    save_path: Path | None = None,
):
    """Display (and optionally save) mean ± std reward curves.

    Args:
        curves: method name -> array of shape (runs, points), as produced by
            ``run_experiment``; all arrays share the same number of points.
        env_id: used as the plot title.
        win: trailing moving-average window (1 = no smoothing). Clamped to the
            number of points so smoothing never yields an empty/invalid curve.
        save_path: if given, the figure is also written to this path.
    """
    colors = {
        "primitive": "tab:blue",
        "random": "tab:purple",
        "eigen": "tab:orange",
        "vps": "tab:green",
    }
    labels = {
        "primitive": "Flat Q-Learning",
        "random": "Random Option",
        "eigen": "Eigenoption",
        "vps": "VPS Option",
    }

    n_pts = next(iter(curves.values())).shape[1]
    win = max(1, min(win, n_pts))
    # Point i is recorded after training episode (i + 1) * EVAL_EVERY
    # (run_experiment stores it when (ep + 1) % EVAL_EVERY == 0), so the axis
    # starts at EVAL_EVERY, not 0.
    raw_x = (np.arange(n_pts) + 1) * EVAL_EVERY
    # A trailing window of length `win` is first full at index win - 1.
    x = raw_x[win - 1 :]

    plt.figure(figsize=(8, 4.5))
    for m, mat in curves.items():
        mean, std = mat.mean(0), mat.std(0)
        if win > 1:
            mean = _moving_average(mean, win)
            std = _moving_average(std, win)
        # Unknown methods fall back to matplotlib's color cycle / their key;
        # the band reuses the line's actual color so the two always match.
        (line,) = plt.plot(x, mean, color=colors.get(m), label=labels.get(m, m))
        plt.fill_between(
            x, mean - std, mean + std, color=line.get_color(), alpha=0.25
        )

    plt.xlabel("Episode", fontsize=16)
    plt.ylabel("Return", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    title = f"{env_id}  (window={win})" if win > 1 else f"{env_id}"
    plt.title(title, fontsize=16)
    plt.legend(fontsize=14)
    plt.grid(alpha=0.3)
    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path)
        print(f"[✓] plot saved → {save_path.resolve()}")

    plt.show()


# ----------------------------- CLI --------------------------
def main():
    """Command-line entry point: load option groups, run the study, plot.

    Option files are read from, and the plot is written to,
    ``<directory of this script>/<--out_dir>/<--env>/``.
    """
    pa = argparse.ArgumentParser()
    pa.add_argument("--env", default="Taxi-v3")
    pa.add_argument("--out_dir", default="option_results")
    pa.add_argument("--outer", type=int, default=1)
    pa.add_argument("--inner", type=int, default=1)
    pa.add_argument("--episodes", type=int, default=EPISODES)
    pa.add_argument("--max_len", type=int, default=STEPS_PER_EP)
    pa.add_argument("--epsilon", type=float, default=EPSILON)
    pa.add_argument(
        "--option_p",
        type=float,
        default=0.01,
        help=(
            "When ε-greedy triggers exploration, choose a random *startable* option "
            "with probability p and a random primitive action with probability 1-p. "
            "If no option is startable at the current state, always fall back to a "
            "random primitive action."
        ),
    )
    pa.add_argument("--alpha", type=float, default=ALPHA)
    pa.add_argument("--gamma", type=float, default=GAMMA)
    pa.add_argument("--plot_path", default="reward_curve.png")
    pa.add_argument(
        "--smooth",
        type=int,
        default=1,  # 1 = no smoothing
        help="moving-average window length for the reward curve",
    )
    pa.add_argument(
        "--eval_trials",
        type=int,
        default=EVAL_EPISODES,  # single source of truth for the default
        help="number of evaluation episodes to average",
    )

    args = pa.parse_args()

    # Fail fast on values that would otherwise crash mid-run (eval_trials=0
    # divides by zero) or silently produce degenerate/empty curves.
    for name in ("outer", "inner", "episodes", "max_len", "eval_trials"):
        if getattr(args, name) < 1:
            pa.error(f"--{name} must be >= 1")
    if not 0.0 <= args.option_p <= 1.0:
        pa.error("--option_p must be in [0, 1]")

    # Resolve option directory relative to this script:
    # `<script dir>/<out_dir>/<Env>/...`.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    root_dir = os.path.join(script_dir, args.out_dir)
    save_dir = os.path.join(root_dir, args.env)
    # Created up front so the plot can always be saved there.
    os.makedirs(save_dir, exist_ok=True)
    plot_path = os.path.join(save_dir, args.plot_path)

    groups = load_option_groups(args.env, Path(save_dir), args.outer)
    curves = run_experiment(
        args.env,
        groups,
        args.outer,
        args.inner,
        args.episodes,
        args.max_len,
        args.epsilon,
        args.option_p,
        args.alpha,
        args.gamma,
        args.eval_trials,
    )
    plot_curves(curves, env_id=args.env, win=args.smooth, save_path=Path(plot_path))


if __name__ == "__main__":
    main()
