#!/usr/bin/env python
"""Reward maximization with a goal in MiniGrid using options.

Loads saved option Q-tables, trains a controller with ε-greedy over
primitive actions + options, and plots reward curves.
"""
import argparse, random, os, glob
import numpy as np
import matplotlib.pyplot as plt
from bottleneck_env import SimpleEnv
from generate_state_transition_matrix import build_state_transition_matrix

# ---------------- Tasks & hyper-parameters ------------------
# Start -> goal presets as (x, y) grid coordinates; the label after each pair
# is carried over from the original notes (presumably the option family the
# task was set up for — TODO confirm):
#   (1, 1)→(10, 4)  VPS , (13, 10)→(13, 7)  Eigen, (13, 1)→(3, 6)  Random
START_POS, GOAL_POS = (13, 1), (3, 6)

# Episode budget and evaluation schedule.
EPISODES, STEPS_PER_E = 100, 200    # training episodes / primitive-step cap per episode
EVAL_EVERY, EVAL_TRIALS = 1, 1      # evaluate every N episodes / greedy rollouts per evaluation

# Controller learning hyper-parameters.
EPSILON = 0.1     # exploration rate of the ε-greedy controller
ALPHA = 0.05      # TD step size
GAMMA = 0.99      # discount factor
# Probability of executing a primitive step when mixing.
# NOTE(review): not referenced by any code visible in this file.
P_PRIM_RAND = 0.16666

# ------------------- utilities ------------------------------
def option_terminated(q_row: np.ndarray, horizon: int = 15) -> bool:
    """Decide whether an option stops in the state whose Q-row is given.

    An option terminates if (1) none of its action values is positive
    (``q_row.max() <= 0``) or (2) a random draw succeeds with probability
    ``1 / horizon`` per call.

    Args:
        q_row: the option's action values in the current state.
        horizon: inverse per-call termination probability; the default 15 is
            the value that used to be hard-coded.

    Returns:
        True if the option should terminate.

    Raises:
        ValueError: if ``horizon`` is not positive.
    """
    if horizon <= 0:
        raise ValueError(f"horizon must be positive, got {horizon}")
    # Short-circuit: the RNG is only consumed when the option is still valid.
    return bool(q_row.max() <= 0 or random.random() <= 1 / horizon)


def available_mask(Qopt, s):
    """
    Build the boolean availability vector for state ``s``.

    The first four entries (primitive actions) are always True. When ``Qopt``
    is an array, one extra entry follows per option; it is False whenever
    ``option_terminated`` fires for that option's Q-row at ``s``.
    """
    if Qopt is None or not hasattr(Qopt, "shape"):
        return np.ones(4, bool)
    # Evaluate options in index order (option_terminated may consume the RNG).
    option_ok = [not option_terminated(Qopt[o, s]) for o in range(Qopt.shape[0])]
    return np.concatenate([np.ones(4, bool), np.array(option_ok, dtype=bool)])


def epsilon_action(Q, Qopt, s):
    """Pick a controller action at state ``s`` ε-greedily.

    With probability ``EPSILON`` a uniformly random available action (primitive
    or option) is returned. Otherwise a highest-valued available entry of
    ``Q[s]`` is returned, with ties broken uniformly at random.
    """
    allowed = available_mask(Qopt, s)
    if random.random() < EPSILON:
        # Exploration: any available action.
        candidates = np.flatnonzero(allowed)
    else:
        # Exploitation: unavailable actions can never win the argmax.
        scores = np.where(allowed, Q[s], -np.inf)
        candidates = np.flatnonzero(scores == scores.max())
    return int(random.choice(candidates))


def greedy_action(Qrow, mask):
    """Return the index of a maximal entry of ``Qrow`` among positions where
    ``mask`` is True, choosing uniformly (via ``np.random``) among ties."""
    scores = np.where(mask, Qrow, -np.inf)
    top = scores.max()
    winners = np.flatnonzero(scores == top)
    return int(np.random.choice(winners))


