"""Tests for infidelity metrics."""

from __future__ import annotations

import math

import pytest
import torch
from torch import Tensor

from expected_gradcam.metrics.exceptions import (
    InfidelityComputationError,
    InvalidMetricInputError,
)
from expected_gradcam.metrics.infidelity import InternalInfidelity


class TestInternalInfidelity:
    """Behavioral tests for the InternalInfidelity metric.

    These tests rely on pytest fixtures (``sample_alpha``, ``sample_I_samples``,
    ``sample_g_z0``, ``sample_g_perturbed``, ``perfect_g_perturbed``) that are
    not defined in this module (presumably in a ``conftest.py``).
    """

    def test_compute_basic(
        self,
        sample_alpha: Tensor,
        sample_I_samples: Tensor,
        sample_g_z0: float,
        sample_g_perturbed: Tensor,
    ):
        """Basic computation should return a finite float."""
        metric = InternalInfidelity()
        result = metric.compute(
            alpha=sample_alpha,
            I_samples=sample_I_samples,
            g_z0=sample_g_z0,
            g_perturbed=sample_g_perturbed,
        )

        assert isinstance(result, float)
        assert math.isfinite(result)

    def test_infidelity_non_negative(
        self,
        sample_alpha: Tensor,
        sample_I_samples: Tensor,
        sample_g_z0: float,
        sample_g_perturbed: Tensor,
    ):
        """Infidelity (MSE) should always be non-negative."""
        metric = InternalInfidelity()
        result = metric.compute(
            alpha=sample_alpha,
            I_samples=sample_I_samples,
            g_z0=sample_g_z0,
            g_perturbed=sample_g_perturbed,
        )

        assert result >= 0.0

    def test_perfect_prediction_zero_infidelity(
        self,
        sample_alpha: Tensor,
        sample_I_samples: Tensor,
        sample_g_z0: float,
        perfect_g_perturbed: Tensor,
    ):
        """When predicted == actual, infidelity should be ~0."""
        metric = InternalInfidelity()
        result = metric.compute(
            alpha=sample_alpha,
            I_samples=sample_I_samples,
            g_z0=sample_g_z0,
            g_perturbed=perfect_g_perturbed,
        )

        # Should be very close to zero (floating point tolerance)
        assert result < 1e-10

    def test_missing_alpha_raises_error(
        self,
        sample_I_samples: Tensor,
        sample_g_z0: float,
        sample_g_perturbed: Tensor,
    ):
        """Missing alpha should raise InvalidMetricInputError naming the input."""
        metric = InternalInfidelity()

        with pytest.raises(InvalidMetricInputError) as exc_info:
            metric.compute(
                alpha=None,
                I_samples=sample_I_samples,
                g_z0=sample_g_z0,
                g_perturbed=sample_g_perturbed,
            )

        assert "alpha" in str(exc_info.value)

    def test_missing_I_samples_raises_error(
        self,
        sample_alpha: Tensor,
        sample_g_z0: float,
        sample_g_perturbed: Tensor,
    ):
        """Missing I_samples should raise InvalidMetricInputError."""
        metric = InternalInfidelity()

        with pytest.raises(InvalidMetricInputError):
            metric.compute(
                alpha=sample_alpha,
                I_samples=None,
                g_z0=sample_g_z0,
                g_perturbed=sample_g_perturbed,
            )

    def test_shape_mismatch_raises_error(
        self,
        sample_g_z0: float,
        sample_g_perturbed: Tensor,
    ):
        """Mismatched shapes should raise error."""
        metric = InternalInfidelity()

        # alpha [64] doesn't match I_samples [100, 32]
        alpha = torch.randn(64)
        I_samples = torch.randn(100, 32)

        # RuntimeError is accepted too, so a mismatch surfaced by torch itself
        # (rather than by explicit validation in the metric) also passes.
        with pytest.raises((InvalidMetricInputError, RuntimeError)):
            metric.compute(
                alpha=alpha,
                I_samples=I_samples,
                g_z0=sample_g_z0,
                g_perturbed=sample_g_perturbed,
            )

    def test_metric_timing(
        self,
        sample_alpha: Tensor,
        sample_I_samples: Tensor,
        sample_g_z0: float,
        sample_g_perturbed: Tensor,
    ):
        """Metric should record timing via @timed decorator."""
        metric = InternalInfidelity()
        metric.compute(
            alpha=sample_alpha,
            I_samples=sample_I_samples,
            g_z0=sample_g_z0,
            g_perturbed=sample_g_perturbed,
        )

        assert hasattr(metric, "last_compute_time_ms")
        assert metric.last_compute_time_ms >= 0

    def test_infidelity_with_different_sample_sizes(self, sample_alpha: Tensor):
        """Infidelity should work with different M (number of perturbation samples)."""
        metric = InternalInfidelity()

        # Derive the feature dimension, dtype and device from the fixture rather
        # than hard-coding them, so the test keeps working if the fixture changes.
        n_features = sample_alpha.shape[-1]
        tensor_kwargs = {"dtype": sample_alpha.dtype, "device": sample_alpha.device}

        for M in [10, 50, 100, 500]:
            I_samples = torch.randn(M, n_features, **tensor_kwargs)
            g_perturbed = torch.randn(M, **tensor_kwargs)

            result = metric.compute(
                alpha=sample_alpha,
                I_samples=I_samples,
                g_z0=5.0,
                g_perturbed=g_perturbed,
            )

            assert isinstance(result, float)
            assert math.isfinite(result)
            assert result >= 0.0


