"""Tests for baseline provider registry."""

from __future__ import annotations

import pytest
import torch
from torch import Tensor, nn

from expected_gradcam.baselines.registry import (
    ProviderRegistry,
    ProviderMetadata,
    baseline_provider,
    get_registry,
)
from expected_gradcam.baselines.base import BaseProvider
from expected_gradcam.exceptions.baseline import ProviderNotFoundError


class TestProviderRegistry:
    """Tests for ProviderRegistry class."""

    def test_get_registry_returns_singleton(self):
        """Test that get_registry returns the same instance."""
        registry1 = get_registry()
        registry2 = get_registry()
        assert registry1 is registry2

    def test_list_providers_returns_registered_providers(self):
        """Test that list_providers returns all registered provider names."""
        registry = get_registry()
        providers = registry.list_providers()

        # Check that the default providers are registered
        assert "directory" in providers
        assert "cached" in providers
        assert "torch_dataset" in providers
        assert "imagenet" in providers

    def test_get_returns_provider_class(self):
        """Test that get returns the provider class."""
        registry = get_registry()
        provider_cls = registry.get("directory")

        assert provider_cls is not None
        assert issubclass(provider_cls, BaseProvider)

    def test_get_with_alias_returns_provider_class(self):
        """Test that get works with aliases."""
        registry = get_registry()

        # "dir" is an alias for "directory"
        provider_cls = registry.get("dir")
        assert provider_cls is not None

        # Verify it's the same as the main name
        assert provider_cls is registry.get("directory")

    def test_get_unknown_raises_error(self):
        """Test that get raises ProviderNotFoundError for unknown names."""
        registry = get_registry()

        with pytest.raises(ProviderNotFoundError) as exc_info:
            registry.get("nonexistent_provider")

        # The offending name should be surfaced in the error message.
        assert "nonexistent_provider" in str(exc_info.value)

    def test_internal_metadata_available(self):
        """Test that provider metadata is stored in registry."""
        registry = get_registry()

        # Access internal metadata (for testing purposes)
        assert "directory" in registry._metadata
        metadata = registry._metadata["directory"]

        assert isinstance(metadata, ProviderMetadata)
        assert metadata.name == "directory"
        assert "dir" in metadata.aliases
        assert metadata.supports_caching is True

    def test_create_creates_provider_instance(self, tmp_path):
        """Test that create instantiates a provider.

        The "cached" provider is used because it only needs a ``.npy`` file
        on disk. The file is written under pytest's per-test ``tmp_path``
        directory, so pytest removes it afterwards. The file is also fully
        written and closed before the provider is given its path.
        """
        # numpy is not imported at module level, so import it locally.
        import numpy as np

        registry = get_registry()

        cache_path = tmp_path / "baseline_cache.npy"
        # Fixed seed keeps the test data deterministic across runs.
        rng = np.random.default_rng(0)
        test_data = rng.standard_normal((100, 64)).astype(np.float32)
        np.save(cache_path, test_data)

        # Pass a plain string path, matching how the provider was exercised
        # originally (a temp-file name string).
        provider = registry.create("cached", cache_path=str(cache_path))

        assert provider is not None
        assert isinstance(provider, BaseProvider)
        assert provider.provider_type == "cached"


class TestBaselineProviderDecorator:
    """Tests for baseline_provider decorator.

    Each test registers a throwaway provider in the process-wide registry
    returned by ``get_registry()``. Cleanup runs in ``finally`` so that a
    failing assertion cannot leave stale registrations behind for later tests.
    """

    def test_decorator_registers_provider(self):
        """Test that decorator registers a provider class."""
        registry = get_registry()

        # Create a test provider
        @baseline_provider("test_provider_decorator", aliases=("test_alias",))
        class TestProviderDecorator(BaseProvider):
            def __init__(self, value: int = 42):
                super().__init__()
                self.value = value

            @property
            def provider_type(self) -> str:
                return "test_provider_decorator"

            def _do_initialize(self) -> None:
                self._gap_cache = torch.randn(10, 64)
                self._n_channels = 64

            def _get_raw_samples(self, n: int) -> Tensor:
                return self._gap_cache[:n]

            def __len__(self) -> int:
                return 10

        try:
            # Check registration under both the primary name and the alias.
            assert "test_provider_decorator" in registry.list_providers()
            assert registry.get("test_provider_decorator") is TestProviderDecorator
            assert registry.get("test_alias") is TestProviderDecorator
        finally:
            # Always clean up the shared registry, even if an assertion failed.
            registry.unregister("test_provider_decorator")

    def test_decorator_preserves_class(self):
        """Test that decorator returns the original class."""
        registry = get_registry()

        @baseline_provider("test_preserve_class")
        class TestPreserveClass(BaseProvider):
            class_attribute = "test"

            @property
            def provider_type(self) -> str:
                return "test_preserve_class"

            def _do_initialize(self) -> None:
                pass

            def _get_raw_samples(self, n: int) -> Tensor:
                return torch.zeros(n, 10)

            def __len__(self) -> int:
                return 0

        try:
            # A wrapping subclass would still expose class_attribute, so also
            # check identity with the registered class and the preserved name.
            assert TestPreserveClass.class_attribute == "test"
            assert TestPreserveClass.__name__ == "TestPreserveClass"
            assert registry.get("test_preserve_class") is TestPreserveClass
        finally:
            # Always clean up the shared registry, even if an assertion failed.
            registry.unregister("test_preserve_class")


class TestProviderMetadata:
    """Tests for ProviderMetadata dataclass."""

    def test_metadata_fields(self):
        """Test that every constructor argument is exposed as a field."""
        metadata = ProviderMetadata(
            name="test",
            full_name="Test Provider",
            description="A test provider",
            aliases=("t", "tst"),
            supports_caching=True,
            supports_streaming=False,
            requires_packages=(),
        )

        assert metadata.name == "test"
        assert metadata.full_name == "Test Provider"
        assert metadata.description == "A test provider"
        # Check every alias, not just the first one.
        assert "t" in metadata.aliases
        assert "tst" in metadata.aliases
        assert metadata.supports_caching is True
        assert metadata.supports_streaming is False
        assert tuple(metadata.requires_packages) == ()
