#!/usr/bin/env python
"""Train Random/Eigen/VPS options on a MiniGrid layout (tabular).

Two-stage training for VPS and Random options (buffer then offline Q),
eigenoptions from graph Laplacian eigenvectors. Produces saved Q-tables
compatible with downstream experiments and visualization.
"""

import os, argparse, random, collections, ast
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh
from .bottleneck_env import SimpleEnv
from .generate_state_transition_matrix import build_state_transition_matrix
from .utils import BottleneckVisualization


# ---------- misc utilities ----------------------------------
def set_seed(sd):
    """Seed the global Python and NumPy RNGs with ``sd``.

    Passing ``None`` leaves both generators untouched.
    """
    if sd is None:
        return
    random.seed(sd)
    np.random.seed(sd)


def _parse_num_opts_list(x) -> list[int]:
    """Normalise the ``--num_opts`` value into a list of ints.

    Accepted inputs:
      - an int, e.g. 4
      - a list/tuple of int-convertible values
      - a string holding a single int, e.g. "4"
      - a bare comma-separated string, e.g. "2,4,6"
      - a Python list/tuple literal, e.g. "[2,4,6]" or "(2, 4, 6)"

    Raises:
        ValueError: if the string is blank, cannot be converted to ints,
            or evaluates to a literal that is neither an int nor a
            list/tuple.
    """
    if isinstance(x, int):
        return [int(x)]
    if isinstance(x, (list, tuple)):
        return list(map(int, x))

    text = str(x).strip()
    if text == "":
        raise ValueError("--num_opts is empty")

    # Bare "a,b,c" form: split manually and ignore empty pieces.
    bracketed = text.startswith(("[", "("))
    if "," in text and not bracketed:
        pieces = (piece.strip() for piece in text.split(","))
        return [int(piece) for piece in pieces if piece]

    try:
        parsed = ast.literal_eval(text)
    except Exception:
        # Not a Python literal; let int() accept it or raise ValueError.
        return [int(text)]

    if isinstance(parsed, (list, tuple)):
        return list(map(int, parsed))
    if isinstance(parsed, int):
        return [int(parsed)]
    raise ValueError(f"Unsupported --num_opts format: {x!r}")


def option_rows_dup(arr, sign):
    """Return ``arr`` stacked on top of itself when ``sign`` is truthy.

    With a falsy ``sign`` the input is returned unchanged (same object).
    """
    if not sign:
        return arr
    return np.vstack((arr, arr))


# ---------- phase-0 : collect buffer ------------------------
def collect_buffer(T, free_idx, *, episodes=1000, max_len=200):
    """Roll out a uniformly random policy and record every transition.

    Each episode starts from a state drawn with ``np.random.choice`` from
    ``free_idx`` and always runs for exactly ``max_len`` steps; actions are
    drawn with ``random.randint(0, 3)`` and the successor is ``T[s, a]``.

    Returns:
        transitions : list of (s, a, sn) tuples in the order they were taken
        seen        : set of every state that appears as ``s`` or ``sn``
    """
    transitions = []
    seen = set()
    for _episode in range(episodes):
        state = int(np.random.choice(free_idx))
        step = 0
        while step < max_len:
            action = random.randint(0, 3)
            nxt = int(T[state, action])
            transitions.append((state, action, nxt))
            seen.update((state, nxt))
            state = nxt
            step += 1
    return transitions, seen


