"""Neural components for VPS-style option learning.

- AtariCNN: Lightweight CNN backbone mapping 84×84 frames → feature vector.
- RFFLayer: Random Fourier features ϕ(h) = cos(Wh + b) (fixed after init).
- ValueNet: CNN + MLP head to predict k-dim value V_i(s).
- VPSNet:   CNN + MLP head to predict k-dim potential φ_i(s).
- DQNHead:  MLP head to output Q(s,a) for a given option head.
- RootQNet: MLP head to output Q-values over primitive actions and options.
"""

from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import math


# --------------------------------------------------------------------------- #
#                       0.   Generic Atari-style CNN                          #
# --------------------------------------------------------------------------- #
class AtariCNN(nn.Module):
    """Convolutional encoder in the layout of the DQN paper.

    Three conv + ReLU stages followed by a flatten, turning a batch of
    ``(B, in_channels, 84, 84)`` frames into ``(B, flat_dim)`` features.

    Attributes:
        net: The conv / ReLU / flatten stack.
        flat_dim: Width of the flattened output for an 84×84 input, measured
            once at construction time.
    """

    def __init__(self, in_channels: int = 4):
        super().__init__()
        # (out_channels, kernel_size, stride) of each conv stage, in order.
        conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))

        stages = []
        channels = in_channels
        for out_channels, kernel, stride in conv_specs:
            stages.append(
                nn.Conv2d(channels, out_channels, kernel, stride=stride)
            )
            stages.append(nn.ReLU())
            channels = out_channels
        stages.append(nn.Flatten())
        self.net = nn.Sequential(*stages)

        # Push a zero frame through the stack to read off the flattened
        # width rather than hard-coding it.
        probe = torch.zeros(1, in_channels, 84, 84)
        with torch.no_grad():
            self.flat_dim = self.net(probe).shape[1]

    def forward(self, x: Tensor) -> Tensor:          # (B, C, 84, 84)
        features = self.net(x)
        return features                              # (B, flat_dim)


# --------------------------------------------------------------------------- #
#                       1. Fixed Random Fourier Feature Layer                 #
# --------------------------------------------------------------------------- #
class RFFLayer(nn.Module):
    """
    Fixed random Fourier feature map.

    Input feature vector h ∈ ℝ^D, output k-dimensional random feature:
        φ_i(h) = scale · cos(w_i^T h + b_i),
    where   w_i ~ 𝒩(0, σ² I)  (σ can be 1),
            b_i ~ Uniform[0, 2π).

    ``scale`` is fixed at 1.0, i.e. the √(2/k) normalization sometimes used
    with random Fourier features is NOT applied here.

    W and b are registered as buffers (not parameters), so they are saved in
    the state dict and follow ``.to(device)`` but are never touched by an
    optimizer → they act as fixed basis functions.

    Args:
        in_dim: Dimension D of the input feature vector.
        k: Number of random features produced.
        sigma: Standard deviation σ of the Gaussian the rows of W are
            drawn from.
        seed: Seed for the private generator used to draw W and b. The same
            seed always yields the same W and b.
    """
    def __init__(self, in_dim: int, k: int, sigma: float = 1.0, seed: int = 0):
        super().__init__()
        # Draw from a private generator rather than calling
        # torch.manual_seed(): reseeding the global RNG here would silently
        # make every module initialised afterwards seed-independent.
        gen = torch.Generator()
        gen.manual_seed(seed)
        W = torch.randn(k, in_dim, generator=gen, device=gen.device) * sigma  # (k, D)
        b = torch.rand(k, generator=gen, device=gen.device) * 2 * math.pi     # (k,)
        self.register_buffer("W", W)
        self.register_buffer("b", b)
        self.scale = 1.0

    @torch.no_grad()
    def forward(self, h: Tensor) -> Tensor:          # (B, k)
        """Map features h (B, D) to random features (B, k).

        Runs under ``torch.no_grad()``, so the result is detached from the
        autograd graph and no gradient flows back into ``h``.
        """
        # h (B,D)  ->  (B,k)
        return self.scale * torch.cos(F.linear(h, self.W, self.b))


