"""Tests for baseline sampling classes."""

import pytest
import torch

from expected_gradcam.sampling import (
    CenteredBaselineSampler,
    DataAwareBaselineSampler,
    sample_centered_baselines,
    verify_centered,
)


class TestCenteredBaselineSampler:
    """Behavioral checks for ``CenteredBaselineSampler``."""

    @pytest.fixture
    def sampler(self):
        """Provide a ``"gaussian"`` sampler built with ``scale=0.1``."""
        gaussian_sampler = CenteredBaselineSampler("gaussian", scale=0.1)
        return gaussian_sampler

    def test_sample_shape(self, sampler):
        """Positional ``sample(K, N)`` must return a tensor shaped ``(N, K)``."""
        dim_k, count_n = 64, 20
        drawn = sampler.sample(dim_k, count_n)
        expected_shape = (count_n, dim_k)
        assert drawn.shape == expected_shape

    def test_samples_are_centered(self, sampler):
        """Drawn samples must pass ``verify_centered`` (E[z'] = 0)."""
        drawn = sampler.sample(K=128, N=100)
        is_centered = verify_centered(drawn)
        assert is_centered

    def test_gaussian_distribution(self):
        """Overall std of Gaussian samples should be close to ``scale``."""
        gaussian = CenteredBaselineSampler("gaussian", scale=0.1)
        drawn = gaussian.sample(K=64, N=1000)

        # Compare the empirical std over all entries against the requested
        # scale, with a generous absolute tolerance.
        deviation = abs(drawn.std().item() - 0.1)
        assert deviation < 0.05

    def test_uniform_distribution(self):
        """Uniform samples with ``scale=1.0`` must stay within loose bounds."""
        uniform = CenteredBaselineSampler("uniform", scale=1.0)
        drawn = uniform.sample(K=64, N=1000)

        # Bounds are deliberately wider than [-scale, scale]: the check is
        # only approximate.
        largest = drawn.max()
        assert largest < 1.5
        smallest = drawn.min()
        assert smallest > -1.5

    def test_sphere_distribution(self):
        """Per-sample norms of sphere samples should be nearly uniform."""
        sphere = CenteredBaselineSampler("sphere", scale=1.0)
        drawn = sphere.sample(K=64, N=100)

        # One norm per row (dim=1); their spread should be small.
        norm_spread = drawn.norm(dim=1).std()
        assert norm_spread < 0.5

    def test_device_placement(self):
        """Samples must live on the device the sampler was built with."""
        cpu_sampler = CenteredBaselineSampler("gaussian", device="cpu")
        drawn = cpu_sampler.sample(K=64, N=20)
        device_kind = drawn.device.type
        assert device_kind == "cpu"

    def test_sample_with_stats(self, sampler):
        """``sample_with_stats`` returns samples plus per-column mean and std."""
        drawn, col_mean, col_std = sampler.sample_with_stats(K=64, N=100)

        assert drawn.shape == (100, 64)
        assert col_mean.shape == (64,)
        assert col_std.shape == (64,)

        # The reported mean should be zero up to floating-point noise.
        worst_mean = col_mean.abs().max()
        assert worst_mean < 1e-6

    def test_invalid_distribution(self):
        """An unknown distribution name must raise ``ValueError`` on sampling."""
        # Construction happens outside the ``raises`` block: this test expects
        # the error to surface from ``sample`` rather than the constructor.
        bad_sampler = CenteredBaselineSampler("invalid", scale=0.1)
        with pytest.raises(ValueError, match="Unknown distribution"):
            bad_sampler.sample(K=64, N=20)


class TestDataAwareBaselineSampler:
    """Behavioral checks for ``DataAwareBaselineSampler``."""

    @pytest.fixture
    def cache(self):
        """Provide a random ``(100, 64)`` tensor standing in for a feature cache."""
        mock_cache = torch.randn(100, 64)
        return mock_cache

    def test_sample_from_cache(self, cache):
        """Sampling from a cache-backed sampler yields an ``(N, K)`` tensor."""
        cached_sampler = DataAwareBaselineSampler(cache)
        drawn = cached_sampler.sample(K=64, N=20)

        expected_shape = (20, 64)
        assert drawn.shape == expected_shape

    def test_samples_are_centered(self, cache):
        """Samples drawn from the cache must pass ``verify_centered``."""
        cached_sampler = DataAwareBaselineSampler(cache)
        drawn = cached_sampler.sample(K=64, N=50)

        is_centered = verify_centered(drawn)
        assert is_centered

    def test_set_cache(self):
        """A sampler built without a cache fails until ``set_cache`` is called."""
        late_cache = torch.randn(100, 64)
        empty_sampler = DataAwareBaselineSampler()

        # With no cache configured, sampling is expected to be rejected.
        with pytest.raises(ValueError, match="cache not set"):
            empty_sampler.sample(K=64, N=20)

        # Once a cache is attached, the same call should succeed.
        empty_sampler.set_cache(late_cache)
        drawn = empty_sampler.sample(K=64, N=20)
        assert drawn.shape == (20, 64)


class TestSampleCenteredBaselines:
    """Tests for the ``sample_centered_baselines`` convenience function."""

    def test_basic_usage(self):
        """A plain call returns a centered tensor shaped ``(N, K)``."""
        baselines = sample_centered_baselines(K=128, N=30, scale=0.1)

        assert baselines.shape == (30, 128)
        assert verify_centered(baselines)

    def test_different_distributions(self):
        """Each distribution option yields centered ``(N, K)`` samples."""
        for dist in ("gaussian", "uniform", "sphere"):
            baselines = sample_centered_baselines(
                K=64, N=20, distribution=dist
            )
            # The messages name the distribution under test; without them a
            # failure inside this loop does not reveal which option broke.
            assert baselines.shape == (20, 64), (
                f"unexpected shape for distribution={dist!r}"
            )
            assert verify_centered(baselines), (
                f"samples not centered for distribution={dist!r}"
            )

    def test_device_parameter(self):
        """The ``device`` argument controls where the returned tensor lives."""
        baselines = sample_centered_baselines(K=64, N=20, device="cpu")
        assert baselines.device.type == "cpu"
