"""Tests for sampling utility functions."""

import pytest
import torch

from expected_gradcam.sampling.utils import (
    center_samples,
    verify_centered,
    compute_optimal_baseline_scale,
    normalize_perturbations,
    safe_divide,
)


class TestCenterSamples:
    """Tests for center_samples function."""

    def setup_method(self, method):
        """Seed torch's global RNG before each test.

        The tests below draw their inputs from ``torch.randn`` and then check
        tight numeric tolerances. A fixed seed keeps the data identical across
        runs, so any failure can be reproduced.
        """
        torch.manual_seed(0)

    def test_centers_samples(self):
        """Test that samples are centered to have zero mean."""
        samples = torch.randn(100, 64) * 5 + 10  # Non-zero mean
        centered = center_samples(samples)

        # Per-column mean should be essentially zero (within float32 precision)
        mean = centered.mean(dim=0)
        assert mean.abs().max() < 1e-5

    def test_shape_preserved(self):
        """Test that output shape matches input."""
        samples = torch.randn(50, 128)
        centered = center_samples(samples)
        assert centered.shape == samples.shape

    def test_already_centered(self):
        """Test that already-centered samples are unchanged."""
        samples = torch.randn(100, 64)
        # Pre-center along dim 0 so center_samples should be (nearly) a no-op.
        samples = samples - samples.mean(dim=0, keepdim=True)

        centered = center_samples(samples)
        assert torch.allclose(centered, samples, atol=1e-6)


class TestVerifyCentered:
    """Tests for verify_centered function."""

    def setup_method(self, method):
        """Seed torch's global RNG before each test.

        The tolerance checks in this class sit close to floating-point noise
        (e.g. 1e-7). Fixed inputs make their outcome reproducible.
        """
        torch.manual_seed(0)

    def test_returns_true_for_centered(self):
        """Test that centered samples pass verification."""
        samples = torch.randn(100, 64)
        centered = center_samples(samples)
        assert verify_centered(centered) is True

    def test_returns_false_for_non_centered(self):
        """Test that non-centered samples fail verification."""
        samples = torch.randn(100, 64) + 1.0  # Shift every column mean by ~1
        assert verify_centered(samples) is False

    def test_custom_tolerance(self):
        """Test with custom tolerance."""
        samples = torch.randn(100, 64)
        centered = center_samples(samples)

        # Should pass with larger tolerance
        assert verify_centered(centered, tolerance=1e-3) is True

        # Adding a constant 1e-5 shifts every column mean by 1e-5: inside a
        # 1e-4 tolerance, outside a 1e-7 one.
        slightly_off = centered + 1e-5
        assert verify_centered(slightly_off, tolerance=1e-4) is True
        assert verify_centered(slightly_off, tolerance=1e-7) is False


class TestComputeOptimalBaselineScale:
    """Tests for compute_optimal_baseline_scale function."""

    def setup_method(self, method):
        """Seed torch's global RNG before each test.

        The assertions here are strict comparisons between scales computed on
        random feature maps. A fixed seed makes them deterministic.
        """
        torch.manual_seed(0)

    def test_returns_positive_scale(self):
        """Test that returned scale is positive."""
        features = torch.randn(1, 2048, 7, 7)
        scale = compute_optimal_baseline_scale(features)
        assert scale > 0

    def test_scale_proportional_to_feature_std(self):
        """Test that scale grows with feature map variation.

        Both inputs are scaled copies of the same base tensor. The only
        difference between them is magnitude, not a fresh random draw.
        """
        base = torch.randn(1, 64, 7, 7)

        # Small variation
        small_scale = compute_optimal_baseline_scale(base * 0.1)

        # Large variation (100x the small one)
        large_scale = compute_optimal_baseline_scale(base * 10.0)

        # Large should be bigger
        assert large_scale > small_scale

    def test_percentile_affects_scale(self):
        """Test that percentile parameter affects scale."""
        features = torch.randn(1, 64, 7, 7)

        scale_low = compute_optimal_baseline_scale(features, percentile=0.05)
        scale_high = compute_optimal_baseline_scale(features, percentile=0.2)

        assert scale_high > scale_low


class TestNormalizePerturbations:
    """Tests for normalize_perturbations function."""

    def setup_method(self, method):
        """Seed torch's global RNG before each test.

        The std and mean checks below use tolerances on random inputs. A fixed
        seed keeps those inputs identical across runs.
        """
        torch.manual_seed(0)

    def test_normalizes_to_target_scale(self):
        """Test that output has target standard deviation."""
        samples = torch.randn(100, 64) * 10  # Large scale

        target_scale = 0.3
        normalized = normalize_perturbations(samples, target_scale=target_scale)

        # Overall std (over all elements) should be close to the target scale
        assert abs(normalized.std().item() - target_scale) < 0.05

    def test_centers_by_default(self):
        """Test that samples are centered by default."""
        samples = torch.randn(100, 64) + 5  # Non-zero mean
        normalized = normalize_perturbations(samples)

        mean = normalized.mean(dim=0)
        assert mean.abs().max() < 1e-6

    def test_can_skip_centering(self):
        """Test that centering can be skipped."""
        samples = torch.randn(100, 64) + 5
        normalized = normalize_perturbations(samples, center=False)

        # With center=False the column means should NOT be zero
        mean = normalized.mean(dim=0)
        assert mean.abs().max() > 0.1


class TestSafeDivide:
    """Behavioural checks for the ``safe_divide`` helper."""

    @staticmethod
    def _quotient_is_finite(denominator):
        """Divide ``[1, 2, 3]`` by *denominator* and report whether every entry is finite."""
        numerator = torch.tensor([1.0, 2.0, 3.0])
        quotient = safe_divide(numerator, denominator)
        return torch.isfinite(quotient).all()

    def test_normal_division(self):
        """Well-conditioned denominators yield the ordinary element-wise quotient."""
        numerator = torch.tensor([1.0, 2.0, 3.0])
        denominator = torch.tensor([2.0, 4.0, 3.0])

        actual = safe_divide(numerator, denominator)
        reference = numerator / denominator
        assert torch.allclose(actual, reference)

    def test_handles_small_values(self):
        """Tiny denominators must not blow up to inf or nan."""
        assert self._quotient_is_finite(torch.tensor([1e-10, 1e-10, 1.0]))

    def test_handles_zeros(self):
        """Exact-zero denominators must not blow up to inf or nan."""
        assert self._quotient_is_finite(torch.tensor([0.0, 0.0, 1.0]))

    def test_custom_threshold(self):
        """A larger ``min_threshold`` changes the result for small denominators."""
        numerator = torch.ones(10)
        denominator = torch.ones(10) * 0.001

        default_result = safe_divide(numerator, denominator)
        clamped_result = safe_divide(numerator, denominator, min_threshold=0.1)

        # The custom threshold must actually take effect.
        assert not torch.allclose(default_result, clamped_result)