# ---------- 1. VPS-Option -----------------------------------
def _learn_successor_representation(
    buffer, n_states, *, gamma, alpha, collect_ep, ep_len, sr_epochs, sr_lambda
):
    """Learn a tabular SR ψ (n_states, n_states) from ``buffer`` via TD(0)/TD(λ)."""
    psi = np.zeros((n_states, n_states), dtype=np.float32)

    if sr_lambda > 0.0:
        # TD(λ) with accumulating eligibility traces.
        print(f"[SR] Using TD(λ) with λ={sr_lambda} for {sr_epochs} epochs")
        decay = gamma * sr_lambda
        for _ in range(sr_epochs):
            eligibility = np.zeros(n_states, dtype=np.float32)
            # The buffer is read episode by episode; this relies on it holding
            # exactly ``ep_len`` consecutive transitions per episode.
            for ep in range(collect_ep):
                offset = ep * ep_len
                for t in range(ep_len):
                    s, _, sn = buffer[offset + t]

                    # e ← γλ e ; e(s) ← e(s) + 1
                    eligibility *= decay
                    eligibility[s] += 1.0

                    # Vector TD error: one-hot(s) + γ ψ(s') − ψ(s)
                    delta = gamma * psi[sn] - psi[s]
                    delta[s] += 1.0

                    # Update only rows with a live trace, in one array op
                    # instead of a Python loop over all states.
                    active = eligibility > 0
                    psi[active] += (alpha * eligibility[active])[:, None] * delta
                # Traces must not leak across episode boundaries.
                eligibility.fill(0.0)
    else:
        print(f"[SR] Using TD(0) for {sr_epochs} epochs")
        for _ in range(sr_epochs):
            # Visit the transitions in a fresh random order every epoch.
            for idx in np.random.permutation(len(buffer)):
                s, _, sn = buffer[idx]
                # ψ(s) ← ψ(s) + α [ e(s) + γ ψ(s') − ψ(s) ]
                delta = gamma * psi[sn] - psi[s]
                delta[s] += 1.0
                psi[s] += alpha * delta
    return psi


def _vps_features(buffer, V_base):
    """Average (V_i(s') − V_i(s))² over the buffer transitions leaving each state."""
    phi = np.zeros_like(V_base)
    visits = np.zeros(V_base.shape[1], dtype=np.int32)
    for s, _, sn in buffer:
        phi[:, s] += (V_base[:, sn] - V_base[:, s]) ** 2
        visits[s] += 1
    seen = visits > 0  # states never left in the buffer keep φ = 0
    phi[:, seen] /= visits[seen]
    return phi


