"""Property-based tests for baseline provider system.

These tests verify mathematical properties that must hold for all providers:
1. Centering: E[z'] = 0 (required for completeness axiom)
2. Non-degeneracy: samples have non-zero variance
3. Reproducibility: same provider config produces consistent samples
"""

from __future__ import annotations

import tempfile
from pathlib import Path

import numpy as np
import pytest
import torch
from torch import Tensor, nn

from expected_gradcam.baselines.base import BaseProvider, CacheableBaseProvider
from expected_gradcam.baselines.builder import BaselineProviderBuilder
from expected_gradcam.baselines.factory import baseline_from
from expected_gradcam.baselines.protocols import BaselineProvider
from expected_gradcam.baselines.registry import baseline_provider, get_registry


class TestCenteringProperty:
    """Verify the centering property ``E[z'] = 0`` of baseline samples.

    The completeness axiom of Expected Gradients depends on this property,
    so every test here checks that the per-feature mean (taken over the
    sample dimension) of ``get_baseline_samples`` output is numerically zero.
    """

    @pytest.fixture
    def mock_provider(self):
        """Return a minimal in-memory ``BaseProvider`` subclass (the class, not an instance)."""

        class MockProvider(BaseProvider):
            """Provider that draws rows, with replacement, from a fixed tensor."""

            def __init__(self, data: Tensor):
                super().__init__()
                self._data = data

            @property
            def provider_type(self) -> str:
                return "mock"

            def _do_initialize(self) -> None:
                # The size of the second dimension is recorded as the channel count.
                self._n_channels = self._data.shape[1]

            def _get_raw_samples(self, n: int) -> Tensor:
                # Uniformly random row indices; repeats are possible.
                picks = torch.randint(0, len(self._data), (n,))
                return self._data[picks]

            def __len__(self) -> int:
                return len(self._data)

        return MockProvider

    def test_centered_after_get_baseline_samples(self, mock_provider):
        """Samples drawn from data whose mean is around 10 come back with zero mean."""
        cpu = torch.device("cpu")

        # Shift standard-normal data so the raw mean is clearly non-zero.
        data = torch.randn(100, 64) + 10.0
        source = mock_provider(data)

        # One small module fills both leading arguments of ``initialize``;
        # per the original note, the mock provider does not use them.
        stub = nn.Linear(10, 10)
        source.initialize(stub, stub, cpu)

        drawn = source.get_baseline_samples(50, cpu)

        mean = drawn.mean(dim=0)
        assert mean.abs().max() < 1e-5, f"Samples not centered: max mean = {mean.abs().max()}"

    def test_centering_with_various_distributions(self, mock_provider):
        """Centering holds across raw data with differing scale, offset and shape."""
        cpu = torch.device("cpu")
        rows, cols = 100, 64

        candidates = (
            torch.randn(rows, cols) * 0.1,  # small variance
            torch.randn(rows, cols) * 10,  # large variance
            torch.randn(rows, cols) + 100,  # large positive mean
            torch.randn(rows, cols) - 100,  # large negative mean
            torch.rand(rows, cols),  # uniform on [0, 1]
            torch.abs(torch.randn(rows, cols)),  # half-normal, positive values only
        )

        for data in candidates:
            source = mock_provider(data)
            stub = nn.Linear(10, 10)
            source.initialize(stub, stub, cpu)

            mean = source.get_baseline_samples(50, cpu).mean(dim=0)

            assert mean.abs().max() < 1e-4, f"Centering failed for distribution with original mean {data.mean():.4f}"

    def test_centering_with_small_n(self, mock_provider):
        """Centering holds even when only a handful of samples is requested."""
        cpu = torch.device("cpu")

        source = mock_provider(torch.randn(100, 64) + 5.0)
        stub = nn.Linear(10, 10)
        source.initialize(stub, stub, cpu)

        for n in (1, 2, 3, 5, 10):
            mean = source.get_baseline_samples(n, cpu).mean(dim=0)
            assert mean.abs().max() < 1e-4, f"Centering failed for n={n}"


class TestCachedProviderCentering:
    """Centering check for ``CachedFeatureProvider`` reading a saved ``.npy`` file."""

    def test_cached_provider_centers_samples(self, tmp_path: Path):
        """Each of several draws from a cache with mean around 5 is centered."""
        from expected_gradcam.baselines.providers.cached import CachedFeatureProvider

        cpu = torch.device("cpu")

        # The +5.0 offset keeps the stored features far from zero mean.
        features = np.random.randn(100, 64).astype(np.float32) + 5.0
        cache_file = tmp_path / "features.npy"
        np.save(cache_file, features)

        cached = CachedFeatureProvider(cache_file)
        stub = nn.Linear(10, 10)
        cached.initialize(stub, stub, cpu)

        # Draw repeatedly: every individual batch must be centered.
        for _attempt in range(5):
            batch = cached.get_baseline_samples(30, cpu)
            assert batch.mean(dim=0).abs().max() < 1e-4


class TestNonDegeneracyProperty:
    """Check that returned baseline samples are not collapsed to a constant."""

    def test_samples_have_nonzero_variance(self, tmp_path: Path):
        """Samples drawn from a standard-normal cache keep a mean variance above 0.01."""
        from expected_gradcam.baselines.providers.cached import CachedFeatureProvider

        cpu = torch.device("cpu")

        # Standard-normal features, so the cached data itself has spread.
        cache_file = tmp_path / "features.npy"
        np.save(cache_file, np.random.randn(100, 64).astype(np.float32))

        cached = CachedFeatureProvider(cache_file)
        stub = nn.Linear(10, 10)
        cached.initialize(stub, stub, cpu)

        batch = cached.get_baseline_samples(50, cpu)

        # Per-feature variance over the sample dimension, averaged across features.
        per_feature_var = batch.var(dim=0)
        assert per_feature_var.mean() > 0.01, "Samples have near-zero variance"


