"""Unit tests for the predictor module.

Tests Predictor and BatchedPredictor classes for evaluating g(z; A).
"""

from __future__ import annotations

import pytest
import torch
from torch import Tensor, nn

from expected_gradcam.core.predictor import BatchedPredictor, Predictor


class SimpleClassifierHead(nn.Module):
    """Minimal classification head used as a stand-in model in these tests.

    Global-average-pools the input with ``nn.AdaptiveAvgPool2d(1)`` and feeds
    the flattened per-channel result through one linear layer.
    """

    def __init__(self, in_channels: int = 64, num_classes: int = 10) -> None:
        """Create the pooling and linear layers.

        Args:
            in_channels: Number of input features of the linear layer (K).
            num_classes: Number of output features of the linear layer.
        """
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(in_features=in_channels, out_features=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Map feature maps ``[B, K, U, V]`` to scores ``[B, num_classes]``."""
        # Pool to [B, K, 1, 1], then collapse everything after the batch axis.
        channel_means = self.gap(x).flatten(start_dim=1)  # [B, K]
        logits = self.fc(channel_means)  # [B, num_classes]
        return logits


class TestPredictor:
    """Tests for ``Predictor`` built on a small GAP + Linear classifier head.

    Every test uses target class 5 and feature maps of shape [1, 64, 8, 8],
    so perturbation vectors ``z`` have 64 entries.
    """

    @pytest.fixture
    def classifier_head(self) -> nn.Module:
        """Create a simple classifier head with seeded weights."""
        torch.manual_seed(42)
        return SimpleClassifierHead(in_channels=64, num_classes=10)

    @pytest.fixture
    def sample_features(self) -> Tensor:
        """Create seeded, non-negative sample feature maps [1, 64, 8, 8]."""
        torch.manual_seed(42)
        return torch.relu(torch.randn(1, 64, 8, 8))

    @pytest.fixture
    def predictor(self, classifier_head, sample_features) -> Predictor:
        """Build the ``Predictor`` under test from the two fixtures above.

        pytest caches function-scoped fixtures per test, so a test that also
        requests ``classifier_head`` or ``sample_features`` receives the very
        objects this predictor was constructed with.
        """
        return Predictor(classifier_head, target_class=5, feature_maps=sample_features)

    def test_initialization(self, predictor, classifier_head):
        """Test Predictor initialization stores its constructor arguments."""
        assert predictor.classifier_head is classifier_head
        assert predictor.target_class == 5
        assert predictor.num_features == 64

    def test_call_returns_tensor(self, predictor):
        """Test predictor returns tensor."""
        result = predictor(torch.ones(64))

        assert isinstance(result, Tensor)

    def test_output_shape(self, predictor):
        """Test output has correct shape [B]."""
        result = predictor(torch.ones(64))

        # Should return [B] where B is batch size
        assert result.shape == (1,)

    def test_perturbation_effect(self, predictor):
        """Test that different perturbations give different outputs."""
        result_ones = predictor(torch.ones(64))
        result_zeros = predictor(torch.zeros(64))

        # Results should be different
        assert not torch.allclose(result_ones, result_zeros)

    def test_batch_z(self, predictor):
        """Test with batched z vectors [B, K]."""
        # Batch of 4 z vectors
        z_batch = torch.rand(4, 64)
        result = predictor(z_batch)

        # Output should be [4] (one score per z vector)
        # Note: feature_maps are broadcast
        assert result.shape == (4,)

    def test_reference_perturbation(self, predictor):
        """Test evaluate_at_reference method."""
        result = predictor.evaluate_at_reference()

        # Should return class score at z=1
        assert result.shape == (1,)
        assert isinstance(result, Tensor)

    def test_zero_perturbation(self, predictor, classifier_head, sample_features):
        """Test that z=0 removes every dependence on the feature maps.

        With z=0 all features are zeroed, so the classifier head only ever
        sees zeros. The score must therefore be identical for predictors that
        were built from different feature maps.
        """
        # A second predictor over clearly different (shifted and scaled) features.
        other_features = sample_features * 2.0 + 1.0
        other_predictor = Predictor(
            classifier_head, target_class=5, feature_maps=other_features
        )

        z_zero = torch.zeros(64)
        result = predictor(z_zero)
        other_result = other_predictor(z_zero)

        assert result.shape == (1,)
        # Only the feature-independent part of the head (e.g. its bias) remains.
        assert torch.allclose(result, other_result, atol=1e-6)

    def test_gradient_flow(self, predictor):
        """Test gradients can flow through predictor."""
        z = torch.ones(64, requires_grad=True)
        result = predictor(z)

        loss = result.sum()
        loss.backward()

        assert z.grad is not None
        assert z.grad.shape == z.shape

    def test_compute_output_difference(self, predictor):
        """Test compute_output_difference method."""
        perturbation = torch.rand(64) * 0.5
        diff = predictor.compute_output_difference(perturbation)

        # Should be g(z_0) - g(z_0 - perturbation)
        expected = predictor.evaluate_at_reference() - predictor.evaluate_at_baseline(
            perturbation
        )

        assert torch.allclose(diff, expected)

    def test_device_property(self, predictor, sample_features):
        """Test device property returns correct device."""
        assert predictor.device == sample_features.device


class TestBatchedPredictor:
    """Exercise ``BatchedPredictor`` and cross-check it against ``Predictor``."""

    @pytest.fixture
    def classifier_head(self) -> nn.Module:
        """Seeded GAP + Linear head with 64 input channels and 10 classes."""
        torch.manual_seed(42)
        head = SimpleClassifierHead(in_channels=64, num_classes=10)
        return head

    @pytest.fixture
    def sample_features(self) -> Tensor:
        """Seeded, ReLU-rectified random feature maps of shape [1, 64, 8, 8]."""
        torch.manual_seed(42)
        raw = torch.randn(1, 64, 8, 8)
        return torch.relu(raw)

    def test_initialization(self, classifier_head, sample_features):
        """Constructor exposes the feature count and the target class."""
        batched = BatchedPredictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )

        assert batched.num_features == 64
        assert batched.target_class == 5

    def test_requires_batch_size_1(self, classifier_head):
        """Feature maps with a batch dimension other than 1 are rejected."""
        too_many = torch.randn(4, 64, 8, 8)

        with pytest.raises(ValueError, match="batch size 1"):
            BatchedPredictor(classifier_head, target_class=5, feature_maps=too_many)

    def test_multiple_perturbations(self, classifier_head, sample_features):
        """Fifty perturbation vectors yield fifty scores."""
        batched = BatchedPredictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )

        num_samples = 50
        scores = batched(torch.rand(num_samples, 64))

        assert scores.shape == (num_samples,)

    def test_consistency_with_predictor(self, classifier_head, sample_features):
        """Batched evaluation matches evaluating each z with ``Predictor``."""
        single = Predictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )
        batched = BatchedPredictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )

        torch.manual_seed(42)
        z_samples = torch.rand(20, 64)

        # Reference: feed the rows one at a time and stitch the scores together.
        one_by_one = torch.cat([single(row) for row in z_samples], dim=0)

        # Under test: the whole matrix in a single call.
        all_at_once = batched(z_samples)

        assert torch.allclose(one_by_one, all_at_once, atol=1e-5)

    def test_gradient_flow(self, classifier_head, sample_features):
        """Backpropagation through the batched predictor reaches z."""
        batched = BatchedPredictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )

        z_samples = torch.rand(10, 64, requires_grad=True)
        batched(z_samples).sum().backward()

        grad = z_samples.grad
        assert grad is not None
        assert grad.shape == z_samples.shape

    def test_use_compile_option(self, classifier_head, sample_features):
        """Passing ``use_compile=True`` still produces one score per sample."""
        # Should not raise even if compile not available
        options = {
            "target_class": 5,
            "feature_maps": sample_features,
            "use_compile": True,
        }
        batched = BatchedPredictor(classifier_head, **options)

        scores = batched(torch.rand(10, 64))

        assert scores.shape == (10,)

    def test_1d_input_handling(self, classifier_head, sample_features):
        """A 1D z of length K is one sample, not K samples (regression test).

        Guards the fix for the completeness axiom bug in which a 1D input [K]
        was misread as K separate samples rather than one sample with K
        features.
        """
        batched = BatchedPredictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )
        single = Predictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )

        # One perturbation vector, deliberately without a batch axis.
        z_vector = torch.rand(64)

        from_batched = batched(z_vector)
        from_single = single(z_vector)

        # Both predictors must agree on the score ...
        assert torch.allclose(from_batched, from_single, atol=1e-5)
        # ... and the batched one must return shape [1], not [K].
        assert from_batched.shape == (1,)

    def test_1d_gradient_correctness(self, classifier_head, sample_features):
        """Gradients for a 1D z match ``Predictor`` (regression test).

        Guards against the bug in which gradients came out K times larger
        because a 1D input was misread as K separate samples.
        """
        batched = BatchedPredictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )
        single = Predictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )

        # Two independent leaf tensors so each predictor gets its own .grad.
        z_for_batched = torch.ones(64, requires_grad=True)
        z_for_single = torch.ones(64, requires_grad=True)

        out_batched = batched(z_for_batched)
        out_single = single(z_for_single)

        out_batched.sum().backward()
        out_single.sum().backward()

        # Gradients should be the same (not K times different)
        assert z_for_batched.grad is not None
        assert z_for_single.grad is not None
        assert torch.allclose(z_for_batched.grad, z_for_single.grad, atol=1e-5)

    def test_2d_input_unchanged(self, classifier_head, sample_features):
        """A 2D batch of z vectors still yields one score per row."""
        batched = BatchedPredictor(
            classifier_head, target_class=5, feature_maps=sample_features
        )

        z_matrix = torch.rand(10, 64)
        scores = batched(z_matrix)

        # One score for each of the 10 rows.
        assert scores.shape == (10,)
