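"""Architecture plugin registry with automatic plugin lookup.

Provides a centralized registry for architecture plugins that can be
extended at runtime. Plugins are registered explicitly (directly or via the
``register_plugin`` decorator); for a given model, the first registered
plugin whose ``supports`` check accepts it is used.
"""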

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, ClassVar, Type

from expected_gradcam.architectures.protocols import ArchitecturePlugin
from expected_gradcam.exceptions import UnsupportedArchitectureError


if TYPE_CHECKING:
    from torch import nn


class ArchitectureRegistry:
    """Central registry for architecture plugins.

    The registry maintains an ordered list of plugin classes and provides
    methods to find the appropriate plugin for a given model, plus
    convenience wrappers that delegate to that plugin.

    Plugins are checked in registration order, so more specific plugins
    should be registered before general ones.

    Note:
        The plugin list is class-level state shared by all callers; use
        :meth:`clear` to reset it (e.g. between tests).

    Example:
        >>> ArchitectureRegistry.register(MyPlugin)
        >>> plugin = ArchitectureRegistry.find_plugin(model)
        >>> if plugin:
        ...     head = plugin.extract_classifier_head(model, target_layer)
    """

    # Ordered registry of plugin classes (shared, class-level state).
    _plugins: ClassVar[list[Type[ArchitecturePlugin]]] = []

    # Attributes every plugin must expose as callables (checked by ``register``).
    _REQUIRED_METHODS: ClassVar[tuple[str, ...]] = (
        "supports",
        "get_target_layer",
        "extract_classifier_head",
        "get_num_features",
    )

    # Conventional class-name suffix stripped by ``get_supported_architectures``.
    _NAME_SUFFIX: ClassVar[str] = "Plugin"

    @classmethod
    def register(cls, plugin: Type[ArchitecturePlugin]) -> None:
        """Register an architecture plugin.

        Registering a plugin that is already present is a no-op, so its
        original position in the lookup order is preserved.

        Args:
            plugin: Plugin class implementing ArchitecturePlugin protocol.

        Raises:
            TypeError: If plugin doesn't implement required methods.
        """
        # Structural check only: each required attribute must exist and be callable.
        for method in cls._REQUIRED_METHODS:
            if not callable(getattr(plugin, method, None)):
                # Fall back to repr() so objects lacking __name__ (e.g. an
                # instance passed by mistake) still yield the documented TypeError.
                name = getattr(plugin, "__name__", repr(plugin))
                raise TypeError(
                    f"Plugin {name} missing required method: {method}"
                )

        # Don't add duplicates
        if plugin not in cls._plugins:
            cls._plugins.append(plugin)

    @classmethod
    def unregister(cls, plugin: Type[ArchitecturePlugin]) -> None:
        """Unregister an architecture plugin.

        Unregistering a plugin that is not registered is a no-op.

        Args:
            plugin: Plugin class to remove.
        """
        if plugin in cls._plugins:
            cls._plugins.remove(plugin)

    @classmethod
    def find_plugin(cls, model: "nn.Module") -> Type[ArchitecturePlugin] | None:
        """Find a plugin that supports the given model.

        Plugins are checked in registration order. A plugin whose
        ``supports`` check raises is skipped (logged at DEBUG level) so that
        one faulty plugin cannot block the others.

        Args:
            model: The model to find a plugin for.

        Returns:
            Plugin class if found, None otherwise.
        """
        for plugin in cls._plugins:
            try:
                # bool() stays inside the try: truth-testing the result may
                # itself raise for some return types.
                supported = bool(plugin.supports(model))
            except Exception:
                # Deliberately best-effort, but leave a trace instead of
                # silently hiding plugin bugs.
                logging.getLogger(__name__).debug(
                    "Skipping plugin %r: supports() raised",
                    plugin,
                    exc_info=True,
                )
                continue
            if supported:
                return plugin
        return None

    @classmethod
    def get_supported_architectures(cls) -> list[str]:
        """Get list of supported architecture names.

        Each name is the plugin class name with a trailing ``"Plugin"``
        suffix removed (e.g. ``ResNetPlugin`` -> ``ResNet``). Occurrences
        elsewhere in the name are kept, and a class named exactly
        ``Plugin`` is reported unchanged rather than as an empty string.

        Returns:
            List of architecture names from registered plugins, in
            registration order.
        """
        suffix = cls._NAME_SUFFIX
        names = []
        for plugin in cls._plugins:
            name = plugin.__name__
            if name.endswith(suffix) and len(name) > len(suffix):
                name = name[: -len(suffix)]
            names.append(name)
        return names

    @classmethod
    def _require_plugin(cls, model: "nn.Module") -> Type[ArchitecturePlugin]:
        """Return the plugin supporting *model*, or raise if there is none.

        Raises:
            UnsupportedArchitectureError: If no plugin supports the model.
        """
        plugin = cls.find_plugin(model)
        if plugin is None:
            raise UnsupportedArchitectureError(
                model.__class__.__name__,
                supported=cls.get_supported_architectures(),
            )
        return plugin

    @classmethod
    def extract_classifier_head(
        cls,
        model: "nn.Module",
        target_layer: "nn.Module | None" = None,
    ) -> "nn.Module":
        """Extract classifier head using registered plugins.

        Args:
            model: The model.
            target_layer: Optional target layer, passed through to the plugin.

        Returns:
            Classifier head module.

        Raises:
            UnsupportedArchitectureError: If no plugin supports the model.
        """
        return cls._require_plugin(model).extract_classifier_head(model, target_layer)

    @classmethod
    def get_target_layer(cls, model: "nn.Module") -> "nn.Module":
        """Get target layer using registered plugins.

        Args:
            model: The model.

        Returns:
            Target layer.

        Raises:
            UnsupportedArchitectureError: If no plugin supports the model.
        """
        return cls._require_plugin(model).get_target_layer(model)

    @classmethod
    def get_num_features(cls, model: "nn.Module") -> int:
        """Get number of features using registered plugins.

        Args:
            model: The model.

        Returns:
            Number of feature channels K.

        Raises:
            UnsupportedArchitectureError: If no plugin supports the model.
        """
        return cls._require_plugin(model).get_num_features(model)

    @classmethod
    def clear(cls) -> None:
        """Clear all registered plugins (for testing)."""
        cls._plugins.clear()


def register_plugin(plugin: Type[ArchitecturePlugin]) -> Type[ArchitecturePlugin]:
    """Register *plugin* with ``ArchitectureRegistry`` and return it unchanged.

    Because the argument is handed back as-is, this works both as a class
    decorator (the decorated name stays bound to the class) and as a plain
    function call.

    Args:
        plugin: Plugin class implementing the ArchitecturePlugin protocol.

    Returns:
        The same ``plugin`` object that was passed in.

    Raises:
        TypeError: Propagated from ``ArchitectureRegistry.register`` when the
            plugin lacks a required method.

    Example:
        >>> @register_plugin
        ... class MyPlugin:
        ...     ...

        >>> register_plugin(MyPlugin)
    """
    registry = ArchitectureRegistry
    registry.register(plugin)
    registered = plugin
    return registered
