"""Architecture plugin system for Expected GradCAM.

Provides extensible support for extracting classifier heads from CNN models.
Users can register custom plugins for unsupported architectures.

Built-in plugins:
- ResNetPlugin: ResNet, ResNeXt
- VGGPlugin: VGG family
- DenseNetPlugin: DenseNet family
- EfficientNetPlugin: EfficientNet family
- MobileNetPlugin: MobileNet V2/V3
"""

# torch is only referenced by the string annotations ("torch.nn.Module") in
# this module; importing it keeps those annotations resolvable at runtime
# (e.g. by typing.get_type_hints).
import torch

from expected_gradcam.architectures.protocols import ArchitecturePlugin
from expected_gradcam.architectures.registry import ArchitectureRegistry, register_plugin
from expected_gradcam.architectures.base import (
    BaseArchitecturePlugin,
    ClassifierHeadWrapper,
)
from expected_gradcam.architectures.resnet import ResNetPlugin
from expected_gradcam.architectures.vgg import VGGPlugin
from expected_gradcam.architectures.densenet import DenseNetPlugin
from expected_gradcam.architectures.efficientnet import EfficientNetPlugin
from expected_gradcam.architectures.mobilenet import MobileNetPlugin


# Built-in plugins, registered automatically when this module is imported,
# in the order listed here.
_BUILTIN_PLUGINS = [
    ResNetPlugin,
    VGGPlugin,
    DenseNetPlugin,
    EfficientNetPlugin,
    MobileNetPlugin,
]


def _register_builtin_plugins() -> None:
    """Register every entry of ``_BUILTIN_PLUGINS`` via ``register_plugin``.

    The loop lives inside a function so its loop variable stays local.
    A module-level ``for plugin in ...`` would otherwise leave a stray
    public ``plugin`` attribute on this module.
    """
    for builtin_plugin in _BUILTIN_PLUGINS:
        register_plugin(builtin_plugin)


_register_builtin_plugins()


def extract_classifier_head(
    model: "torch.nn.Module",
    target_layer: "torch.nn.Module | None" = None,
) -> "torch.nn.Module":
    """Return the classifier head of ``model`` using the plugin registry.

    Convenience wrapper: plugin lookup and head extraction are both
    delegated to ``ArchitectureRegistry.extract_classifier_head``.

    Args:
        model: The CNN model whose classifier head is wanted.
        target_layer: Optional GradCAM target layer, forwarded unchanged to
            the registry (intended for validation).

    Returns:
        The classifier head module produced by the matching plugin.

    Raises:
        UnsupportedArchitectureError: If no registered plugin supports
            ``model``.
    """
    registry = ArchitectureRegistry
    head = registry.extract_classifier_head(model, target_layer)
    return head


def get_target_layer(model: "torch.nn.Module") -> "torch.nn.Module":
    """Return the recommended GradCAM target layer for ``model``.

    Convenience wrapper around ``ArchitectureRegistry.get_target_layer``.

    Args:
        model: The CNN model to inspect.

    Returns:
        The layer the matching plugin recommends as the GradCAM target.

    Raises:
        UnsupportedArchitectureError: If no registered plugin supports
            ``model``.
    """
    registry = ArchitectureRegistry
    layer = registry.get_target_layer(model)
    return layer


def get_num_features(model: "torch.nn.Module") -> int:
    """Return the number of feature channels K at ``model``'s target layer.

    Convenience wrapper around ``ArchitectureRegistry.get_num_features``.

    Args:
        model: The CNN model to inspect.

    Returns:
        The channel count K reported by the matching plugin.

    Raises:
        UnsupportedArchitectureError: If no registered plugin supports
            ``model``.
    """
    registry = ArchitectureRegistry
    channel_count = registry.get_num_features(model)
    return channel_count


# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    # Protocols
    "ArchitecturePlugin",
    # Registry
    "ArchitectureRegistry",
    "register_plugin",
    # Base classes
    "BaseArchitecturePlugin",
    "ClassifierHeadWrapper",
    # Built-in plugins
    "ResNetPlugin",
    "VGGPlugin",
    "DenseNetPlugin",
    "EfficientNetPlugin",
    "MobileNetPlugin",
    # Convenience functions
    "extract_classifier_head",
    "get_target_layer",
    "get_num_features",
]
