"""EfficientNet architecture plugin."""

from __future__ import annotations

from torch import nn

from expected_gradcam.architectures.base import BaseArchitecturePlugin, ClassifierHeadWrapper
from expected_gradcam.exceptions import ClassifierHeadExtractionError


class EfficientNetPlugin(BaseArchitecturePlugin):
    """Plugin for EfficientNet architectures.

    Supports:
    - EfficientNet B0-B7
    - EfficientNet V2 (S, M, L)

    EfficientNet structure (varies by implementation):
        features -> avgpool -> classifier

    Target layer: Last block in features
    Feature channels: Varies by architecture (1280-2560)

    Attribute names differ between implementations, so each lookup below
    probes several candidate names instead of assuming a single layout.
    """

    @classmethod
    def supports(cls, model: nn.Module) -> bool:
        """Check if model is EfficientNet-like.

        Args:
            model: Model to inspect.

        Returns:
            True when the model's class name contains "efficientnet"
            (case-insensitive). No structural check is performed here.
        """
        return "efficientnet" in model.__class__.__name__.lower()

    @classmethod
    def _get_last_conv_layer(cls, model: nn.Module) -> nn.Module:
        """Get last feature block.

        Probes ``model.features`` and then ``model._blocks`` and returns the
        last element of the first container that can be indexed.

        Args:
            model: Model to extract the target layer from.

        Returns:
            The last entry of the feature container.

        Raises:
            ClassifierHeadExtractionError: If neither attribute exists, or the
                container found is ``None``, empty, or not indexable.
        """
        # Try different attribute names
        for attr in ("features", "_blocks"):
            blocks = getattr(model, attr, None)
            if blocks is None:
                continue
            try:
                return blocks[-1]
            except (IndexError, KeyError, TypeError):
                # Empty or non-indexable container: try the next candidate so
                # the caller gets the plugin's own error type rather than a
                # bare indexing error.
                continue
        raise ClassifierHeadExtractionError(
            model.__class__.__name__,
            reason="Could not find features block",
        )

    @classmethod
    def _get_fc_layers(cls, model: nn.Module) -> nn.Module:
        """Get classifier layers.

        Args:
            model: Model to extract the classifier head from.

        Returns:
            The first non-``None`` attribute among ``classifier``, ``_fc``,
            ``fc`` and ``head``, checked in that order.

        Raises:
            ClassifierHeadExtractionError: If none of those attributes is set.
        """
        # Try different attribute names used by various implementations
        for attr in ("classifier", "_fc", "fc", "head"):
            classifier = getattr(model, attr, None)
            if classifier is not None:
                return classifier
        raise ClassifierHeadExtractionError(
            model.__class__.__name__,
            reason="Could not find classifier layer",
        )

    @classmethod
    def _get_feature_channels(cls, model: nn.Module) -> int:
        """Get number of channels.

        EfficientNet B0: 1280
        EfficientNet B1-B7: scales up
        EfficientNet V2: 1280-1536

        The value is read from the ``in_features`` of the first module in the
        classifier head that exposes a positive integer ``in_features``. The
        head itself is checked first, then its submodules in registration
        order.

        Args:
            model: Model whose classifier head is inspected.

        Returns:
            The detected input width of the classifier head, or 1280 when no
            head or no usable ``in_features`` can be found.
        """
        fallback = 1280
        try:
            fc = cls._get_fc_layers(model)
        except ClassifierHeadExtractionError:
            # Best effort: a model without a recognisable head still gets the
            # default width instead of an error.
            return fallback

        # ``modules()`` yields the head itself followed by all nested
        # submodules, so bare layers, flat containers and nested containers
        # are handled by the same loop.
        candidates = fc.modules() if isinstance(fc, nn.Module) else (fc,)
        for candidate in candidates:
            width = getattr(candidate, "in_features", None)
            # Skip non-positive values (e.g. a lazily-initialised layer that
            # has not inferred its input size yet); they are not a usable
            # channel count.
            if isinstance(width, int) and width > 0:
                return width
        return fallback
