"""Unit tests for the heatmap module.

Tests heatmap generation and processing functions.
"""

from __future__ import annotations

import pytest
import torch
from torch import Tensor

import numpy as np

from expected_gradcam.core.heatmap import (
    apply_contrast_enhancement,
    generate_heatmap,
    normalize_heatmap,
    process_heatmap,
    upsample_heatmap,
)


class TestGenerateHeatmap:
    """Tests for ``generate_heatmap``.

    The fixtures provide a seeded 64-element weight vector and a seeded
    ``[1, 64, 8, 8]`` feature map made non-negative with ``torch.relu``.
    """

    @pytest.fixture
    def sample_weights(self) -> Tensor:
        """Create 64 reproducible per-channel weights (may be negative)."""
        torch.manual_seed(42)
        return torch.randn(64)

    @pytest.fixture
    def sample_features(self) -> Tensor:
        """Create a reproducible non-negative ``[1, 64, 8, 8]`` feature map."""
        torch.manual_seed(42)
        return torch.relu(torch.randn(1, 64, 8, 8))

    def test_basic_generation(
        self, sample_weights: Tensor, sample_features: Tensor
    ) -> None:
        """Test that generation returns a 2-D or 3-D tensor."""
        heatmap = generate_heatmap(sample_features, sample_weights)

        assert isinstance(heatmap, Tensor)
        # Output should be [B, H, W] or [H, W]
        assert heatmap.ndim in (2, 3)

    def test_output_shape(
        self, sample_weights: Tensor, sample_features: Tensor
    ) -> None:
        """Test output shape matches feature spatial dimensions."""
        heatmap = generate_heatmap(sample_features, sample_weights)

        # Should match spatial dimensions of features
        if heatmap.ndim == 3:
            assert heatmap.shape[1:] == sample_features.shape[2:]
        else:
            assert heatmap.shape == sample_features.shape[2:]

    def test_relu_application(
        self, sample_weights: Tensor, sample_features: Tensor
    ) -> None:
        """Test that enabling ReLU yields a non-negative heatmap."""
        heatmap = generate_heatmap(sample_features, sample_weights, apply_relu=True)

        assert (heatmap >= 0).all()

    def test_no_relu(self, sample_features: Tensor) -> None:
        """Test that disabling ReLU lets negative values through."""
        # The fixture features are non-negative, so all-negative weights make
        # every channel contribution <= 0. With ReLU disabled the weighted
        # combination must therefore contain negative values; a type check
        # alone could never detect a ReLU that is applied regardless.
        weights = -torch.ones(sample_features.shape[1])

        heatmap = generate_heatmap(sample_features, weights, apply_relu=False)

        assert isinstance(heatmap, Tensor)
        assert (heatmap < 0).any()

    def test_weighted_combination(self, sample_features: Tensor) -> None:
        """Test heatmap is weighted combination of channels."""
        # Simple weights: only first channel
        weights = torch.zeros(64)
        weights[0] = 1.0

        heatmap = generate_heatmap(sample_features, weights, apply_relu=False)

        # Heatmap should match first channel
        expected = sample_features[0, 0]
        assert torch.allclose(heatmap.squeeze(), expected.squeeze(), atol=1e-5)

    def test_batch_support(self, sample_weights: Tensor) -> None:
        """Test that a batch of feature maps yields one heatmap per sample."""
        torch.manual_seed(42)
        features_batch = torch.relu(torch.randn(4, 64, 8, 8))

        heatmap = generate_heatmap(features_batch, sample_weights)

        # Batch dimension kept, spatial dimensions match the features.
        assert heatmap.shape == (4, 8, 8)