class TestInternalInfidelityMathematical:
    """Mathematical property tests for InternalInfidelity.

    All random tensors are drawn from a local ``torch.Generator`` so each test
    is fully deterministic and does not mutate the global RNG state that other
    tests may depend on.
    """

    def test_infidelity_formula_correctness(self):
        """Verify the infidelity formula: E[(I^T α - (g_z0 - g_perturbed))²]."""
        metric = InternalInfidelity()
        alpha = torch.tensor([1.0, 2.0])
        I_samples = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # [3, 2]
        g_z0 = 5.0
        # predicted_change = I @ alpha = [1, 2, 3] for both cases below.

        # Case 1: actual_change = g_z0 - [4, 3, 2] = [1, 2, 3]
        # error = predicted - actual = [0, 0, 0] -> MSE = 0
        exact = metric.compute(
            alpha=alpha,
            I_samples=I_samples,
            g_z0=g_z0,
            g_perturbed=torch.tensor([4.0, 3.0, 2.0]),
        )
        assert exact < 1e-10  # Should be ~0

        # Case 2: actual_change = g_z0 - [4, 3, 4] = [1, 2, 1]
        # error = predicted - actual = [0, 0, 2] -> MSE = (0 + 0 + 4) / 3 = 4/3.
        # A zero-error case alone cannot tell the mean of squared errors apart
        # from SSE (4), RMSE (~1.155) or MAE (2/3); this case can.
        nonzero = metric.compute(
            alpha=alpha,
            I_samples=I_samples,
            g_z0=g_z0,
            g_perturbed=torch.tensor([4.0, 3.0, 4.0]),
        )
        assert nonzero == pytest.approx(4.0 / 3.0, rel=1e-5)

    def test_infidelity_increases_with_error(self):
        """Infidelity should increase monotonically as prediction error increases."""
        gen = torch.Generator().manual_seed(0)
        metric = InternalInfidelity()
        alpha = torch.ones(64)
        I_samples = torch.randn(100, 64, generator=gen)

        # Perfect prediction
        predicted = torch.mv(I_samples, alpha)
        g_z0 = 5.0
        perfect_g_perturbed = g_z0 - predicted

        infid_perfect = metric.compute(
            alpha=alpha,
            I_samples=I_samples,
            g_z0=g_z0,
            g_perturbed=perfect_g_perturbed,
        )

        # Scale one fixed noise direction, so a larger scale is guaranteed to be
        # a larger error; fresh noise per scale would make the ordering random.
        noise = torch.randn(100, generator=gen)
        previous = infid_perfect
        for noise_scale in [0.1, 1.0, 10.0]:
            noisy_g_perturbed = perfect_g_perturbed + noise * noise_scale

            infid_noisy = metric.compute(
                alpha=alpha,
                I_samples=I_samples,
                g_z0=g_z0,
                g_perturbed=noisy_g_perturbed,
            )

            assert infid_noisy > infid_perfect
            assert infid_noisy > previous
            previous = infid_noisy

    def test_scaling_invariance(self):
        """Infidelity should scale quadratically with error magnitude.

        Despite its name, this test checks quadratic scaling, not invariance.
        """
        gen = torch.Generator().manual_seed(123)
        metric = InternalInfidelity()
        alpha = torch.ones(64)
        I_samples = torch.randn(100, 64, generator=gen)
        g_z0 = 5.0

        # Base case with fixed noise
        noise = torch.randn(100, generator=gen)
        predicted = torch.mv(I_samples, alpha)
        g_perturbed_base = g_z0 - predicted + noise

        infid_base = metric.compute(
            alpha=alpha,
            I_samples=I_samples,
            g_z0=g_z0,
            g_perturbed=g_perturbed_base,
        )

        # Double the noise -> 4x the infidelity (MSE scales quadratically)
        g_perturbed_2x = g_z0 - predicted + 2 * noise

        infid_2x = metric.compute(
            alpha=alpha,
            I_samples=I_samples,
            g_z0=g_z0,
            g_perturbed=g_perturbed_2x,
        )

        # The per-sample error is exactly the (scaled) noise vector up to float
        # rounding, so the ratio is 4 to within a tight relative tolerance.
        ratio = infid_2x / infid_base
        assert ratio == pytest.approx(4.0, rel=1e-3)


class TestInternalInfidelityCUDA:
    """CUDA-specific tests for InternalInfidelity."""

    @pytest.mark.gpu
    def test_cuda_computation(self, cuda_tensors):
        """Infidelity should work on CUDA tensors and return a finite, non-negative float."""
        # The fixture yields None when CUDA is unavailable.
        if cuda_tensors is None:
            pytest.skip("CUDA not available")

        metric = InternalInfidelity()
        alpha = cuda_tensors["alpha"]
        I_samples = cuda_tensors["I_samples"]
        # One perturbed output per sample row, on the same device as the inputs.
        # Sizing from I_samples avoids silently depending on the fixture's M.
        g_perturbed = torch.randn(I_samples.shape[0], device=I_samples.device)

        result = metric.compute(
            alpha=alpha,
            I_samples=I_samples,
            g_z0=5.0,
            g_perturbed=g_perturbed,
        )

        assert isinstance(result, float)
        assert math.isfinite(result)
        assert result >= 0.0
