#!/usr/bin/env python
"""Train Random/Eigen/VPS options on a MiniGrid layout (tabular).

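VPS and Random options are trained from a random-walk transition buffer
followed by offline Q-learning (VPS additionally fits TD(λ) value estimates
first); eigenoptions use eigenvectors of the normalized graph Laplacian built
from the same kind of buffer. Saves one .npy Q-table array per option set for
downstream experiments and visualization.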
"""
import os, argparse, random, collections
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh
from bottleneck_env import SimpleEnv
from generate_state_transition_matrix import build_state_transition_matrix
from utils import BottleneckVisualization


# ---------- misc utilities ----------------------------------
def set_seed(sd):
    """Seed Python's ``random`` module and NumPy's global RNG with *sd*.

    Passing ``None`` leaves both generators untouched.
    """
    if sd is None:
        return
    random.seed(sd)
    np.random.seed(sd)


def bfs_distance(T, goal, N):
    """Return float32 step counts from *goal* along edges ``s -> T[s, a]``.

    Explores the four actions of each state breadth-first, starting at
    *goal*. States never reached stay at ``+inf``. Because edges are followed
    forward from *goal*, this equals the distance *to* goal only when moves
    are reversible (NOTE(review): presumably true for the grid layouts used
    here — confirm).
    """
    dist = np.full(N, np.inf, dtype=np.float32)
    dist[goal] = 0
    frontier = collections.deque((goal,))
    while frontier:
        cur = frontier.popleft()
        step = dist[cur] + 1  # dist[cur] cannot shrink while we expand cur
        for nxt in (int(T[cur, act]) for act in range(4)):
            if step < dist[nxt]:
                dist[nxt] = step
                frontier.append(nxt)
    return dist


def option_rows_dup(arr, sign):
    """Stack *arr* on top of itself (``np.vstack``) when *sign* is truthy.

    This gives the negated-reward option rows the same auxiliary arrays as
    their positive counterparts. When *sign* is falsy, *arr* itself is
    returned (same object, no copy).
    """
    if not sign:
        return arr
    return np.vstack((arr,) * 2)


# ---------- phase-0 : collect buffer ------------------------
def collect_buffer(T, free_idx, *, episodes=1000, max_len=200):
    """
    Roll out a uniform-random policy and record every transition.

    Each of ``episodes`` episodes starts from a state drawn from ``free_idx``
    (NumPy global RNG) and runs exactly ``max_len`` steps. Actions are drawn
    uniformly from {0, 1, 2, 3} with Python's ``random`` module, and the next
    state is read from ``T[s, a]``. There is no early termination, so the
    buffer holds ``episodes * max_len`` transitions stored episode by episode.

    Returns:
        · buffer : List[(s, a, sn)] in generation order
        · visited_set : set of every s and sn appearing in the buffer
    """

    def _rollout(start):
        # Yield max_len consecutive transitions beginning at ``start``.
        cur = start
        for _ in range(max_len):
            act = random.randint(0, 3)
            nxt = int(T[cur, act])
            yield cur, act, nxt
            cur = nxt

    transitions = []
    for _ in range(episodes):
        transitions.extend(_rollout(int(np.random.choice(free_idx))))
    seen = {state for s, _, sn in transitions for state in (s, sn)}
    return transitions, seen


# ---------- 1. VPS-Option -----------------------------------
def train_vps_options(
    T,
    free_idx,
    *,
    k_base,
    sign,
    gamma,
    alpha,
    lam=0.9,
    collect_ep,
    ep_len,
    phi_lr=0.001,
):
    """
    Train VPS options in two offline stages over one random-walk buffer.

      (1) For each of ``k_base`` random Gaussian reward functions (reward
          ``randR[i, sn]`` on arriving in sn), estimate a value function V
          with accumulating-trace TD(λ). Alongside it, track φ: an
          exponential moving average (rate ``phi_lr``) of the squared value
          difference (V(sn) − V(s))², stored at the departing state s.
      (2) Offline Q-learning on the same buffer with shaped reward
          r = φ(sn) − φ(s). When ``sign`` is truthy, a second bank of
          options is trained on −r.

    Args:
        T: transition table, ``T[s, a]`` gives the next state; ``T.shape[0]``
            is the number of states.
        free_idx: walkable state indices used as random-walk start states.
        k_base: number of base reward functions / positive options.
        sign: if truthy, also train the negated-reward options.
        gamma: discount factor.
        alpha: learning rate for both the TD(λ) and Q-learning updates.
        lam: eligibility-trace decay λ.
        collect_ep: number of random-walk episodes.
        ep_len: steps per episode.
        phi_lr: EMA rate for φ (default 0.001, the previously hard-coded value).

    Returns:
        Q   : (k_base * (2 if sign else 1), n_states, 4) float32 Q-tables.
              Rows [0, k_base) use +r; rows [k_base, 2*k_base) use −r.
        V   : (k_base, n_states) value estimates, rows duplicated when sign.
        phi : (k_base, n_states) φ estimates, rows duplicated when sign.
    """
    n_states = T.shape[0]

    # ------- collect buffer ----------
    buffer, _ = collect_buffer(
        T, free_idx, episodes=collect_ep, max_len=ep_len
    )

    # ------- Phase-1 : learn V & φ ----
    V = np.zeros((k_base, n_states), np.float32)
    phi = np.zeros_like(V)  # EMA of squared value difference (V(sn) - V(s))^2
    elig = np.zeros(n_states, np.float32)
    randR = np.random.randn(k_base, n_states).astype(np.float32)

    # collect_buffer always emits exactly ep_len transitions per episode, so
    # episode ep occupies buffer[ep * ep_len : (ep + 1) * ep_len].
    for ep in range(collect_ep):
        elig.fill(0.0)  # traces do not carry across episodes
        offset = ep * ep_len
        for s, a, sn in buffer[offset:offset + ep_len]:
            elig *= gamma * lam
            elig[s] += 1
            td = randR[:, sn] + gamma * V[:, sn] - V[:, s]
            # Measured before this step's V update.
            td_v = (V[:, sn] - V[:, s]) ** 2
            V += alpha * td[:, None] * elig
            phi[:, s] += phi_lr * (td_v - phi[:, s])

    # ------- Phase-2 : learn Q options -
    # One potential row per output option (+φ, then −φ when sign), so every
    # option's independent Q-table is updated in a single vectorized step.
    # (−φ)(sn) − (−φ)(s) is exactly −r, matching the negated-reward bank.
    potentials = np.vstack([phi, -phi]) if sign else phi
    Q = np.zeros((potentials.shape[0], n_states, 4), np.float32)
    for s, a, sn in buffer:
        target = potentials[:, sn] - potentials[:, s] + gamma * Q[:, sn].max(1)
        Q[:, s, a] += alpha * (target - Q[:, s, a])

    return Q, option_rows_dup(V, sign), option_rows_dup(phi, sign)


# ---------- 2. Eigen-Option ---------------------------------
def train_eigen_options(
    T,
    free_idx,
    N,
    *,
    k_base,
    sign,
    gamma,
    alpha,
    collect_ep,
    ep_len,
):
    """
    Train eigenoptions from the normalized Laplacian of a random-walk graph.

    The function proceeds in four steps:

    1. Build an unweighted, symmetrized adjacency A over the N states from
       the transitions observed in a random-walk buffer.
    2. Form L = I − D^{-1/2} A D^{-1/2}.
    3. Take the k_base eigenvectors with the smallest eigenvalues after the
       first (trivial) one.
    4. Train one Q-table per eigenvector v with shaped reward
       r = v(sn) − v(s), using offline Q-learning on the same buffer.

    Eigenvectors are determined only up to sign by the eigensolver.

    Args:
        T: transition table, ``T[s, a]`` gives the next state.
        free_idx: walkable state indices used as random-walk start states.
        N: number of states.
        k_base: number of eigenvectors (positive options).
        sign: if truthy, also train an option on −v for every eigenvector.
        gamma: discount factor.
        alpha: Q-learning rate.
        collect_ep: number of random-walk episodes.
        ep_len: steps per episode.

    Returns:
        Q : (k_base * (2 if sign else 1), N, 4) float32 Q-tables. With
            ``sign``, rows are interleaved as [+v1, −v1, +v2, −v2, ...]
            (unlike VPS, which stacks all + rows before all − rows).

    Raises:
        ValueError: if k_base + 1 >= N (eigsh needs k < N).
    """
    k_need = k_base + 1  # the smallest eigenpair is trivial and is skipped
    if k_need >= N:
        raise ValueError(
            f"k_base + 1 ({k_need}) must be smaller than the number of states ({N})"
        )

    buffer, _ = collect_buffer(
        T, free_idx, episodes=collect_ep, max_len=ep_len
    )

    # ----- build un-weighted, symmetric adjacency from buffer -----
    trans = np.asarray(buffer, dtype=np.int64).reshape(-1, 3)
    adj = np.zeros((N, N), dtype=bool)
    adj[trans[:, 0], trans[:, 2]] = True
    # eigsh assumes a symmetric matrix, but a finite random walk may observe
    # an edge in only one direction, so symmetrize explicitly.
    adj = adj | adj.T
    A = sp.csr_matrix(adj, dtype=np.float64)

    deg = np.asarray(A.sum(axis=1)).ravel()
    d_inv_sqrt = sp.diags(1.0 / np.sqrt(np.maximum(deg, 1e-8)))
    L = (sp.eye(N) - d_inv_sqrt @ A @ d_inv_sqrt).tocsc()

    # ARPACK converges poorly when asked for the smallest eigenvalues directly
    # (which="SM"). Shift-invert around a small negative sigma instead: the
    # normalized Laplacian is positive semi-definite, so L + 1e-3*I is
    # non-singular.
    vals, vecs = eigsh(L, k=k_need, sigma=-1e-3, which="LM")
    order = np.argsort(vals)  # do not rely on the solver's output order
    eig_vecs = vecs[:, order[1:]]  # (N, k_base), trivial vector dropped

    # One potential row per output option; interleave ± per eigenvector.
    pots = eig_vecs.T
    if sign:
        pots = np.stack([pots, -pots], axis=1).reshape(-1, N)

    # Options are independent, so all Q-tables are updated per transition.
    Q = np.zeros((pots.shape[0], N, 4), np.float32)
    for s, a, sn in buffer:
        target = pots[:, sn] - pots[:, s] + gamma * Q[:, sn].max(1)
        Q[:, s, a] += alpha * (target - Q[:, s, a])
    return Q


# ------------------------------------------------------------
#  Random-Option (Grid-World)
#  · Each option uses a random potential φ ~ N(0,1)
#  · shaped-reward  r = φ(sn) − φ(s)
# ------------------------------------------------------------
def train_random_options(
    T,            # transition table, T[s, a] -> next state
    free_idx,     # walkable indices; random-walk start states for the buffer
    N,            # number of states
    *,            # keyword-only
    k_base,       # number of options
    gamma,
    alpha,
    collect_ep,
    ep_len,
    seed=None,
):
    """
    Train Random options by offline Q-learning on a random-walk buffer.

    Each option i has its own random potential φ_i ~ N(0, 1) over all N
    states and is trained with the shaped reward r = φ_i(sn) − φ_i(s).
    All k_base Q-tables are updated in parallel for every transition.

    Note: the potentials come from ``np.random.default_rng(seed)``. With
    ``seed=None`` they are drawn from fresh OS entropy and are therefore not
    controlled by the global seeds. The buffer, by contrast, is drawn from
    NumPy's global RNG and Python's ``random``.

    Returns:
        Q   : (k_base, N, 4) float32 Q-table for each Random Option
        phi : (k_base, N)    float32 random potential functions
    """
    rng = np.random.default_rng(seed)

    # ---------- Phase-0: collect buffer ------------------
    buffer, _ = collect_buffer(
        T, free_idx, episodes=collect_ep, max_len=ep_len
    )

    # ---------- random potentials φ ~ N(0,1) -------------
    phi = rng.standard_normal((k_base, N)).astype(np.float32)

    # ---------- offline parallel Q-learning --------------
    Q = np.zeros((k_base, N, 4), np.float32)
    for s, a, sn in buffer:
        # shaped reward + bootstrap from the greedy value at sn
        target = phi[:, sn] - phi[:, s] + gamma * Q[:, sn].max(1)
        Q[:, s, a] += alpha * (target - Q[:, s, a])

    return Q, phi


# -------------------------- main ----------------------------
def main():
    """Parse CLI arguments, then train and save each requested option family.

    For every run m in ``range(--outer_num)``, the families selected by
    ``--only`` are trained on the environment's transition table. Each
    Q-table array is saved as ``gridworld_{n_options}_{Tag}_{m}.npy`` in
    ``--save_dir``. With ``--vis``, up to four option policies per family are
    plotted.
    """

    def str2bool(text):
        """Parse a CLI boolean ("true"/"false", "1"/"0", "yes"/"no", ...)."""
        # argparse's ``type=bool`` maps every non-empty string, including
        # "False", to True, so the flags could never be switched off.
        if isinstance(text, bool):
            return text
        lowered = str(text).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    pa = argparse.ArgumentParser()
    g = pa.add_argument
    g("--num_opts", type=int, default=10)
    g("--sign", type=str2bool, default=True)
    g("--gamma", type=float, default=0.99)
    g("--alpha", type=float, default=0.05)
    g("--collect_ep", type=int, default=1000, help="random-walk episodes")
    g("--ep_len", type=int, default=200, help="steps per episode")
    g("--seed", type=int, default=0)
    g("--save_dir", default="option_results")
    g("--vis", type=str2bool, default=True)
    g("--only", choices=["vps", "eigen", "rand", "all"], default="all")
    g(
        "--outer_num",
        type=int,
        default=10,
        help="number of independent option sets",
    )
    args = pa.parse_args()

    set_seed(args.seed)
    env = SimpleEnv(render_mode=None)
    env.reset()
    T, wall = build_state_transition_matrix(env)
    free_idx = np.where(wall.flatten() == 0)[0]
    N = T.shape[0]
    os.makedirs(args.save_dir, exist_ok=True)
    vis = BottleneckVisualization(env) if args.vis else None

    # Hyper-parameters shared by every option family.
    common = dict(
        gamma=args.gamma,
        alpha=args.alpha,
        collect_ep=args.collect_ep,
        ep_len=args.ep_len,
    )

    def save_and_plot(Q, tag, label, m):
        """Save one option set's Q-tables and optionally plot a few policies."""
        total = Q.shape[0]
        fn = f"gridworld_{total}_{tag}_{m}.npy"
        np.save(os.path.join(args.save_dir, fn), Q)
        print("  saved", fn)
        if vis is not None:
            for k in range(min(4, total)):
                vis.plot_policy_arrows(
                    np.argmax(Q[k], 1), Q[k], title=f"{label} {k}"
                )

    for m in range(args.outer_num):
        # ---- VPS -------------------------------------------------------
        if args.only in ("vps", "all"):
            print(f"[VPS] training {m} …")
            Q_vps, _, _ = train_vps_options(
                T, free_idx, k_base=args.num_opts, sign=args.sign, **common
            )
            save_and_plot(Q_vps, "VPSOpt", "VPS-Option", m)

        # ---- Eigen -----------------------------------------------------
        if args.only in ("eigen", "all"):
            print(f"[Eigen] training {m} …")
            Q_eig = train_eigen_options(
                T, free_idx, N, k_base=args.num_opts, sign=args.sign, **common
            )
            save_and_plot(Q_eig, "EigenOpt", "Eigen-Option", m)

        # ---- Random ----------------------------------------------------
        if args.only in ("rand", "all"):
            print(f"[Random] training {m} …")
            # Same option count as VPS/Eigen: doubled only when sign is set.
            total = args.num_opts * (2 if args.sign else 1)
            # Derive the potentials' seed from the globally seeded NumPy RNG
            # so --seed also makes Random options reproducible (seed=None
            # would pull fresh OS entropy on every run).
            rnd_seed = int(np.random.randint(2**31 - 1))
            Q_rnd, _ = train_random_options(
                T, free_idx, N, k_base=total, seed=rnd_seed, **common
            )
            save_and_plot(Q_rnd, "RandomOption", "Random-Option", m)


if __name__ == "__main__":
    main()
