"""Base classes for architecture plugins."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn


if TYPE_CHECKING:
    from torch import Tensor


class ClassifierHeadWrapper(nn.Module):
    """Wrapper that takes scaled feature maps and produces class logits.

    Handles:
    - Adaptive average pooling (if enabled)
    - Flattening all non-batch dimensions
    - Applying the wrapped layers (typically fully connected layers)

    Attributes:
        layers: Module applied to the flattened features.
        needs_pooling: Whether to apply adaptive average pooling.
        pool_size: Output size passed to ``nn.AdaptiveAvgPool2d``.
        pool: The ``nn.AdaptiveAvgPool2d`` module, or ``None`` when pooling
            was disabled at construction time.
    """

    def __init__(
        self,
        layers: nn.Module,
        needs_pooling: bool = True,
        pool_size: tuple[int, int] = (1, 1),
    ) -> None:
        """Initialize classifier head wrapper.

        Args:
            layers: Layers applied after pooling and flattening (typically FC
                layers).
            needs_pooling: Whether to apply adaptive average pooling.
            pool_size: Output size for adaptive pooling. Stored even when
                ``needs_pooling`` is False.
        """
        super().__init__()
        self.layers = layers
        self.needs_pooling = needs_pooling
        # Keep the documented attribute available regardless of pooling mode.
        self.pool_size = pool_size
        self.pool = nn.AdaptiveAvgPool2d(pool_size) if needs_pooling else None

    def forward(self, x: "Tensor") -> "Tensor":
        """Forward pass through classifier head.

        Args:
            x: Scaled feature maps [B, K, U, V].

        Returns:
            Logits [B, num_classes].
        """
        # `pool` is None when pooling was disabled at construction; the flag is
        # checked too so that changing `needs_pooling` later is honored.
        if self.needs_pooling and self.pool is not None:
            x = self.pool(x)
        # Collapse every dimension except the batch dimension before the FC layers.
        x = torch.flatten(x, 1)
        return self.layers(x)


class BaseArchitecturePlugin(ABC):
    """Abstract base class for architecture plugins.

    Provides default implementations and helper methods.
    Subclasses must implement the abstract methods.

    Example:
        >>> class MyPlugin(BaseArchitecturePlugin):
        ...     @classmethod
        ...     def supports(cls, model: nn.Module) -> bool:
        ...         return "myarch" in model.__class__.__name__.lower()
        ...
        ...     @classmethod
        ...     def _get_last_conv_layer(cls, model: nn.Module) -> nn.Module:
        ...         return model.features[-1]
        ...
        ...     @classmethod
        ...     def _get_fc_layers(cls, model: nn.Module) -> nn.Module:
        ...         return model.classifier
        ...
        ...     @classmethod
        ...     def _get_feature_channels(cls, model: nn.Module) -> int:
        ...         return model.features[-1].out_channels
    """

    @classmethod
    @abstractmethod
    def supports(cls, model: nn.Module) -> bool:
        """Check if this plugin supports the given model."""
        ...

    @classmethod
    @abstractmethod
    def _get_last_conv_layer(cls, model: nn.Module) -> nn.Module:
        """Get the last convolutional layer before pooling."""
        ...

    @classmethod
    @abstractmethod
    def _get_fc_layers(cls, model: nn.Module) -> nn.Module:
        """Get the fully connected layers."""
        ...

    @classmethod
    @abstractmethod
    def _get_feature_channels(cls, model: nn.Module) -> int:
        """Get number of feature channels at last conv layer."""
        ...

    @classmethod
    def _get_pool_size(cls) -> tuple[int, int]:
        """Get pooling size. Override for non-standard architectures."""
        return (1, 1)

    @classmethod
    def get_target_layer(cls, model: nn.Module) -> nn.Module:
        """Get recommended target layer for GradCAM."""
        return cls._get_last_conv_layer(model)

    @classmethod
    def extract_classifier_head(
        cls,
        model: nn.Module,
        target_layer: nn.Module | None = None,
    ) -> nn.Module:
        """Extract classifier head from model."""
        fc_layers = cls._get_fc_layers(model)

        # Wrap in Sequential if not already
        if not isinstance(fc_layers, nn.Sequential):
            fc_layers = nn.Sequential(fc_layers)

        return ClassifierHeadWrapper(
            layers=fc_layers,
            needs_pooling=True,
            pool_size=cls._get_pool_size(),
        )

    @classmethod
    def get_num_features(cls, model: nn.Module) -> int:
        """Get number of feature channels."""
        return cls._get_feature_channels(model)