class TestBuilderPattern:
    """Exercise the fluent ``BaselineProviderBuilder`` API."""

    def test_builder_from_cache(self, tmp_path: Path):
        """Building from a saved ``.npy`` cache yields a provider of type ``"cached"``."""
        cache_file = tmp_path / "features.npy"
        np.save(cache_file, np.random.randn(50, 32).astype(np.float32))

        configured = BaselineProviderBuilder().from_cache(cache_file)
        built = configured.build()

        assert built is not None
        assert built.provider_type == "cached"

    def test_builder_with_mmap_mode(self, tmp_path: Path):
        """Requesting mmap mode ``"r"`` on a cache source still builds a provider."""
        cache_file = tmp_path / "features.npy"
        np.save(cache_file, np.random.randn(100, 64).astype(np.float32))

        configured = BaselineProviderBuilder().from_cache(cache_file)
        configured = configured.with_mmap_mode("r")
        built = configured.build()

        assert built is not None

    def test_builder_raises_without_source(self):
        """``build()`` on a fresh builder raises ``ValueError`` naming the missing source."""
        empty_builder = BaselineProviderBuilder()

        with pytest.raises(ValueError) as raised:
            empty_builder.build()

        message = str(raised.value)
        assert "No data source specified" in message


class TestFactoryFunction:
    """Exercise ``baseline_from`` with path inputs and explicit dict configurations."""

    def test_factory_from_npy_path(self, tmp_path: Path):
        """A path to a saved ``.npy`` array yields a provider of type ``"cached"``."""
        npy_file = tmp_path / "features.npy"
        np.save(npy_file, np.random.randn(50, 32).astype(np.float32))

        created = baseline_from(npy_file)

        assert created.provider_type == "cached"

    def test_factory_from_npz_path(self, tmp_path: Path):
        """A path to a saved ``.npz`` archive yields a provider of type ``"cached"``."""
        npz_file = tmp_path / "features.npz"
        # The array is stored under the ``features`` key of the archive.
        np.savez(npz_file, features=np.random.randn(50, 32).astype(np.float32))

        created = baseline_from(npz_file)

        assert created.provider_type == "cached"

    def test_factory_from_directory_path(self, tmp_path: Path):
        """A path to an existing directory yields a provider of type ``"directory"``."""
        # The directory stays empty: only construction and the reported type are
        # checked here (per the original note, initializing on it would fail).
        image_dir = tmp_path / "images"
        image_dir.mkdir()

        created = baseline_from(image_dir)

        assert created.provider_type == "directory"

    def test_factory_from_dict_config(self, tmp_path: Path):
        """A dict naming type ``"cached"`` and a cache path yields a cached provider."""
        npy_file = tmp_path / "features.npy"
        np.save(npy_file, np.random.randn(50, 32).astype(np.float32))

        config = {
            "type": "cached",
            "cache_path": str(npy_file),
        }
        created = baseline_from(config)

        assert created.provider_type == "cached"

    def test_factory_from_string_detects_huggingface(self):
        """A dict config with type ``"huggingface"`` yields a HuggingFace provider.

        The dataset is selected through an explicit dict configuration rather
        than a bare string; the original note explains that auto-detection
        checks for an existing path first. The provider is only constructed,
        never initialized. If an ``ImportError`` is raised the test is skipped.
        """
        config = {"type": "huggingface", "dataset_name": "imagenet-1k"}
        try:
            created = baseline_from(config)
            assert created.provider_type == "huggingface"
        except ImportError:
            pytest.skip("datasets package not installed")


class TestProtocolCompliance:
    """Structural checks applied to every provider class found in the registry."""

    @pytest.fixture
    def all_provider_classes(self):
        """Return the object registered under each name the registry lists."""
        registry = get_registry()
        return [registry.get(name) for name in registry.list_providers()]

    def test_all_providers_implement_protocol(self, all_provider_classes):
        """Every registered provider exposes the attributes the tests rely on.

        This is a name-based check (``hasattr``) on the registered class; it
        does not instantiate providers or verify signatures.
        """
        required = (
            "is_initialized",
            "provider_type",
            "initialize",
            "get_baseline_samples",
            "__len__",
        )
        for cls in all_provider_classes:
            for attr in required:
                assert hasattr(cls, attr), f"{cls!r} is missing required attribute {attr!r}"

    def test_all_providers_have_provider_type(self, all_provider_classes):
        """Every registered provider has ``provider_type`` and consistent metadata.

        For each registered name the class must expose ``provider_type``, and,
        when the registry can supply metadata for that name, the metadata's
        ``name`` must equal the name it was looked up under.
        """
        registry = get_registry()

        for name in registry.list_providers():
            cls = registry.get(name)
            assert hasattr(cls, "provider_type"), f"Provider {name!r} has no provider_type"

            # Only the lookup is best-effort. The assertion sits outside the
            # ``try`` so that a mismatch fails the test: AssertionError is a
            # subclass of Exception and would otherwise be swallowed.
            try:
                metadata = registry.get_metadata(name)
            except Exception:
                continue

            assert metadata.name == name, (
                f"Metadata name {metadata.name!r} does not match registered name {name!r}"
            )
