"""Metric registry with metaclass-based auto-registration.

This module provides a metaclass and decorator for automatically registering
metric classes, enabling dynamic lookup and discovery of available metrics.

Example:
    >>> from expected_gradcam.metrics.registry import register_metric, MetricRegistry
    >>>
    >>> @register_metric("my_metric", display_name="My Custom Metric")
    ... class MyMetric:
    ...     def compute(self, **kwargs) -> float:
    ...         return 0.0
    >>>
    >>> MetricRegistry.get("my_metric")
    <class 'MyMetric'>
    >>> MetricRegistry.list_metrics()
    ['my_metric', ...]
"""

from __future__ import annotations

import threading
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from expected_gradcam.metrics.exceptions import MetricNotFoundError

if TYPE_CHECKING:
    from expected_gradcam.metrics.base import BaseMetric

# Type variable for the class passed through ``register_metric``'s decorator,
# so the decorated class keeps its own type. The bound is a string forward
# reference because ``BaseMetric`` is only imported under TYPE_CHECKING.
T = TypeVar("T", bound="BaseMetric")


class MetricRegistryMeta(type):
    """Metaclass that records metric classes in a process-wide registry.

    Any class created with this metaclass is registered under the
    ``metric_name`` string defined in its own class body. The same registry
    is also filled explicitly through :meth:`register` (which the
    ``register_metric`` decorator uses), so both mechanisms share one
    name -> class mapping.

    Registration rules applied when a class is created:

    - A class whose own body sets ``_abstract = True`` is not registered.
      The flag is read from the class body only, so subclasses of an
      abstract base are not themselves treated as abstract.
    - A class is registered only when its own body defines a non-empty
      ``metric_name`` string. A name that is merely inherited from a parent
      is ignored. This way, defining a subclass never silently replaces the
      parent's registry entry.
    - Registering a name that is already taken replaces the previous entry.

    Every access to the registry is guarded by a single non-reentrant
    ``threading.Lock``.

    Example:
        >>> class MyMetric(BaseMetric, metaclass=MetricRegistryMeta):
        ...     metric_name = "my_metric"
        ...     def compute(self, **kwargs) -> float:
        ...         return 0.0
        >>>
        >>> MetricRegistryMeta.get("my_metric")
        <class 'MyMetric'>
    """

    # Maps metric name -> metric class. Shared by every class created with
    # this metaclass and by every caller of the classmethods below.
    _registry: dict[str, type] = {}
    # Guards every read/write of ``_registry``. It is a plain (non-reentrant)
    # Lock, so no method may call another locking method while holding it.
    _lock = threading.Lock()

    def __new__(
        mcs,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
        **kwargs: Any,
    ) -> "MetricRegistryMeta":
        """Create a new metric class and register it if it declares a name.

        Args:
            name: Name of the class being created.
            bases: Base classes of the class being created.
            namespace: The class body's namespace.
            **kwargs: Class keyword arguments (``class X(Base, key=...)``).
                They are forwarded to ``type.__new__`` so that
                ``__init_subclass__`` hooks on base classes receive them.

        Returns:
            The newly created class.
        """
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)

        # Abstract bases opt out explicitly in their own body.
        if namespace.get("_abstract", False):
            return cls

        # Only a name declared in this class body counts. Falling back to an
        # inherited ``metric_name`` would register every subclass under its
        # parent's name and overwrite the parent's entry.
        metric_name = namespace.get("metric_name")
        if metric_name and isinstance(metric_name, str):
            with mcs._lock:
                mcs._registry[metric_name] = cls

        return cls

    @classmethod
    def get(mcs, name: str) -> type:
        """Return the metric class registered under ``name``.

        Args:
            name: The metric name to look up.

        Returns:
            The metric class registered under that name.

        Raises:
            MetricNotFoundError: If no metric is registered with that name.
                The error receives the list of currently registered names.
        """
        with mcs._lock:
            try:
                return mcs._registry[name]
            except KeyError:
                raise MetricNotFoundError(name, list(mcs._registry)) from None

    @classmethod
    def list_metrics(mcs) -> list[str]:
        """List all registered metric names.

        Returns:
            Registered metric names, sorted alphabetically.
        """
        with mcs._lock:
            return sorted(mcs._registry)

    @classmethod
    def register(mcs, name: str, cls: type) -> None:
        """Register ``cls`` under ``name``, replacing any existing entry.

        Args:
            name: The name to register the metric under.
            cls: The metric class to register.
        """
        with mcs._lock:
            mcs._registry[name] = cls

    @classmethod
    def unregister(mcs, name: str) -> None:
        """Remove the metric registered under ``name``.

        Args:
            name: The metric name to unregister.

        Raises:
            MetricNotFoundError: If no metric is registered with that name.
        """
        with mcs._lock:
            if name not in mcs._registry:
                raise MetricNotFoundError(name, list(mcs._registry))
            del mcs._registry[name]

    @classmethod
    def clear(mcs) -> None:
        """Remove every registered metric (intended for tests)."""
        with mcs._lock:
            mcs._registry.clear()

    @classmethod
    def is_registered(mcs, name: str) -> bool:
        """Report whether a metric is registered under ``name``.

        Args:
            name: The metric name to check.

        Returns:
            True if the metric is registered, False otherwise.
        """
        with mcs._lock:
            return name in mcs._registry


# Public alias: ``MetricRegistry`` is the very same object as
# ``MetricRegistryMeta``, so callers can query the shared registry
# (e.g. ``MetricRegistry.get(name)``) without naming the metaclass.
MetricRegistry = MetricRegistryMeta


def register_metric(
    name: str,
    *,
    display_name: str = "",
    lower_is_better: bool = True,
    streamable: bool = False,
    category: str = "general",
) -> Callable[[type[T]], type[T]]:
    """Build a class decorator that attaches metric metadata and registers it.

    The returned decorator stores ``name`` on the class as ``metric_name``,
    together with the private attributes ``_display_name``,
    ``_lower_is_better``, ``_streamable`` and ``_category``. It then
    registers the class in ``MetricRegistryMeta`` under ``name``, replacing
    any class already registered there. The decorated class itself is
    returned unchanged otherwise.

    Args:
        name: Unique lookup name for the metric.
        display_name: Human-readable label. When empty, it is derived from
            ``name`` by replacing underscores with spaces and title-casing
            (``"condition_number"`` -> ``"Condition Number"``).
        lower_is_better: Whether lower values indicate better performance.
        streamable: Whether the metric can be computed incrementally.
        category: Grouping label for the metric (solver, heatmap, etc.).

    Returns:
        A decorator that annotates and registers the class it receives.

    Example:
        >>> @register_metric(
        ...     "condition_number",
        ...     display_name="Condition Number",
        ...     lower_is_better=True,
        ...     category="solver"
        ... )
        ... class ConditionNumber(BaseMetric):
        ...     def compute(self, M_I):
        ...         ...
    """

    def _decorate(cls: type[T]) -> type[T]:
        # The lookup name goes on first, before the remaining metadata.
        cls.metric_name = name  # type: ignore

        if display_name:
            label = display_name
        else:
            label = name.replace("_", " ").title()

        extra_metadata = (
            ("_display_name", label),
            ("_lower_is_better", lower_is_better),
            ("_streamable", streamable),
            ("_category", category),
        )
        for attribute, value in extra_metadata:
            setattr(cls, attribute, value)

        # Explicit registration: works whether or not cls uses the metaclass.
        MetricRegistryMeta.register(name, cls)
        return cls

    return _decorate
