"""Provider registry with decorator-based registration.

This module provides a thread-safe singleton registry for baseline providers.
Providers can be registered using the @baseline_provider decorator.

Example::

    from expected_gradcam.baselines import baseline_provider, BaseProvider

    @baseline_provider(
        "my_custom",
        full_name="My Custom Provider",
        aliases=("custom", "mc"),
        description="Custom provider for specialized use cases"
    )
    class MyCustomProvider(BaseProvider):
        def __init__(self, custom_arg: str):
            super().__init__()
            self.custom_arg = custom_arg

        # ... implement abstract methods
"""

from __future__ import annotations

import threading
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from expected_gradcam.exceptions.baseline import (
    ProviderNotFoundError,
    ProviderInitializationError,
)

if TYPE_CHECKING:
    from expected_gradcam.baselines.protocols import BaselineProvider

# Type variable used by the registration decorators in this module. It is
# bound to BaselineProvider (forward reference, only imported under
# TYPE_CHECKING) so decorating a class returns that same class type.
T = TypeVar("T", bound="BaselineProvider")


@dataclass(frozen=True)
class ProviderMetadata:
    """Immutable description of a single registered provider.

    Attributes:
        name: Canonical lookup key (e.g. "directory", "imagenet").
        full_name: Display name intended for humans.
        description: Short, single-line summary of the provider.
        aliases: Additional names that also identify this provider.
        supports_caching: True if the provider implements CacheableProvider.
        supports_streaming: True if the provider implements StreamingProvider.
        requires_packages: Optional packages the provider depends on.
    """

    name: str
    full_name: str = ""
    description: str = ""
    aliases: tuple[str, ...] = field(default_factory=tuple)
    supports_caching: bool = False
    supports_streaming: bool = False
    requires_packages: tuple[str, ...] = field(default_factory=tuple)

    def matches(self, query: str) -> bool:
        """Tell whether *query* identifies this provider.

        The comparison ignores case and considers the canonical name first,
        then each alias in order.

        Args:
            query: Candidate name or alias.

        Returns:
            True when *query* equals the name or one of the aliases.
        """
        wanted = query.lower()
        # Name first, then aliases: same evaluation order as a name check
        # followed by an alias scan.
        candidates = (self.name, *self.aliases)
        return any(candidate.lower() == wanted for candidate in candidates)


