#!/usr/bin/env python
"""Reward maximization with a goal in MiniGrid using options.

Loads saved option Q-tables, trains a controller with ε-greedy over
primitive actions + options, and plots reward curves.
"""

import argparse, random, os, glob
import numpy as np
import matplotlib.pyplot as plt
from .bottleneck_env import SimpleEnv
from .generate_state_transition_matrix import build_state_transition_matrix

# ---------------------------- Task definition ----------------------------
# Positions are (x, y) grid coordinates; main() flattens them as x + y * width.
# Task presets (start -> goal), labelled as in the original notes:
#   VPS:    (1, 1)   -> (10, 4)
#   Eigen:  (13, 10) -> (13, 7)
#   Random: (13, 1)  -> (3, 6)     <- active preset
START_POS = (13, 1)
GOAL_POS = (3, 6)

# ---------------------------- Hyper-parameters ---------------------------
EPISODES = 100  # training episodes per controller run
STEPS_PER_E = 200  # primitive-step budget of a single episode
EVAL_EVERY = 1  # run an evaluation after every EVAL_EVERY training episodes
EVAL_TRIALS = 1  # greedy (non-exploring) rollouts averaged per evaluation
EPSILON = 0.1  # exploration rate of the ε-greedy controller
ALPHA = 0.05  # Q-learning step size
GAMMA = 0.99  # discount factor
# Probability of executing a primitive step when mixing.
# NOTE(review): not referenced anywhere else in this script.
P_PRIM_RAND = 0.16666


def option_terminated(q_row: np.ndarray, expected_length: float = 15) -> bool:
    """Decide whether a running option stops in the state whose Q-row is *q_row*.

    The option terminates if (1) all of its local action values are <= 0, or
    (2) a random draw succeeds with probability ``1 / expected_length``.
    When this is checked once per primitive step, condition (2) alone gives a
    geometric option duration with mean ``expected_length``.

    This is used *during rollout* of an already-initiated option.

    Args:
        q_row: The option's action values in the current state.
        expected_length: The ``L`` in the ``1/L`` random termination chance
            (default 15, the previously hard-coded value).

    Returns:
        True if the option should terminate here.

    Raises:
        ValueError: If ``expected_length`` < 1, since ``1/expected_length``
            would not be a valid probability.
    """
    if expected_length < 1:
        raise ValueError(f"expected_length must be >= 1, got {expected_length!r}")
    if q_row.max() <= 0:
        return True
    # random.random() lies in [0, 1), so a strict '<' gives exactly 1/L.
    return random.random() < 1.0 / expected_length


def option_can_start(q_row: np.ndarray) -> bool:
    """Return whether an option may be *initiated* in the current state.

    Every state is nominally in the initiation set. The only extra requirement
    is that the option's best local action value, ``max(q_row)``, is strictly
    positive.
    """
    best_local_value = q_row.max()
    return best_local_value > 0


def available_options(Qopt, s):
    """List the option indices that can be *initiated* in state *s*.

    ``Qopt`` is indexed as ``Qopt[option, state]``. An option qualifies when
    ``option_can_start`` accepts its row for *s*, i.e. max Q(s, ·) > 0.
    ``None``, or any object without a ``shape`` attribute, yields no options.
    """
    if Qopt is not None and hasattr(Qopt, "shape"):
        return [
            option_id
            for option_id in range(Qopt.shape[0])
            if option_can_start(Qopt[option_id, s])
        ]
    return []


