from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SPDQNConfig:
    """Hyper-parameters for the SP-DQN network.

    ``__post_init__`` validates the values eagerly, so a bad config fails here
    with a clear message instead of later inside layer construction.
    """

    in_channels: int = 1  # channels of the preprocessed observation tensor
    num_actions: int = 4  # number of discrete actions (A)
    feat_dim: int = 1024  # d
    # Input image size (H,W) after preprocessing.
    input_hw: Tuple[int, int] = (84, 84)

    def __post_init__(self) -> None:
        """Validate sizes and normalise ``input_hw`` to a ``(int, int)`` tuple.

        Raises:
            ValueError: if a size is not positive or ``input_hw`` is not a pair.
        """
        for name in ("in_channels", "num_actions", "feat_dim"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")
        try:
            h, w = self.input_hw
        except (TypeError, ValueError) as err:
            raise ValueError(f"input_hw must be an (H, W) pair, got {self.input_hw!r}") from err
        h, w = int(h), int(w)
        if h < 1 or w < 1:
            raise ValueError(f"input_hw entries must be positive, got {self.input_hw!r}")
        # Normalise e.g. a list loaded from JSON into the annotated tuple form.
        self.input_hw = (h, w)


class SPDQNNet(nn.Module):
    """
    Successor-Predecessor exploration network (function approximation).

    Shared encoder:
      φ̃(s) = Conv(s),  φ(s) = φ̃(s) / ||φ̃(s)||_2

    Heads:
      - Q head: q(s,a)
      - SF head: ψ(s,a) for all actions (action-conditional), output [B,A,d]
      - PF head: ξ(s) (state-only), output [B,d]
      - Recon head: ŝ_{t+1} = Deconv(φ(s)) (state-only)

    IMPORTANT: For SP exploration, SF/PF TD losses must not update the encoder.
    ``sf_all`` and ``pf`` detach their input themselves, so gradients from
    SF/PF losses flow only into those heads whatever tensor the caller passes.
    """

    # Channel count of every hidden conv / transposed-conv feature map. The
    # recon head reshapes rec_fc's output back to this many channels, so it
    # must agree with conv3's output channels.
    _HIDDEN_CH = 64

    def __init__(self, cfg: SPDQNConfig) -> None:
        super().__init__()
        self.cfg = cfg
        ch = self._HIDDEN_CH
        in_h, in_w = int(cfg.input_hw[0]), int(cfg.input_hw[1])

        # Encoder conv stack (same as SR model)
        self.conv1 = nn.Conv2d(cfg.in_channels, ch, kernel_size=6, stride=2, padding=0)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=6, stride=2, padding=2)
        self.conv3 = nn.Conv2d(ch, ch, kernel_size=6, stride=2, padding=2)

        # Probe the conv trunk with a dummy batch of one to learn its output
        # spatial size for this input_hw (84x84 -> 10x10, i.e. 6400 features).
        with torch.no_grad():
            probe = self._conv_trunk(torch.zeros(1, cfg.in_channels, in_h, in_w))
        self.conv_h = int(probe.size(2))
        self.conv_w = int(probe.size(3))
        self.conv_out = int(probe[0].numel())
        self.fc_phi = nn.Linear(self.conv_out, cfg.feat_dim)

        # Q head
        self.q_head = nn.Linear(cfg.feat_dim, cfg.num_actions)

        # SF head: φ_detach -> ψ(s,·) flattened then reshape to [B,A,d]
        self.sf_fc1 = nn.Linear(cfg.feat_dim, 2 * cfg.feat_dim)
        self.sf_fc2 = nn.Linear(2 * cfg.feat_dim, cfg.num_actions * cfg.feat_dim)

        # PF head: φ_detach -> ξ(s)
        self.pf_fc1 = nn.Linear(cfg.feat_dim, 2 * cfg.feat_dim)
        self.pf_fc2 = nn.Linear(2 * cfg.feat_dim, cfg.feat_dim)

        # Recon head: φ -> conv_out (6400 for 84x84) -> deconv stack (state-only)
        self.rec_fc = nn.Linear(cfg.feat_dim, self.conv_out)
        # kernel 6 / stride 2 / padding 2 transposed convs exactly double H and W:
        # (n - 1) * 2 - 2 * 2 + 6 = 2n.
        self.deconv1 = nn.ConvTranspose2d(ch, ch, kernel_size=6, stride=2, padding=2)
        self.deconv2 = nn.ConvTranspose2d(ch, ch, kernel_size=6, stride=2, padding=2)
        up_h, up_w = 4 * self.conv_h, 4 * self.conv_w
        # The last layer (stride 2, padding 0) outputs (n - 1) * 2 + k, so pick
        # k per axis to land exactly on input_hw.
        k_h = in_h - 2 * (up_h - 1)
        k_w = in_w - 2 * (up_w - 1)
        if k_h <= 0 or k_w <= 0:
            raise ValueError(
                f"Invalid decoder kernel computed for input_hw={cfg.input_hw} from conv_hw={(self.conv_h, self.conv_w)}. "
                f"Got kernel={(k_h, k_w)}."
            )
        self.deconv3 = nn.ConvTranspose2d(ch, 1, kernel_size=(k_h, k_w), stride=2, padding=0)

    @staticmethod
    def l2_normalize(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Row-wise L2 normalisation along dim 1: ``x / max(||x||_2, eps)``.

        The ``eps`` floor keeps all-zero rows (possible after the ReLU in
        ``encode``) at zero instead of producing NaN.
        """
        return F.normalize(x, p=2.0, dim=1, eps=eps)

    def _conv_trunk(self, obs: torch.Tensor) -> torch.Tensor:
        """Apply the three ReLU conv layers; returns ``[B, 64, conv_h, conv_w]``."""
        h = F.relu(self.conv1(obs))
        h = F.relu(self.conv2(h))
        return F.relu(self.conv3(h))

    def encode(self, obs: torch.Tensor) -> torch.Tensor:
        """obs: float in [0,1], shape [B,C,H,W] -> φ(s) [B,d] (L2-normalized)."""
        h = torch.flatten(self._conv_trunk(obs), 1)
        phi_tilde = F.relu(self.fc_phi(h))
        return self.l2_normalize(phi_tilde)

    def q(self, phi: torch.Tensor) -> torch.Tensor:
        """Q-values ``[B, A]`` from encoder features ``φ`` (gradients reach the encoder)."""
        return self.q_head(phi)

    def sf_all(self, phi_detach: torch.Tensor) -> torch.Tensor:
        """Successor features ψ(s, a) for every action, shape ``[B, A, d]``.

        The input is detached here so SF losses can never update the encoder,
        even if the caller passes the live ``φ``.
        """
        z = F.relu(self.sf_fc1(phi_detach.detach()))
        flat = self.sf_fc2(z)
        return flat.view(flat.size(0), self.cfg.num_actions, self.cfg.feat_dim)

    def pf(self, phi_detach: torch.Tensor) -> torch.Tensor:
        """Predecessor features ξ(s), shape ``[B, d]``, ReLU-bounded at zero.

        The input is detached here so PF losses can never update the encoder.
        """
        z = F.relu(self.pf_fc1(phi_detach.detach()))
        return F.relu(self.pf_fc2(z))

    def recon_next(self, phi: torch.Tensor) -> torch.Tensor:
        """Decode ``φ`` to a single-channel image ``[B, 1, H, W]`` in (0, 1).

        H and W equal ``cfg.input_hw``. Gradients do flow into the encoder here.
        """
        z = F.relu(self.rec_fc(phi))
        z = z.view(z.size(0), self._HIDDEN_CH, self.conv_h, self.conv_w)
        z = F.relu(self.deconv1(z))
        z = F.relu(self.deconv2(z))
        return torch.sigmoid(self.deconv3(z))

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
          phi: [B,d]
          q: [B,A]
          psi_all: [B,A,d]  (from phi.detach())
          xi: [B,d]         (from phi.detach())
        """
        phi = self.encode(obs)
        q = self.q(phi)
        phi_d = phi.detach()
        psi_all = self.sf_all(phi_d)
        xi = self.pf(phi_d)
        return phi, q, psi_all, xi

