"""Tests for baseline provider implementations."""

from __future__ import annotations

import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pytest
import torch
from torch import Tensor, nn

from expected_gradcam.baselines.protocols import BaselineProvider, CacheableProvider
from expected_gradcam.baselines.providers import (
    CachedFeatureProvider,
    DirectoryProvider,
    ImageNetProvider,
    TorchDatasetProvider,
)
from expected_gradcam.exceptions.baseline import (
    CacheCorruptedError,
    DirectoryNotFoundError,
    ProviderInitializationError,
)

if TYPE_CHECKING:
    from tests.conftest import SimpleCNN


class TestCachedFeatureProvider:
    """Exercise CachedFeatureProvider construction, initialization and sampling."""

    @staticmethod
    def _conv_model_and_layer():
        """Build a Conv2d + ReLU model; return it together with its first layer."""
        network = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU())
        return network, network[0]

    def test_load_npy_cache(self, tmp_path: Path) -> None:
        """A provider built over a .npy file is "cached" and empty before init."""
        npy_file = tmp_path / "features.npy"
        features = np.random.randn(100, 64).astype(np.float32)
        np.save(npy_file, features)

        cached = CachedFeatureProvider(npy_file)

        assert cached.provider_type == "cached"
        # initialize() has not been called yet, so the length is expected to be 0.
        assert len(cached) == 0

    def test_load_npz_cache(self, tmp_path: Path) -> None:
        """A provider built over a .npz archive reports the "cached" type."""
        npz_file = tmp_path / "features.npz"
        features = np.random.randn(50, 32).astype(np.float32)
        np.savez(npz_file, features=features)

        cached = CachedFeatureProvider(npz_file)

        assert cached.provider_type == "cached"

    def test_invalid_cache_raises_error(self, tmp_path: Path) -> None:
        """Initializing from a non-NumPy file raises CacheCorruptedError."""
        bogus_file = tmp_path / "invalid.npy"
        bogus_file.write_text("not a numpy file")
        cached = CachedFeatureProvider(bogus_file)

        # Any small model will do; initialize() just needs a model and a layer.
        network, layer = self._conv_model_and_layer()
        cpu = torch.device("cpu")

        with pytest.raises(CacheCorruptedError):
            cached.initialize(network, layer, cpu)

    def test_nonexistent_cache_raises_error(self, tmp_path: Path) -> None:
        """Initializing from a missing file raises ProviderInitializationError."""
        missing_file = tmp_path / "nonexistent.npy"
        cached = CachedFeatureProvider(missing_file)

        network, layer = self._conv_model_and_layer()
        cpu = torch.device("cpu")

        with pytest.raises(ProviderInitializationError):
            cached.initialize(network, layer, cpu)

    def test_get_samples_returns_centered(self, tmp_path: Path) -> None:
        """Samples drawn from a cache with non-zero mean come back centered."""
        cpu = torch.device("cpu")
        npy_file = tmp_path / "features.npy"
        # Shift the data by 5 so an uncentered result would be easy to spot.
        shifted = np.random.randn(100, 64).astype(np.float32) + 5.0
        np.save(npy_file, shifted)

        cached = CachedFeatureProvider(npy_file)
        network, layer = self._conv_model_and_layer()
        cached.initialize(network, layer, cpu)

        drawn = cached.get_baseline_samples(20, cpu)

        assert drawn.shape == (20, 64)
        # Per-feature mean over the drawn batch must be numerically zero.
        per_feature_mean = drawn.mean(dim=0)
        assert per_feature_mean.abs().max() < 1e-5

    def test_protocol_compliance(self, tmp_path: Path) -> None:
        """CachedFeatureProvider instances pass the BaselineProvider isinstance check."""
        npy_file = tmp_path / "features.npy"
        features = np.random.randn(100, 64).astype(np.float32)
        np.save(npy_file, features)

        cached = CachedFeatureProvider(npy_file)

        assert isinstance(cached, BaselineProvider)


class TestDirectoryProvider:
    """Exercise DirectoryProvider error handling and protocol conformance."""

    @staticmethod
    def _conv_model_and_layer():
        """Build a Conv2d + ReLU model; return it together with its first layer."""
        network = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU())
        return network, network[0]

    def test_nonexistent_directory_raises_error(self, tmp_path: Path) -> None:
        """Initializing over a missing directory raises DirectoryNotFoundError."""
        missing_dir = tmp_path / "nonexistent"
        from_dir = DirectoryProvider(missing_dir)

        network, layer = self._conv_model_and_layer()
        cpu = torch.device("cpu")

        with pytest.raises(DirectoryNotFoundError):
            from_dir.initialize(network, layer, cpu)

    def test_empty_directory_raises_error(self, tmp_path: Path) -> None:
        """Initializing over an existing but empty directory raises EmptyBaselineDatasetError."""
        from expected_gradcam.exceptions.baseline import EmptyBaselineDatasetError

        # The directory exists but contains no files at all.
        vacant_dir = tmp_path / "empty"
        vacant_dir.mkdir()
        from_dir = DirectoryProvider(vacant_dir)

        network, layer = self._conv_model_and_layer()
        cpu = torch.device("cpu")

        with pytest.raises(EmptyBaselineDatasetError):
            from_dir.initialize(network, layer, cpu)

    def test_protocol_compliance(self, tmp_path: Path) -> None:
        """DirectoryProvider passes both the Baseline and Cacheable isinstance checks."""
        from_dir = DirectoryProvider(tmp_path)

        for protocol in (BaselineProvider, CacheableProvider):
            assert isinstance(from_dir, protocol)


