from __future__ import annotations

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except Exception:  # pragma: no cover
    torch = None
    nn = None
    F = None

class TwoLayerMLP(nn.Module):
    """Fully connected network with a single hidden ReLU layer.

    Any input with a leading batch dimension is flattened to ``(N, -1)``
    before the first linear layer, so ``in_dim`` must equal the number of
    elements per sample.

    Args:
        in_dim: Number of features per sample after flattening.
        hidden: Width of the hidden layer.
        out_dim: Number of outputs (raw, un-normalized scores).
    """

    def __init__(self, in_dim: int, hidden: int = 128, out_dim: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, out_dim)

    def forward(self, x):
        """Return outputs of shape ``(N, out_dim)`` for a batch ``x`` of N samples."""
        # reshape (not view): view() raises on non-contiguous inputs such as
        # permuted or channels_last image batches; reshape copies only if needed.
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class TwoHiddenMLP(nn.Module):
    def __init__(self, in_dim: int, hidden: int = 256, out_dim: int = 10):
        super().__init__()
        h2 = hidden
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, h2)
        self.fc3 = nn.Linear(h2, out_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class LeNet(nn.Module):
    """LeNet-5 style CNN for square images of configurable size.

    Spatial sizes through the conv/pool stack (H = ``input_size``):

    * conv1 (5x5, padding=2) keeps H; 2x2 max-pool -> ``H // 2``
    * conv2 (5x5, no padding) -> ``H // 2 - 4``; 2x2 max-pool -> ``(H // 2 - 4) // 2``
    * the 16 resulting feature maps are flattened into ``fc1``

    For MNIST (28x28x1): 28 -> 14 -> 10 -> 5, i.e. 5x5x16 = 400 features.
    For CIFAR (32x32x3): 32 -> 16 -> 12 -> 6, i.e. 6x6x16 = 576 features.

    Args:
        in_channels: Number of input channels.
        num_classes: Number of output logits.
        input_size: Height and width of the (square) input images. Must be
            at least 12 so that one spatial position survives the second pool.

    Raises:
        ValueError: If ``input_size`` is too small for the conv/pool stack.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10, input_size: int = 28):
        super().__init__()
        # Kept so forward() can un-flatten (N, C*H*W) inputs.
        self.in_channels = in_channels
        self.input_size = input_size

        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

        self._fc1_in = self._compute_fc1_size(input_size)
        self.fc1 = nn.Linear(self._fc1_in, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def _compute_fc1_size(self, input_size: int) -> int:
        """Return the flattened feature count produced by the conv/pool stack.

        Raises:
            ValueError: If no spatial position would survive the second pool.
        """
        # conv1 keeps the size (padding=2 with a 5x5 kernel); pool halves, flooring.
        after_pool1 = input_size // 2
        # conv2 has no padding, so a 5x5 kernel trims 4; pool halves, flooring.
        after_pool2 = (after_pool1 - 4) // 2
        if after_pool2 < 1:
            raise ValueError(
                f"LeNet requires input_size >= 12, got {input_size}: "
                "the feature maps vanish after the second pooling layer"
            )
        # conv2 emits 16 channels.
        return 16 * after_pool2 * after_pool2

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``.

        Accepts image batches ``(N, C, H, W)`` or flattened ``(N, C*H*W)``
        inputs; the latter are reshaped using ``in_channels``/``input_size``.
        """
        if x.dim() == 2:
            x = x.reshape(x.size(0), self.in_channels, self.input_size, self.input_size)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # reshape, not view: the conv output need not be contiguous (e.g. for
        # channels_last inputs), and view() cannot flatten such a tensor.
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class Autoencoder8Sigmoid(nn.Module):
    """8-layer autoencoder with sigmoid activations (4 encoder + 4 decoder).

    Designed for MNIST-like flattened inputs. The encoder maps
    ``in_dim -> w1 -> w2 -> w3 -> w4`` and the decoder mirrors it back
    ``w4 -> w3 -> w2 -> w1 -> in_dim``; every layer, including the output,
    is followed by a sigmoid.

    Args:
        in_dim: Number of features per sample after flattening.
        widths: Exactly four encoder widths ``(w1, w2, w3, w4)``.

    Raises:
        ValueError: If ``widths`` does not contain exactly four entries.
    """

    def __init__(self, in_dim: int = 28 * 28, widths=(512, 256, 128, 64)):
        super().__init__()
        widths = tuple(widths)
        if len(widths) != 4:
            raise ValueError(f"widths must have exactly 4 entries, got {len(widths)}")
        w1, w2, w3, w4 = widths
        # Encoder
        self.enc1 = nn.Linear(in_dim, w1)
        self.enc2 = nn.Linear(w1, w2)
        self.enc3 = nn.Linear(w2, w3)
        self.enc4 = nn.Linear(w3, w4)
        # Decoder
        self.dec1 = nn.Linear(w4, w3)
        self.dec2 = nn.Linear(w3, w2)
        self.dec3 = nn.Linear(w2, w1)
        self.dec4 = nn.Linear(w1, in_dim)

    def forward(self, x):
        """Return the reconstruction of ``x`` with shape ``(N, in_dim)``.

        Accepts ``(N, D)`` or any ``(N, ...)`` input (flattened per sample).
        The output is always flat, even when the input was not, and every
        value lies in [0, 1] because the final layer is a sigmoid.
        """
        if x.dim() > 2:
            # reshape (not view) so non-contiguous image batches are accepted.
            x = x.reshape(x.size(0), -1)
        h = x
        for layer in (self.enc1, self.enc2, self.enc3, self.enc4,
                      self.dec1, self.dec2, self.dec3, self.dec4):
            h = torch.sigmoid(layer(h))
        return h


class TinyTransformerLM(nn.Module):
    """A minimal Transformer language model (decoder-only) for small corpora.

    Token embeddings plus a learned positional embedding go through a stack
    of ``nn.TransformerEncoderLayer`` blocks and are projected to per-token
    vocabulary logits. A causal attention mask is applied by default, so
    position ``t`` only attends to positions ``<= t``. Without it, next-token
    training would let every position see the tokens it is meant to predict.

    Not optimized; intended for quick experiments.

    Args:
        vocab_size: Number of token ids; also the size of the logit dimension.
        d_model: Embedding / hidden width.
        n_heads: Attention heads per layer.
        n_layers: Number of Transformer layers.
        max_len: Number of rows in the positional table; longer inputs are
            truncated to their first ``max_len`` tokens.
        causal: Apply the causal attention mask (default). Pass ``False`` for
            bidirectional attention over the whole sequence.
    """

    def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 2, n_layers: int = 2,
                 max_len: int = 256, causal: bool = True):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.enc = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.pos = nn.Parameter(torch.randn(1, max_len, d_model) * 0.01)
        self.proj = nn.Linear(d_model, vocab_size)
        self.max_len = max_len
        self.causal = causal

    def forward(self, x):
        """Return logits of shape ``(B, min(T, max_len), vocab_size)``.

        Args:
            x: Integer token ids of shape ``(B, T)``.
        """
        _, seq_len = x.shape
        # Long sequences are silently truncated: the positional table has only
        # max_len rows, so the returned length can be shorter than T.
        seq_len = min(seq_len, self.max_len)
        x = x[:, :seq_len]
        h = self.emb(x) + self.pos[:, :seq_len, :]
        mask = None
        if self.causal:
            # Boolean attention mask: True blocks a (query, key) pair, here every
            # key strictly after its query (upper triangle above the diagonal).
            mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
        h = self.enc(h, mask=mask)
        return self.proj(h)
