"""ResNet architecture plugin."""

from __future__ import annotations

from torch import nn

from expected_gradcam.architectures.base import BaseArchitecturePlugin


class ResNetPlugin(BaseArchitecturePlugin):
    """Plugin for ResNet and ResNeXt architectures.

    Supports:
    - ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
    - ResNeXt50_32x4d, ResNeXt101_32x8d
    - Wide ResNet variants

    ResNet structure:
        conv1 -> bn1 -> relu -> maxpool ->
        layer1 -> layer2 -> layer3 -> layer4 ->
        avgpool -> fc

    Target layer: layer4 (the last residual stage, i.e. the whole stack of
    residual blocks that precedes avgpool)
    Feature channels: 512 (ResNet18/34) or 2048 (ResNet50+)
    """

    @classmethod
    def supports(cls, model: nn.Module) -> bool:
        """Check if model is ResNet-like.

        Matching is a case-insensitive substring test on the model's class
        name. "resnet" already matches names such as "wide_resnet", so only
        "resnext" needs a separate entry.

        Args:
            model: Model to inspect.

        Returns:
            True if the class name contains "resnet" or "resnext".
        """
        name = model.__class__.__name__.lower()
        return any(arch in name for arch in ("resnet", "resnext"))

    @classmethod
    def _get_last_conv_layer(cls, model: nn.Module) -> nn.Module:
        """Return ``model.layer4``, the last residual stage (stack of blocks)."""
        return model.layer4  # type: ignore

    @classmethod
    def _get_fc_layers(cls, model: nn.Module) -> nn.Module:
        """Return the classification head ``model.fc``."""
        return model.fc  # type: ignore

    @classmethod
    def _get_feature_channels(cls, model: nn.Module) -> int:
        """Get number of channels produced by layer4.

        ResNet18/34: 512 channels
        ResNet50+: 2048 channels

        Resolution order:

        1. ``model.fc.in_features``, or the first layer exposing
           ``in_features`` inside an ``nn.Sequential`` head.
        2. ``out_channels`` of the last ``nn.Conv2d`` inside ``model.layer4``.
           This handles heads without ``in_features`` (e.g. a head replaced by
           ``nn.Identity``).
        3. A class-name guess: 512 if the name contains "18" or "34", else
           2048. This is a last resort only. torchvision builds every depth
           with the same ``ResNet`` class, so the name usually carries no
           depth information.

        Args:
            model: ResNet-like model.

        Returns:
            Number of feature channels output by layer4.
        """
        # 1. Read it from the classifier head's input size.
        fc = getattr(model, "fc", None)
        if hasattr(fc, "in_features"):
            return fc.in_features  # type: ignore[union-attr]
        if isinstance(fc, nn.Sequential):
            # Handle case where fc might be wrapped (e.g. Dropout -> Linear).
            for layer in fc:
                if hasattr(layer, "in_features"):
                    return layer.in_features

        # 2. Inspect layer4 directly. In standard residual blocks the last
        #    conv (or a downsample conv) emits the block's output width.
        layer4 = getattr(model, "layer4", None)
        if isinstance(layer4, nn.Module):
            convs = [m for m in layer4.modules() if isinstance(m, nn.Conv2d)]
            if convs:
                return convs[-1].out_channels

        # 3. Fallback based on architecture name.
        name = model.__class__.__name__.lower()
        if "18" in name or "34" in name:
            return 512
        return 2048