def train_vps_options(
    T,
    free_idx,
    *,
    k_base,
    sign,
    gamma,
    alpha,
    lam=0.9,  # kept for backward-compatibility, not used in SR updates
    collect_ep,
    ep_len,
    sr_epochs=10,  # Number of epochs for SR learning (more epochs = smoother SR)
    sr_lambda=0.0,  # TD(λ) parameter for SR learning (0.0 = TD(0), >0 = TD(λ))
):
    """
    Train tabular VPS options from a random-walk buffer.

    Stages:
      (0) Collect a random-walk buffer with ``collect_buffer``.
      (1) Draw ``k_base`` mutually orthonormal reward-weight vectors over
          states (QR of a Gaussian matrix) and learn a successor
          representation ψ(s) over one-hot state features from the buffer.
      (2) Compute V_i(s) = w_i^T ψ(s), build the VPS features
          φ_i(s) ≈ E[(V_i(s') − V_i(s))²] from the buffer, and run one pass
          of offline Q-learning with intrinsic reward r_i = φ_i(s') − φ_i(s).
          With ``sign`` truthy a second set of options is trained on −r_i.

    Args:
        T: state-transition table indexed as ``T[s, a]``; ``T.shape[0]`` is
            taken as the number of states.
        free_idx: candidate start states, forwarded to ``collect_buffer``.
        k_base: number of base options; must not exceed the number of states.
        sign: if truthy, also train the negated-reward copy of every option.
        gamma: discount used for both SR learning and Q-learning.
        alpha: learning rate used for both SR learning and Q-learning.
        lam: unused; accepted only for backward compatibility.
        collect_ep, ep_len: number and length of random-walk episodes.
        sr_epochs: passes over the buffer during SR learning.
        sr_lambda: 0.0 selects TD(0); a positive value selects TD(λ).

    Returns:
        Q   : float32 array (total_opts, n_states, 4); rows [0, k_base) are
              the +φ options, rows [k_base, 2*k_base) the −φ options.
        V   : V_base (k_base, n_states), duplicated row-wise when ``sign``.
        phi : φ_base (k_base, n_states), duplicated row-wise when ``sign``.

    Raises:
        ValueError: if ``k_base`` is larger than the number of states, since
            that many orthonormal vectors over the states do not exist.
    """
    n_states = T.shape[0]
    if k_base > n_states:
        # Reduced QR would silently return only n_states vectors and leave
        # the remaining options with all-zero Q-tables.
        raise ValueError(
            f"k_base={k_base} exceeds the number of states ({n_states}); "
            "cannot build that many orthonormal reward vectors"
        )
    total_opts = k_base * (2 if sign else 1)

    # ------- Phase-0 : collect buffer ----------
    buffer, _ = collect_buffer(T, free_idx, episodes=collect_ep, max_len=ep_len)

    # ------- Phase-1 : reward weights & SR ------
    # QR on the transpose gives orthonormal columns, so the rows of Q_mat.T
    # are orthonormal reward-weight vectors, one per base option.
    randR_raw = np.random.randn(k_base, n_states).astype(np.float32)
    Q_mat, _ = np.linalg.qr(randR_raw.T)  # (n_states, k_base)
    reward_weights = Q_mat.T.astype(np.float32)  # (k_base, n_states)

    psi = _learn_successor_representation(
        buffer,
        n_states,
        gamma=gamma,
        alpha=alpha,  # alpha doubles as the SR learning rate
        collect_ep=collect_ep,
        ep_len=ep_len,
        sr_epochs=sr_epochs,
        sr_lambda=sr_lambda,
    )

    # ------- Phase-1.5 : V_i(s) and VPS features φ_i(s) --------
    # V_base[i, s] = w_i^T ψ(s)
    V_base = reward_weights @ psi.T  # (k_base, n_states)
    phi_base = _vps_features(buffer, V_base)

    # ------- Phase-2 : offline Q-learning with VPS rewards -----
    Q = np.zeros((total_opts, n_states, 4), np.float32)
    Q_pos = Q[:k_base]  # views into Q: updates below write through
    Q_neg = Q[k_base:]  # empty when sign is falsy
    for s, a, sn in buffer:
        r = phi_base[:, sn] - phi_base[:, s]  # (k_base,) reward per option
        Q_pos[:, s, a] += alpha * (r + gamma * Q_pos[:, sn].max(axis=1) - Q_pos[:, s, a])
        if sign:
            Q_neg[:, s, a] += alpha * (-r + gamma * Q_neg[:, sn].max(axis=1) - Q_neg[:, s, a])

    # V and φ are duplicated for the negative options so that row indices
    # line up with Q.
    return Q, option_rows_dup(V_base, sign), option_rows_dup(phi_base, sign)


