"""Unit tests for path integration completeness axiom.

Tests verify that Integrated Gradients and Expected Gradients satisfy
the completeness axiom:

    I^T @ phi = g(z_0) - g(z_0 - I)

This ensures that attributions correctly explain the change in model output.
"""

from __future__ import annotations

import pytest
import torch
from torch import nn

from expected_gradcam.core.predictor import Predictor, BatchedPredictor
from expected_gradcam.core.path_integration import IntegratedGradients, ExpectedGradients


class SimpleLinearHead(nn.Module):
    """Bias-free linear classifier over globally average-pooled feature maps.

    Global average pooling and a bias-free ``nn.Linear`` are both linear maps,
    so the whole head is linear in its input. That keeps the completeness
    checks in this file mathematically exact (up to floating-point round-off).

    Args:
        K: Number of input channels.
        num_classes: Number of output logits.
    """

    def __init__(self, K: int = 64, num_classes: int = 10) -> None:
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(K, num_classes, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 4-D ``(N, K, H, W)`` feature tensor to ``(N, num_classes)`` logits."""
        # (N, K, H, W) -> (N, K, 1, 1) -> (N, K)
        pooled = torch.flatten(self.gap(x), start_dim=1)
        return self.fc(pooled)


class TestIntegratedGradientsCompleteness:
    """Completeness checks for Integrated Gradients on a linear head.

    Each test compares ``I^T @ phi`` with ``g(z0) - g(z0 - I)`` using a
    relative-error tolerance, where ``phi`` is the IG attribution vector.
    """

    @pytest.fixture
    def setup(self):
        """Seed torch and return ``(head, features, K)`` for a 64-channel head."""
        torch.manual_seed(42)
        channels = 64
        linear_head = SimpleLinearHead(K=channels, num_classes=10)
        fmaps = torch.relu(torch.randn(1, channels, 8, 8))
        return linear_head, fmaps, channels

    @staticmethod
    def _relative_error(I, phi, rhs):
        """Return ``|I^T phi - rhs| / (|rhs| + 1e-10)``; the epsilon guards ``rhs == 0``."""
        lhs = (I * phi).sum()
        return abs(lhs - rhs) / (abs(rhs) + 1e-10)

    def test_completeness_axiom_ig(self, setup):
        """Check I^T @ phi_IG against g(z0) - g(z0 - I).

        The head is linear, so the identity should hold up to
        floating-point round-off.
        """
        head, features, K = setup
        predictor = Predictor(head, target_class=5, feature_maps=features)

        z0 = torch.ones(K)
        I = torch.rand(K) * 0.5 + 0.1  # entries in [0.1, 0.6)

        phi = IntegratedGradients(T=100).compute(predictor, z0, I)

        # Change in model output when the perturbation I is removed from z0.
        with torch.no_grad():
            output_drop = predictor(z0)[0] - predictor(z0 - I)[0]

        rel_error = self._relative_error(I, phi, output_drop)
        assert rel_error < 0.01, f"Completeness violated: rel_error={rel_error:.6f}"

    def test_completeness_with_varying_T(self, setup):
        """Check completeness for several numbers of integration steps T.

        The head is linear, so its gradient is constant along the path and
        the Riemann sum is exact for every T. Each error must stay small.
        """
        head, features, K = setup
        predictor = Predictor(head, target_class=5, feature_maps=features)

        z0 = torch.ones(K)
        I = torch.rand(K) * 0.5 + 0.1

        with torch.no_grad():
            output_drop = predictor(z0)[0] - predictor(z0 - I)[0]

        def error_for(steps):
            # One IG run with `steps` integration points, reduced to a float.
            phi = IntegratedGradients(T=steps).compute(predictor, z0, I)
            return self._relative_error(I, phi, output_drop).item()

        errors = [error_for(steps) for steps in (10, 50, 100)]

        assert all(e < 0.01 for e in errors), f"Errors too large: {errors}"


class TestExpectedGradientsCompleteness:
    """Test completeness axiom for Expected Gradients.

    Every test draws zero-mean ("centered") baseline samples, which these
    tests treat as required for EG completeness.
    """

    @pytest.fixture
    def setup(self):
        """Seed torch and return ``(head, features, K)`` for a 64-channel head."""
        torch.manual_seed(42)
        K = 64
        head = SimpleLinearHead(K=K, num_classes=10)
        features = torch.relu(torch.randn(1, K, 8, 8))
        return head, features, K

    @staticmethod
    def _centered_baselines(n, k):
        """Draw ``n`` Gaussian baselines (std 0.1) of width ``k``, shifted to zero mean."""
        samples = torch.randn(n, k) * 0.1
        # Subtract the per-channel sample mean so the baselines are exactly centered.
        return samples - samples.mean(dim=0)

    @staticmethod
    def _relative_error(I, phi, rhs):
        """Return ``|I^T phi - rhs| / (|rhs| + 1e-10)``; the epsilon guards ``rhs == 0``."""
        lhs = (I * phi).sum()
        return abs(lhs - rhs) / (abs(rhs) + 1e-10)

    def test_completeness_axiom_eg(self, setup):
        """Test completeness for Expected Gradients with centered baselines."""
        head, features, K = setup
        predictor = Predictor(head, target_class=5, feature_maps=features)

        z0 = torch.ones(K)
        I = torch.rand(K) * 0.5 + 0.1

        # Centered baselines (required for completeness)
        D_samples = self._centered_baselines(50, K)

        eg = ExpectedGradients(T=50)
        phi = eg.compute(predictor, z0, I, D_samples)

        with torch.no_grad():
            rhs = predictor(z0)[0] - predictor(z0 - I)[0]

        rel_error = self._relative_error(I, phi, rhs)
        assert rel_error < 0.05, f"EG completeness violated: rel_error={rel_error:.6f}"

    def test_completeness_with_batched_predictor(self, setup):
        """Test completeness holds with BatchedPredictor (after fix).

        This is the critical test for the K-factor bug fix. Before the fix,
        this test would fail with rel_error ≈ K-1 = 63.
        """
        head, features, K = setup
        predictor = BatchedPredictor(head, target_class=5, feature_maps=features)

        z0 = torch.ones(K)
        I = torch.rand(K) * 0.5 + 0.1
        D_samples = self._centered_baselines(20, K)

        eg = ExpectedGradients(T=50)
        phi = eg.compute(predictor, z0, I, D_samples)

        with torch.no_grad():
            rhs = predictor(z0)[0] - predictor(z0 - I)[0]

        rel_error = self._relative_error(I, phi, rhs)
        # After fix, this should pass
        assert rel_error < 0.05, (
            f"BatchedPredictor completeness violated: rel_error={rel_error:.6f}\n"
            f"This indicates the 1D input handling fix may have regressed."
        )

    def test_predictor_vs_batched_predictor_consistency(self, setup):
        """Test that Predictor and BatchedPredictor give consistent results.

        The attributions should be identical (within numerical precision)
        regardless of which predictor is used.
        """
        head, features, K = setup

        predictor = Predictor(head, target_class=5, feature_maps=features)
        batched_predictor = BatchedPredictor(head, target_class=5, feature_maps=features)

        z0 = torch.ones(K)
        I = torch.rand(K) * 0.5 + 0.1
        D_samples = self._centered_baselines(20, K)

        eg = ExpectedGradients(T=50)

        # Same EG instance, same inputs; only the predictor wrapper differs.
        phi_predictor = eg.compute(predictor, z0, I, D_samples)
        phi_batched = eg.compute(batched_predictor, z0, I, D_samples)

        assert torch.allclose(phi_predictor, phi_batched, atol=1e-5), (
            "Predictor and BatchedPredictor give different attributions"
        )

    def test_completeness_convergence_with_N(self, setup):
        """Test completeness holds for several baseline sample counts N.

        The head is linear, so its gradient is the same everywhere. The
        completeness error is therefore expected to stay near round-off for
        every N, rather than shrink as N grows. The test only checks each
        error against the tolerance; it does not assert a monotone decrease.
        """
        head, features, K = setup
        predictor = Predictor(head, target_class=5, feature_maps=features)

        z0 = torch.ones(K)
        I = torch.rand(K) * 0.5 + 0.1

        with torch.no_grad():
            rhs = predictor(z0)[0] - predictor(z0 - I)[0]

        errors = []
        for N in [10, 50, 100]:
            D_samples = self._centered_baselines(N, K)

            eg = ExpectedGradients(T=50)
            phi = eg.compute(predictor, z0, I, D_samples)
            errors.append(self._relative_error(I, phi, rhs).item())

        # All errors should be small for linear g
        assert all(e < 0.05 for e in errors), f"Errors too large: {errors}"


class TestCompletenessWithDifferentPerturbations:
    """Check Expected Gradients completeness for differently shaped perturbations I.

    Each test builds its own perturbation vector. A shared helper then runs EG
    with centered baselines and compares ``I^T @ phi`` with
    ``g(z0) - g(z0 - I)``.
    """

    @pytest.fixture
    def setup(self):
        """Seed torch and return ``(head, features, K)`` for a 64-channel head."""
        torch.manual_seed(42)
        channels = 64
        linear_head = SimpleLinearHead(K=channels, num_classes=10)
        fmaps = torch.relu(torch.randn(1, channels, 8, 8))
        return linear_head, fmaps, channels

    @staticmethod
    def _assert_eg_completeness(predictor, z0, I, K, label):
        """Run EG (T=50) with 20 centered baselines and assert completeness.

        Fails with the message ``"<label>: rel_error=..."`` when the relative
        error is not below 0.05.
        """
        baselines = torch.randn(20, K) * 0.1
        baselines = baselines - baselines.mean(dim=0)

        phi = ExpectedGradients(T=50).compute(predictor, z0, I, baselines)
        lhs = (I * phi).sum()

        # Change in model output when the perturbation I is removed from z0.
        with torch.no_grad():
            rhs = predictor(z0)[0] - predictor(z0 - I)[0]

        rel_error = abs(lhs - rhs) / (abs(rhs) + 1e-10)
        assert rel_error < 0.05, f"{label}: rel_error={rel_error:.6f}"

    def test_sparse_perturbation(self, setup):
        """Only the first 10 channels are perturbed; the rest of I stays zero."""
        head, features, K = setup
        predictor = Predictor(head, target_class=5, feature_maps=features)

        z0 = torch.ones(K)
        active = torch.rand(10) * 0.5 + 0.1
        I = torch.cat([active, torch.zeros(K - 10)])

        self._assert_eg_completeness(predictor, z0, I, K, "Sparse perturbation")

    def test_uniform_perturbation(self, setup):
        """Every channel is perturbed by the same amount (0.3)."""
        head, features, K = setup
        predictor = Predictor(head, target_class=5, feature_maps=features)

        z0 = torch.ones(K)
        I = torch.ones(K) * 0.3

        self._assert_eg_completeness(predictor, z0, I, K, "Uniform perturbation")

    def test_large_perturbation(self, setup):
        """Perturbations in [0.1, 0.99), pushing z0 - I close to zero."""
        head, features, K = setup
        predictor = Predictor(head, target_class=5, feature_maps=features)

        z0 = torch.ones(K)
        I = torch.rand(K) * 0.89 + 0.1

        self._assert_eg_completeness(predictor, z0, I, K, "Large perturbation")
