#!/usr/bin/env python
"""Train Random/Eigen/VPS options in tabular discrete Gym environments.

Provides a unified CLI that:
- Samples a random-walk buffer.
- Trains three families of options (Random, Eigenoption, VPS) offline.
- Saves each Q-table to `option_results/<Env>/*.npy` and renders one GIF per trained option.
"""
from __future__ import annotations
import argparse, random, warnings
from pathlib import Path
from typing import List, Tuple
import numpy as np
import scipy.sparse as sp
import os
from scipy.sparse.linalg import eigsh
import gymnasium as gym
import imageio.v2 as imageio

# ============================================================
#                   VPS-Option Helper
# ============================================================
class VPSOptionAgent:
    """Tabular VPS option trainer driven by a random-walk replay buffer.

    Phase-0: roll out a uniformly random policy, storing transitions while
             updating a successor representation ψ and the VPS features φ
             (a running average of squared one-step value differences).
    Phase-1: fixed-budget offline Q-learning on the φ-shaped rewards
             r_i = φ_i(s') − φ_i(s).
    """

    def __init__(
        self,
        env,
        *,
        k: int,
        sign: bool,
        gamma_v: float,
        gamma_q: float,
        lam: float = 0.9,
        alpha: float = 0.05,
        seed: int | None = None,
        phi_lr: float = 0.001,
    ):
        """
        Set up reward weights and zeroed tables for a tabular Gym env.

        The original author describes the procedure as mirroring
        `gridworld_options.train_vps_options`:
          1) Build k orthogonal reward weights via QR over one-hot states.
          2) Learn a tabular SR ψ(s) from a random-walk buffer.
          3) Compute V_i(s) = w_i^T ψ(s) and VPS features
             φ_i(s) ≈ E[(V_i(s') − V_i(s))^2 | s].
          4) Offline Q-learning on intrinsic rewards r_i = φ_i(s') − φ_i(s),
             with optional sign-doubling.

        Parameters
        ----------
        env      : environment exposing `observation_space.n`,
                   `action_space.n`, `reset(seed=...)` and `step(a)`.
        k        : number of base options; must satisfy 1 <= k <= #states.
        sign     : if True, also train one option per negated potential
                   (2k options in total).
        gamma_v  : discount used in the SR update.
        gamma_q  : discount used in the Q-learning update.
        lam      : accepted for backward compatibility; not used by this class.
        alpha    : step size shared by the SR and Q-learning updates.
        seed     : if given, seeds the *global* `numpy.random` and `random`
                   generators.
        phi_lr   : step size of the moving average that tracks φ.

        Raises
        ------
        ValueError : if k is outside [1, #states]; QR over an (S, k) matrix
                     cannot produce k orthonormal columns when k > S.
        """
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        self.env = env
        self.S = env.observation_space.n
        self.A = env.action_space.n

        if not 1 <= k <= self.S:
            raise ValueError(
                f"k must be between 1 and the number of states ({self.S}); got {k}"
            )

        self.k = k
        self.sign = sign
        self.ktot = k * (2 if sign else 1)

        # Store separate discounts for SR (gamma_v) and Q-learning (gamma_q).
        self.gamma_v = float(gamma_v)
        self.gamma_q = float(gamma_q)
        # Backwards-compatible alias used in the Q-updates below.
        self.gamma = self.gamma_q
        self.alpha = float(alpha)
        self.phi_lr = float(phi_lr)

        # QR-based orthogonal reward weights over states (fixed per agent).
        randR_raw = np.random.randn(self.k, self.S).astype(np.float32)
        Q_mat, _ = np.linalg.qr(randR_raw.T)              # (S, k)
        self.reward_weights = Q_mat.T.astype(np.float32)  # (k, S)

        # Online SR ψ(s,·) and VPS feature φ_i(s).
        self.psi = np.zeros((self.S, self.S), dtype=np.float32)
        self.phi = np.zeros((self.k, self.S), dtype=np.float32)
        self.terminal_mask = np.zeros(self.S, dtype=bool)

        self.Qopt = np.zeros((self.ktot, self.S, self.A), np.float32)

    # ------------ Phase-0: sample buffer (random walk) -------
    def _collect_buffer(self, episodes: int, max_len: int) -> List[Tuple[int, int, int, bool]]:
        """Collect (s, a, sn, done) under a random policy and update SR/VPS online.

        ψ, φ and the terminal mask are reset first, so every call starts
        from fresh statistics. The stored `done` flag is True only for
        environment termination; truncation merely ends the rollout.
        """
        buffer: List[Tuple[int, int, int, bool]] = []

        # Reset SR/VPS statistics for this training run.
        self.psi.fill(0.0)
        self.phi.fill(0.0)
        self.terminal_mask.fill(False)

        for _ in range(episodes):
            s, _ = self.env.reset(seed=int(np.random.randint(1_000_000)))
            s = int(s)
            for _ in range(max_len):
                a = random.randrange(self.A)
                sn, _, term, trunc, _ = self.env.step(a)
                sn = int(sn)
                done_alg = bool(term)
                done_env = bool(term or trunc)

                # ----- Online SR update ψ(s,·) -----------------------
                if done_alg:
                    # Mark terminal next-state; terminal SR rows stay zero.
                    self.terminal_mask[sn] = True
                    # Terminal transition: target is e(s) only.
                    delta = -self.psi[s]
                else:
                    # Continuing: e(s) + γ_v ψ(sn) − ψ(s)
                    delta = self.gamma_v * self.psi[sn] - self.psi[s]
                delta[s] += 1.0
                self.psi[s] += self.alpha * delta

                # ----- Online VPS feature update φ_i(s) --------------
                # V_i(s) = w_i^T ψ(s), evaluated with the freshly updated ψ(s).
                V_s = self.reward_weights @ self.psi[s]          # (k,)
                if done_alg or self.terminal_mask[sn]:
                    V_sn = np.zeros_like(V_s)
                else:
                    V_sn = self.reward_weights @ self.psi[sn]    # (k,)
                # Moving average of the squared value difference.
                sq_diff = (V_sn - V_s) ** 2
                self.phi[:, s] += self.phi_lr * (sq_diff - self.phi[:, s])

                buffer.append((s, a, sn, done_alg))
                s = sn
                if done_env:
                    break

        return buffer

    # ------------ Phase-1: offline Q-learning -----------------
    def _train_from_buffer(self, buffer: List[Tuple[int, int, int, bool]], total_steps: int):
        """Run `total_steps` Q-learning updates on uniformly sampled transitions.

        All options are updated in parallel from each sampled transition.
        Rows [0, k) of the resulting table use the potentials φ; when
        `sign` is set, rows [k, 2k) use −φ. The result is stored in
        `self.Qopt`.

        Raises
        ------
        RuntimeError : if `buffer` is empty.
        """
        if not buffer:
            raise RuntimeError("Buffer is empty; nothing to train on.")

        gamma = self.gamma
        alpha = self.alpha

        # Stacking −φ under φ turns sign-doubling into extra rows of the same
        # vectorised update (negation is exact, so rewards are exactly −r).
        potentials = np.concatenate([self.phi, -self.phi]) if self.sign else self.phi
        Q = np.zeros((potentials.shape[0], self.S, self.A), np.float32)

        for _ in range(total_steps):
            s, a, sn, done = random.choice(buffer)
            # Intrinsic reward r_i = φ_i(sn) − φ_i(s) for every option at once.
            td = potentials[:, sn] - potentials[:, s]
            if not done:
                td += gamma * Q[:, sn].max(axis=1)
            td -= Q[:, s, a]
            Q[:, s, a] += alpha * td

        self.Qopt = Q

    def train(self, phase1_eps: int, phase2_steps: int, max_len: int = 200):
        """
        Collect a random-walk buffer, then train all options offline.

        phase1_eps  : #random-walk episodes used to build the buffer.
        phase2_steps: total number of Q-learning updates; each update uses one
                      transition drawn uniformly at random from the buffer.
        max_len     : maximum number of steps per random-walk episode.

        Returns the trained Q-table (also kept as `self.Qopt`).
        """
        buffer = self._collect_buffer(phase1_eps, max_len)
        self._train_from_buffer(buffer, total_steps=phase2_steps)
        return self.Qopt


