"""Tests for heatmap metrics (entropy, Gini, weight norm)."""

from __future__ import annotations

import math

import pytest
import torch
from torch import Tensor

from expected_gradcam.metrics.exceptions import InvalidMetricInputError
from expected_gradcam.metrics.heatmap import (
    HeatmapEntropy,
    HeatmapGini,
    WeightNorm,
    WeightSparsity,
    compute_heatmap_stats,
    compute_weight_stats,
)


class TestHeatmapEntropy:
    """Tests for HeatmapEntropy metric."""

    def test_compute_basic(self, sample_heatmap_2d: Tensor):
        """Basic computation should return a float in [0, 1]."""
        metric = HeatmapEntropy(normalize=True)
        result = metric.compute(heatmap=sample_heatmap_2d)

        assert isinstance(result, float)
        assert 0.0 <= result <= 1.0

    def test_uniform_max_entropy(self, uniform_heatmap: Tensor):
        """Uniform distribution should have max entropy (normalized ~1)."""
        metric = HeatmapEntropy(normalize=True)
        result = metric.compute(heatmap=uniform_heatmap)

        # Should be close to 1.0 (max entropy)
        assert result > 0.95

    def test_single_point_min_entropy(self, single_point_heatmap: Tensor):
        """Single point should have minimum entropy (normalized ~0)."""
        metric = HeatmapEntropy(normalize=True)
        result = metric.compute(heatmap=single_point_heatmap)

        # Should be close to 0.0 (min entropy)
        assert result < 0.1

    def test_focused_medium_entropy(self, focused_heatmap: Tensor):
        """Focused heatmap should have medium entropy."""
        metric = HeatmapEntropy(normalize=True)
        result = metric.compute(heatmap=focused_heatmap)

        # Should be between extremes
        assert 0.1 < result < 0.9

    def test_3d_input(self, sample_heatmap_3d: Tensor):
        """Should handle 3D input [B, H, W] and return a finite float."""
        metric = HeatmapEntropy()
        result = metric.compute(heatmap=sample_heatmap_3d)

        assert isinstance(result, float)
        # A type check alone would accept NaN/inf from a bad reduction.
        assert math.isfinite(result)

    def test_missing_heatmap_raises_error(self):
        """Missing heatmap should raise InvalidMetricInputError."""
        metric = HeatmapEntropy()

        with pytest.raises(InvalidMetricInputError):
            metric.compute(heatmap=None)

    def test_unnormalized_entropy(self, sample_heatmap_2d: Tensor):
        """Unnormalized entropy should be finite and exceed the normalized value.

        Assumption (not visible from this test): the normalized variant is
        the raw entropy divided by the maximum achievable entropy log(N) for
        N pixels. That divisor is greater than 1 for any realistic heatmap
        size, for example ln(64) ≈ 4.16 nats or log2(64) = 6 bits for an
        8x8 map. The raw value must therefore be the larger of the two,
        whichever log base the implementation uses.
        """
        metric_norm = HeatmapEntropy(normalize=True)
        metric_raw = HeatmapEntropy(normalize=False)

        entropy_norm = metric_norm.compute(heatmap=sample_heatmap_2d)
        entropy_raw = metric_raw.compute(heatmap=sample_heatmap_2d)

        assert math.isfinite(entropy_raw)
        assert entropy_raw > entropy_norm