# ------- single run: (SMDP) Q-learning over primitives + options -------
def train_controller(
    seed,
    T,
    N_STATES,
    start_state,
    goal_state,
    Qopt=None,
    opt_pi=None,
):
    """Train one tabular (SMDP) Q-learning controller and record eval returns.

    The controller's action set is the 4 primitive actions followed by one
    entry per option in ``Qopt``. After every ``EVAL_EVERY`` training episodes
    the greedy controller is rolled out ``EVAL_TRIALS`` times.

    Args:
        seed: seed applied to both ``random`` and ``np.random``.
        T: transition table; ``T[s, a]`` is the next state index.
        N_STATES: number of states (rows of the controller Q-table).
        start_state: state index every episode starts from.
        goal_state: terminal state; entering it gives +1, any other step -0.01.
        Qopt: optional option Q-tables indexed ``Qopt[option, state, action]``
            (as in ``main``). If None or without ``.shape``, only primitive
            actions are used.
        opt_pi: optional option policies ``opt_pi[option, state] -> action``;
            derived greedily from ``Qopt`` when omitted.

    Returns:
        1-D array with one mean undiscounted evaluation return per
        evaluation point (``EPISODES // EVAL_EVERY`` entries).
    """
    random.seed(seed)
    np.random.seed(seed)

    has_opts = Qopt is not None and hasattr(Qopt, "shape")
    k_opts = Qopt.shape[0] if has_opts else 0
    N_ACT = 4 + k_opts
    Q = np.zeros((N_STATES, N_ACT), np.float32)
    eval_curve = []

    # Actions that can ever be selected in each state: primitives always; an
    # option only where its own Q-row has a positive entry (the deterministic
    # part of option_terminated). Bootstrap maxima are restricted to these so
    # never-selectable options, whose controller Q stays at its 0 init, cannot
    # inflate the TD target.
    avail = np.ones((N_STATES, N_ACT), bool)
    if k_opts:
        if opt_pi is None:
            opt_pi = np.argmax(Qopt, axis=2)
        avail[:, 4:] = (np.asarray(Qopt).max(axis=2)[:, :N_STATES] > 0).T

    def reward(sn):
        return 1.0 if sn == goal_state else -0.01

    def bootstrap(sn):
        """Max controller value over selectable actions at sn; 0 at the goal."""
        if sn == goal_state:
            return 0.0
        return float(Q[sn, avail[sn]].max())

    def run_episode(train=True):
        s, steps, ret = start_state, 0, 0
        while steps < STEPS_PER_E:
            a = (
                epsilon_action(Q, Qopt, s)
                if train
                else greedy_action(Q[s], available_mask(Qopt, s))
            )

            if a < 4:
                # ---------- primitive action ----------------------------
                sn = int(T[s, a])
                r = reward(sn)
                if train:
                    Q[s, a] += ALPHA * (r + GAMMA * bootstrap(sn) - Q[s, a])
                s, steps, ret = sn, steps + 1, ret + r
            else:
                # ---------- option: follow its policy until it stops ----
                oid, sn, acc_r, dur = a - 4, s, 0.0, 0
                budget = STEPS_PER_E - steps  # >= 1 inside the outer loop
                # The option always takes at least one primitive step;
                # termination is evaluated in the states it arrives in. This
                # also guarantees `steps` strictly increases every iteration.
                while True:
                    sn = int(T[sn, opt_pi[oid, sn]])
                    r = reward(sn)
                    acc_r += (GAMMA**dur) * r
                    dur += 1
                    ret += r
                    if (
                        sn == goal_state
                        or dur >= budget
                        or option_terminated(Qopt[oid, sn])
                    ):
                        break
                if train:
                    # SMDP backup: discounted option reward + gamma^duration.
                    td = acc_r + (GAMMA**dur) * bootstrap(sn) - Q[s, a]
                    Q[s, a] += ALPHA * td
                s, steps = sn, steps + dur
            if s == goal_state:
                break
        return ret

    # ----- training loop ------------------------------------
    for ep in range(1, EPISODES + 1):
        run_episode(train=True)
        if ep % EVAL_EVERY == 0:
            avg_r = np.mean(
                [run_episode(train=False) for _ in range(EVAL_TRIALS)]
            )
            eval_curve.append(avg_r)
    return np.asarray(eval_curve)