# ============================================================
#               DiscreteOption  (Random / Eigen / VPS)
# ============================================================
class DiscreteOption:
    """Train and save Random / Eigen / VPS option Q-tables for one tabular env.

    Each `train_*` method samples a random-walk buffer, learns one Q-table
    per option offline, writes the stacked table to
    `<out_dir>/<env_id>_<K>_<Kind>Opt_<idx>.npy` and returns it.
    """

    def __init__(
        self,
        env_id: str,
        num: int,
        sign: bool,
        gamma: float,
        alpha: float,
        seed: int,
        out_dir: str,
        idx: int,
    ):
        """
        env_id  : Gym environment id passed to `gym.make`.
        num     : number of base options per family.
        sign    : if True, the number of options is doubled.
        gamma   : discount used in the Q-learning targets.
        alpha   : Q-learning step size.
        seed    : if not None, seeds the global `numpy.random` and `random`
                  generators.
        out_dir : directory for the saved `.npy` files and GIFs; created
                  (including parents) if missing.
        idx     : run index, used only in output file names.
        """
        self.env_id, self.num, self.sign = env_id, num, sign
        self.gamma, self.alpha = gamma, alpha
        self.seed, self.idx = seed, idx
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)

        self.env = gym.make(env_id)
        self.S = self.env.observation_space.n
        self.A = self.env.action_space.n

        # Transition table for tabular envs: element [1] of the first entry of
        # P[s][a]. NOTE(review): the original comment calls this the
        # deterministic next state; for stochastic envs it would be only one
        # of several listed outcomes - confirm before relying on it.
        self.T = np.asarray(
            [
                [self.env.unwrapped.P[s][a][0][1] for a in range(self.A)]
                for s in range(self.S)
            ],
            dtype=np.int32,
        )

        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

    # ---------- Phase-0 buffer sampling ----------------------
    def _collect_buffer(self, episodes: int, max_len: int = 200):
        """Roll out a uniformly random policy in a fresh copy of the env.

        Returns `(buffer, visited)`: `buffer` holds `(s, a, sn, done)` tuples
        where `done` is True only on environment termination (truncation
        just ends the rollout); `visited` lists every `s` and `sn` seen,
        duplicates included.
        """
        buffer: list[tuple[int, int, int, bool]] = []
        visited: list[int] = []  # keep duplicates
        sampler = gym.make(self.env_id)

        try:
            for _ in range(episodes):
                s, _ = sampler.reset(seed=int(np.random.randint(1_000_000)))
                s = int(s)
                for _ in range(max_len):
                    a = random.randrange(self.A)
                    sn, _, term, trunc, _ = sampler.step(a)
                    sn = int(sn)
                    done_alg = bool(term)
                    done_env = bool(term or trunc)

                    buffer.append((s, a, sn, done_alg))
                    visited.append(s)
                    visited.append(sn)
                    s = sn
                    if done_env:
                        break
        finally:
            # Release the sampler even if a rollout step raises.
            sampler.close()
        return buffer, visited  # list with duplicates

    # ---------- Random Option --------------------------------
    def train_random(
        self,
        collect_ep: int = 2000,
        ep_len: int = 200,
        train_steps: int = 400_000,
    ):
        """Random options with independent potentials φ ~ N(0,1).

        Shaped reward: r = φ(sn) − φ(s). With sign=True the number of
        options is doubled; note that every option (including the extra
        ones) gets its own independently drawn potential rather than a
        negated copy.

        collect_ep  : number of random-walk episodes for the buffer.
        ep_len      : maximum steps per random-walk episode.
        train_steps : number of Q-learning updates; each update touches all
                      options in parallel.

        Returns the (K, S, A) Q-table that was saved.
        """
        # Phase-0: unified buffer
        buf, _ = self._collect_buffer(collect_ep, ep_len)

        k_all = self.num * 2 if self.sign else self.num
        rng = np.random.default_rng(self.seed)

        # random potentials (k_all, S)
        phi = rng.standard_normal((k_all, self.S)).astype(np.float32)

        # Phase-1: offline parallel Q-learning
        Q = np.zeros((k_all, self.S, self.A), np.float32)
        for _ in range(train_steps):
            s, a, sn, done = random.choice(buf)
            td = phi[:, sn] - phi[:, s]
            if not done:
                td += self.gamma * Q[:, sn].max(1)
            td -= Q[:, s, a]
            Q[:, s, a] += self.alpha * td

        path = self.out_dir / f"{self.env_id}_{k_all}_RandomOpt_{self.idx}.npy"
        np.save(path, Q)
        print(f"[✓] Save {path}  (K={k_all})")
        return Q

    # ---------- Eigen Option ---------------------------------
    @staticmethod
    def _normalized_laplacian(buf, n_states: int):
        """Build the normalized Laplacian I − D^{-1/2} A D^{-1/2} from a buffer.

        A is the 0/1 adjacency of observed transitions, symmetrised so that
        the result is symmetric as `eigsh` requires. States with no observed
        edge get a diagonal entry of 1 (degree is clamped away from zero
        before the inverse square root).
        """
        adjacency = np.zeros((n_states, n_states), np.int32)
        for s, _, sn, _ in buf:
            adjacency[s, sn] = 1
            adjacency[sn, s] = 1
        A = sp.csr_matrix(adjacency)
        deg = np.ravel(A.sum(1))
        Dinv = sp.diags(1 / np.sqrt(np.maximum(deg, 1e-8)))
        return sp.eye(n_states, dtype=np.float32) - Dinv @ A @ Dinv

    def train_eigen(self, collect_ep: int, train_steps: int):
        """Eigenoptions from the graph Laplacian of observed transitions.

        The `num` eigenvectors after the first one returned by
        `eigsh(..., which="SM")` serve as potentials; with sign=True each is
        also used negated. Each option is trained separately for
        `train_steps` updates with reward r = φ(sn) − φ(s).

        collect_ep  : number of random-walk episodes for the buffer
                      (episode length is `_collect_buffer`'s default).
        train_steps : Q-learning updates *per option*.

        Returns the (K, S, A) Q-table that was saved.
        """
        buf, _ = self._collect_buffer(collect_ep)
        L = self._normalized_laplacian(buf, self.S)

        k_base = self.num
        _, vecs = eigsh(L, k=k_base + 1, which="SM")
        eig_vecs = vecs[:, 1:]  # drop the first returned eigenvector

        signs = (1, -1) if self.sign else (1,)
        total = k_base * len(signs)
        Q = np.zeros((total, self.S, self.A), np.float32)

        # Option order: (+v1, −v1, +v2, −v2, ...) when sign is set.
        potentials = (sg * vec for vec in eig_vecs.T for sg in signs)
        for opt_id, phi in enumerate(potentials):
            for _ in range(train_steps):
                s, a, sn, done = random.choice(buf)
                r = phi[sn] - phi[s]
                tgt = r if done else r + self.gamma * Q[opt_id, sn].max()
                Q[opt_id, s, a] += self.alpha * (tgt - Q[opt_id, s, a])

        path = self.out_dir / f"{self.env_id}_{total}_EigenOpt_{self.idx}.npy"
        np.save(path, Q)
        print(f"[✓] Save {path} (K={Q.shape[0]})")
        return Q

    # ---------- VPS Option -----------------------------------
    def train_vps(self, collect_ep: int, train_steps: int):
        """Train VPS options by delegating to `VPSOptionAgent` on `self.env`.

        The same `gamma` is used for both the agent's SR and Q-learning
        discounts. Returns the Q-table that was saved.
        """
        agent = VPSOptionAgent(
            self.env,
            k=self.num,
            sign=self.sign,
            gamma_v=self.gamma,
            gamma_q=self.gamma,
            alpha=self.alpha,
            seed=self.seed,
        )
        Q = agent.train(phase1_eps=collect_ep, phase2_steps=train_steps)
        k_all = self.num * (2 if self.sign else 1)
        path = self.out_dir / f"{self.env_id}_{k_all}_VPSOpt_{self.idx}.npy"
        np.save(path, Q)
        print(f"[✓] Save {path} (K={Q.shape[0]})")
        return Q

    # ---------- GIF visualization ----------------------------
    @staticmethod
    def _terminated(qrow):
        """Return True when no action in `qrow` has a positive Q-value."""
        return qrow.max() <= 0

    def save_gifs(self, Q: np.ndarray, method: str, steps: int = 60, fps: int = 2):
        """Render one greedy rollout per option to `gifs_<method>_<idx>/`.

        A rollout stops after `steps` frames, when the option's Q-values at
        the current state are all non-positive, or when the env terminates
        or truncates.
        """
        gif_dir = self.out_dir / f"gifs_{method}_{self.idx}"
        gif_dir.mkdir(parents=True, exist_ok=True)
        policy = np.argmax(Q, 2)

        for oid in range(Q.shape[0]):
            env = gym.make(self.env_id, render_mode="rgb_array")
            frames = []
            try:
                s, _ = env.reset(seed=int(np.random.randint(1_000_000)))
                s = int(s)
                for _ in range(steps):
                    frames.append(env.render())
                    if self._terminated(Q[oid, s]):
                        break
                    a = int(policy[oid, s])
                    s, _, term, trunc, _ = env.step(a)
                    s = int(s)
                    if term or trunc:
                        break
            finally:
                # Release the render env even if a step or render call raises.
                env.close()
            imageio.mimsave(gif_dir / f"option_{oid:02d}.gif", frames, fps=fps)


