#!/usr/bin/env python
"""Train Random/Eigen/VPS options in tabular discrete Gym environments.

Provides a unified CLI that:
- Samples a random-walk buffer.
- Trains three families of options (Random, Eigenoption, VPS) offline.
- Saves Q-tables to `option_results/<Env>/...npy` and can render GIFs.
"""
from __future__ import annotations
import argparse, random, warnings
from pathlib import Path
from typing import List, Tuple
import numpy as np
import scipy.sparse as sp
import os
from scipy.sparse.linalg import eigsh
import gymnasium as gym
import imageio.v2 as imageio

# ============================================================
#                   VPS-Option Helper
# ============================================================
class VPSOptionAgent:
    """
    Phase-0: collect a replay buffer → learn V and φ (|TD| proxy)
    Phase-1: fixed-step offline Q-learning under φ-shaped rewards
    """

    def __init__(
        self,
        env,
        *,
        k: int,
        sign: bool,
        gamma_v: float,
        gamma_q: float,
        lam: float = 0.9,
        alpha: float = 0.05,
        seed: int | None = None,
    ):
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        self.env = env
        self.S = env.observation_space.n
        self.A = env.action_space.n

        self.k = k
        self.sign = sign
        self.ktot = k * (2 if sign else 1)

        self.gamma_v, self.gamma_q = gamma_v, gamma_q
        self.lam, self.alpha = lam, alpha

        self.V = np.zeros((k, self.S), np.float32)
        self.phi = np.zeros_like(self.V)
        self.elig = np.zeros(self.S, np.float32)
        self.randR = np.random.randn(k, self.S).astype(np.float32)
        self.Qopt = np.zeros((self.ktot, self.S, self.A), np.float32)
        self.buffer: List[Tuple[int, int, int, bool]] = []

    # ------------ Phase-0: sample buffer ---------------------
    def collect(self, episodes: int = 1000, max_len: int = 200):
        for _ in range(episodes):
            s, _ = self.env.reset(seed=int(np.random.randint(1_000_000)))
            s = int(s)
            self.elig.fill(0.0)
            for _ in range(max_len):
                a = random.randrange(self.A)
                sn, _, term, trunc, _ = self.env.step(a)
                sn = int(sn)
                done = term or trunc

                # online TD update for V and φ
                self.elig *= self.gamma_v * self.lam
                self.elig[s] += 1
                td = self.randR[:, sn] + (0 if done else self.gamma_v * self.V[:, sn]) - self.V[:, s]
                td_v = ((0 if done else self.V[:, sn]) - self.V[:, s]) ** 2
                self.V += self.alpha * td[:, None] * self.elig
                self.phi[:, s] += 0.001 * (np.abs(td_v) - self.phi[:, s])

                self.buffer.append((s, a, sn, done))
                if done:
                    break
                s = sn

    # ------------ Phase-1: fixed-step Q-learning -------------
    def learn_options(self, total_steps: int = 200_000):
        if not self.buffer:
            raise RuntimeError("Buffer is empty; run collect() first.")

        k = self.k
        for _ in range(total_steps):
            s, a, sn, done = random.choice(self.buffer)
            r_vec = self.phi[:, sn] - self.phi[:, s]
            for i, r in enumerate(r_vec):
                tgt = r if done else r + self.gamma_q * self.Qopt[i, sn].max()
                self.Qopt[i, s, a] += self.alpha * (tgt - self.Qopt[i, s, a])

                if self.sign:
                    neg_i = i + k
                    r2 = -r
                    tgt2 = r2 if done else r2 + self.gamma_q * self.Qopt[neg_i, sn].max()
                    self.Qopt[neg_i, s, a] += self.alpha * (tgt2 - self.Qopt[neg_i, s, a])

    def train(self, phase1_eps: int, phase2_steps: int, max_len: int = 200):
        self.collect(phase1_eps, max_len)
        self.learn_options(total_steps=phase2_steps)
        return self.Qopt


