"""MobileNet architecture plugin."""

from __future__ import annotations

from torch import nn

from expected_gradcam.architectures.base import BaseArchitecturePlugin


class MobileNetPlugin(BaseArchitecturePlugin):
    """Plugin for MobileNet architectures.

    Supports:
    - MobileNet V2
    - MobileNet V3 (Small, Large)

    MobileNet structure:
        features -> global average pooling -> classifier

    Target layer: Last block in features
    Feature channels: 1280 (V2), 960/576 (V3 Large/Small)
    """

    @classmethod
    def supports(cls, model: nn.Module) -> bool:
        """Return True if the model's class name contains "mobilenet"."""
        return "mobilenet" in model.__class__.__name__.lower()

    @classmethod
    def _get_last_conv_layer(cls, model: nn.Module) -> nn.Module:
        """Return the last block of ``model.features`` (the Grad-CAM target)."""
        return model.features[-1]  # type: ignore

    @classmethod
    def _get_fc_layers(cls, model: nn.Module) -> nn.Module:
        """Return ``model.classifier``."""
        return model.classifier  # type: ignore

    @classmethod
    def _get_feature_channels(cls, model: nn.Module) -> int:
        """Return the number of channels produced by the feature extractor.

        Resolution order:

        1. The positive ``in_features`` of the first layer in
           ``model.classifier`` that exposes one (or of ``model.classifier``
           itself when it is not an ``nn.Sequential``).
        2. ``out_channels`` of the last ``nn.Conv2d`` found in
           ``model.features``.
        3. A hard-coded default chosen from the class name.

        Typical values:
            MobileNet V2: 1280
            MobileNet V3 Large: 960
            MobileNet V3 Small: 576
        """
        classifier = model.classifier  # type: ignore
        # Scan the Sequential's children first, then the classifier itself.
        candidates = list(classifier) if isinstance(classifier, nn.Sequential) else []
        candidates.append(classifier)
        for layer in candidates:
            in_features = getattr(layer, "in_features", None)
            # Lazy layers (e.g. nn.LazyLinear) report 0 until materialised;
            # that is not a real channel count, so keep looking.
            if isinstance(in_features, int) and in_features > 0:
                return in_features

        # Inspect the actual feature extractor before guessing from the name.
        # This is the only way to tell V3 Small from V3 Large when both share
        # a class name (as torchvision's ``MobileNetV3`` does).
        features = getattr(model, "features", None)
        if isinstance(features, nn.Module):
            convs = [m for m in features.modules() if isinstance(m, nn.Conv2d)]
            if convs:
                return convs[-1].out_channels

        # Last-resort fallback based on the class name.
        name = model.__class__.__name__.lower()
        if "v3" in name and "small" in name:
            return 576
        elif "v3" in name:
            return 960
        return 1280  # MobileNet V2 default
