from __future__ import annotations
import collections, os, datetime as dt, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import gymnasium as gym
from torch.utils.tensorboard import SummaryWriter
from networks import AtariCNN, RFFLayer, ValueNet, VPSNet, DQNHead
from replay_buffer import Memory

# -----------------------------------------------------------
def env_name(env: gym.Env) -> str:
    """Return a human-readable identifier for *env*, used for logging.

    The value of ``env.spec.id`` is returned when both attributes exist;
    otherwise (no ``spec`` attribute, ``spec`` is ``None``, or the spec has
    no ``id``) the class name of *env* is used instead.
    """
    spec = getattr(env, "spec", None)
    fallback = env.__class__.__name__
    return getattr(spec, "id", fallback)

# -----------------------------------------------------------
class ContinuousVPSAgent:
    """Three-stage trainer for VPS-based options on image observations.

    Stages:
      A. Learn V(s) under random Fourier feature rewards.
      B. Fit VPS φ(s) to the squared TD of V(s).
      C. Train per-option DQN heads with intrinsic returns from φ.
    """
    # --------------------------- init ----------------------
    def __init__(
        self,
        env: gym.Env,
        k_options: int = 8,
        gamma_v: float = 0.99,
        gamma_q: float = 0.9,
        device: str = "cuda:0",
        buffer_cap: int = 200_000,
        frame_stack_len: int = 4,
        batch_size: int = 64,
        max_episode_length: int = 500,
        n_step: int = 10,
    ):
        self.env  = env
        self.k    = k_options
        self.dev  = torch.device(device)
        self.g_v  = gamma_v
        self.g_q  = gamma_q
        self.F    = frame_stack_len
        self.B    = batch_size
        self.max_episode_length = max_episode_length
        self.n_step = n_step

        # ---------- TensorBoard ----------
        eid  = env_name(env)
        category = eid.split("/")[0]
        logdir = os.path.join("runs", category, dt.datetime.now().strftime("%Y%m%d-%H%M%S"))
        self.writer = SummaryWriter(logdir)

        # ---------- RFF (frozen) ----------
        self.backbone_rff = AtariCNN(in_channels=self.F).to(self.dev)
        for p in self.backbone_rff.parameters():
            p.requires_grad_(False)
        self.rff = RFFLayer(self.backbone_rff.flat_dim, k_options, seed=0, sigma=10).to(self.dev)

        # ---------- ValueNet ----------
        self.backbone_val = AtariCNN(in_channels=self.F).to(self.dev)
        self.val_net  = ValueNet(k_options, self.backbone_val).to(self.dev)

        # ---------- VPSNet (separate backbone) ----------
        self.backbone_vps = AtariCNN(in_channels=self.F).to(self.dev)
        self.vps_net  = VPSNet(k_options, self.backbone_vps).to(self.dev)

        # ---------- Option-DQN ----------
        in_ch_q = frame_stack_len
        self.backbone_q = AtariCNN(in_channels=in_ch_q).to(self.dev)
        n_act = env.action_space.n
        self.opt_heads = nn.ModuleList(
            [DQNHead(self.backbone_q.flat_dim, n_act).to(self.dev) for _ in range(k_options)]
        )

        # ---------- Optimizers ----------
        self.value_opt = torch.optim.Adam(
            list(self.backbone_val.parameters()) + list(self.val_net.head.parameters()),
            lr=1e-3,
        )

        self.vps_opt = torch.optim.Adam(
            list(self.backbone_vps.parameters()) + list(self.vps_net.head.parameters()),
            lr=1e-4,
        )

        self.q_opt = torch.optim.Adam(
            list(self.backbone_q.parameters()) +
            [p for h in self.opt_heads for p in h.parameters()],
            lr=1e-4,
        )

        # ---------- Replay Buffer ----------
        self.buffer = Memory(
            max_size=buffer_cap,
            storage_devices="cuda",
            target_device=self.dev,
            frame_stack_num=self.F,
        )

        self.frame_stack: collections.deque[torch.Tensor] | None = None
        self.step_A = self.step_B = self.step_C = 0

    # ------------------------ preprocessing ----------------
    @torch.no_grad()
    def _preprocess(self, frame: np.ndarray) -> torch.Tensor:
        """
        RGB → 84×84 gray-scale 0-1 tensor, shape (1,84,84).
        """
        import cv2
        g = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        g = cv2.resize(g, (84, 84), interpolation=cv2.INTER_AREA)
        return torch.as_tensor(g, device=self.dev, dtype=torch.float32).unsqueeze(0) / 255.0

    # ----------------- utils: frame-stack -----------------
    def reset_frame_stack(self) -> None:
        """Clear the internal F-frame stack at episode/replay start."""
        self.frame_stack = None

    @torch.no_grad()
    def obs_to_state(self, obs: np.ndarray) -> torch.Tensor:
        """Convert Gymnasium RGB obs to stacked (F,84,84) tensor and update stack."""
        frame_1chw = self._preprocess(obs)       # (1,84,84)
        return self._update_stack(frame_1chw)    # (F,84,84)

    def _update_stack(self, f1chw: torch.Tensor) -> torch.Tensor:
        if self.frame_stack is None:
            self.frame_stack = collections.deque(maxlen=self.F)
            for _ in range(self.F):
                self.frame_stack.append(f1chw.clone())
        else:
            self.frame_stack.append(f1chw)
        return torch.cat(list(self.frame_stack), 0)

    # ------------------------- collect ---------------------
    @torch.no_grad()
    def collect(self):
        """Fill the replay buffer using a random policy (no learning)."""
        while not self.buffer.full:
            obs, _ = self.env.reset()
            self.frame_stack = None
            frame = self._preprocess(obs)
            for _ in range(self.max_episode_length):
                a = self.env.action_space.sample()
                nxt_obs, r, term, trunc, _ = self.env.step(a)
                r_rand_vec = np.random.normal(0.0, 0.001, size=self.k)
                self.buffer.store(frame.cpu().numpy(), a, r, r_rand_vec, int(term or trunc))
                frame = self._preprocess(nxt_obs)

                # Debug helper (disabled by default):
                # if r != 0:
                #     pil_img = TF.to_pil_image(frame.squeeze(0).cpu())  # 8-bit gray
                #     pil_img.show()  # or display(pil_img) in Jupyter

                if term or trunc:
                    self.buffer.store(frame.cpu().numpy(), 0, 0.0, r_rand_vec, int(term or trunc))
                    break

    # -------------------- Stage-A : Value ------------------
    def train_value(self):
        if self.buffer.size() < self.B:
            return
        s, sp, _, reward, _, done = self.buffer.sample(self.B)
        with torch.no_grad():
            diff   = sp - s
            r_rand = self.rff(self.backbone_rff(sp))          # (B,k)

        v_s  = self.val_net.head(self.backbone_val(s))
        with torch.no_grad():
            v_sp = self.val_net.head(self.backbone_val(sp))
            target = r_rand + self.g_v * (1 - done.float()) * v_sp

        loss = F.smooth_l1_loss(v_s, target)
        self.value_opt.zero_grad()
        loss.backward()
        self.value_opt.step()
        self.writer.add_scalar("A/value_loss", loss.item(), self.step_A)
        self.step_A += 1

    # -------------------- Stage-B : VPS  -------------------
    def freeze_value(self):
        for p in self.backbone_val.parameters():
            p.requires_grad_(False)
        for p in self.val_net.head.parameters():
            p.requires_grad_(False)

    def train_vps(self):
        if self.buffer.size() < self.B:
            return
        s, sp, _, reward, _, done = self.buffer.sample(self.B)

        with torch.no_grad():                                    # fixed V̂
            v_s  = self.val_net.head(self.backbone_val(s))
            v_sp = self.val_net.head(self.backbone_val(sp))
            td_abs = (1 - done.float()) * v_sp - v_s
            td_abs = td_abs ** 2

        # We intentionally replace the target to test how well the VPS network can fit.
        target = td_abs
        phi_s  = self.vps_net.head(self.backbone_vps(s))
        loss   = F.smooth_l1_loss(phi_s, target)

        self.vps_opt.zero_grad()
        loss.backward()
        self.vps_opt.step()
        self.writer.add_scalar("B/vps_loss", loss.item(), self.step_B)
        self.step_B += 1

    # -------------------- Stage-C : Options ---------------
    def freeze_vps(self):
        for p in self.backbone_vps.parameters():
            p.requires_grad_(False)
        for p in self.vps_net.head.parameters():
            p.requires_grad_(False)

    def _gamma_vec(self, n_step: int, device):
        """Cache γ^0 … γ^{n-1} on self to avoid rebuilding every call."""
        if getattr(self, "_cached_n", None) != n_step:
            self._cached_gamma = self.g_q ** torch.arange(n_step, device=device)
            self._cached_n = n_step
        return self._cached_gamma  # (n,)

    def train_options(self):
        if self.buffer.size() < self.B:
            return

        # ------- n-step sampling -------
        S_seq, A_seq, _, _, D_seq, S_n = self.buffer.sample_nstep(self.B, self.n_step)

        # First state s_t
        S0 = S_seq[:, 0]
        A0 = A_seq[:, 0, 0].long()       # (B,)
        done_last = D_seq[:, -1, 0].float()  # (B,)

        # ------- φ(s) list for intrinsic rewards -------
        with torch.no_grad():
            Bn = self.B * self.n_step
            phi_seq = self.vps_net.head(
                self.backbone_vps(S_seq.reshape(Bn, *S_seq.shape[2:]))
            ).view(self.B, self.n_step, self.k)          # (B,n,k)
            phi_sn = self.vps_net.head(self.backbone_vps(S_n))      # (B,k)

        # r_int[j] = φ(s_{t+j+1}) - φ(s_{t+j})
        r_int_seq = torch.zeros_like(phi_seq)            # (B,n,k)
        r_int_seq[:, :-1] = phi_seq[:, 1:] - phi_seq[:, :-1]
        r_int_seq[:, -1]  = phi_sn - phi_seq[:, -1]

        # ------- n-step cumulative intrinsic return G_int -------
        gamma_vec = self._gamma_vec(self.n_step, device=self.dev)   # (n,)
        G_int = (r_int_seq * gamma_vec.view(1, self.n_step, 1)).sum(dim=1)  # (B,k)

        # ------- Forward Q(s_t) and Q(s_{t+n}) -------
        feat_s  = self.backbone_q(S0)        # (B,flat)
        feat_sn = self.backbone_q(S_n).detach()

        total = 0.0
        for i, head in enumerate(self.opt_heads):
            q      = head(feat_s).gather(1, A0.unsqueeze(1)).squeeze(1)   # (B,)
            q_next = head(feat_sn).max(1)[0]                              # (B,)

            # target_i = G_int_i + γ^n (1-d) Q_i(s_{t+n})
            target = G_int[:, i] + (self.g_q ** self.n_step) * (1 - done_last) * q_next
            total += F.mse_loss(q, target)

        self.q_opt.zero_grad()
        total.backward()
        self.q_opt.step()
        self.writer.add_scalar("C/option_q_loss", total.item(), self.step_C)
        self.step_C += 1

    # ----------------------- pipeline ---------------------
    def train(
        self,
        value_iters: int = 10_000,
        vps_iters: int = 10_000,
        option_iters: int = 10_000,
    ):
        print("[Collect]"); self.collect()

        print("[Stage-A] Value")
        for _ in range(value_iters):
            self.train_value()

        self.freeze_value()

        print("[Stage-B] VPS")
        for _ in range(vps_iters):
            self.train_vps()

        self.freeze_vps()

        print("[Stage-C] Options")
        for _ in range(option_iters):
            self.train_options()

    # ----------------------- save -------------------------
    def save_all(self, path: str):
        torch.save(
            {
                "k": self.k,
                "frame_stack_len": self.F,
                "backbone_rff": self.backbone_rff.state_dict(),
                "rff": self.rff.state_dict(),
                "rff_scale": self.rff.scale,
                "backbone_val": self.backbone_val.state_dict(),
                "value_head": self.val_net.head.state_dict(),
                "backbone_vps": self.backbone_vps.state_dict(),
                "vps_head": self.vps_net.head.state_dict(),
                "backbone_q": self.backbone_q.state_dict(),
                "opt_heads": [h.state_dict() for h in self.opt_heads],
            },
            path,
        )
        print(f"[✓] saved → {path}")
