from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SRDQNConfig:
    """Hyper-parameters that ``SRDQNNet`` reads when it builds its layers."""

    # Channel count of the observation tensor fed to the first convolution.
    in_channels: int = 1
    # Size of the discrete action set: width of the Q output and of the action one-hot.
    num_actions: int = 4
    # Dimensionality of the feature vector phi (and of the successor features psi).
    phi_dim: int = 1024
    # Observation resolution after preprocessing, ordered (height, width).
    input_hw: Tuple[int, int] = (84, 84)

class SRDQNNet(nn.Module):
    """
    Shared encoder + three heads (Q, successor features, reconstruction).

    This matches the provided architecture diagram, but uses a **single-frame** input.

    Encoder φ(s): Conv(64,6x6,s2,p0) -> Conv(64,6x6,s2,p2) -> Conv(64,6x6,s2,p2) -> FC -> ReLU
                  -> L2-normalize, giving φ ∈ R^{phi_dim}
    Q head: single linear layer q(φ) -> Q(s,·) ∈ R^{A}
    SF head (action-independent): two-layer MLP ψ(φ) ∈ R^{phi_dim}; `forward` feeds it
             `phi.detach()` so its loss does not backpropagate into the encoder.
    Recon head: predict next frame ŝ_{t+1} from (φ, a) using an action-gated MLP followed by
                three transposed convolutions and a sigmoid (single output channel).
    """

    # Channel width shared by every encoder conv and decoder deconv feature map.
    _CONV_CHANNELS = 64

    def __init__(self, cfg: SRDQNConfig) -> None:
        """
        Build all layers for the given configuration.

        Raises:
            ValueError: if `cfg.input_hw` is such that no positive kernel size for the last
                transposed convolution can reproduce the input resolution.
        """
        super().__init__()
        self.cfg = cfg
        ch = self._CONV_CHANNELS

        # Encoder conv stack (matches the diagram / SFNet)
        self.conv1 = nn.Conv2d(cfg.in_channels, ch, kernel_size=6, stride=2, padding=0)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=6, stride=2, padding=2)
        self.conv3 = nn.Conv2d(ch, ch, kernel_size=6, stride=2, padding=2)

        # Infer the conv output size by pushing one dummy frame of the configured
        # resolution through the (already constructed) conv stack.
        tgt_h, tgt_w = int(cfg.input_hw[0]), int(cfg.input_hw[1])
        with torch.no_grad():
            probe = self._conv_trunk(torch.zeros(1, cfg.in_channels, tgt_h, tgt_w))

        self.conv_h = int(probe.size(2))
        self.conv_w = int(probe.size(3))
        self.conv_out = int(probe[0].numel())
        self.fc_phi = nn.Linear(self.conv_out, cfg.phi_dim)

        # Q head
        self.q_head = nn.Linear(cfg.phi_dim, cfg.num_actions)

        # Successor feature head ψ(s) (action-independent)
        self.psi_fc1 = nn.Linear(cfg.phi_dim, 2048)
        self.psi_fc2 = nn.Linear(2048, cfg.phi_dim)

        # Decoder (action-gated) for next-frame reconstruction
        self.state_gate_fc = nn.Linear(cfg.phi_dim, 2048)
        self.action_gate_fc = nn.Linear(cfg.num_actions, 2048)
        self.dec_fc1 = nn.Linear(2048, 2048)
        self.dec_fc2 = nn.Linear(2048, 1024)
        self.dec_fc3 = nn.Linear(1024, self.conv_out)  # -> 64*conv_h*conv_w

        self.deconv1 = nn.ConvTranspose2d(ch, ch, kernel_size=6, stride=2, padding=2)
        self.deconv2 = nn.ConvTranspose2d(ch, ch, kernel_size=6, stride=2, padding=2)
        # Final deconv is chosen to exactly match the input resolution (H,W).
        # deconv1/deconv2 (k=6, s=2, p=2) each double the spatial size, so the map entering
        # deconv3 is 4*conv_{h,w}. A stride-2 deconv with padding=0 gives:
        # out = 2*(in-1) + kernel. So kernel = target - 2*(in-1).
        h2 = 4 * self.conv_h
        w2 = 4 * self.conv_w
        k_h = tgt_h - 2 * (h2 - 1)
        k_w = tgt_w - 2 * (w2 - 1)
        if k_h <= 0 or k_w <= 0:
            raise ValueError(
                f"Invalid decoder kernel computed for input_hw={cfg.input_hw} from conv_hw={(self.conv_h, self.conv_w)}. "
                f"Got kernel={(k_h, k_w)}."
            )
        self.deconv3 = nn.ConvTranspose2d(ch, 1, kernel_size=(k_h, k_w), stride=2, padding=0)

    @staticmethod
    def l2_normalize(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """
        Scale each row of `x` to unit L2 norm along dim 1.

        Rows whose norm is below `eps` are divided by `eps` instead, so an all-zero row
        stays all-zero rather than producing NaN.
        """
        # F.normalize computes x / max(||x||_2, eps), which is the same formula the
        # hand-written torch.norm(...).clamp_min(eps) version used.
        return F.normalize(x, p=2.0, dim=1, eps=eps)

    def _conv_trunk(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three encoder convolutions, each followed by ReLU; returns [B, 64, h, w]."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return F.relu(self.conv3(x))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode observation to phi.

        x: float tensor in [0,1], shape [B, C, H, W]
        Returns the L2-normalized, non-negative feature vector of shape [B, phi_dim].
        """
        flat = torch.flatten(self._conv_trunk(x), 1)
        phi_raw = F.relu(self.fc_phi(flat))
        return self.l2_normalize(phi_raw)

    def q(self, phi: torch.Tensor) -> torch.Tensor:
        """Map features phi [B, phi_dim] to action values [B, num_actions]."""
        return self.q_head(phi)

    def psi(self, phi_detached: torch.Tensor) -> torch.Tensor:
        """
        Map features to successor features of the same width, [B, phi_dim].

        The argument is used as given; detaching it from the encoder graph is the
        caller's job (`forward` does so).
        """
        hidden = F.relu(self.psi_fc1(phi_detached))
        return F.relu(self.psi_fc2(hidden))

    def decode_next(self, phi: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """
        Predict the next frame from features and the action taken.

        phi: [B, phi_dim] features.
        actions: action indices in [0, num_actions). Any integer dtype is accepted and the
            tensor is flattened, so both [B] and [B, 1] layouts work.
        Returns a tensor of shape [B, 1, H, W] with values in (0, 1) (sigmoid output).
        """
        # F.one_hot only accepts int64 indices, and a [B, 1] index tensor would otherwise
        # yield a [B, 1, 2048] gate that broadcasts against the [B, 2048] state branch.
        action_idx = actions.reshape(-1).to(torch.long)
        # action one-hot -> 2048 (cast to the layer's dtype so .double()/.half() models work)
        a_onehot = F.one_hot(action_idx, num_classes=self.cfg.num_actions).to(
            self.action_gate_fc.weight.dtype
        )
        a_2048 = F.relu(self.action_gate_fc(a_onehot))
        # state -> 2048
        s_2048 = F.relu(self.state_gate_fc(phi))
        # multiplicative gating of the state branch by the action branch
        z = s_2048 * a_2048
        # project to conv feature map
        z = F.relu(self.dec_fc1(z))
        z = F.relu(self.dec_fc2(z))
        z = F.relu(self.dec_fc3(z))
        z = z.view(z.size(0), self._CONV_CHANNELS, self.conv_h, self.conv_w)
        z = F.relu(self.deconv1(z))
        z = F.relu(self.deconv2(z))
        return torch.sigmoid(self.deconv3(z))

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run the encoder and all three heads.

        Returns:
          phi: [B, D] (L2-normalized)
          psi: [B, D] (computed from phi.detach())
          recon_next: [B, 1, H, W]
          q: [B, A]
        """
        phi = self.encode(obs)
        psi = self.psi(phi.detach())
        recon_next = self.decode_next(phi, actions)
        q = self.q(phi)
        return phi, psi, recon_next, q

