import torch
from torch import nn
from torch.nn import functional as F


def create_cnn_encoder(obs_shape, out_dim=256):
    """Create a shared CNN encoder."""
    c, h, w = obs_shape
    conv1 = nn.Conv2d(c, 32, 8, 4)
    conv2 = nn.Conv2d(32, 64, 4, 2)
    conv3 = nn.Conv2d(64, 64, 3, 1)

    # Compute flattened dimension
    def conv_shape(x, k, s, p=0):
        return (x + 2 * p - k) // s + 1

    h1 = conv_shape(h, 8, 4)
    w1 = conv_shape(w, 8, 4)
    h2 = conv_shape(h1, 4, 2)
    w2 = conv_shape(w1, 4, 2)
    h3 = conv_shape(h2, 3, 1)
    w3 = conv_shape(w2, 3, 1)
    flat = h3 * w3 * 64

    fc = nn.Linear(flat, out_dim)
    return nn.ModuleList([conv1, conv2, conv3, fc]), (h3, w3)


class FPVRVisualModel(nn.Module):
    """FPVR visual model: CNN encoder (outputs φ_raw) + successor-feature head ψ.

    Future-Past Visitation Redundancy (FPVR) visual model that learns successor features
    ψ(s,a) for exploration. The model encodes visual observations to features φ_raw,
    and predicts successor features ψ(s,a) for each action.

    Important: this variant removes any separate `phi_head` and uses the CNN encoder output as φ_raw.
    Whitening φ̃ = ZCA(φ_raw) is applied in the agent.
    """

    def __init__(self, obs_shape, n_actions, phi_dim: int = 128, psi_dim: int | None = None,
                 env_size: int = None):
        super().__init__()
        psi_dim = psi_dim if psi_dim is not None else phi_dim
        self.phi_dim = int(phi_dim)
        self.psi_dim = int(psi_dim)
        self.n_actions = n_actions
        self.env_size = env_size

        # CNN encoder: outputs φ_raw with dimension phi_dim
        self.encoder, (_h3, _w3) = create_cnn_encoder(obs_shape, out_dim=phi_dim)
        # NOTE: keep encoder output as signed φ_raw (no LayerNorm by default).

        # Successor-feature head ψ (takes whitened φ̃)
        self.psi_head = nn.Sequential(
            nn.Linear(phi_dim, 2 * phi_dim),
            nn.ReLU(),
            nn.Linear(2 * phi_dim, n_actions * psi_dim)
        )

        self._init_weights()

    def _init_weights(self):
        """Orthogonal initialization."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight, gain=1.0)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, obs):
        """Forward: obs -> φ_raw.

        Args:
            obs: (B, C, H, W) visual observation
        Returns:
            phi_raw: raw features φ_raw (whitening φ̃ = ZCA(φ_raw) is applied in the agent)
        """
        x = obs / 255.0
        for layer in self.encoder[:-1]:  # conv layers
            x = F.relu(layer(x))
        x = x.contiguous().view(x.size(0), -1)
        x = self.encoder[-1](x)  # fc layer (no ReLU; keep signed features for whitening)
        return x

