"""
FPVR Network Model (Successor Feature Network)
Reference: standard CNN encoder + successor-feature head design.
"""
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def conv_shape(input, kernel_size, stride, padding=0):
    """Return the output length of a convolution along one spatial axis.

    Evaluates ``(input + 2 * padding - kernel_size) // stride + 1`` using
    floor division.

    Args:
        input: size of the axis before the convolution.
        kernel_size: kernel extent along that axis.
        stride: step between consecutive kernel positions.
        padding: padding added to each side of the axis (default 0).
    Returns:
        Size of the axis after the convolution. Nothing is validated here, so
        the result can be zero or negative when the kernel does not fit.
    """
    padded_extent = input + 2 * padding
    full_steps = (padded_extent - kernel_size) // stride
    return full_steps + 1


class FPVRNetwork(nn.Module):
    """
    FPVR Network: Encoder → φ → ψ head

    Future-Past Visitation Redundancy (FPVR) network that learns successor features ψ(s,a)
    for exploration. The network encodes states to features φ, then predicts successor
    features ψ(s,a) for each action.

    NOTE: no LayerNorm is applied to φ in this version; ``forward`` returns the
    plain output of ``fc2``.
    """

    def __init__(self, state_shape, n_actions, phi_dim=256):
        """
        Build the conv encoder, the φ projection and the ψ head.

        Args:
            state_shape: (channels, size1, size2) of one observation. The two
                spatial sizes are treated symmetrically.
            n_actions: number of actions; the ψ head emits one vector per action.
            phi_dim: dimensionality D of φ and of every ψ(s,a); coerced to int.
        Raises:
            ValueError: if the observation is too small for the conv stack.
        """
        super().__init__()
        self.state_shape = state_shape
        self.n_actions = n_actions
        self.phi_dim = int(phi_dim)

        c, w, h = state_shape

        # ========== Encoder (Nature CNN) ==========
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        conv_stack = (self.conv1, self.conv2, self.conv3)

        # Propagate the spatial size through the stack and fail fast when a
        # kernel no longer fits. Without this check two negative sizes multiply
        # into a positive flatten size and the error only shows up in forward().
        for idx, conv in enumerate(conv_stack, start=1):
            w = conv_shape(w, conv.kernel_size[0], conv.stride[0], conv.padding[0])
            h = conv_shape(h, conv.kernel_size[1], conv.stride[1], conv.padding[1])
            if w < 1 or h < 1:
                raise ValueError(
                    f"state_shape={state_shape!r} is too small for the encoder: "
                    f"spatial size after conv{idx} would be {w}x{h}."
                )
        flatten_size = w * h * self.conv3.out_channels

        self.fc1 = nn.Linear(in_features=flatten_size, out_features=256)

        # ========== φ branch ==========
        self.fc2 = nn.Linear(in_features=256, out_features=self.phi_dim)

        # ========== ψ head (successor features) ==========
        self.psi_fc1 = nn.Linear(in_features=self.phi_dim, out_features=2 * self.phi_dim)
        self.psi_fc2 = nn.Linear(in_features=2 * self.phi_dim, out_features=self.n_actions * self.phi_dim)

        # ========== Initialization ==========
        # Orthogonal weights (gain sqrt(2)) and zero biases for the encoder and
        # the φ projection. The ψ head keeps PyTorch's default initialization.
        gain = np.sqrt(2)
        for layer in (*conv_stack, self.fc1, self.fc2):
            nn.init.orthogonal_(layer.weight, gain=gain)
            nn.init.zeros_(layer.bias)

    def forward(self, state, phi_whitened=None):
        """
        Forward pass

        Args:
            state: [B, C, H, W] observations. uint8 input is cast to float and
                divided by 255; any other dtype is fed to the encoder unchanged.
            phi_whitened: [B, D] optional whitened features (for ψ head during training)
        Returns:
            phi_raw: [B, D] raw features (output of fc2, no normalization)
            psi_all: [B, n_actions, D] successor features
        """
        if state.dtype == torch.uint8:
            x = state.float() / 255.0
        else:
            x = state

        # Encode to φ_raw
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        phi_raw = self.fc2(x)

        # ψ head: use whitened features if provided, else use raw phi.
        # The input is detached, so gradients of ψ never reach the encoder
        # (or the tensor passed as phi_whitened).
        phi_for_psi = phi_raw if phi_whitened is None else phi_whitened
        z = F.relu(self.psi_fc1(phi_for_psi.detach()))
        psi_all = self.psi_fc2(z).view(z.size(0), self.n_actions, self.phi_dim)

        return phi_raw, psi_all