class TestNormalizeHeatmap:
    """Tests for ``normalize_heatmap`` across its ``method`` options."""

    @pytest.fixture
    def sample_heatmap(self) -> Tensor:
        """Create a seeded 8x8 heatmap with values in [5, 15)."""
        torch.manual_seed(42)
        return torch.rand(8, 8) * 10 + 5  # Values in [5, 15]

    def test_minmax_normalization(self, sample_heatmap: Tensor) -> None:
        """Test min-max normalization maps the input onto exactly [0, 1]."""
        normalized = normalize_heatmap(sample_heatmap, method="minmax")

        assert normalized.min() >= 0.0
        assert normalized.max() <= 1.0
        assert torch.isclose(normalized.min(), torch.tensor(0.0), atol=1e-5)
        assert torch.isclose(normalized.max(), torch.tensor(1.0), atol=1e-5)

    def test_quantile_normalization(self, sample_heatmap: Tensor) -> None:
        """Test quantile normalization stays within [0, 1]."""
        normalized = normalize_heatmap(sample_heatmap, method="quantile")

        assert normalized.min() >= 0.0
        assert normalized.max() <= 1.0

    def test_sum_normalization(self, sample_heatmap: Tensor) -> None:
        """Test sum normalization rescales the heatmap to a total of ~1."""
        normalized = normalize_heatmap(sample_heatmap, method="sum")

        assert not torch.isnan(normalized).any()
        # The input is strictly positive, so after sum normalization its
        # entries should add up to one (small slack for float rounding / eps).
        assert torch.isclose(normalized.sum(), torch.tensor(1.0), atol=1e-4)

    def test_constant_heatmap(self) -> None:
        """Test a constant heatmap (zero range) does not produce NaN or inf."""
        constant = torch.ones(8, 8) * 5

        normalized = normalize_heatmap(constant, method="minmax")

        # max - min == 0 here; a missing guard would divide by zero.
        assert torch.isfinite(normalized).all()

    def test_batch_normalization(self) -> None:
        """Test normalization preserves batch dimension."""
        torch.manual_seed(42)
        heatmap_batch = torch.rand(4, 8, 8)

        normalized = normalize_heatmap(heatmap_batch, method="minmax")

        assert normalized.shape == heatmap_batch.shape


class TestUpsampleHeatmap:
    """Tests for resizing heatmaps with ``upsample_heatmap``."""

    @pytest.fixture
    def sample_heatmap(self) -> Tensor:
        """Return a seeded ``[1, 8, 8]`` heatmap with values in [0, 1)."""
        torch.manual_seed(42)
        return torch.rand(1, 8, 8)

    def test_upscale(self, sample_heatmap: Tensor) -> None:
        """Enlarging to 224x224 produces a 224x224 spatial grid."""
        size = (224, 224)
        out = upsample_heatmap(sample_heatmap, target_size=size)
        assert out.shape[-2:] == size

    def test_downscale(self, sample_heatmap: Tensor) -> None:
        """Shrinking to 4x4 produces a 4x4 spatial grid."""
        size = (4, 4)
        out = upsample_heatmap(sample_heatmap, target_size=size)
        assert out.shape[-2:] == size

    def test_bilinear_interpolation(self, sample_heatmap: Tensor) -> None:
        """Requesting bilinear mode still honours the target size."""
        size = (16, 16)
        out = upsample_heatmap(sample_heatmap, target_size=size, mode="bilinear")
        assert out.shape[-2:] == size

    def test_bicubic_interpolation(self, sample_heatmap: Tensor) -> None:
        """Requesting bicubic mode still honours the target size."""
        size = (16, 16)
        out = upsample_heatmap(sample_heatmap, target_size=size, mode="bicubic")
        assert out.shape[-2:] == size

    def test_batch_resize(self) -> None:
        """Every map in a batch of four is resized independently."""
        torch.manual_seed(42)
        batch = torch.rand(4, 8, 8)

        out = upsample_heatmap(batch, target_size=(16, 16))

        assert out.shape == (4, 16, 16)

    def test_value_range_preserved(self, sample_heatmap: Tensor) -> None:
        """Resized values stay close to the original min/max bounds."""
        out = upsample_heatmap(sample_heatmap, target_size=(16, 16))

        # Interpolation may overshoot the source range a little.
        slack = 0.1
        lo, hi = sample_heatmap.min(), sample_heatmap.max()
        assert out.min() >= lo - slack
        assert out.max() <= hi + slack


class TestProcessHeatmap:
    """Tests for the ``process_heatmap`` convenience pipeline.

    ``process_heatmap`` returns a ``(heatmap, coarse_heatmap)`` pair: the
    first is sized to ``input_size``; the second keeps the feature map's
    spatial resolution.
    """

    @pytest.fixture
    def sample_weights(self) -> Tensor:
        """Create 64 reproducible per-channel weights (may be negative)."""
        torch.manual_seed(42)
        return torch.randn(64)

    @pytest.fixture
    def sample_features(self) -> Tensor:
        """Create a reproducible non-negative ``[1, 64, 8, 8]`` feature map."""
        torch.manual_seed(42)
        return torch.relu(torch.randn(1, 64, 8, 8))

    def test_full_pipeline(
        self, sample_weights: Tensor, sample_features: Tensor
    ) -> None:
        """Test the normalized pipeline output is input-sized and in [0, 1]."""
        heatmap, _ = process_heatmap(
            sample_features,
            sample_weights,
            input_size=(224, 224),
            normalize=True,
        )

        assert heatmap.shape[-2:] == (224, 224)
        assert heatmap.min() >= 0.0
        assert heatmap.max() <= 1.0

    def test_without_normalization(
        self, sample_weights: Tensor, sample_features: Tensor
    ) -> None:
        """Test processing without normalization."""
        heatmap, _ = process_heatmap(
            sample_features,
            sample_weights,
            input_size=(224, 224),
            normalize=False,
        )

        # Without normalization the maximum is not capped at 1, so only the
        # non-negative lower bound is checked.
        assert heatmap.min() >= 0.0

    def test_returns_coarse_heatmap(
        self, sample_weights: Tensor, sample_features: Tensor
    ) -> None:
        """Test that coarse heatmap is also returned."""
        _, coarse_heatmap = process_heatmap(
            sample_features,
            sample_weights,
            input_size=(224, 224),
        )

        # Coarse heatmap should match feature map spatial size
        assert coarse_heatmap.shape[-2:] == sample_features.shape[-2:]


