from __future__ import annotations
import collections, os, datetime as dt, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import gymnasium as gym
from torch.utils.tensorboard import SummaryWriter
from networks import AtariCNN, RFFLayer, ValueNet, VPSNet, DQNHead
from replay_buffer import Memory

# -----------------------------------------------------------
def env_name(env: gym.Env) -> str:
    """Return a human-readable identifier for *env*, used to build log paths.

    Uses ``env.spec.id`` when the environment carries a spec with an ``id``;
    otherwise (no ``spec`` attribute, ``spec`` is ``None``, or the spec has no
    ``id``) falls back to the environment's class name.
    """
    spec = getattr(env, "spec", None)
    fallback = env.__class__.__name__
    return getattr(spec, "id", fallback)

# -----------------------------------------------------------
class ContinuousVPSAgent:
    """Three-stage trainer for VPS-based options on image observations.

    Stages:
      A. Learn V(s) under random Fourier feature (RFF) rewards.
      B. Fit VPS φ(s) to a squared TD-style residual of the frozen V(s)
         (see ``train_vps`` for the exact target).
      C. Train per-option DQN heads with n-step intrinsic returns built from
         differences of φ along sampled trajectories.

    Data is collected once, up front, with a uniformly random policy
    (``collect``); all three stages then train from that replay buffer.
    """
    # --------------------------- init ----------------------
    def __init__(
        self,
        env: gym.Env,
        k_options: int = 8,
        gamma_v: float = 0.99,
        gamma_q: float = 0.9,
        device: str = "cuda:0",
        buffer_cap: int = 200_000,
        frame_stack_len: int = 4,
        batch_size: int = 64,
        max_episode_length: int = 500,
        n_step: int = 10,
    ):
        """Build networks, optimizers, replay buffer and TensorBoard writer.

        Args:
            env: Gymnasium environment. ``env.action_space.n`` is read, so a
                discrete action space is required; it sizes the option Q-heads.
            k_options: Number of options; also used as the output size of the
                RFF reward layer and of the value / VPS heads.
            gamma_v: Discount for the Stage-A value regression.
            gamma_q: Discount for the Stage-C n-step option returns.
            device: Torch device for all networks.
            buffer_cap: Replay-buffer capacity (passed as ``Memory.max_size``).
            frame_stack_len: Number of stacked grayscale frames per state.
            batch_size: Minibatch size used by every stage.
            max_episode_length: Step cap per episode during ``collect``.
            n_step: Horizon of the Stage-C n-step returns.
        """
        self.env  = env
        self.k    = k_options
        self.dev  = torch.device(device)
        self.g_v  = gamma_v
        self.g_q  = gamma_q
        self.F    = frame_stack_len
        self.B    = batch_size
        self.max_episode_length = max_episode_length
        self.n_step = n_step

        # ---------- TensorBoard ----------
        eid  = env_name(env)
        category = eid.split("/")[0]
        logdir = os.path.join("runs", category, dt.datetime.now().strftime("%Y%m%d-%H%M%S"))
        self.writer = SummaryWriter(logdir)

        # ---------- RFF (frozen) ----------
        # Randomly initialised encoder + RFF projection; never trained, used
        # only to produce the k-dimensional random reward in `collect`.
        self.backbone_rff = AtariCNN(in_channels=self.F).to(self.dev)
        for p in self.backbone_rff.parameters():
            p.requires_grad_(False)
        self.rff = RFFLayer(self.backbone_rff.flat_dim, k_options, seed=0, sigma=10).to(self.dev)
        # Random-reward normalization stats (computed after data collection).
        self.rand_reward_mean: torch.Tensor | None = None
        self.rand_reward_std: torch.Tensor | None = None

        # ---------- Shared backbone for Value/VPS ----------
        # We intentionally share the CNN encoder so that ValueNet and VPSNet
        # operate on the same representation. Training remains staged:
        # - Stage A trains backbone + value head
        # - Stage B freezes backbone + value head, and trains only the VPS head
        self.backbone_shared = AtariCNN(in_channels=self.F).to(self.dev)

        # Keep the old attribute names for compatibility with existing code.
        self.backbone_val = self.backbone_shared
        self.backbone_vps = self.backbone_shared

        # Heads
        self.val_net = ValueNet(k_options, self.backbone_val).to(self.dev)
        self.vps_net = VPSNet(k_options, self.backbone_vps).to(self.dev)

        # ---------- Option-DQN ----------
        # Separate encoder for the option Q-functions, shared by all k heads.
        in_ch_q = frame_stack_len
        self.backbone_q = AtariCNN(in_channels=in_ch_q).to(self.dev)
        n_act = env.action_space.n
        self.opt_heads = nn.ModuleList(
            [DQNHead(self.backbone_q.flat_dim, n_act).to(self.dev) for _ in range(k_options)]
        )

        # ---------- Optimizers ----------
        self.value_opt = torch.optim.Adam(
            list(self.backbone_val.parameters()) + list(self.val_net.head.parameters()),
            lr=1e-3,
        )

        self.vps_opt = torch.optim.Adam(
            # Only train the VPS head in Stage-B since backbone is shared and
            # frozen after Stage-A (see `freeze_value`).
            list(self.vps_net.head.parameters()),
            lr=1e-4,
        )

        self.q_opt = torch.optim.Adam(
            list(self.backbone_q.parameters()) +
            [p for h in self.opt_heads for p in h.parameters()],
            lr=1e-4,
        )

        # ---------- Replay Buffer ----------
        # NOTE(review): the storage device is hard-coded to "cuda" and does not
        # follow `device`; with device="cpu" on a machine without CUDA this will
        # presumably fail — confirm against Memory's API before changing it.
        self.buffer = Memory(
            max_size=buffer_cap,
            storage_devices="cuda",
            target_device=self.dev,
            frame_stack_num=self.F,
        )

        self.frame_stack: collections.deque[torch.Tensor] | None = None
        self.step_A = self.step_B = self.step_C = 0

    # ------------------------ preprocessing ----------------
    @torch.no_grad()
    def _preprocess(self, frame: np.ndarray) -> torch.Tensor:
        """Convert an RGB frame to a (1,84,84) float32 grayscale tensor on ``self.dev``.

        Pixel values are divided by 255, so a uint8 input maps to [0, 1]
        (assumes uint8 RGB observations — TODO confirm for non-Atari envs).
        """
        import cv2  # local import: OpenCV is only needed for preprocessing
        g = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        g = cv2.resize(g, (84, 84), interpolation=cv2.INTER_AREA)
        return torch.as_tensor(g, device=self.dev, dtype=torch.float32).unsqueeze(0) / 255.0

    # ----------------- utils: frame-stack -----------------
    def reset_frame_stack(self) -> None:
        """Clear the internal F-frame stack at episode/replay start."""
        self.frame_stack = None

    @torch.no_grad()
    def obs_to_state(self, obs: np.ndarray) -> torch.Tensor:
        """Convert Gymnasium RGB obs to stacked (F,84,84) tensor and update stack."""
        frame_1chw = self._preprocess(obs)       # (1,84,84)
        return self._update_stack(frame_1chw)    # (F,84,84)

    def _update_stack(self, f1chw: torch.Tensor) -> torch.Tensor:
        """Push one (1,H,W) frame and return the concatenated (F,H,W) stack.

        On the first call after a reset the stack is filled with F copies of
        the frame, so the very first state is well-defined.
        """
        if self.frame_stack is None:
            self.frame_stack = collections.deque(maxlen=self.F)
            for _ in range(self.F):
                self.frame_stack.append(f1chw.clone())
        else:
            self.frame_stack.append(f1chw)
        return torch.cat(list(self.frame_stack), 0)

    # ------------------------- collect ---------------------
    @torch.no_grad()
    def collect(self):
        """Fill the replay buffer using a random policy (no learning).

        For each transition the RFF random reward of the *next* stacked state
        is computed online and stored alongside the single (1,84,84) frame of
        the current state. Reaching ``max_episode_length`` is treated like an
        environment truncation: the transition is flagged done and a terminal
        frame is appended, so stored transitions never chain into the next
        episode. After collection, the stored reward vectors are standardized
        per option dimension via ``Memory.normalize_rand_rewards_``.

        Raises:
            ValueError: if ``max_episode_length`` is not positive (the buffer
                could never fill and this loop would not terminate).
        """
        if self.max_episode_length <= 0:
            raise ValueError(
                f"max_episode_length must be positive, got {self.max_episode_length}"
            )

        last_step = self.max_episode_length - 1
        while not self.buffer.full:
            obs, _ = self.env.reset()
            self.reset_frame_stack()
            frame = self._preprocess(obs)          # (1,84,84)
            self._update_stack(frame)              # init stack (F,84,84)

            for t in range(self.max_episode_length):
                a = self.env.action_space.sample()
                nxt_obs, r, term, trunc, _ = self.env.step(a)
                nxt_frame = self._preprocess(nxt_obs)
                sp_stack = self._update_stack(nxt_frame)  # (F,84,84) for s'

                # RFF random reward for next stacked state.
                r_rand_vec = (
                    self.rff(self.backbone_rff(sp_stack.unsqueeze(0)))[0]
                    .detach()
                    .cpu()
                    .numpy()
                )

                # Hitting the step cap is a truncation just like `trunc`;
                # without this the last transition would point at the next
                # episode's reset frame with done=0.
                done = bool(term or trunc) or t == last_step
                self.buffer.store(frame.cpu().numpy(), a, r, r_rand_vec, int(done))
                frame = nxt_frame

                if done:
                    # Store one extra frame to make s' valid for terminal transitions.
                    # The random reward associated with this dummy transition is not used.
                    self.buffer.store(
                        frame.cpu().numpy(),
                        0,
                        0.0,
                        np.zeros(self.k, dtype=np.float32),
                        1,
                    )
                    break

        # Standardize the stored random rewards after buffer is filled.
        self.rand_reward_mean, self.rand_reward_std = self.buffer.normalize_rand_rewards_()
        print(
            f"[Collect] rand reward normalized: "
            f"mean(abs)={float(self.rand_reward_mean.abs().mean()):.4f}, "
            f"std(mean)={float(self.rand_reward_std.mean()):.4f}"
        )

    # -------------------- Stage-A : Value ------------------
    def train_value(self):
        """One TD(0) step of the multi-output value net under RFF rewards.

        Target: ``r_rand + γ_v · (1 - done) · V(s')``, evaluated with the same
        (online) network without gradients.

        Returns:
            The smooth-L1 loss as a float, or ``None`` while the buffer holds
            fewer than ``batch_size`` entries.
        """
        if self.buffer.size() < self.B:
            return None
        s, sp, _, _, r_rand, done = self.buffer.sample(self.B)

        v_s  = self.val_net.head(self.backbone_val(s))
        with torch.no_grad():
            v_sp = self.val_net.head(self.backbone_val(sp))
            target = r_rand + self.g_v * (1 - done.float()) * v_sp

        loss = F.smooth_l1_loss(v_s, target)
        self.value_opt.zero_grad()
        loss.backward()
        self.value_opt.step()
        self.writer.add_scalar("A/value_loss", loss.item(), self.step_A)
        self.step_A += 1
        return float(loss.item())

    # -------------------- Stage-B : VPS  -------------------
    def freeze_value(self):
        """Stop gradients into the shared backbone and the value head."""
        for p in self.backbone_val.parameters():
            p.requires_grad_(False)
        for p in self.val_net.head.parameters():
            p.requires_grad_(False)

    def train_vps(self):
        """One regression step of the VPS head φ(s) against the frozen value net.

        NOTE: the target is deliberately ``((1 - done) · V(s') - V(s))²``; the
        RFF reward and γ_v of the full TD error are intentionally omitted
        (experimental choice to probe how well the VPS network can fit).
        Only the VPS head is updated (``vps_opt`` holds just its parameters).

        Returns:
            The smooth-L1 loss as a float, or ``None`` while the buffer holds
            fewer than ``batch_size`` entries.
        """
        if self.buffer.size() < self.B:
            return None
        s, sp, _, _, _, done = self.buffer.sample(self.B)

        with torch.no_grad():                                    # fixed V̂
            v_s  = self.val_net.head(self.backbone_val(s))
            v_sp = self.val_net.head(self.backbone_val(sp))
            td_sq = ((1 - done.float()) * v_sp - v_s) ** 2

        phi_s = self.vps_net.head(self.backbone_vps(s))
        loss  = F.smooth_l1_loss(phi_s, td_sq)

        self.vps_opt.zero_grad()
        loss.backward()
        self.vps_opt.step()
        self.writer.add_scalar("B/vps_loss", loss.item(), self.step_B)
        self.step_B += 1
        return float(loss.item())

    # -------------------- Stage-C : Options ---------------
    def freeze_vps(self):
        """Stop gradients into the (shared) VPS backbone and the VPS head."""
        for p in self.backbone_vps.parameters():
            p.requires_grad_(False)
        for p in self.vps_net.head.parameters():
            p.requires_grad_(False)

    def _gamma_vec(self, n_step: int, device):
        """Return ``[γ_q^0, …, γ_q^{n-1}]`` as a tensor on *device*, cached on self.

        The cache key includes the device and γ_q, so changing either rebuilds
        the vector instead of returning a stale tensor.
        """
        key = (n_step, str(device), self.g_q)
        if getattr(self, "_cached_gamma_key", None) != key:
            self._cached_gamma = self.g_q ** torch.arange(n_step, device=device)
            self._cached_gamma_key = key
            self._cached_n = n_step  # kept for code that may still read it
        return self._cached_gamma  # (n,)

    def train_options(self):
        """One gradient step on all option Q-heads using n-step intrinsic returns.

        For option i, the intrinsic reward along the sampled window is
        ``r_int[j] = φ_i(s_{t+j+1}) - φ_i(s_{t+j})`` and the regression target is

            G_int_i + γ_q^n · (1 - done) · max_a Q_i(s_{t+n}, a)

        Rewards after the first terminal flag in the window are masked out, and
        the bootstrap is dropped if *any* step in the window terminated, so
        returns never cross episode boundaries. The target is computed without
        gradients (semi-gradient TD): only Q_i(s_t, a_t) is regressed.

        Returns:
            The summed MSE loss over all heads as a float, or ``None`` while the
            buffer holds fewer than ``batch_size`` entries.
        """
        if self.buffer.size() < self.B:
            return None

        # ------- n-step sampling -------
        # Indexing below implies A_seq / D_seq are (B, n, 1).
        S_seq, A_seq, _, _, D_seq, S_n = self.buffer.sample_nstep(self.B, self.n_step)

        # First state s_t and action a_t
        S0 = S_seq[:, 0]
        A0 = A_seq[:, 0, 0].long()       # (B,)

        # ------- episode-boundary masks -------
        not_done = 1.0 - D_seq[:, :, 0].float()                 # (B,n)
        alive_after = torch.cumprod(not_done, dim=1)            # no terminal in steps 0..j
        alive_before = torch.cat(                               # no terminal in steps 0..j-1
            [torch.ones_like(alive_after[:, :1]), alive_after[:, :-1]], dim=1
        )
        bootstrap_mask = alive_after[:, -1]                     # (B,)

        # ------- φ(s) list for intrinsic rewards -------
        with torch.no_grad():
            Bn = self.B * self.n_step
            phi_seq = self.vps_net.head(
                self.backbone_vps(S_seq.reshape(Bn, *S_seq.shape[2:]))
            ).view(self.B, self.n_step, self.k)          # (B,n,k)
            phi_sn = self.vps_net.head(self.backbone_vps(S_n))      # (B,k)

        # r_int[j] = φ(s_{t+j+1}) - φ(s_{t+j}), zeroed after the first terminal.
        r_int_seq = torch.zeros_like(phi_seq)            # (B,n,k)
        r_int_seq[:, :-1] = phi_seq[:, 1:] - phi_seq[:, :-1]
        r_int_seq[:, -1]  = phi_sn - phi_seq[:, -1]
        r_int_seq = r_int_seq * alive_before.unsqueeze(-1)

        # ------- n-step cumulative intrinsic return G_int -------
        gamma_vec = self._gamma_vec(self.n_step, device=self.dev)   # (n,)
        G_int = (r_int_seq * gamma_vec.view(1, self.n_step, 1)).sum(dim=1)  # (B,k)

        # ------- Forward Q(s_t) and Q(s_{t+n}) -------
        feat_s = self.backbone_q(S0)        # (B,flat)
        with torch.no_grad():
            feat_sn = self.backbone_q(S_n)

        discount_n = self.g_q ** self.n_step
        total = 0.0
        for i, head in enumerate(self.opt_heads):
            q = head(feat_s).gather(1, A0.unsqueeze(1)).squeeze(1)   # (B,)
            with torch.no_grad():
                # Target must not carry gradients into the head weights.
                q_next = head(feat_sn).max(dim=1).values             # (B,)
                target = G_int[:, i] + discount_n * bootstrap_mask * q_next
            total = total + F.mse_loss(q, target)

        self.q_opt.zero_grad()
        total.backward()
        self.q_opt.step()
        self.writer.add_scalar("C/option_q_loss", total.item(), self.step_C)
        self.step_C += 1
        return float(total.item())

    # ----------------------- pipeline ---------------------
    def train(
        self,
        value_iters: int = 10_000,
        vps_iters: int = 10_000,
        option_iters: int = 10_000,
        print_every: int = 50,
    ):
        """Run collection followed by Stages A, B and C in order.

        Args:
            value_iters / vps_iters / option_iters: Update calls per stage.
            print_every: Print the loss every this many steps (<= 0 disables).
        """
        print("[Collect]"); self.collect()

        print("[Stage-A] Value")
        for _ in range(value_iters):
            loss = self.train_value()
            if loss is not None and print_every > 0 and (self.step_A % print_every == 0):
                print(f"[Stage-A] iter={self.step_A:>7d}/{value_iters}  loss={loss:.6f}")

        self.freeze_value()

        print("[Stage-B] VPS")
        for _ in range(vps_iters):
            loss = self.train_vps()
            if loss is not None and print_every > 0 and (self.step_B % print_every == 0):
                print(f"[Stage-B] iter={self.step_B:>7d}/{vps_iters}  loss={loss:.6f}")

        self.freeze_vps()

        print("[Stage-C] Options")
        for _ in range(option_iters):
            loss = self.train_options()
            if loss is not None and print_every > 0 and (self.step_C % print_every == 0):
                print(f"[Stage-C] iter={self.step_C:>7d}/{option_iters}  loss={loss:.6f}")

        # Make sure buffered TensorBoard events reach disk.
        self.writer.flush()

    # ----------------------- save -------------------------
    def save_all(self, path: str):
        """Save all network weights plus RFF normalization stats to *path*."""
        torch.save(
            {
                "k": self.k,
                "frame_stack_len": self.F,
                "shared_backbone_val_vps": True,
                "backbone_rff": self.backbone_rff.state_dict(),
                "rff": self.rff.state_dict(),
                "rff_scale": self.rff.scale,
                "rand_reward_norm_mean": self.rand_reward_mean,
                "rand_reward_norm_std": self.rand_reward_std,
                "backbone_val": self.backbone_val.state_dict(),
                "value_head": self.val_net.head.state_dict(),
                "backbone_vps": self.backbone_vps.state_dict(),
                "vps_head": self.vps_net.head.state_dict(),
                "backbone_q": self.backbone_q.state_dict(),
                "opt_heads": [h.state_dict() for h in self.opt_heads],
            },
            path,
        )
        print(f"[✓] saved → {path}")