# ============================================================
#               DiscreteOption  (Random / Eigen / VPS)
# ============================================================
class DiscreteOption:
    """Train, save and render option Q-tables for one tabular Gym environment.

    Three families are supported (``train_random``, ``train_eigen``,
    ``train_vps``); each returns a Q-table of shape ``(K, S, A)`` and saves it
    as ``.npy`` under ``out_dir``.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        num: Number of base options.
        sign: If True, the number of options is doubled (``2 * num``).
        gamma: Discount factor for Q-learning.
        alpha: Q-learning step size.
        seed: Seed for the global ``numpy.random`` / ``random`` generators.
        out_dir: Directory receiving the ``.npy`` files and GIF folders.
        idx: Run index embedded in the output file names.
    """

    def __init__(
        self,
        env_id: str,
        num: int,
        sign: bool,
        gamma: float,
        alpha: float,
        seed: int,
        out_dir: str,
        idx: int,
    ):
        self.env_id, self.num, self.sign = env_id, num, sign
        self.gamma, self.alpha = gamma, alpha
        self.seed, self.idx = seed, idx
        self.out_dir = Path(out_dir)
        # parents=True: out_dir may be nested (the CLI passes
        # "<out_dir>/<env>"), and the parent need not exist on a fresh run.
        self.out_dir.mkdir(parents=True, exist_ok=True)

        self.env = gym.make(env_id)
        self.S = self.env.observation_space.n
        self.A = self.env.action_space.n

        # Transition table for tabular envs: T[s, a] is the next state of the
        # first outcome listed in P[s][a].
        # NOTE(review): assumes the unwrapped env exposes `P[s][a]` as a list
        # of (prob, next_state, ...) tuples and is deterministic — confirm.
        self.T = np.asarray(
            [
                [self.env.unwrapped.P[s][a][0][1] for a in range(self.A)]
                for s in range(self.S)
            ],
            dtype=np.int32,
        )

        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

    # ---------- Phase-0 buffer sampling ----------------------
    def _collect_buffer(
        self, episodes: int, max_len: int = 200
    ) -> tuple[list[tuple[int, int, int, bool]], list[int]]:
        """Roll out uniform-random episodes in a fresh env instance.

        Args:
            episodes: Number of episodes to sample.
            max_len: Maximum number of steps per episode.

        Returns:
            ``(buffer, visited)`` where ``buffer`` holds ``(s, a, sn, done)``
            transitions and ``visited`` lists every ``s`` and ``sn`` seen,
            duplicates included.
        """
        buffer: list[tuple[int, int, int, bool]] = []
        visited: list[int] = []  # keep duplicates
        sampler = gym.make(self.env_id)

        try:
            for _ in range(episodes):
                s, _ = sampler.reset(seed=int(np.random.randint(1_000_000)))
                s = int(s)
                for _ in range(max_len):
                    a = random.randrange(self.A)
                    sn, _, term, trunc, _ = sampler.step(a)
                    sn = int(sn)
                    done = term or trunc

                    buffer.append((s, a, sn, done))
                    visited.append(s)
                    visited.append(sn)
                    s = sn
                    if done:
                        break
        finally:
            # Release the sampler even if a rollout raises.
            sampler.close()
        return buffer, visited  # list with duplicates

    # ---------- Random Option --------------------------------
    def train_random(
        self,
        collect_ep: int = 2000,
        ep_len: int = 200,
        train_steps: int = 400_000,
    ) -> np.ndarray:
        """Random options with independent potentials φ ~ N(0,1).

        Shaped reward: r = φ(sn) − φ(s). With ``sign=True`` the number of
        options is doubled to ``2 * num``; every option still receives its
        own independently sampled potential.

        Args:
            collect_ep: Number of random-walk episodes for the buffer.
            ep_len: Maximum steps per collection episode.
            train_steps: Number of sampled Q-learning updates (each update
                touches all options in parallel).

        Returns:
            Q-table of shape ``(K, S, A)``; also saved to ``out_dir``.
        """
        # Phase-0: unified buffer
        buf, _ = self._collect_buffer(collect_ep, ep_len)

        k_all = self.num * 2 if self.sign else self.num
        rng = np.random.default_rng(self.seed)

        # random potentials (k_all, S)
        phi = rng.standard_normal((k_all, self.S)).astype(np.float32)

        # Phase-1: offline parallel Q-learning
        Q = np.zeros((k_all, self.S, self.A), np.float32)
        for _ in range(train_steps):
            s, a, sn, done = random.choice(buf)
            td = phi[:, sn] - phi[:, s]
            if not done:
                td += self.gamma * Q[:, sn].max(1)
            td -= Q[:, s, a]
            Q[:, s, a] += self.alpha * td

        path = self.out_dir / f"{self.env_id}_{k_all}_RandomOpt_{self.idx}.npy"
        np.save(path, Q)
        print(f"[✓] Save {path}  (K={k_all})")
        return Q

    # ---------- Eigen Option ---------------------------------
    @staticmethod
    def _normalized_laplacian(buf, n_states: int):
        """Build ``I - D^{-1/2} A D^{-1/2}`` from observed transitions.

        ``A`` is the 0/1 adjacency of the *undirected* transition graph: an
        observed ``s -> sn`` also sets ``sn -> s``. The symmetry matters
        because the result is handed to ``eigsh``, which expects a symmetric
        operator. States with zero degree have their degree clamped to
        ``1e-8`` to avoid division by zero.

        Args:
            buf: Iterable of ``(s, a, sn, done)`` transitions.
            n_states: Number of states (matrix dimension).

        Returns:
            A SciPy sparse matrix of shape ``(n_states, n_states)``.
        """
        counts = np.zeros((n_states, n_states), np.int32)
        for s, _, sn, _ in buf:
            counts[s, sn] = 1
            counts[sn, s] = 1
        adj = sp.csr_matrix(counts)
        deg = np.ravel(adj.sum(1))
        d_inv_sqrt = sp.diags(1 / np.sqrt(np.maximum(deg, 1e-8)))
        return sp.eye(n_states, dtype=np.float32) - d_inv_sqrt @ adj @ d_inv_sqrt

    def train_eigen(self, collect_ep: int, train_steps: int) -> np.ndarray:
        """Eigenoptions from the normalized graph Laplacian of the buffer.

        The ``num`` eigenvectors following the first one returned by
        ``eigsh(..., which="SM")`` serve as potentials φ; each option is
        trained separately with shaped reward r = φ(sn) − φ(s). With
        ``sign=True`` each eigenvector yields two options (φ and −φ).

        Args:
            collect_ep: Number of random-walk episodes for the buffer.
            train_steps: Number of Q-learning updates *per option*.

        Returns:
            Q-table of shape ``(K, S, A)``; also saved to ``out_dir``.
        """
        buf, _ = self._collect_buffer(collect_ep)

        L = self._normalized_laplacian(buf, self.S)

        k_base = self.num
        _, vecs = eigsh(L, k=k_base + 1, which="SM")
        eig_vecs = vecs[:, 1:]  # drop the first returned eigenvector

        signs = (1, -1) if self.sign else (1,)
        total = k_base * len(signs)
        Q = np.zeros((total, self.S, self.A), np.float32)

        opt_id = 0
        for phi_raw in eig_vecs.T:
            for sg in signs:
                phi = sg * phi_raw
                for _ in range(train_steps):
                    s, a, sn, done = random.choice(buf)
                    r = phi[sn] - phi[s]
                    tgt = r if done else r + self.gamma * Q[opt_id, sn].max()
                    Q[opt_id, s, a] += self.alpha * (tgt - Q[opt_id, s, a])
                opt_id += 1

        path = self.out_dir / f"{self.env_id}_{total}_EigenOpt_{self.idx}.npy"
        np.save(path, Q)
        print(f"[✓] Save {path} (K={Q.shape[0]})")
        return Q

    # ---------- VPS Option -----------------------------------
    def train_vps(self, collect_ep: int, train_steps: int) -> np.ndarray:
        """Train VPS options by delegating to ``VPSOptionAgent`` on ``self.env``.

        Args:
            collect_ep: Number of Phase-0 random-walk episodes.
            train_steps: Number of Phase-1 Q-learning updates.

        Returns:
            Q-table of shape ``(K, S, A)``; also saved to ``out_dir``.
        """
        agent = VPSOptionAgent(
            self.env,
            k=self.num,
            sign=self.sign,
            gamma_v=self.gamma,
            gamma_q=self.gamma,
            alpha=self.alpha,
            seed=self.seed,
        )
        Q = agent.train(phase1_eps=collect_ep, phase2_steps=train_steps)
        k_all = self.num * (2 if self.sign else 1)
        path = self.out_dir / f"{self.env_id}_{k_all}_VPSOpt_{self.idx}.npy"
        np.save(path, Q)
        print(f"[✓] Save {path} (K={Q.shape[0]})")
        return Q

    # ---------- GIF visualization ----------------------------
    @staticmethod
    def _terminated(qrow):
        """Return True when no action in ``qrow`` has a positive Q-value."""
        return qrow.max() <= 0

    def save_gifs(self, Q: np.ndarray, method: str, steps: int = 60, fps: int = 2):
        """Render one greedy rollout per option and save each as a GIF.

        A rollout stops after ``steps`` frames, when the option's Q-row at
        the current state has no positive entry, or when the env ends the
        episode. Files go to ``out_dir / f"gifs_{method}_{idx}"``.

        Args:
            Q: Option Q-table of shape ``(K, S, A)``.
            method: Label used in the GIF directory name.
            steps: Maximum number of frames per rollout.
            fps: Frames per second passed to ``imageio.mimsave``.
        """
        gif_dir = self.out_dir / f"gifs_{method}_{self.idx}"
        gif_dir.mkdir(parents=True, exist_ok=True)
        policy = np.argmax(Q, 2)

        for oid in range(Q.shape[0]):
            env = gym.make(self.env_id, render_mode="rgb_array")
            frames = []
            try:
                s, _ = env.reset(seed=int(np.random.randint(1_000_000)))
                s = int(s)
                for _ in range(steps):
                    frames.append(env.render())
                    if self._terminated(Q[oid, s]):
                        break
                    a = int(policy[oid, s])
                    s, _, term, trunc, _ = env.step(a)
                    s = int(s)
                    if term or trunc:
                        break
            finally:
                # Always release the renderer, even if a step/render fails.
                env.close()
            imageio.mimsave(gif_dir / f"option_{oid:02d}.gif", frames, fps=fps)