# ---------- 2. Eigen-Option ---------------------------------
def train_eigen_options(
    T,
    free_idx,
    N,
    *,
    k_base,
    sign,
    gamma,
    alpha,
    collect_ep,
    ep_len,
):
    """
    Train tabular eigenoptions from a random-walk buffer.

    The buffer's transition counts are symmetrised into an undirected
    weighted graph W. ``eigsh`` only handles symmetric matrices, so the
    eigenvectors of the random-walk Laplacian I − D⁻¹W are obtained from the
    symmetric normalised Laplacian I − D^{-1/2} W D^{-1/2}: if u is an
    eigenvector of the latter, D^{-1/2} u is an eigenvector of the former
    with the same eigenvalue. The eigenvector of the smallest eigenvalue is
    skipped and the next ``k_base`` are used as potentials φ with shaped
    reward r = ±(φ(s') − φ(s)) in one pass of offline Q-learning.

    Args:
        T: state-transition table, forwarded to ``collect_buffer``.
        free_idx: candidate start states, forwarded to ``collect_buffer``.
        N: number of states (size of the Laplacian).
        k_base: number of eigenvectors / base options.
        sign: if truthy, also train an option on the negated potential.
        gamma, alpha: discount and learning rate of the Q-learning update.
        collect_ep, ep_len: number and length of random-walk episodes.

    Returns:
        Q        : float32 array (total, N, 4); rows [0, k_base) use +φ_k and,
                   when ``sign``, rows [k_base, 2*k_base) use −φ_k.
        eig_vecs : float32 array (N, k_base); column k is potential φ_k,
                   scaled to unit L2 norm.

    Raises:
        ValueError: if ``k_base + 1`` is not smaller than ``N`` (``eigsh``
            needs k < N for a sparse matrix).
    """
    k_need = k_base + 1  # one extra for the skipped smallest-eigenvalue vector
    if k_need >= N:
        raise ValueError(f"k_base + 1 = {k_need} must be smaller than the number of states N = {N}")

    buffer, _ = collect_buffer(T, free_idx, episodes=collect_ep, max_len=ep_len)

    # ----- empirical transition counts: counts[i, j] = #(i -> j) -----
    counts = np.zeros((N, N), np.float64)
    for s, _, sn in buffer:
        counts[s, sn] += 1.0

    # ----- symmetric normalised Laplacian of the undirected graph -----
    W = counts + counts.T
    deg = W.sum(axis=1)
    seen = deg > 0
    inv_sqrt_deg = np.zeros(N, np.float64)
    inv_sqrt_deg[seen] = 1.0 / np.sqrt(deg[seen])  # never-visited states stay 0
    S = inv_sqrt_deg[:, None] * W * inv_sqrt_deg[None, :]
    L_sym = sp.eye(N, format="csr") - sp.csr_matrix(S)

    vals, vecs = eigsh(L_sym, k=k_need, which="SM")
    vecs = vecs[:, np.argsort(vals)]  # make the ascending order explicit

    # Map back to random-walk Laplacian eigenvectors and drop the first one.
    eig_vecs = (inv_sqrt_deg[:, None] * vecs)[:, 1:]
    norms = np.linalg.norm(eig_vecs, axis=0)
    norms[norms == 0] = 1.0  # keep an all-zero column as zeros
    eig_vecs = (eig_vecs / norms).astype(np.float32)

    # ----- offline Q-learning, all options updated in parallel -----
    # Positive-sign options come first, then the negated ones, so option k
    # and its negation sit at rows k and k + k_base.
    potentials = eig_vecs.T  # (k_base, N)
    if sign:
        potentials = np.vstack([potentials, -potentials])
    total = potentials.shape[0]

    Q = np.zeros((total, N, 4), np.float32)
    for s, a, sn in buffer:
        r = potentials[:, sn] - potentials[:, s]
        Q[:, s, a] += alpha * (r + gamma * Q[:, sn].max(axis=1) - Q[:, s, a])
    return Q, eig_vecs


# ------------------------------------------------------------
#  Random-Option (Grid-World)
#  · Each option uses a random potential φ ~ N(0,1)
#  · shaped-reward  r = φ(sn) − φ(s)
# ------------------------------------------------------------
def train_random_options(
    T,  # transition table
    free_idx,  # candidate start states for the random walk
    N,  # number of states
    *,  # everything below is keyword-only
    k_base,  # number of options
    gamma,
    alpha,
    collect_ep,
    ep_len,
    seed=None,
):
    """
    Train random-potential options with one pass of offline Q-learning.

    A random-walk buffer is collected first; each option then gets a
    potential φ drawn from N(0, 1) by a dedicated generator built from
    ``seed`` and is trained on the shaped reward r = φ(sn) − φ(s).

    Returns:
        Q   : (k_base, N, 4)   Q-table for each Random Option
        phi : (k_base, N)      random potential functions
    """
    # ---------- Phase-0: random-walk buffer ---------------
    transitions, _visited = collect_buffer(T, free_idx, episodes=collect_ep, max_len=ep_len)

    # ---------- potentials from a private generator -------
    potentials = np.random.default_rng(seed).standard_normal((k_base, N)).astype(np.float32)

    # ---------- Q-learning over all options at once -------
    q_table = np.zeros((k_base, N, 4), dtype=np.float32)
    for state, action, nxt in transitions:
        reward = potentials[:, nxt] - potentials[:, state]
        target = reward + gamma * q_table[:, nxt].max(axis=1)
        q_table[:, state, action] += alpha * (target - q_table[:, state, action])

    return q_table, potentials