# --------------------------------------------------------------------------- #
#                2. ValueNet & VPSNet Using a Shared Backbone                 #
# --------------------------------------------------------------------------- #
class _DeepHead(nn.Module):
    """
    MLP head with three linear layers:
        in_dim → 512 → 256 → k
    ReLU follows the first two layers; the last layer is left linear.
    Every weight matrix is re-drawn with Kaiming-normal initialization
    (ReLU gain) and every bias is zeroed.
    """
    def __init__(self, in_dim: int, k: int):
        super().__init__()
        wide, narrow = 512, 256
        self.fc1 = nn.Linear(in_dim, wide)
        self.fc2 = nn.Linear(wide, narrow)
        self.fc3 = nn.Linear(narrow, k)

        # Re-initialise only after all three layers exist, visiting them in
        # registration order (fc1, fc2, fc3).
        for layer in self.children():
            nn.init.zeros_(layer.bias)
            nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")

    def forward(self, feat: Tensor) -> Tensor:       # (B, in_dim) → (B, k)
        hidden = feat
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class ValueNet(nn.Module):
    """Predicts the k-dimensional value vector V_i(s).

    The backbone is handed in by the caller rather than built here, which
    lets one backbone instance be shared between several networks. The MLP
    head is sized from ``backbone.flat_dim``.
    """
    def __init__(self, k: int, backbone: AtariCNN):
        super().__init__()
        feature_width = backbone.flat_dim
        self.backbone = backbone
        self.head = _DeepHead(feature_width, k)

    def forward(self, x: Tensor) -> Tensor:          # (B, k)
        # Backbone features feed straight into the head.
        return self.head(self.backbone(x))


class VPSNet(nn.Module):
    """Predicts the k-dimensional potential vector φ_i(s).

    The backbone is supplied by the caller, so it may be a dedicated
    instance or one shared with another network. The MLP head is sized from
    ``backbone.flat_dim``.
    """
    def __init__(self, k: int, backbone: AtariCNN):
        super().__init__()
        feature_width = backbone.flat_dim
        self.backbone = backbone
        self.head = _DeepHead(feature_width, k)

    def forward(self, x: Tensor) -> Tensor:          # (B, k)
        # Backbone features feed straight into the head.
        return self.head(self.backbone(x))


# --------------------------------------------------------------------------- #
#                               3.  DQN Head                                  #
# --------------------------------------------------------------------------- #
class DQNHead(nn.Module):
    """
    MLP mapping a feature vector to one Q-value per action.

    Three linear layers with ReLU after the first two:
        feat_dim → 512 → 256 → n_actions
    Linear weights use Kaiming-normal initialization (ReLU gain); biases
    start at zero.
    """
    def __init__(self, feat_dim: int, n_actions: int):
        super().__init__()
        widths = (feat_dim, 512, 256, n_actions)

        # Interleave a ReLU between consecutive linear layers, giving
        # Linear, ReLU, Linear, ReLU, Linear.
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            if stack:
                stack.append(nn.ReLU())
            stack.append(nn.Linear(fan_in, fan_out))
        self.q = nn.Sequential(*stack)

        linears = [layer for layer in self.q if isinstance(layer, nn.Linear)]
        for linear in linears:
            nn.init.kaiming_normal_(linear.weight, nonlinearity="relu")
            nn.init.zeros_(linear.bias)

    def forward(self, feat: Tensor) -> Tensor:       # (B, A)
        q_values = self.q(feat)
        return q_values


# --------------------------------------------------------------------------- #
# 4. Root Q-Network  (primitive actions + k options)                           #
# --------------------------------------------------------------------------- #
class RootQNet(nn.Module):
    """
    Root-level Q-network producing n_prim + k_opt outputs per state.

    Three linear layers with ReLU after the first two:
        feat_dim → 512 → 256 → n_prim + k_opt
    Linear weights use Kaiming-normal initialization (ReLU gain); biases
    start at zero.
    """
    def __init__(self, feat_dim: int, n_prim: int, k_opt: int):
        super().__init__()
        n_outputs = n_prim + k_opt
        layers = [
            nn.Linear(feat_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, n_outputs),
        ]
        self.q = nn.Sequential(*layers)
        # apply() walks the children in order, so the linear layers are
        # re-initialised first to last.
        self.q.apply(self._init_linear)

    @staticmethod
    def _init_linear(module: nn.Module) -> None:
        """Kaiming-normal weights and zero bias for linear layers only."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        nn.init.zeros_(module.bias)

    def forward(self, feat: Tensor) -> Tensor:       # (B, n_prim + k_opt)
        q_values = self.q(feat)
        return q_values