# ============================================================
#                            CLI
# ============================================================
def main():
    parser = argparse.ArgumentParser()
    g = parser.add_argument
    g("--env", default="Taxi-v3")  # e.g., FrozenLake-v1
    g("--num", type=int, default=10)
    g("--sign", type=str, default="true")
    g("--gamma", type=float, default=0.999)
    g("--alpha", type=float, default=0.05)
    g("--collect", type=int, default=1000,
      help="Phase-0 random-walk episodes for ALL options")
    g("--steps", type=int, default=1_000_000,
      help="Phase-1 total Q-learning updates for ALL options")
    g("--gif_len", type=int, default=60)
    g("--fps", type=int, default=2)
    g("--outer", type=int, default=1,
      help="repeat experiments this many times")
    g("--out_dir", default="option_results")
    args = parser.parse_args()

    sign_flag = args.sign.lower() in ("true", "1", "yes")
    save_dir = os.path.join(args.out_dir, args.env)
    for idx in range(args.outer):
        print(f"\n=====  Training group {idx}/{args.outer - 1} =====")
        opt = DiscreteOption(
            args.env,
            args.num,
            sign_flag,
            args.gamma,
            args.alpha,
            seed=idx,
            out_dir=save_dir,
            idx=idx,
        )

        # ---- Random Option ----
        Q_rnd = opt.train_random(args.collect, args.steps)
        opt.save_gifs(Q_rnd, "random", args.gif_len, args.fps)

        # ---- Eigen Option -----
        Q_eig = opt.train_eigen(args.collect, args.steps)
        opt.save_gifs(Q_eig, "eigen", args.gif_len, args.fps)

        # ---- VPS Option -------
        Q_vps = opt.train_vps(args.collect, args.steps)
        opt.save_gifs(Q_vps, "vps", args.gif_len, args.fps)


if __name__ == "__main__":
    # Script entry point: ignore UserWarnings whose message starts with
    # "Matplotlib is currently using agg", then run the CLI.
    warnings.filterwarnings("ignore", "Matplotlib is currently using agg", UserWarning)
    main()
