"""Model fixtures for testing Expected GradCAM.

Provides simple CNN models that can be used for testing CAM methods
without requiring large pretrained models.
"""

from __future__ import annotations

from typing import Literal

import torch
from torch import Tensor, nn


class SimpleCNN(nn.Module):
    """Three-stage convolutional classifier small enough for unit tests.

    Layer layout:
        - conv1: Conv2d(in_channels -> 16, 3x3, padding 1) + BN + ReLU + MaxPool(2)
        - conv2: Conv2d(16 -> 32, 3x3, padding 1) + BN + ReLU + MaxPool(2)
        - conv3: Conv2d(32 -> 64, 3x3, padding 1) + BN + ReLU, exposed as
          ``target_layer``
        - pool: AdaptiveAvgPool2d(1)
        - fc: Linear(64 -> num_classes)

    Attributes:
        num_classes: Number of output classes.
        input_size: Expected input image size. It is only stored; none of
            the layers built here depend on it.
        in_channels: Number of channels in the input image.
    """

    def __init__(
        self,
        num_classes: int = 10,
        input_size: int = 32,
        in_channels: int = 3,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.input_size = input_size
        self.in_channels = in_channels

        # Feature extractor: the first two stages halve the resolution,
        # the third keeps it so the CAM target layer retains spatial detail.
        self.conv1 = self._conv_block(in_channels, 16, downsample=True)
        self.conv2 = self._conv_block(16, 32, downsample=True)
        self.conv3 = self._conv_block(32, 64, downsample=False)

        # Global pooling followed by the linear classifier.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    @staticmethod
    def _conv_block(
        channels_in: int,
        channels_out: int,
        *,
        downsample: bool,
    ) -> nn.Sequential:
        """Build Conv3x3 -> BatchNorm -> ReLU, optionally followed by MaxPool(2)."""
        layers: list[nn.Module] = [
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
            nn.ReLU(inplace=True),
        ]
        if downsample:
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    @property
    def target_layer(self) -> nn.Module:
        """Layer whose output CAM methods should use (the ``conv3`` stage)."""
        return self.conv3

    @property
    def classifier_head(self) -> nn.Module:
        """Final linear layer producing the class scores."""
        return self.fc

    @property
    def feature_dim(self) -> int:
        """Channel count of the feature maps produced by ``target_layer``."""
        return 64

    def forward(self, x: Tensor) -> Tensor:
        """Run the full network: feature extractor, then pooled classifier."""
        return self.forward_from_features(self.forward_features(x))

    def forward_features(self, x: Tensor) -> Tensor:
        """Run only the three conv stages and return the un-pooled feature maps."""
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        return x

    def forward_from_features(self, features: Tensor) -> Tensor:
        """Map feature maps (as returned by ``forward_features``) to class scores."""
        pooled = self.pool(features)
        # Collapse the pooled 1x1 spatial dims so the linear layer sees (batch, channels).
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


class SmallResNet(nn.Module):
    """Compact model that reuses ResNet attribute names for testing.

    The submodules are named ``conv1``, ``bn1``, ``relu``, ``maxpool``,
    ``layer1`` .. ``layer4``, ``avgpool`` and ``fc`` so that architecture
    detection and the plugin system can be exercised. The ``layerN`` blocks
    are plain two-conv stacks; no skip connection is added.

    Attributes:
        num_classes: Number of output classes.
        in_channels: Number of channels in the input image.
    """

    def __init__(
        self,
        num_classes: int = 10,
        in_channels: int = 3,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels

        # Stem: a single 3x3 conv followed by a 2x down-sampling max-pool.
        stem_width = 32
        self.conv1 = nn.Conv2d(in_channels, stem_width, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(2)

        # Four stages; every stage after the first doubles the width and
        # halves the resolution through a stride-2 first conv.
        self.layer1 = self._make_layer(stem_width, 32)
        self.layer2 = self._make_layer(32, 64, stride=2)
        self.layer3 = self._make_layer(64, 128, stride=2)
        self.layer4 = self._make_layer(128, 256, stride=2)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build one stage: two Conv3x3 + BN + ReLU units, stride on the first conv."""
        strided_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        plain_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        return nn.Sequential(
            strided_conv,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            plain_conv,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @property
    def target_layer(self) -> nn.Module:
        """Layer whose output CAM methods should use (``layer4``)."""
        return self.layer4

    @property
    def classifier_head(self) -> nn.Module:
        """Final linear layer producing the class scores."""
        return self.fc

    @property
    def feature_dim(self) -> int:
        """Channel count of the feature maps produced by ``target_layer``."""
        return 256

    def forward(self, x: Tensor) -> Tensor:
        """Run stem, the four stages and the pooled classifier in order."""
        backbone = (
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.avgpool,
        )
        for stage in backbone:
            x = stage(x)

        # Collapse the pooled 1x1 spatial dims before the linear layer.
        flat = x.view(x.size(0), -1)
        return self.fc(flat)


class SmallVGG(nn.Module):
    """A small VGG-like model for testing architecture detection.

    Uses the VGG attribute layout: a ``features`` stack of three
    Conv3x3 -> ReLU -> MaxPool(2) blocks (32, 64 and 128 channels, no batch
    norm), an ``avgpool`` and a two-layer ``classifier``.

    Attributes:
        num_classes: Number of output classes.
        in_channels: Number of channels in the input image.
    """

    def __init__(
        self,
        num_classes: int = 10,
        in_channels: int = 3,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels

        # VGG-style feature extractor
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # VGG-style classifier
        self.classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, num_classes),
        )

    @property
    def target_layer(self) -> nn.Module:
        """Get the last conv layer of ``features``.

        Raises:
            RuntimeError: If ``features`` contains no ``nn.Conv2d``.
        """
        # Each block is Conv2d -> ReLU -> MaxPool2d, so the last conv sits
        # three from the end (features[-2] is the ReLU). Locate it by type
        # so the result stays a Conv2d even if the stack is edited.
        # NOTE: the ReLU that follows is in-place and overwrites this
        # layer's output tensor; hooks that keep a reference instead of a
        # copy will therefore observe rectified values.
        for layer in reversed(list(self.features.children())):
            if isinstance(layer, nn.Conv2d):
                return layer
        raise RuntimeError("SmallVGG.features contains no Conv2d layer")

    @property
    def feature_dim(self) -> int:
        """Number of feature channels at target layer."""
        return 128

    def forward(self, x: Tensor) -> Tensor:
        """Run features, global average pooling and the classifier."""
        x = self.features(x)
        x = self.avgpool(x)
        # Collapse the pooled 1x1 spatial dims before the linear layers.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def create_simple_cnn(
    num_classes: int = 10,
    input_size: int = 32,
    device: torch.device | str = "cpu",
    pretrained_like: bool = False,
    in_channels: int = 3,
) -> SimpleCNN:
    """Factory function to create SimpleCNN with options.

    Args:
        num_classes: Number of output classes.
        input_size: Expected input image size.
        device: Device to place model on.
        pretrained_like: If True, re-initialize the weights: Kaiming-normal
            (fan-out, ReLU) for conv weights, N(0, 0.01) for linear weights,
            ones for batch-norm scales and zeros for every bias.
        in_channels: Number of channels in the input image.

    Returns:
        Initialized SimpleCNN model, switched to eval mode and moved to
        ``device``.
    """
    model = SimpleCNN(
        num_classes=num_classes,
        input_size=input_size,
        in_channels=in_channels,
    )

    if pretrained_like:
        # Initialize with more realistic weights
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    # Fixtures are used for inference-style tests, so freeze batch-norm
    # statistics by leaving the model in eval mode.
    model.eval()
    return model.to(device)


def create_small_resnet(
    num_classes: int = 10,
    device: torch.device | str = "cpu",
    in_channels: int = 3,
) -> SmallResNet:
    """Factory function to create SmallResNet.

    Args:
        num_classes: Number of output classes.
        device: Device to place model on.
        in_channels: Number of channels in the input image.

    Returns:
        Initialized SmallResNet model, switched to eval mode and moved to
        ``device``.
    """
    model = SmallResNet(num_classes=num_classes, in_channels=in_channels)
    model.eval()
    return model.to(device)


def get_expected_feature_shape(
    model: SimpleCNN | SmallResNet,
    input_size: int,
) -> tuple[int, int, int]:
    """Get expected feature map shape at the model's target layer.

    The spatial size is derived layer by layer, so it is also correct for
    input sizes that are not a multiple of the total down-sampling factor.

    Args:
        model: The model.
        input_size: Input image size (square inputs are assumed).

    Returns:
        Tuple of (channels, height, width).

    Raises:
        ValueError: If ``model`` is neither a SimpleCNN nor a SmallResNet.
    """
    if isinstance(model, SimpleCNN):
        # The 3x3/padding-1 convs keep the size; each of the two
        # MaxPool2d(2) layers halves it, rounding down.
        spatial = (input_size // 2) // 2
        return (model.feature_dim, spatial, spatial)

    if isinstance(model, SmallResNet):
        # Stem max-pool halves the size, rounding down.
        spatial = input_size // 2
        # layer2-4 each start with a 3x3, stride-2, padding-1 conv whose
        # output size is floor((n - 1) / 2) + 1, i.e. n / 2 rounded up.
        for _ in range(3):
            spatial = (spatial + 1) // 2
        return (model.feature_dim, spatial, spatial)

    raise ValueError(f"Unknown model type: {type(model)}")
