"""Protocol definitions for architecture plugins.

Defines the interface that all architecture plugins must implement.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol, runtime_checkable


if TYPE_CHECKING:
    from torch import nn


@runtime_checkable
class ArchitecturePlugin(Protocol):
    """Protocol for architecture-specific classifier head extraction.

    Architecture plugins provide methods to:
    1. Detect if a model uses their architecture
    2. Get the recommended target layer
    3. Extract the classifier head
    4. Get the number of feature channels

    Every member is declared as a classmethod, so a plugin class can be used
    directly without being instantiated. Because the protocol is
    ``runtime_checkable``, ``isinstance(obj, ArchitecturePlugin)`` is
    supported, but that check only verifies that the four members exist; it
    does not validate their signatures or that they are classmethods.

    Example implementation:
        >>> class MyArchitecturePlugin:
        ...     @classmethod
        ...     def supports(cls, model: nn.Module) -> bool:
        ...         return "myarch" in model.__class__.__name__.lower()
        ...
        ...     @classmethod
        ...     def get_target_layer(cls, model: nn.Module) -> nn.Module:
        ...         return model.features[-1]
        ...
        ...     @classmethod
        ...     def extract_classifier_head(
        ...         cls, model: nn.Module, target_layer: nn.Module | None = None
        ...     ) -> nn.Module:
        ...         return model.classifier
        ...
        ...     @classmethod
        ...     def get_num_features(cls, model: nn.Module) -> int:
        ...         return model.features[-1].out_channels
        >>> isinstance(MyArchitecturePlugin, ArchitecturePlugin)
        True
    """

    @classmethod
    def supports(cls, model: nn.Module) -> bool:
        """Check if this plugin supports the given model.

        Args:
            model: The model to check.

        Returns:
            True if this plugin can handle the model.
        """
        ...

    @classmethod
    def get_target_layer(cls, model: nn.Module) -> nn.Module:
        """Get the recommended target layer for GradCAM.

        Args:
            model: The model.

        Returns:
            The target layer (typically the last conv layer before pooling).
        """
        ...

    @classmethod
    def extract_classifier_head(
        cls,
        model: nn.Module,
        target_layer: nn.Module | None = None,
    ) -> nn.Module:
        """Extract the classifier head from the model.

        The classifier head should:
        1. Take feature maps [B, K, U, V] as input
        2. Apply global average pooling (or similar)
        3. Return class logits [B, num_classes]

        Implementations should keep ``target_layer`` optional (defaulting to
        ``None``) so that callers may pass only ``model``.

        Args:
            model: The model.
            target_layer: Optional target layer (for validation). Defaults
                to ``None``.

        Returns:
            Classifier head module.
        """
        ...

    @classmethod
    def get_num_features(cls, model: nn.Module) -> int:
        """Get the number of feature channels at the target layer.

        Args:
            model: The model.

        Returns:
            Number of feature channels K.
        """
        ...