class ProviderRegistry:
    """Thread-safe singleton registry for baseline providers.

    This class maintains a global registry of provider classes that can be
    instantiated by name. It uses the singleton pattern to ensure a single
    registry instance across the application: every ``ProviderRegistry()``
    call returns the same object.

    Names and aliases are matched case-insensitively. Mutations (register /
    unregister) are serialized by an internal lock; lookups read the
    underlying dicts without taking it.

    Example::

        >>> registry = ProviderRegistry()
        >>> @registry.register("custom", aliases=("my_provider",))
        ... class CustomProvider:
        ...     def __init__(self, source: str) -> None:
        ...         self.source = source
        >>> provider = registry.create("custom", source="/path/to/data")
    """

    _instance: ProviderRegistry | None = None
    _lock: threading.Lock = threading.Lock()

    def __new__(cls) -> ProviderRegistry:
        """Create or return the singleton instance.

        The unlocked check is a fast path once the instance exists; the
        second check under ``_lock`` prevents two threads from both creating
        an instance.
        """
        if cls._instance is None:
            with cls._lock:
                # Double-check locking pattern
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Initialize before publishing so no caller can observe
                    # an instance without its state dicts.
                    instance._initialize()
                    cls._instance = instance
        return cls._instance

    def _initialize(self) -> None:
        """Initialize registry state (called once from ``__new__``)."""
        # Canonical (lowercase) name -> provider class.
        self._providers: dict[str, type[BaselineProvider]] = {}
        # Canonical name -> metadata recorded at registration time.
        self._metadata: dict[str, ProviderMetadata] = {}
        # Lowercase alias -> canonical name.
        self._aliases: dict[str, str] = {}
        # Canonical name of the default provider, if one was marked.
        self._default: str | None = None
        # Serializes register/unregister mutations.
        self._entry_lock = threading.Lock()

    def register(
        self,
        name: str,
        *,
        full_name: str = "",
        description: str = "",
        aliases: tuple[str, ...] | list[str] = (),
        supports_caching: bool = False,
        supports_streaming: bool = False,
        requires_packages: tuple[str, ...] = (),
        default: bool = False,
    ) -> Callable[[type[T]], type[T]]:
        """Decorator to register a provider class.

        Args:
            name: Primary provider name (used for lookup, case-insensitive).
            full_name: Human-readable name for display. Defaults to *name*
                title-cased with underscores replaced by spaces.
            description: One-line description.
            aliases: Alternative names for lookup.
            supports_caching: Whether provider implements CacheableProvider.
            supports_streaming: Whether provider implements StreamingProvider.
            requires_packages: Required optional packages (checked at creation).
            default: Mark this as the default provider (replaces any previous
                default).

        Returns:
            Decorator function that registers the class and returns it
            unchanged.

        Raises:
            ValueError: When the decorator is applied, if the name or an
                alias collides with an existing name or alias.

        Example::

            >>> @registry.register(
            ...     "directory",
            ...     full_name="Directory Provider",
            ...     aliases=("dir", "folder"),
            ...     supports_caching=True,
            ... )
            ... class DirectoryProvider:
            ...     pass
        """

        def decorator(cls: type[T]) -> type[T]:
            self._register_class(
                cls,
                name=name,
                full_name=full_name or name.title().replace("_", " "),
                description=description,
                aliases=tuple(aliases),
                supports_caching=supports_caching,
                supports_streaming=supports_streaming,
                requires_packages=requires_packages,
                default=default,
            )
            return cls

        return decorator

    def _register_class(
        self,
        cls: type[BaselineProvider],
        name: str,
        full_name: str,
        description: str,
        aliases: tuple[str, ...],
        supports_caching: bool,
        supports_streaming: bool,
        requires_packages: tuple[str, ...],
        default: bool,
    ) -> None:
        """Internal registration with validation.

        All conflict checks run before any state is mutated, so a failed
        registration leaves the registry unchanged.

        Args:
            cls: The provider class to register.
            name: Primary name.
            full_name: Display name.
            description: Description.
            aliases: Alternative names.
            supports_caching: Caching support flag.
            supports_streaming: Streaming support flag.
            requires_packages: Required packages.
            default: Default flag.

        Raises:
            ValueError: If the name or an alias is already registered as a
                provider name or alias.
        """
        name_lower = name.lower()
        alias_keys = [alias.lower() for alias in aliases]

        with self._entry_lock:
            # Check for conflicts
            if name_lower in self._providers:
                raise ValueError(f"Provider '{name}' already registered")

            # _resolve_name consults aliases first, so a provider whose name
            # is already an alias would be unreachable by its own name.
            if name_lower in self._aliases:
                raise ValueError(
                    f"Provider '{name}' conflicts with an existing alias"
                )

            for alias, alias_lower in zip(aliases, alias_keys):
                if alias_lower in self._aliases:
                    raise ValueError(f"Alias '{alias}' already in use")
                if alias_lower in self._providers:
                    raise ValueError(
                        f"Alias '{alias}' conflicts with provider name"
                    )

            # Create metadata
            metadata = ProviderMetadata(
                name=name_lower,
                full_name=full_name,
                description=description,
                aliases=aliases,
                supports_caching=supports_caching,
                supports_streaming=supports_streaming,
                requires_packages=requires_packages,
            )

            # Register
            self._providers[name_lower] = cls
            self._metadata[name_lower] = metadata

            for alias_lower in alias_keys:
                self._aliases[alias_lower] = name_lower

            if default:
                self._default = name_lower

    def _resolve_name(self, name: str) -> str:
        """Resolve alias to canonical name.

        Args:
            name: Name or alias to resolve.

        Returns:
            Canonical provider name (lowercase). Unknown names are returned
            lowercased, unchanged otherwise.
        """
        lower = name.lower()
        return self._aliases.get(lower, lower)

    def get(self, name: str) -> type[BaselineProvider]:
        """Get provider class by name.

        Args:
            name: Provider name or alias.

        Returns:
            Provider class.

        Raises:
            ProviderNotFoundError: If provider not registered.
        """
        # Single lookup: avoids a check-then-index race with unregister().
        provider_cls = self._providers.get(self._resolve_name(name))
        if provider_cls is None:
            raise ProviderNotFoundError(name, available=self.list_providers())
        return provider_cls

    def create(
        self,
        name: str,
        **kwargs: Any,
    ) -> BaselineProvider:
        """Create provider instance.

        Args:
            name: Provider name or alias.
            **kwargs: Arguments for provider constructor.

        Returns:
            Provider instance.

        Raises:
            ProviderNotFoundError: If provider not registered.
            ProviderInitializationError: If a required package cannot be
                imported (the ImportError is chained as the cause).
            Exception: Anything raised by the provider constructor
                propagates unchanged.
        """
        provider_cls = self.get(name)

        # Check for required packages
        metadata = self.get_info(name)
        for package in metadata.requires_packages:
            try:
                __import__(package)
            except ImportError as err:
                raise ProviderInitializationError(
                    name,
                    reason=f"Required package '{package}' not installed. "
                    f"Install with: pip install {package}",
                ) from err

        return provider_cls(**kwargs)

    def get_info(self, name: str) -> ProviderMetadata:
        """Get metadata for provider.

        Args:
            name: Provider name or alias.

        Returns:
            Provider metadata.

        Raises:
            ProviderNotFoundError: If provider not registered.
        """
        # Single lookup: avoids a check-then-index race with unregister().
        metadata = self._metadata.get(self._resolve_name(name))
        if metadata is None:
            raise ProviderNotFoundError(name, available=self.list_providers())
        return metadata

    def list_providers(self) -> list[str]:
        """List all registered provider names.

        Returns:
            List of canonical provider names (sorted).
        """
        return sorted(self._providers.keys())

    def list_all_names(self) -> list[str]:
        """List all provider names including aliases.

        Returns:
            List of all names and aliases (sorted, lowercase, deduplicated).
        """
        all_names = set(self._providers.keys())
        all_names.update(self._aliases.keys())
        return sorted(all_names)

    def get_default(self) -> str | None:
        """Get default provider name.

        Returns:
            Default provider name, or None if not set.
        """
        return self._default

    def is_registered(self, name: str) -> bool:
        """Check if a provider name is registered.

        Args:
            name: Provider name or alias.

        Returns:
            True if registered.
        """
        canonical = self._resolve_name(name)
        return canonical in self._providers

    def unregister(self, name: str) -> None:
        """Unregister a provider (mainly for testing).

        Removes the provider class, its metadata and its aliases, and clears
        the default if it pointed at this provider.

        Args:
            name: Provider name (not alias).

        Raises:
            ProviderNotFoundError: If provider not registered.
        """
        name_lower = name.lower()

        with self._entry_lock:
            if name_lower not in self._providers:
                raise ProviderNotFoundError(name, available=self.list_providers())

            # Get metadata to find aliases
            metadata = self._metadata[name_lower]

            # Remove aliases
            for alias in metadata.aliases:
                self._aliases.pop(alias.lower(), None)

            # Remove provider
            del self._providers[name_lower]
            del self._metadata[name_lower]

            # Clear default if this was it
            if self._default == name_lower:
                self._default = None