def train_controller(
    seed,
    T,
    N_STATES,
    start_state,
    goal_state,
    Qopt=None,
    opt_pi=None,
):
    """Train an ε-greedy tabular controller over primitives + options; return its eval curve.

    The controller learns Q(s, a) for 4 primitive actions plus one SMDP action
    per option in ``Qopt``. Primitive actions use one-step Q-learning. Options
    use SMDP Q-learning: an option that ran τ steps is credited with
    ∑ γ^i r_{t+i+1} and bootstraps with γ^τ. Bootstrap targets take the max only
    over actions that are legal in the next state.

    Args:
        seed: Seed applied to both ``random`` and ``np.random``.
        T: Transition table; ``T[s, a]`` is the next state for primitive ``a``.
        N_STATES: Number of states (rows of the controller Q-table).
        start_state: State index where every episode starts.
        goal_state: Entering it yields reward 1.0 and ends the episode.
            Every other step costs -0.01.
        Qopt: Optional option Q-tables indexed ``Qopt[option, state]`` -> action
            values. Presumably shape (K, N_STATES, 4), matching
            ``np.argmax(Q, 2)`` in ``main`` — confirm.
        opt_pi: Optional option policies, ``opt_pi[option, state]`` -> primitive.
            Derived as ``argmax(Qopt, axis=2)`` when omitted.

    Returns:
        1-D array with the mean evaluation return recorded every ``EVAL_EVERY``
        training episodes.
    """
    random.seed(seed)
    np.random.seed(seed)

    # Upper-level controller learns Q over primitive actions (4) + options (if any).
    N_PRIM = 4
    N_OPT = Qopt.shape[0] if (Qopt is not None and hasattr(Qopt, "shape")) else 0
    N_ACT = N_PRIM + N_OPT
    if N_OPT > 0 and opt_pi is None:
        # Same greedy-policy extraction that main() performs before calling us.
        opt_pi = np.argmax(Qopt, 2)
    Q = np.zeros((N_STATES, N_ACT), np.float32)

    # Qopt never changes during training, so each state's availability mask is
    # computed on first visit and cached instead of being rebuilt every step.
    mask_cache = {}

    def action_mask(s):
        """Boolean mask over the N_ACT controller actions that are legal in state s."""
        mask = mask_cache.get(s)
        if mask is None:
            mask = np.ones(N_ACT, dtype=bool)  # primitives are always available
            for oid in range(N_OPT):
                mask[N_PRIM + oid] = option_can_start(Qopt[oid, s])
            mask_cache[s] = mask
        return mask

    def best_value(s):
        """max_a Q(s, a) over the actions legal in s (used as the bootstrap value)."""
        # Options that cannot start in s are never chosen there, so their Q stays
        # at its initial 0. Including them would make targets optimistic once the
        # legal actions' values turn negative under the -0.01 step cost.
        return Q[s][action_mask(s)].max()

    def select_action(s, explore):
        """ε-greedy choice among the legal actions; greedy ties are broken uniformly."""
        mask = action_mask(s)
        if explore and random.random() < EPSILON:
            return int(random.choice(np.flatnonzero(mask)))
        masked_q = np.where(mask, Q[s], -np.inf)
        return int(random.choice(np.flatnonzero(masked_q == masked_q.max())))

    def run_episode(train=True):
        """Roll out one episode (learning if ``train``); return its undiscounted return."""
        s, steps, ret = start_state, 0, 0.0
        while steps < STEPS_PER_E:
            a = select_action(s, explore=train)

            if a < N_PRIM:
                # ---------- primitive action: one-step Q-learning ----------
                sn = int(T[s, a])
                r = 1.0 if sn == goal_state else -0.01
                steps += 1
                ret += r
                if train:
                    # The goal is terminal. Exhausting the step budget is also
                    # treated as terminal (no bootstrap).
                    done = sn == goal_state or steps >= STEPS_PER_E
                    target = r if done else r + GAMMA * best_value(sn)
                    Q[s, a] += ALPHA * (target - Q[s, a])
            else:
                # ---------- option: SMDP Q-learning ----------
                oid = a - N_PRIM
                sn = s
                cumulative = 0.0  # ∑ γ^i r_{t+i+1}
                disc = 1.0  # γ^τ after τ primitive steps
                # An option always takes at least one primitive step, and its
                # termination is evaluated in each state it arrives in. The outer
                # loop guarantees steps < STEPS_PER_E here, so the budget check
                # cannot stop the option before that first step.
                while True:
                    ain = int(opt_pi[oid, sn])  # primitive chosen by the option
                    sn = int(T[sn, ain])
                    r = 1.0 if sn == goal_state else -0.01
                    cumulative += disc * r
                    disc *= GAMMA
                    ret += r
                    steps += 1
                    if sn == goal_state or steps >= STEPS_PER_E:
                        break
                    if option_terminated(Qopt[oid, sn]):
                        break
                if train:
                    done = sn == goal_state or steps >= STEPS_PER_E
                    target = cumulative if done else cumulative + disc * best_value(sn)
                    Q[s, a] += ALPHA * (target - Q[s, a])

            s = sn
            if s == goal_state:
                break
        return ret

    # ----- training loop ------------------------------------
    eval_curve = []
    for ep in range(1, EPISODES + 1):
        run_episode(train=True)
        if ep % EVAL_EVERY == 0:
            eval_curve.append(np.mean([run_episode(train=False) for _ in range(EVAL_TRIALS)]))
    return np.asarray(eval_curve)