class QNetwork(nn.Module):
    """
    Simple DQN/Double-DQN head with Nature CNN encoder.

    Independent of the FPVR network; learns extrinsic-reward Q(s,a).
    """

    def __init__(self, state_shape, n_actions):
        """
        Build the conv encoder, the hidden FC layer and the Q head.

        Args:
            state_shape: (channels, size1, size2) of one observation. The two
                spatial sizes are treated symmetrically.
            n_actions: number of actions, i.e. the width of the Q output.
        Raises:
            ValueError: if the observation is too small for the conv stack.
        """
        super().__init__()
        c, w, h = state_shape

        self.conv1 = nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        conv_stack = (self.conv1, self.conv2, self.conv3)

        # Propagate the spatial size through the stack and fail fast when a
        # kernel no longer fits. Without this check two negative sizes multiply
        # into a positive flatten size and the error only shows up in forward().
        for idx, conv in enumerate(conv_stack, start=1):
            w = conv_shape(w, conv.kernel_size[0], conv.stride[0], conv.padding[0])
            h = conv_shape(h, conv.kernel_size[1], conv.stride[1], conv.padding[1])
            if w < 1 or h < 1:
                raise ValueError(
                    f"state_shape={state_shape!r} is too small for the encoder: "
                    f"spatial size after conv{idx} would be {w}x{h}."
                )
        flatten_size = w * h * self.conv3.out_channels

        self.fc = nn.Linear(flatten_size, 512)
        self.head = nn.Linear(512, n_actions)

        # ===== Initialization =====
        # Conv/FC: standard Nature-CNN style orthogonal init.
        gain = np.sqrt(2)
        for layer in (*conv_stack, self.fc):
            nn.init.orthogonal_(layer.weight, gain=gain)
            nn.init.zeros_(layer.bias)

        # Q-head: FORCE near-zero initial Q(s,a).
        # This prevents random initial Q magnitudes from dominating FPVR redundancy-guided
        # exploration when reward clipping makes returns small (e.g., +/-1).
        nn.init.zeros_(self.head.weight)
        nn.init.zeros_(self.head.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute Q(s, ·) for a batch of observations.

        Args:
            x: batch of observations. uint8 input is cast to float and divided
                by 255; any other dtype is fed to the encoder unchanged.
        Returns:
            Tensor with one Q-value per action for every batch element.
        """
        if x.dtype == torch.uint8:
            x = x.float() / 255.0
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc(x))
        return self.head(x)


class ICLRQNetwork(nn.Module):
    """
    ICLR-style Q network (Q branch only) following function_approximation.

    Architecture (per def_large_arch.py / ssr_dqn.py):
    - Conv1: out=64, k=6, s=2, pad=0
    - Conv2: out=64, k=6, s=2, pad=2
    - Conv3: out=64, k=6, s=2, pad=2
    - Flatten -> FC(1024) + ReLU
    - L2 normalize (feature dim)
    - Linear -> n_actions
    """

    def __init__(self, state_shape, n_actions: int):
        """
        Build the conv encoder, the hidden FC layer and the Q head.

        Args:
            state_shape: (channels, size1, size2) of one observation. The two
                spatial sizes are treated symmetrically.
            n_actions: number of actions; coerced to int.
        Raises:
            ValueError: if the observation is too small for the conv stack.
        """
        super().__init__()
        c, w, h = state_shape

        self.conv1 = nn.Conv2d(in_channels=c, out_channels=64, kernel_size=6, stride=2, padding=0)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=6, stride=2, padding=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=6, stride=2, padding=2)
        conv_stack = (self.conv1, self.conv2, self.conv3)

        # Propagate the spatial size through the stack and fail fast when a
        # kernel no longer fits. Without this check an undersized input yields
        # a zero (or otherwise bogus) flatten size and the error only shows up
        # in forward().
        for idx, conv in enumerate(conv_stack, start=1):
            w = conv_shape(w, conv.kernel_size[0], conv.stride[0], padding=conv.padding[0])
            h = conv_shape(h, conv.kernel_size[1], conv.stride[1], padding=conv.padding[1])
            if w < 1 or h < 1:
                raise ValueError(
                    f"state_shape={state_shape!r} is too small for the encoder: "
                    f"spatial size after conv{idx} would be {w}x{h}."
                )
        flatten_size = w * h * self.conv3.out_channels

        self.fc = nn.Linear(flatten_size, 1024)
        self.head = nn.Linear(1024, int(n_actions))

        # Match our repo's Q head init behavior to reduce confounds in ablations:
        # orthogonal conv/FC weights with zero biases, all-zero Q head.
        gain = np.sqrt(2)
        for layer in (*conv_stack, self.fc):
            nn.init.orthogonal_(layer.weight, gain=gain)
            nn.init.zeros_(layer.bias)
        nn.init.zeros_(self.head.weight)
        nn.init.zeros_(self.head.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute Q(s, ·) for a batch of observations.

        Args:
            x: batch of observations. uint8 input is cast to float and divided
                by 255; any other dtype is fed to the encoder unchanged.
        Returns:
            Tensor with one Q-value per action for every batch element.
        """
        if x.dtype == torch.uint8:
            x = x.float() / 255.0
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc(x))
        # Unit L2 norm along the feature dimension before the linear Q head.
        x = F.normalize(x, p=2, dim=1)
        return self.head(x)


def make_q_network(state_shape, n_actions: int, q_net_type: str = "nature") -> nn.Module:
    """
    Build the Q network selected by ``q_net_type``.

    The selector is compared case-insensitively, and a falsy value (for
    example ``None`` from a checkpoint saved before the option existed) falls
    back to 'nature', which keeps old checkpoints loadable.

    Args:
        state_shape: observation shape forwarded to the network constructor.
        n_actions: number of actions; coerced to int.
        q_net_type: 'nature' for QNetwork or 'iclr' for ICLRQNetwork.
    Returns:
        A freshly constructed Q network.
    Raises:
        ValueError: if ``q_net_type`` names neither supported architecture.
    """
    selector = "nature" if not q_net_type else str(q_net_type)
    selector = selector.lower()

    # Reject unknown selectors before touching any other argument.
    if selector not in ("iclr", "nature"):
        raise ValueError(f"Unknown q_net_type={q_net_type!r}. Expected one of: 'nature', 'iclr'.")

    if selector == "iclr":
        network = ICLRQNetwork(state_shape, int(n_actions))
    else:
        network = QNetwork(state_shape, int(n_actions))
    return network