class TestApplyContrastEnhancement:
    """Tests for ``apply_contrast_enhancement`` on torch and numpy heatmaps."""

    @staticmethod
    def _assert_unit_range(values) -> None:
        """Assert every element of *values* lies in the closed range [0, 1]."""
        assert values.min() >= 0.0
        assert values.max() <= 1.0

    @pytest.fixture
    def sample_heatmap_torch(self) -> Tensor:
        """Return a seeded 8x8 torch heatmap with values in [0, 1)."""
        torch.manual_seed(42)
        return torch.rand(8, 8)

    @pytest.fixture
    def sample_heatmap_numpy(self) -> np.ndarray:
        """Return a seeded 8x8 float32 numpy heatmap with values in [0, 1)."""
        np.random.seed(42)
        return np.random.rand(8, 8).astype(np.float32)

    def test_torch_input(self, sample_heatmap_torch: Tensor) -> None:
        """A torch tensor goes in, a same-shaped torch tensor comes out."""
        out = apply_contrast_enhancement(sample_heatmap_torch, boost_factor=0.15)

        assert isinstance(out, Tensor)
        assert out.shape == sample_heatmap_torch.shape

    def test_numpy_input(self, sample_heatmap_numpy: np.ndarray) -> None:
        """A numpy array goes in, a same-shaped numpy array comes out."""
        out = apply_contrast_enhancement(sample_heatmap_numpy, boost_factor=0.15)

        assert isinstance(out, np.ndarray)
        assert out.shape == sample_heatmap_numpy.shape

    def test_output_range(self, sample_heatmap_torch: Tensor) -> None:
        """A moderate boost keeps every value inside [0, 1]."""
        out = apply_contrast_enhancement(sample_heatmap_torch, boost_factor=0.15)
        self._assert_unit_range(out)

    def test_contrast_increases(self, sample_heatmap_torch: Tensor) -> None:
        """Enhancement should not reduce contrast (measured as variance)."""
        before = sample_heatmap_torch
        after = apply_contrast_enhancement(before, boost_factor=0.15)

        # Clipping to [0, 1] can eat into the gain at the extremes, hence the
        # 5% tolerance rather than a strict increase.
        assert after.var() >= before.var() * 0.95

    def test_mean_preserved(self, sample_heatmap_torch: Tensor) -> None:
        """The mean intensity should stay roughly where it was."""
        after = apply_contrast_enhancement(sample_heatmap_torch, boost_factor=0.15)

        # Clipping may nudge the mean slightly, so allow a small absolute gap.
        assert torch.isclose(sample_heatmap_torch.mean(), after.mean(), atol=0.05)

    def test_zero_boost(self, sample_heatmap_torch: Tensor) -> None:
        """A boost factor of zero leaves the heatmap unchanged."""
        out = apply_contrast_enhancement(sample_heatmap_torch, boost_factor=0.0)

        assert torch.allclose(out, sample_heatmap_torch)

    def test_large_boost(self, sample_heatmap_torch: Tensor) -> None:
        """Even an aggressive boost keeps values inside [0, 1]."""
        out = apply_contrast_enhancement(sample_heatmap_torch, boost_factor=0.5)
        self._assert_unit_range(out)

    def test_batch_input(self) -> None:
        """A batched input keeps its shape and stays inside [0, 1]."""
        torch.manual_seed(42)
        batch = torch.rand(4, 8, 8)

        out = apply_contrast_enhancement(batch, boost_factor=0.15)

        assert out.shape == batch.shape
        self._assert_unit_range(out)

    def test_dtype_preserved_numpy(self, sample_heatmap_numpy: np.ndarray) -> None:
        """The numpy input's dtype is carried through to the output."""
        out = apply_contrast_enhancement(sample_heatmap_numpy, boost_factor=0.15)

        assert out.dtype == sample_heatmap_numpy.dtype
