"""Tests for the metrics registry metaclass."""

from __future__ import annotations

import pytest

from expected_gradcam.metrics.base import BaseMetric
from expected_gradcam.metrics.exceptions import MetricNotFoundError
from expected_gradcam.metrics.registry import MetricRegistryMeta, register_metric


class TestMetricRegistryMeta:
    """Tests for the class-level lookup API exposed by MetricRegistryMeta."""

    def test_registry_exists(self):
        """MetricRegistryMeta should expose its registry as a dict."""
        assert hasattr(MetricRegistryMeta, "_registry")
        assert isinstance(MetricRegistryMeta._registry, dict)

    def test_abstract_classes_not_registered(self):
        """Abstract base classes should not be in the registry."""
        # BaseMetric has _abstract = True, so the metaclass must skip it.
        assert "BaseMetric" not in MetricRegistryMeta._registry

    def test_get_returns_metric_class(self):
        """get() should return the registered metric class itself."""
        # InternalInfidelity should be registered as "internal_infidelity".
        from expected_gradcam.metrics.infidelity import InternalInfidelity

        metric_cls = MetricRegistryMeta.get("internal_infidelity")
        assert metric_cls is InternalInfidelity

    def test_get_raises_for_unknown_metric(self):
        """get() should raise MetricNotFoundError naming the unknown metric."""
        with pytest.raises(MetricNotFoundError) as exc_info:
            MetricRegistryMeta.get("nonexistent_metric")

        assert "nonexistent_metric" in str(exc_info.value)

    def test_list_metrics_returns_names(self):
        """list_metrics() should return a non-empty list of registered names."""
        metrics = MetricRegistryMeta.list_metrics()
        assert isinstance(metrics, list)
        assert len(metrics) > 0
        assert "internal_infidelity" in metrics

    def test_list_metrics_by_category(self):
        """list_metrics(category=...) should only return infidelity metrics."""
        all_metrics = set(MetricRegistryMeta.list_metrics())
        infidelity_metrics = MetricRegistryMeta.list_metrics(category="infidelity")

        # Filtering may only narrow the full listing, never invent names.
        assert set(infidelity_metrics) <= all_metrics
        for name in infidelity_metrics:
            assert "infidelity" in name.lower()

    def test_list_metrics_category_filter_matches_only_that_category(self):
        """list_metrics(category=...) should include matches and exclude others.

        The check above passes vacuously if the filter returns an empty list,
        so this test registers a temporary metric under a unique category and
        checks it is returned for that category only.
        """
        try:
            @register_metric(
                "test_metric_category_filter",
                display_name="Category Filter Metric",
                lower_is_better=True,
                streamable=False,
                category="registry_test_category",
            )
            class CategoryFilterMetric(BaseMetric):
                _abstract = False

                def validate_inputs(self, **kwargs) -> None:
                    pass

                def compute(self, **kwargs) -> float:
                    return 0.0

            matching = MetricRegistryMeta.list_metrics(category="registry_test_category")
            other = MetricRegistryMeta.list_metrics(category="infidelity")
            assert "test_metric_category_filter" in matching
            assert "test_metric_category_filter" not in other
        finally:
            # Always remove the temporary entry so it cannot leak into
            # other tests that iterate the shared registry.
            MetricRegistryMeta._registry.pop("test_metric_category_filter", None)

    def test_get_info_returns_metadata(self):
        """get_info() should return the metric's metadata."""
        info = MetricRegistryMeta.get_info("internal_infidelity")

        assert info["name"] == "internal_infidelity"
        assert "display_name" in info
        assert "lower_is_better" in info
        assert info["lower_is_better"] is True  # Infidelity is lower-is-better


class TestRegisterMetricDecorator:
    """Tests for the @register_metric decorator.

    Each test registers a temporary metric in the shared, class-level
    registry. Cleanup runs in a ``finally`` block so a failing assertion
    cannot leave a stray entry behind for later tests (e.g. ones that
    iterate ``list_metrics()``).
    """

    def test_decorator_registers_metric(self):
        """@register_metric should add the decorated class to the registry."""
        try:
            @register_metric(
                "test_metric_decorator",
                display_name="Test Metric",
                lower_is_better=True,
                streamable=False,
                category="test",
            )
            class TestMetricDec(BaseMetric):
                _abstract = False

                def validate_inputs(self, **kwargs) -> None:
                    pass

                def compute(self, **kwargs) -> float:
                    return 0.0

            # Verify registration under the given name, and that lookup
            # yields the decorated class itself.
            assert "test_metric_decorator" in MetricRegistryMeta._registry
            assert MetricRegistryMeta.get("test_metric_decorator") is TestMetricDec
        finally:
            MetricRegistryMeta._registry.pop("test_metric_decorator", None)

    def test_decorator_sets_attributes(self):
        """@register_metric should set the metadata class attributes."""
        try:
            @register_metric(
                "test_metric_attrs",
                display_name="Test Display Name",
                lower_is_better=False,
                streamable=True,
                category="test",
            )
            class TestMetricAttrs(BaseMetric):
                _abstract = False

                def validate_inputs(self, **kwargs) -> None:
                    pass

                def compute(self, **kwargs) -> float:
                    return 1.0

            assert TestMetricAttrs.metric_name == "test_metric_attrs"
            assert TestMetricAttrs._display_name == "Test Display Name"
            assert TestMetricAttrs._lower_is_better is False
            assert TestMetricAttrs._streamable is True
            assert TestMetricAttrs._category == "test"
        finally:
            MetricRegistryMeta._registry.pop("test_metric_attrs", None)


class TestMetricInstantiation:
    """Tests for instantiating metrics obtained from registry lookups."""

    def test_instantiate_from_registry(self):
        """A class returned by get() should construct a BaseMetric instance."""
        metric_cls = MetricRegistryMeta.get("internal_infidelity")
        metric = metric_cls()

        assert isinstance(metric, BaseMetric)

    def test_all_registered_metrics_instantiable(self):
        """Every registered metric with a no-argument constructor should build.

        Metrics whose constructor requires arguments are skipped. The check
        uses the signature rather than catching ``TypeError`` from the call,
        so a ``TypeError`` raised by a buggy ``__init__`` is not silently
        swallowed.
        """
        import inspect

        for name in MetricRegistryMeta.list_metrics():
            metric_cls = MetricRegistryMeta.get(name)
            try:
                # bind() with no arguments fails only if the constructor has
                # required parameters.
                inspect.signature(metric_cls).bind()
            except TypeError:
                continue

            metric = metric_cls()
            assert isinstance(metric, BaseMetric), name