class TestTorchDatasetProvider:
    """Exercise TorchDatasetProvider wrapping, error handling and protocols."""

    @staticmethod
    def _labelled_images(count: int):
        """Return a TensorDataset of ``count`` random 3x32x32 images with labels in [0, 10)."""
        from torch.utils.data import TensorDataset

        pixels = torch.randn(count, 3, 32, 32)
        targets = torch.randint(0, 10, (count,))
        return TensorDataset(pixels, targets)

    def test_wraps_torch_dataset(self) -> None:
        """A provider built over a TensorDataset reports the "torch_dataset" type."""
        data = self._labelled_images(50)

        wrapped = TorchDatasetProvider(data, max_samples=20)

        assert wrapped.provider_type == "torch_dataset"

    def test_empty_dataset_raises_error(self) -> None:
        """Initializing over a zero-length dataset raises EmptyBaselineDatasetError."""
        from expected_gradcam.exceptions.baseline import EmptyBaselineDatasetError

        # Zero samples: the tensors keep their trailing dims but have length 0.
        data = self._labelled_images(0)
        wrapped = TorchDatasetProvider(data)

        network = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU())
        layer = network[0]
        cpu = torch.device("cpu")

        # The failure is expected at initialize(), not at construction.
        with pytest.raises(EmptyBaselineDatasetError):
            wrapped.initialize(network, layer, cpu)

    def test_protocol_compliance(self) -> None:
        """TorchDatasetProvider passes both the Baseline and Cacheable isinstance checks."""
        wrapped = TorchDatasetProvider(self._labelled_images(10))

        for protocol in (BaselineProvider, CacheableProvider):
            assert isinstance(wrapped, protocol)


class TestImageNetProvider:
    """Exercise ImageNetProvider error handling and protocol conformance."""

    def test_nonexistent_root_raises_error(self, tmp_path: Path) -> None:
        """Initializing with a missing root directory raises DirectoryNotFoundError."""
        missing_root = tmp_path / "nonexistent"
        imagenet = ImageNetProvider(missing_root)

        network = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU())
        layer = network[0]
        cpu = torch.device("cpu")

        with pytest.raises(DirectoryNotFoundError):
            imagenet.initialize(network, layer, cpu)

    def test_protocol_compliance(self, tmp_path: Path) -> None:
        """ImageNetProvider passes both the Baseline and Cacheable isinstance checks."""
        imagenet = ImageNetProvider(tmp_path)

        for protocol in (BaselineProvider, CacheableProvider):
            assert isinstance(imagenet, protocol)


class TestProviderCaching:
    """Exercise the cache save/load path of a cacheable provider."""

    def test_cacheable_provider_save_load_cycle(self, tmp_path: Path) -> None:
        """Two providers sharing one cache path yield same-shaped, centered samples."""
        from torch.utils.data import TensorDataset

        cpu = torch.device("cpu")

        # Random image/label dataset backing both providers.
        pixels = torch.randn(30, 3, 32, 32)
        targets = torch.randint(0, 10, (30,))
        data = TensorDataset(pixels, targets)

        cache_file = str(tmp_path / "cache.npy")
        first = TorchDatasetProvider(data, cache_path=cache_file)

        # Small conv model; the first conv layer is the target layer.
        network = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),
        )
        layer = network[0]
        network.eval()

        # Initializing the first provider is expected to leave a cache file behind.
        first.initialize(network, layer, cpu)
        assert Path(cache_file).exists()

        # The second provider is pointed at the same, now existing, cache path.
        second = TorchDatasetProvider(data, cache_path=cache_file)
        second.initialize(network, layer, cpu)

        drawn_first = first.get_baseline_samples(10, cpu)
        drawn_second = second.get_baseline_samples(10, cpu)

        assert drawn_first.shape == drawn_second.shape
        # Each batch must be centered on its own.
        for drawn in (drawn_first, drawn_second):
            assert drawn.mean(dim=0).abs().max() < 1e-5
