"""VGG architecture plugin."""

from __future__ import annotations

from torch import nn

from expected_gradcam.architectures.base import BaseArchitecturePlugin


class VGGPlugin(BaseArchitecturePlugin):
    """Plugin for VGG architectures.

    Supports:
    - VGG11, VGG13, VGG16, VGG19 (with and without batch norm)

    VGG structure:
        features (conv layers) -> avgpool -> classifier (fc layers)

    Target layer: last ``nn.Conv2d`` found (recursively) in ``model.features``
    Feature channels: ``out_channels`` of the target layer (512 for standard
        VGG; 512 is also the fallback when it cannot be read from the model)
    Pool size: (7, 7) - VGG uses 7x7 spatial pooling
    """

    # Channel count at the final conv layer of every standard VGG variant.
    # Used when the channel count cannot be read from the model itself.
    _DEFAULT_FEATURE_CHANNELS = 512

    @classmethod
    def supports(cls, model: nn.Module) -> bool:
        """Return True if the model's class name contains "vgg" (case-insensitive)."""
        return "vgg" in model.__class__.__name__.lower()

    @classmethod
    def _get_last_conv_layer(cls, model: nn.Module) -> nn.Module:
        """Return the last ``nn.Conv2d`` registered under ``model.features``.

        The search is recursive, so conv layers nested inside sub-blocks of
        ``features`` are found, not only direct children. If ``features``
        contains no ``nn.Conv2d`` at all, its last element is returned.
        """
        features = model.features  # type: ignore[attr-defined]
        # modules() walks sub-modules depth-first in registration order, so
        # the last Conv2d seen is the last one registered. For an
        # nn.Sequential that registration order is also the forward order.
        last_conv: nn.Module | None = None
        for module in features.modules():
            if isinstance(module, nn.Conv2d):
                last_conv = module
        if last_conv is not None:
            return last_conv
        # Return last layer if no Conv2d found
        return features[-1]

    @classmethod
    def _get_fc_layers(cls, model: nn.Module) -> nn.Module:
        """Return the ``classifier`` (fully connected) sub-module."""
        return model.classifier  # type: ignore[attr-defined]

    @classmethod
    def _get_feature_channels(cls, model: nn.Module) -> int:
        """Return the number of channels produced by the target conv layer.

        Reads ``out_channels`` from the layer chosen by
        ``_get_last_conv_layer``, so width-modified VGG variants report their
        real channel count. Falls back to 512 (standard VGG) in two cases:
        the layer cannot be located, or the located layer has no integer
        ``out_channels``.
        """
        try:
            layer = cls._get_last_conv_layer(model)
        except (AttributeError, IndexError, TypeError):
            # The model has no usable ``features`` container. Keep the
            # historical fixed answer instead of failing here.
            return cls._DEFAULT_FEATURE_CHANNELS
        channels = getattr(layer, "out_channels", None)
        if isinstance(channels, int):
            return channels
        return cls._DEFAULT_FEATURE_CHANNELS

    @classmethod
    def _get_pool_size(cls) -> tuple[int, int]:
        """Return the spatial pooling output size used by VGG: (7, 7)."""
        return (7, 7)