def load_option_files(method: str, out_dir: str):
    """Load every saved option Q-table for *method* from *out_dir*.

    Files are matched by ``gridworld_*_{method}_*.npy`` and returned in sorted
    path order. Each loaded file is reported on stdout together with its
    leading dimension (K).
    """
    paths = sorted(glob.glob(os.path.join(out_dir, f"gridworld_*_{method}_*.npy")))
    tables = list(map(np.load, paths))
    for path, table in zip(paths, tables):
        print(f"[Load] {path}  (K={table.shape[0]})")
    return tables


def _run_experiments(option_sets, n_inner, T, N_STATES, start_state, goal_state):
    """Train ``n_inner`` controllers per option set for every method; return their eval curves.

    ``option_sets`` maps a method name to a list of option Q-tables, or to
    ``None`` for the flat (primitive-only) baseline. Option set ``i`` of every
    method is paired with the seeds ``i * 1000 + inner``. The baseline is re-run
    for every ``i``, so it gets as many runs as the largest option-set list.

    Returns:
        Dict mapping method name -> list of evaluation curves (one per run).
    """
    curves = {method: [] for method in option_sets}
    n_outer = max([len(qs) for qs in option_sets.values() if qs is not None] + [1])
    for outer in range(n_outer):
        # (Qopt, opt_pi) for each method taking part in this outer round.
        controllers = {}
        for method, q_list in option_sets.items():
            if q_list is None:
                controllers[method] = (None, None)
            elif outer < len(q_list):
                q = q_list[outer]
                controllers[method] = (q, np.argmax(q, 2))
        for inner in range(n_inner):
            seed = outer * 1000 + inner
            for method, (q, pi) in controllers.items():
                curves[method].append(
                    train_controller(
                        seed,
                        T=T,
                        N_STATES=N_STATES,
                        start_state=start_state,
                        goal_state=goal_state,
                        Qopt=q,
                        opt_pi=pi,
                    )
                )
    return curves


def _plot_curves(curves):
    """Plot the mean ± one std evaluation return of each method against training episodes."""
    colors = {"Primitive": "tab:blue", "Random": "tab:purple", "Eigen": "tab:orange", "VPS": "tab:green"}
    legend_labels = {
        "Primitive": "Flat Q-Learning",
        "Random": "Random Option",
        "Eigen": "Eigenoption",
        "VPS": "VPS Option",
    }

    plt.figure(figsize=(8, 4.5))
    for m in ["Primitive", "Random", "Eigen", "VPS"]:
        runs = curves.get(m)
        if not runs:
            continue
        mat = np.vstack(runs)  # (n_runs, n_evals)
        mean, std = mat.mean(0), mat.std(0)
        # The k-th evaluation (1-based) happens after k * EVAL_EVERY training episodes.
        x = np.arange(1, mat.shape[1] + 1) * EVAL_EVERY
        plt.plot(x, mean, color=colors[m], label=legend_labels[m])
        plt.fill_between(x, mean - std, mean + std, color=colors[m], alpha=0.20)

    plt.xlabel("Episodes", fontsize=16)
    plt.ylabel("Return", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.title("Reward Maximization Task", fontsize=16)
    plt.grid(alpha=0.3)
    plt.legend(loc="right", fontsize=14)
    plt.tight_layout()
    plt.show()


def main():
    """Parse CLI args, build the environment, train every method and plot the reward curves."""
    pa = argparse.ArgumentParser()
    pa.add_argument("--inner", type=int, default=10, help="# runs / each option set")
    pa.add_argument("--out_dir", default="option_results")
    args = pa.parse_args()

    # ---- environment ----
    env = SimpleEnv(render_mode=None)
    env.reset()
    T, _wall = build_state_transition_matrix(env)
    N_STATES = T.shape[0]
    # Grid (x, y) -> flat state index.
    start_state = START_POS[0] + START_POS[1] * env.width
    goal_state = GOAL_POS[0] + GOAL_POS[1] * env.width

    # Match the saving convention in gridworld_options.py:
    script_dir = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(script_dir, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # ---- load options (None = flat primitive-only baseline) ----
    option_sets = {
        "Primitive": None,
        "Random": load_option_files("RandomOption", out_dir),
        "Eigen": load_option_files("EigenOpt", out_dir),
        "VPS": load_option_files("VPSOpt", out_dir),
    }

    curves = _run_experiments(option_sets, args.inner, T, N_STATES, start_state, goal_state)
    _plot_curves(curves)

    print("[✓] Done.")


if __name__ == "__main__":
    # Script entry point. Because of the relative imports at the top of the
    # file (`from .bottleneck_env import ...`), this only works when run as a
    # package module (e.g. `python -m <package>.<module>`), not as a plain file.
    main()