class TestHeatmapGini:
    """Tests for HeatmapGini metric."""

    def test_compute_basic(self, sample_heatmap_2d: Tensor):
        """Basic computation should return a float in [0, 1]."""
        metric = HeatmapGini()
        result = metric.compute(heatmap=sample_heatmap_2d)

        assert isinstance(result, float)
        assert 0.0 <= result <= 1.0

    def test_uniform_zero_gini(self, uniform_heatmap: Tensor):
        """Uniform distribution should have Gini ≈ 0."""
        metric = HeatmapGini()
        result = metric.compute(heatmap=uniform_heatmap)

        # Should be close to 0 (perfect equality)
        assert result < 0.05

    def test_single_point_max_gini(self, single_point_heatmap: Tensor):
        """Single point should have max Gini (close to 1)."""
        metric = HeatmapGini()
        result = metric.compute(heatmap=single_point_heatmap)

        # Should be close to 1 (maximum inequality)
        assert result > 0.95

    def test_focused_medium_gini(self, focused_heatmap: Tensor):
        """Focused heatmap should have medium Gini."""
        metric = HeatmapGini()
        result = metric.compute(heatmap=focused_heatmap)

        # Should be between extremes
        assert 0.3 < result < 0.95

    def test_entropy_gini_inverse_relationship(
        self,
        uniform_heatmap: Tensor,
        focused_heatmap: Tensor,
        single_point_heatmap: Tensor,
    ):
        """Entropy and Gini should move in opposite directions as focus grows.

        The fixtures are listed from most spread out (uniform) to most
        concentrated (single point). Along that order, normalized entropy
        must strictly decrease and the Gini coefficient must strictly
        increase.
        """
        entropy_metric = HeatmapEntropy(normalize=True)
        gini_metric = HeatmapGini()

        heatmaps = [uniform_heatmap, focused_heatmap, single_point_heatmap]
        entropies = [entropy_metric.compute(heatmap=h) for h in heatmaps]
        ginis = [gini_metric.compute(heatmap=h) for h in heatmaps]

        assert all(math.isfinite(v) for v in entropies + ginis)
        # The per-fixture range tests in this module place each fixture in
        # disjoint ranges for both metrics. Strict ordering is therefore
        # exactly what those ranges imply.
        assert entropies[0] > entropies[1] > entropies[2]
        assert ginis[0] < ginis[1] < ginis[2]


class TestWeightNorm:
    """Checks WeightNorm against reference norms computed directly by torch."""

    def test_compute_l2(self, sample_alpha: Tensor):
        """Default L2 norm computation."""
        reference = torch.linalg.vector_norm(sample_alpha, ord=2).item()
        observed = WeightNorm(ord=2).compute(alpha=sample_alpha)

        assert observed == pytest.approx(reference, rel=1e-5)

    def test_compute_l1(self, sample_alpha: Tensor):
        """L1 norm computation."""
        reference = torch.linalg.vector_norm(sample_alpha, ord=1).item()
        observed = WeightNorm(ord=1).compute(alpha=sample_alpha)

        assert observed == pytest.approx(reference, rel=1e-5)

    def test_compute_linf(self, sample_alpha: Tensor):
        """L-inf norm computation."""
        infinity = float("inf")
        # The infinity norm is the largest absolute entry.
        reference = torch.linalg.vector_norm(sample_alpha, ord=infinity).item()
        observed = WeightNorm(ord=infinity).compute(alpha=sample_alpha)

        assert observed == pytest.approx(reference, rel=1e-5)

    def test_zero_weights(self):
        """Zero weights should have norm 0."""
        observed = WeightNorm().compute(alpha=torch.zeros(64))

        assert observed == pytest.approx(0.0)

    def test_missing_alpha_raises_error(self):
        """Missing alpha should raise InvalidMetricInputError."""
        norm_metric = WeightNorm()

        with pytest.raises(InvalidMetricInputError):
            norm_metric.compute(alpha=None)

    def test_wrong_shape_raises_error(self):
        """Non-1D alpha should raise InvalidMetricInputError."""
        norm_metric = WeightNorm()
        matrix = torch.randn(64, 2)

        with pytest.raises(InvalidMetricInputError):
            norm_metric.compute(alpha=matrix)


class TestWeightSparsity:
    """Checks the WeightSparsity score on sparse, dense and all-zero inputs."""

    def test_compute_basic(self, sample_alpha: Tensor):
        """Basic computation should return a float in [0, 1]."""
        score = WeightSparsity().compute(alpha=sample_alpha)

        assert isinstance(score, float)
        assert 0.0 <= score <= 1.0

    def test_sparse_weights_high_sparsity(self, sample_alpha_sparse: Tensor):
        """Sparse weights should have high sparsity score."""
        score = WeightSparsity(threshold=0.01).compute(alpha=sample_alpha_sparse)

        # Most entries of the sparse fixture are zero, so the score is high.
        assert score > 0.7

    def test_dense_weights_low_sparsity(self, sample_alpha: Tensor):
        """Dense random weights should have lower sparsity."""
        score = WeightSparsity(threshold=0.01).compute(alpha=sample_alpha)

        # Random weights rarely fall under the threshold.
        assert score < 0.5

    def test_all_zero_full_sparsity(self):
        """All-zero weights should have sparsity = 1.0."""
        score = WeightSparsity().compute(alpha=torch.zeros(64))

        assert score == 1.0

    def test_threshold_affects_sparsity(self, sample_alpha: Tensor):
        """Raising the threshold should never decrease the sparsity score."""
        loose, strict = (
            WeightSparsity(threshold=t).compute(alpha=sample_alpha)
            for t in (0.01, 0.5)
        )

        assert strict >= loose