# Global registry instance, created eagerly at import time. Because
# ProviderRegistry.__new__ returns the cached singleton, later
# ProviderRegistry() calls yield this same object (unless
# ProviderRegistry._instance is reset).
_registry = ProviderRegistry()


def get_registry() -> ProviderRegistry:
    """Return the module-level provider registry.

    This is the same ProviderRegistry object that ``baseline_provider``
    registers classes on, so lookups through it see every decorated provider.

    Returns:
        The shared ProviderRegistry singleton.
    """
    return _registry


def baseline_provider(
    name: str,
    *,
    full_name: str = "",
    description: str = "",
    aliases: tuple[str, ...] | list[str] = (),
    supports_caching: bool = False,
    supports_streaming: bool = False,
    requires_packages: tuple[str, ...] = (),
    default: bool = False,
) -> Callable[[type[T]], type[T]]:
    """Register a baseline provider class on the global registry.

    Public entry point for custom providers. All options are forwarded
    unchanged to ``register`` on the module-level registry singleton.

    Args:
        name: Primary provider name.
        full_name: Human-readable display name.
        description: One-line description.
        aliases: Alternative names for lookup.
        supports_caching: Whether provider implements CacheableProvider.
        supports_streaming: Whether provider implements StreamingProvider.
        requires_packages: Required optional packages.
        default: Mark as default provider.

    Returns:
        A class decorator that registers the class and returns it.

    Example::

        from expected_gradcam.baselines import baseline_provider, BaseProvider

        @baseline_provider(
            "my_custom",
            full_name="My Custom Provider",
            aliases=("custom",),
        )
        class MyCustomProvider(BaseProvider):
            def __init__(self, custom_arg: str):
                super().__init__()
                self.custom_arg = custom_arg

            # ... implement abstract methods
    """
    # Collect the keyword-only options once and forward them as-is.
    options: dict[str, Any] = {
        "full_name": full_name,
        "description": description,
        "aliases": aliases,
        "supports_caching": supports_caching,
        "supports_streaming": supports_streaming,
        "requires_packages": requires_packages,
        "default": default,
    }
    return _registry.register(name, **options)


# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "ProviderRegistry",
    "ProviderMetadata",
    "get_registry",
    "baseline_provider",
]