# ------------- load local option files ----------------------
def load_option_files(method: str, out_dir: str):
    """Load every ``gridworld_*_<method>_*.npy`` file in ``out_dir``.

    Files are loaded in sorted path order and returned as a list of arrays.
    After all files are loaded, one ``[Load]`` line per file is printed with
    the size of its leading axis (K).
    """
    files = sorted(glob.glob(os.path.join(out_dir, f"gridworld_*_{method}_*.npy")))
    tables = list(map(np.load, files))
    for path, table in zip(files, tables):
        print(f"[Load] {path}  (K={table.shape[0]})")
    return tables


# ---------------------------- main --------------------------
def _plot_curves(curves, n_points):
    """Plot mean ± std evaluation return per method and show the figure."""
    colors = {
        "Primitive": "tab:blue",
        "Random": "tab:purple",
        "Eigen": "tab:orange",
        "VPS": "tab:green",
    }
    legend_labels = {
        "Primitive": "Flat Q-Learning",
        "Random": "Random Option",
        "Eigen": "Eigenoption",
        "VPS": "VPS Option",
    }

    # Evaluation k (0-based) is recorded after training episode
    # (k + 1) * EVAL_EVERY, so the x-axis starts at EVAL_EVERY, not 0.
    x = np.arange(1, n_points + 1) * EVAL_EVERY
    plt.figure(figsize=(8, 4.5))
    plotted = False
    for m in ["Primitive", "Random", "Eigen", "VPS"]:
        runs = curves.get(m)
        if not runs:
            continue
        mat = np.vstack(runs)
        mean, std = mat.mean(0), mat.std(0)
        plt.plot(x, mean, color=colors[m], label=legend_labels[m])
        plt.fill_between(x, mean - std, mean + std, color=colors[m], alpha=0.20)
        plotted = True

    plt.xlabel("Episodes", fontsize=16)
    plt.ylabel("Return", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.title("Reward Maximization Task", fontsize=16)
    plt.grid(alpha=0.3)
    if plotted:  # avoid matplotlib's "no artists with labels" warning
        plt.legend(loc="right", fontsize=14)
    plt.tight_layout()
    plt.show()


def main():
    """Parse CLI args, train controllers for every method, and plot returns."""
    pa = argparse.ArgumentParser()
    pa.add_argument(
        "--inner", type=int, default=10, help="# runs / each option set"
    )
    pa.add_argument("--out_dir", default="option_results")
    args = pa.parse_args()

    # ---- environment ----
    env = SimpleEnv(render_mode=None)
    env.reset()
    T, _wall = build_state_transition_matrix(env)
    N_STATES = T.shape[0]
    # States are flattened as x + y * width (same convention as below for goal).
    start_state = START_POS[0] + START_POS[1] * env.width
    goal_state = GOAL_POS[0] + GOAL_POS[1] * env.width

    # ---- load options (one list of Q-tables per option family) ----
    option_sets = {
        "Random": load_option_files("RandomOption", args.out_dir),
        "Eigen": load_option_files("EigenOpt", args.out_dir),
        "VPS": load_option_files("VPSOpt", args.out_dir),
    }

    # ---- experiments ----
    env_kwargs = dict(
        T=T, N_STATES=N_STATES, start_state=start_state, goal_state=goal_state
    )
    curves = {m: [] for m in ("Primitive", *option_sets)}

    # The primitive baseline runs once per outer index, so it gets as many
    # runs as the largest option family (at least one outer iteration).
    total_option_sets = max([1, *map(len, option_sets.values())])
    for outer in range(total_option_sets):
        # Current option set of each family (None once that family runs out).
        current = {
            m: (tables[outer] if outer < len(tables) else None)
            for m, tables in option_sets.items()
        }
        policies = {m: np.argmax(q, 2) for m, q in current.items() if q is not None}

        for inner in range(args.inner):
            # Every controller re-seeds itself with this seed.
            seed = outer * 1000 + inner
            curves["Primitive"].append(train_controller(seed, **env_kwargs))
            for m, q in current.items():
                if q is None:
                    continue
                curves[m].append(
                    train_controller(seed, Qopt=q, opt_pi=policies[m], **env_kwargs)
                )

    # ---- aggregate & plot ----------------------------------
    _plot_curves(curves, EPISODES // EVAL_EVERY)

    print("[✓] Done.")


if __name__ == "__main__":
    main()