def main():
    """Command-line entry point: train and save the requested option sets.

    Builds the environment and its transition table, then for every
    requested ``k_base`` and every repetition ``m`` seeds the global RNGs,
    trains the option families selected with ``--only``, saves each Q-table
    as ``gridworld_{total}_{kind}_{m}.npy`` in ``--save_dir`` (resolved next
    to this file) and, when ``--vis`` is on, plots up to four options.
    """

    def parse_flag(text):
        """Parse a command-line boolean.

        ``type=bool`` cannot be used: argparse would call ``bool("False")``,
        which is True for any non-empty string.
        """
        lowered = str(text).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    pa = argparse.ArgumentParser()
    g = pa.add_argument
    g(
        "--num_opts",
        type=str,
        default="[8]",
        help='Number(s) of options to train. Examples: "4", "2,4,6", "[2,4,6]".',
    )
    g("--sign", type=parse_flag, default=True)
    g("--gamma", type=float, default=0.99)
    g("--alpha", type=float, default=0.05)
    g("--collect_ep", type=int, default=1000, help="random-walk episodes")
    g("--ep_len", type=int, default=200, help="steps per episode")
    g(
        "--sr_epochs",
        type=int,
        default=1,
        help="Number of epochs for SR learning (more epochs = smoother value functions, default: %(default)s)",
    )
    g(
        "--sr_lambda",
        type=float,
        default=0.9,
        help="TD(λ) parameter for SR learning (0.0 = TD(0), >0 = TD(λ) for faster convergence, default: %(default)s)",
    )
    g("--seed", type=int, default=0)
    g("--save_dir", default="option_results")
    g("--vis", type=parse_flag, default=True)
    g("--only", choices=["vps", "eigen", "rand", "all"], default="vps")
    g("--outer_num", type=int, default=1, help="number of independent option sets")
    args = pa.parse_args()

    env = SimpleEnv(render_mode=None)
    env.reset()
    T, wall = build_state_transition_matrix(env)
    free_idx = np.where(wall.flatten() == 0)[0]
    N = T.shape[0]

    # Always save option results in a directory next to this script,
    # regardless of the current working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    save_dir = os.path.join(script_dir, args.save_dir)
    os.makedirs(save_dir, exist_ok=True)

    vis = BottleneckVisualization(env) if args.vis else None

    num_list = _parse_num_opts_list(args.num_opts)
    for k_base in num_list:
        for m in range(args.outer_num):
            # One distinct, reproducible seed per (k_base, repetition) pair.
            run_seed = int(args.seed) + 10000 * int(k_base) + m
            set_seed(run_seed)
            # ---- VPS -------------------------------------------------------
            if args.only in ("vps", "all"):
                print(f"[VPS] k={k_base} training {m} …")
                Q_vps, V_vps, phi_vps = train_vps_options(
                    T,
                    free_idx,
                    k_base=k_base,
                    sign=args.sign,
                    gamma=args.gamma,
                    alpha=args.alpha,
                    collect_ep=args.collect_ep,
                    ep_len=args.ep_len,
                    sr_epochs=args.sr_epochs,
                    sr_lambda=args.sr_lambda,
                )
                total = Q_vps.shape[0]
                fn = f"gridworld_{total}_VPSOpt_{m}.npy"
                np.save(os.path.join(save_dir, fn), Q_vps)
                print("  saved", fn)

                if vis:
                    num_to_vis = min(4, total)
                    for k in range(num_to_vis):
                        vis.plot_2d_heatmap(
                            V_vps[k],
                            topk=32,
                            title=f"VPS-Option {k} Value Function (V)",
                            color_bar=True,
                            cmap_name="viridis",
                        )
                        vis.plot_2d_heatmap(
                            phi_vps[k],
                            topk=32,
                            title=f"VPS-Option {k} VPS Feature (φ)",
                            color_bar=True,
                            cmap_name="hot",
                        )
                        vis.plot_policy_arrows(np.argmax(Q_vps[k], 1), Q_vps[k], title=f"VPS-Option {k} Policy")

            # ---- Eigen -----------------------------------------------------
            if args.only in ("eigen", "all"):
                print(f"[Eigen] k={k_base} training {m} …")
                Q_eig, eig_vecs = train_eigen_options(
                    T,
                    free_idx,
                    N,
                    k_base=k_base,
                    sign=args.sign,
                    gamma=args.gamma,
                    alpha=args.alpha,
                    collect_ep=args.collect_ep,
                    ep_len=args.ep_len,
                )
                total = Q_eig.shape[0]
                fn = f"gridworld_{total}_EigenOpt_{m}.npy"
                np.save(os.path.join(save_dir, fn), Q_eig)
                print("  saved", fn)

                if vis:
                    # NOTE(review): the indices below assume Q_eig keeps the
                    # positive-sign option of eigenvector k at row k and the
                    # negated one at row k + k_base — confirm against
                    # train_eigen_options.
                    num_to_vis = min(4, k_base)
                    for k in range(num_to_vis):
                        vis.plot_2d_heatmap(
                            eig_vecs[:, k],
                            topk=32,
                            title=f"Eigen-Option {k} Eigenvector (Potential)",
                            color_bar=True,
                            cmap_name="coolwarm",
                        )
                        opt_idx = k
                        if opt_idx < total:
                            Q_max = Q_eig[opt_idx].max(axis=1)
                            vis.plot_2d_heatmap(
                                Q_max,
                                topk=32,
                                title=f"Eigen-Option {k} Q-Value (max over actions)",
                                color_bar=True,
                                cmap_name="viridis",
                            )
                            vis.plot_policy_arrows(np.argmax(Q_eig[opt_idx], 1), Q_eig[opt_idx], title=f"Eigen-Option {k} Policy")
                        if args.sign:
                            opt_idx_neg = k + k_base
                            if opt_idx_neg < total:
                                Q_max_neg = Q_eig[opt_idx_neg].max(axis=1)
                                vis.plot_2d_heatmap(
                                    Q_max_neg,
                                    topk=32,
                                    title=f"Eigen-Option {k} (Negative) Q-Value",
                                    color_bar=True,
                                    cmap_name="viridis",
                                )
                                vis.plot_policy_arrows(
                                    np.argmax(Q_eig[opt_idx_neg], 1),
                                    Q_eig[opt_idx_neg],
                                    title=f"Eigen-Option {k} (Negative) Policy",
                                )

            # ---- Random ----------------------------------------------------
            if args.only in ("rand", "all"):
                print(f"[Random] k={k_base} training {m} …")
                # Same option count as the signed families: 2*k_base with
                # --sign, k_base without.
                total = k_base * (2 if args.sign else 1)
                Q_rnd, phi_rnd = train_random_options(
                    T,
                    free_idx,
                    N,
                    k_base=total,
                    gamma=args.gamma,
                    alpha=args.alpha,
                    collect_ep=args.collect_ep,
                    ep_len=args.ep_len,
                    # The potentials come from a private generator that the
                    # global set_seed() call does not reach.
                    seed=run_seed,
                )
                fn = f"gridworld_{total}_RandomOption_{m}.npy"
                np.save(os.path.join(save_dir, fn), Q_rnd)
                print("  saved", fn)

                if vis:
                    num_to_vis = min(4, total)
                    for k in range(num_to_vis):
                        vis.plot_2d_heatmap(
                            phi_rnd[k],
                            topk=32,
                            title=f"Random-Option {k} Potential Function (φ)",
                            color_bar=True,
                            cmap_name="coolwarm",
                        )
                        Q_max = Q_rnd[k].max(axis=1)
                        vis.plot_2d_heatmap(
                            Q_max,
                            topk=32,
                            title=f"Random-Option {k} Q-Value (max over actions)",
                            color_bar=True,
                            cmap_name="viridis",
                        )
                        vis.plot_policy_arrows(np.argmax(Q_rnd[k], 1), Q_rnd[k], title=f"Random-Option {k} Policy")

# Run the training CLI only when this file is the main module, not on import.
# The relative imports at the top mean it has to be launched as part of its
# package (e.g. with ``python -m``).
if __name__ == "__main__":
    main()

