"""DenseNet architecture plugin."""

from __future__ import annotations

import torch
from torch import nn

from expected_gradcam.architectures.base import ClassifierHeadWrapper


class DenseNetHead(nn.Module):
    """Classifier head that finishes a DenseNet forward pass.

    The DenseNet ``features`` block already contains norm5 (batch norm),
    so the ReLU has to be applied here, ahead of the pooling step.
    """

    def __init__(self, classifier: nn.Module) -> None:
        """Build the head around an existing classifier module.

        Args:
            classifier: Module applied to the pooled, flattened features.
        """
        super().__init__()
        # Out-of-place ReLU (the default), so the incoming feature map is
        # left untouched.
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = classifier

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ReLU -> adaptive average pool -> flatten -> classifier.

        Args:
            x: Feature map produced by the target layer.

        Returns:
            Whatever the wrapped classifier returns for the pooled features.
        """
        pooled = self.pool(self.relu(x))
        # Keep dim 0 and collapse everything after it.
        return self.classifier(pooled.flatten(start_dim=1))


class DenseNetPlugin:
    """Plugin for DenseNet architectures.

    Supports:
    - DenseNet121, DenseNet161, DenseNet169, DenseNet201

    DenseNet structure:
        features (includes norm5) -> relu -> pool -> classifier

    Target layer: features (includes final batch norm)
    Feature channels: Varies by architecture (1024, 2208, 1664, 1920)
    """

    # Last-resort lookup: depth suffix in the model's class name -> number
    # of feature channels. Only consulted when the classifier itself does
    # not reveal its input width.
    _CHANNELS_BY_DEPTH = {
        "121": 1024,
        "161": 2208,
        "169": 1664,
        "201": 1920,
    }
    # Returned when nothing else identifies the architecture.
    _DEFAULT_CHANNELS = 1024

    @classmethod
    def supports(cls, model: nn.Module) -> bool:
        """Check if model is DenseNet-like.

        The check is purely name-based: it returns True when the model's
        class name contains "densenet" (case-insensitive).
        """
        return "densenet" in model.__class__.__name__.lower()

    @classmethod
    def get_target_layer(cls, model: nn.Module) -> nn.Module:
        """Get features block (includes norm5).

        Raises:
            AttributeError: If the model has no ``features`` attribute.
        """
        return model.features  # type: ignore

    @classmethod
    def extract_classifier_head(
        cls,
        model: nn.Module,
        target_layer: nn.Module | None = None,
    ) -> nn.Module:
        """Extract classifier head with ReLU before pooling.

        Args:
            model: Model whose ``classifier`` attribute is wrapped.
            target_layer: Not used by this plugin. NOTE(review): presumably
                kept for signature parity with other plugins - confirm.

        Returns:
            A ``DenseNetHead`` wrapping ``model.classifier``.
        """
        return DenseNetHead(model.classifier)  # type: ignore

    @classmethod
    def get_num_features(cls, model: nn.Module) -> int:
        """Get number of channels from classifier input.

        DenseNet121: 1024
        DenseNet161: 2208
        DenseNet169: 1664
        DenseNet201: 1920

        Resolution order:
        1. ``in_features`` of the classifier itself, or of the first layer
           inside it that exposes one. This covers classifiers wrapped in a
           container such as ``nn.Sequential(nn.Dropout(), nn.Linear(...))``.
        2. A depth suffix ("121", "161", "169", "201") in the class name.
        3. A default of 1024.

        Raises:
            AttributeError: If the model has no ``classifier`` attribute.
        """
        classifier = model.classifier  # type: ignore

        # ``modules()`` yields the classifier first and then its descendants
        # in definition order, so a bare ``nn.Linear`` behaves as before and
        # a wrapped one is now found instead of falling through to a guess.
        if isinstance(classifier, nn.Module):
            candidates = classifier.modules()
        else:
            candidates = iter((classifier,))
        for module in candidates:
            if hasattr(module, "in_features"):
                return module.in_features

        # Fallback based on architecture
        name = model.__class__.__name__.lower()
        for depth, channels in cls._CHANNELS_BY_DEPTH.items():
            if depth in name:
                return channels
        return cls._DEFAULT_CHANNELS