# ============================================================
#                            CLI
# ============================================================
def main():
    """Parse CLI arguments and train the requested option families.

    For each of `--outer` repetitions a `DiscreteOption` is built with the
    repetition index as both seed and file index; the selected families are
    trained and a GIF is rendered for every option. Outputs go to
    `<script dir>/<out_dir>/<env>/`.
    """
    parser = argparse.ArgumentParser()
    g = parser.add_argument
    g("--env", default="Taxi-v3")  # e.g., FrozenLake-v1
    g("--num", type=int, default=5)
    g("--sign", type=str, default="true")
    g("--gamma", type=float, default=0.999)
    g("--alpha", type=float, default=0.1)
    g("--collect", type=int, default=1000,
      help="Phase-0 random-walk episodes for ALL options")
    g("--steps", type=int, default=1_000_000,
      help="Phase-1 total Q-learning updates for ALL options")
    g("--gif_len", type=int, default=60)
    g("--fps", type=int, default=2)
    g("--outer", type=int, default=5,
      help="repeat experiments this many times")
    g("--out_dir", default="option_results")
    g("--only", choices=["random", "eigen", "vps", "all"], default="all",
      help="which option type(s) to train (default: all)")
    args = parser.parse_args()

    # Anything other than these spellings is treated as False.
    sign_flag = args.sign.lower() in ("true", "1", "yes")

    # Resolve the save directory relative to this script so all option files
    # live under `<script dir>/<out_dir>/<Env>/...`.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    root_dir = os.path.join(script_dir, args.out_dir)
    save_dir = os.path.join(root_dir, args.env)
    os.makedirs(save_dir, exist_ok=True)

    for idx in range(args.outer):
        print(f"\n=====  Training group {idx}/{args.outer - 1} =====")
        opt = DiscreteOption(
            args.env,
            args.num,
            sign_flag,
            args.gamma,
            args.alpha,
            seed=idx,
            out_dir=save_dir,
            idx=idx,
        )
        try:
            # ---- Random Option ----
            if args.only in ("random", "all"):
                # train_random's second positional parameter is the episode
                # length, so the update budget must be passed by keyword.
                Q_rnd = opt.train_random(
                    collect_ep=args.collect, train_steps=args.steps
                )
                opt.save_gifs(Q_rnd, "random", args.gif_len, args.fps)

            # ---- Eigen Option -----
            if args.only in ("eigen", "all"):
                Q_eig = opt.train_eigen(
                    collect_ep=args.collect, train_steps=args.steps
                )
                opt.save_gifs(Q_eig, "eigen", args.gif_len, args.fps)

            # ---- VPS Option -------
            if args.only in ("vps", "all"):
                Q_vps = opt.train_vps(
                    collect_ep=args.collect, train_steps=args.steps
                )
                opt.save_gifs(Q_vps, "vps", args.gif_len, args.fps)
        finally:
            # Release this group's environment before building the next one.
            opt.env.close()


if __name__ == "__main__":
    # Suppress UserWarnings whose text starts with the Matplotlib "agg"
    # backend notice, then hand over to the CLI.
    _agg_notice = "Matplotlib is currently using agg"
    warnings.filterwarnings("ignore", message=_agg_notice, category=UserWarning)
    main()
