"""Tests for perturbation sampling classes."""

import pytest
import torch

from expected_gradcam.sampling import (
    SimpleBatchedSampler,
)


class TestSimpleBatchedSampler:
    """Tests for SimpleBatchedSampler class.

    Every test feeds a 1-D tensor of ones to ``sampler.sample(z0, M=...)``
    and checks the shape, spread, or device of the returned tensor.
    """

    # Scale used by the shared ``sampler`` fixture and by the tolerance check
    # in ``test_sample_scale``; kept in one place so the two cannot drift.
    DEFAULT_SCALE = 0.3

    @pytest.fixture(autouse=True)
    def _seed_rng(self):
        """Seed torch's global RNG before each test in this class.

        Makes the statistical assertions reproducible from run to run.
        NOTE(review): assumes the sampler draws from torch's default
        generator; if it owns a private generator this is a harmless no-op.
        """
        torch.manual_seed(0)

    @pytest.fixture
    def sampler(self):
        """Create a default sampler."""
        return SimpleBatchedSampler(scale=self.DEFAULT_SCALE)

    def test_sample_shape(self, sampler):
        """Test that samples have correct shape."""
        z0 = torch.ones(64)
        samples = sampler.sample(z0, M=100)
        assert samples.shape == (100, 64)

    def test_sample_scale(self, sampler):
        """Test that samples have approximately target scale."""
        z0 = torch.ones(128)
        samples = sampler.sample(z0, M=1000)

        # Standard deviation over all elements of the returned tensor must
        # land within 0.05 of the configured scale.
        std = samples.std()
        assert std.item() == pytest.approx(self.DEFAULT_SCALE, abs=0.05)

    def test_different_scales(self):
        """Test that a larger scale yields a larger sample spread."""
        z0 = torch.ones(64)

        sampler_small = SimpleBatchedSampler(scale=0.1)
        sampler_large = SimpleBatchedSampler(scale=1.0)

        small_samples = sampler_small.sample(z0, M=100)
        large_samples = sampler_large.sample(z0, M=100)

        assert small_samples.std() < large_samples.std()

    def test_device_placement(self):
        """Test that samples are on correct device."""
        sampler = SimpleBatchedSampler(scale=0.3, device="cpu")
        z0 = torch.ones(64)
        samples = sampler.sample(z0, M=20)

        assert samples.device.type == "cpu"

    def test_zero_m(self, sampler):
        """Test with M=0 (edge case): result is empty along the first axis."""
        z0 = torch.ones(64)
        samples = sampler.sample(z0, M=0)

        assert samples.shape == (0, 64)

    def test_single_sample(self, sampler):
        """Test with M=1."""
        z0 = torch.ones(64)
        samples = sampler.sample(z0, M=1)

        assert samples.shape == (1, 64)


# Note: DataAwarePerturbationSampler and BatchedPerturbationSampler
# require a model and dataset, so they are tested in integration tests.
# Here we only verify that both classes can be imported.

class TestDataAwareSamplerImport:
    """Import smoke tests for the data-aware sampler classes."""

    def test_import_data_aware_sampler(self):
        """DataAwarePerturbationSampler is importable from the sampling module."""
        from expected_gradcam.sampling import (
            DataAwarePerturbationSampler as sampler_cls,
        )

        assert sampler_cls is not None

    def test_import_batched_sampler(self):
        """BatchedPerturbationSampler is importable from the sampling module."""
        from expected_gradcam.sampling import (
            BatchedPerturbationSampler as sampler_cls,
        )

        assert sampler_cls is not None