class TestConvenienceFunctions:
    """Tests for convenience functions."""

    def test_compute_heatmap_stats(self, sample_heatmap_2d: Tensor):
        """compute_heatmap_stats should return comprehensive, consistent stats."""
        stats = compute_heatmap_stats(sample_heatmap_2d)

        expected_keys = {
            "entropy",
            "gini",
            "mean",
            "std",
            "min",
            "max",
            "positive_area",
            "peak_ratio",
        }
        missing = expected_keys - set(stats)
        assert not missing, f"missing keys: {sorted(missing)}"

        # Metrics that are bounded by construction.
        assert 0 <= stats["entropy"] <= 1
        assert 0 <= stats["gini"] <= 1
        assert 0 <= stats["positive_area"] <= 1

        # Summary statistics must agree with each other. The tolerance
        # absorbs float rounding in the mean reduction.
        tol = 1e-6
        assert stats["std"] >= 0
        assert stats["min"] - tol <= stats["mean"] <= stats["max"] + tol

    def test_compute_weight_stats(self, sample_alpha: Tensor):
        """compute_weight_stats should return comprehensive, consistent stats."""
        stats = compute_weight_stats(sample_alpha)

        expected_keys = {
            "l1_norm",
            "l2_norm",
            "linf_norm",
            "mean",
            "std",
            "min",
            "max",
            "positive_fraction",
        }
        missing = expected_keys - set(stats)
        assert not missing, f"missing keys: {sorted(missing)}"

        assert stats["l1_norm"] >= 0
        assert stats["l2_norm"] >= 0
        assert 0 <= stats["positive_fraction"] <= 1

        # Any vector satisfies ||x||_inf <= ||x||_2 <= ||x||_1. Equality
        # holds for a single non-zero entry, so allow a small relative
        # tolerance for rounding.
        tol = 1e-5 * stats["l1_norm"] + 1e-8
        assert stats["linf_norm"] <= stats["l2_norm"] + tol
        assert stats["l2_norm"] <= stats["l1_norm"] + tol

        assert stats["std"] >= 0
        assert stats["min"] - 1e-6 <= stats["mean"] <= stats["max"] + 1e-6


class TestHeatmapMetricsCUDA:
    """CUDA-specific tests for heatmap metrics.

    Every test in this class is skipped at runtime when no CUDA device is
    available.
    """

    @pytest.fixture(autouse=True)
    def _require_cuda(self):
        """Skip the requesting test when CUDA is unavailable."""
        # The check runs per test at runtime (not at import) and replaces
        # the duplicated inline check each test used to carry.
        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

    @pytest.mark.gpu
    def test_entropy_cuda(self):
        """HeatmapEntropy should work on CUDA tensors and match the CPU result."""
        metric = HeatmapEntropy()
        heatmap = torch.rand(8, 8, device="cuda")
        result = metric.compute(heatmap=heatmap)

        assert isinstance(result, float)
        assert 0 <= result <= 1
        # The device must not change the numeric answer. Use a fresh metric
        # so the CPU run is independent of the CUDA run.
        cpu_result = HeatmapEntropy().compute(heatmap=heatmap.cpu())
        assert result == pytest.approx(cpu_result, rel=1e-4)

    @pytest.mark.gpu
    def test_weight_norm_cuda(self):
        """WeightNorm should work on CUDA tensors and match the CPU result."""
        metric = WeightNorm()
        alpha = torch.randn(64, device="cuda")
        result = metric.compute(alpha=alpha)

        assert isinstance(result, float)
        assert result >= 0
        cpu_result = WeightNorm().compute(alpha=alpha.cpu())
        assert result == pytest.approx(cpu_result, rel=1e-4)
