"""Integration tests for the full Expected GradCAM pipeline.

Tests the complete workflow from image input to heatmap output.
"""

from __future__ import annotations

import pytest
import torch
from torch import Tensor, nn

from expected_gradcam.config import ExpectedGradCAMConfig, FAST_CONFIG
from expected_gradcam.core.heatmap import generate_heatmap, process_heatmap
from expected_gradcam.core.optimal_weights import compute_optimal_weights
from expected_gradcam.core.predictor import Predictor
from expected_gradcam.core.second_moment import compute_second_moment_matrix
from expected_gradcam.hooks import FeatureMapHook


class TestFullPipeline:
    """Test the complete E-GradCAM pipeline.

    The small CNN below is split at ``target_layer`` (its last ``Conv2d``).
    ``FeatureMapHook`` captures that layer's output, and ``classifier_head``
    is made of every module after it. As a result,
    ``classifier_head(features)`` reproduces the full model's logits.
    """

    @pytest.fixture
    def simple_model(self) -> nn.Module:
        """Create a small, deterministically initialised CNN in eval mode.

        A 32x32 input is max-pooled twice, so ``target_layer`` emits a
        64-channel 8x8 feature map.
        """
        # Seed before construction so the random weights (and every
        # assertion that depends on them) are reproducible across runs.
        torch.manual_seed(0)
        model = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, 10),
        )
        model.eval()
        return model

    @pytest.fixture
    def target_layer(self, simple_model) -> nn.Module:
        """Get the target layer: the last conv, Conv2d(32, 64)."""
        return simple_model[6]

    @pytest.fixture
    def classifier_head(self, simple_model) -> nn.Module:
        """Get the classifier head: every module after ``target_layer``.

        The ReLU at index 7 must be part of the head. The hooked features
        are the conv output taken *before* that activation. Without the
        ReLU, the head would not reproduce the model's forward pass, and
        reference logits from the full model would disagree with the
        predictor.
        """
        return nn.Sequential(
            simple_model[7],   # ReLU
            simple_model[8],   # AdaptiveAvgPool2d(1)
            simple_model[9],   # Flatten
            simple_model[10],  # Linear(64, 10)
        )

    @pytest.fixture
    def sample_image(self) -> Tensor:
        """Create a seeded random 1x3x32x32 input image."""
        torch.manual_seed(42)
        return torch.randn(1, 3, 32, 32)

    def test_feature_extraction_with_hooks(
        self, simple_model, target_layer, sample_image
    ):
        """Test feature extraction using hooks."""
        with FeatureMapHook(target_layer) as hook:
            logits = simple_model(sample_image)

            assert hook.features is not None
            assert hook.features.shape == (1, 64, 8, 8)

        assert logits.shape == (1, 10)

    def test_classifier_head_reproduces_logits(
        self, simple_model, target_layer, classifier_head, sample_image
    ):
        """Applying the head to the hooked features must give the model logits."""
        with FeatureMapHook(target_layer) as hook:
            logits = simple_model(sample_image)
            features = hook.features.clone()

        assert torch.allclose(classifier_head(features), logits, atol=1e-6)

    def test_predictor_evaluation(
        self, simple_model, target_layer, classifier_head, sample_image
    ):
        """Test predictor g(z; A) evaluation."""
        target_class = 0

        # Get features
        with FeatureMapHook(target_layer) as hook:
            _ = simple_model(sample_image)
            features = hook.features.clone()

        predictor = Predictor(
            classifier_head, target_class=target_class, feature_maps=features
        )

        # Evaluate with reference perturbation (z = 1)
        output_ref = predictor(torch.ones(64))
        assert output_ref.shape == (1,)

        # Evaluate with zero perturbation; the outputs should differ.
        output_zero = predictor(torch.zeros(64))
        assert not torch.allclose(output_ref, output_zero)

    def test_perturbation_sampling(
        self, simple_model, target_layer, classifier_head, sample_image
    ):
        """Test perturbation sampling and per-sample predictor evaluation."""
        config = FAST_CONFIG
        target_class = 0

        # Get features
        with FeatureMapHook(target_layer) as hook:
            _ = simple_model(sample_image)
            features = hook.features.clone()

        M = config.M
        K = features.shape[1]

        # Uniform [0, 1) perturbations, one row per sample.
        torch.manual_seed(42)
        I_samples = torch.rand(M, K)

        predictor = Predictor(
            classifier_head, target_class=target_class, feature_maps=features
        )

        # Each predictor call returns shape (1,); concatenating gives (M,).
        outputs = torch.cat([predictor(sample) for sample in I_samples], dim=0)
        assert outputs.shape == (M,)

    def test_second_moment_computation(self):
        """Test second moment matrix M_I computation."""
        torch.manual_seed(42)

        K = 64
        M = 50

        I_samples = torch.rand(M, K)
        M_I = compute_second_moment_matrix(I_samples)

        # Should be positive semi-definite (up to numerical tolerance)
        eigenvalues = torch.linalg.eigvalsh(M_I)
        assert (eigenvalues >= -1e-6).all()

        # Should be approximately symmetric
        assert torch.allclose(M_I, M_I.T, atol=1e-6)

    def test_optimal_weights_computation(self):
        """Test optimal weight computation."""
        torch.manual_seed(42)

        K = 64
        M = 100

        I_samples = torch.rand(M, K)
        M_I = compute_second_moment_matrix(I_samples)

        # Random phi samples (attribution values)
        phi_samples = torch.randn(M, K)

        weights = compute_optimal_weights(M_I, I_samples, phi_samples)

        assert weights.shape == (K,)
        assert not torch.isnan(weights).any()

    def test_heatmap_generation(self, simple_model, target_layer, sample_image):
        """Test heatmap generation from weights and features."""
        with FeatureMapHook(target_layer) as hook:
            _ = simple_model(sample_image)
            features = hook.features

        # Random weights stand in for computed optimal weights.
        torch.manual_seed(42)
        weights = torch.randn(64)

        # Features first, then weights.
        heatmap = generate_heatmap(features, weights, apply_relu=True)

        assert heatmap.shape[-2:] == features.shape[-2:]
        assert (heatmap >= 0).all()  # ReLU applied

    def test_heatmap_processing(self, simple_model, target_layer, sample_image):
        """Test full heatmap processing pipeline."""
        with FeatureMapHook(target_layer) as hook:
            _ = simple_model(sample_image)
            features = hook.features

        torch.manual_seed(42)
        weights = torch.randn(64)

        # Features first, then weights.
        heatmap, _coarse_heatmap = process_heatmap(
            features,
            weights,
            input_size=(32, 32),
            normalize=True,
        )

        assert heatmap.shape[-2:] == (32, 32)
        assert heatmap.min() >= 0.0
        assert heatmap.max() <= 1.0

    def test_end_to_end_simplified(
        self, simple_model, target_layer, classifier_head, sample_image
    ):
        """Test simplified end-to-end E-GradCAM computation."""
        K = 64
        M = 20
        target_class = 0

        # Step 1: Extract features
        with FeatureMapHook(target_layer) as hook:
            logits = simple_model(sample_image)
            features = hook.features.clone()

        # The reference output below comes from the full model. That is only
        # valid if the head reproduces the model from the hooked features.
        assert torch.allclose(classifier_head(features), logits, atol=1e-6)

        # Step 2: Sample perturbations
        torch.manual_seed(42)
        I_samples = torch.rand(M, K)

        # Step 3: Create predictor
        predictor = Predictor(
            classifier_head, target_class=target_class, feature_maps=features
        )

        y_ref = logits[0, target_class].item()

        # Step 4: Attribution for each perturbation, distributed across
        # channels in proportion to the perturbation itself.
        phi_samples = torch.zeros(M, K)
        for i, perturbation in enumerate(I_samples):
            y_pert = predictor(perturbation).item()
            phi_samples[i] = perturbation * (y_ref - y_pert)

        # Step 5: Compute second moment matrix
        M_I = compute_second_moment_matrix(I_samples)

        # Step 6: Solve for optimal weights
        weights = compute_optimal_weights(M_I, I_samples, phi_samples)

        # Step 7: Generate heatmap
        input_size = tuple(sample_image.shape[-2:])
        heatmap, _coarse_heatmap = process_heatmap(
            features,
            weights,
            input_size=input_size,
            normalize=True,
        )

        assert heatmap.shape[-2:] == input_size
        assert heatmap.min() >= 0.0
        assert heatmap.max() <= 1.0
        assert not torch.isnan(heatmap).any()


class TestConfigurationOptions:
    """Test different configuration options."""

    def test_fast_config(self):
        """Check that FAST_CONFIG keeps M (samples) and T (integration steps) small."""
        config = FAST_CONFIG

        assert config.M < 50  # Fewer samples
        assert config.T < 50  # Fewer integration steps

    def test_custom_config(self):
        """Check that every constructor argument is stored on the config."""
        config = ExpectedGradCAMConfig(
            M=100,
            N=30,
            T=100,
            solver_method="pinv",
        )

        assert config.M == 100
        assert config.N == 30
        assert config.T == 100
        assert config.solver_method == "pinv"


class TestGradientFlow:
    """Test gradient flow through the pipeline."""

    @pytest.fixture
    def simple_model(self) -> nn.Module:
        """Create a small, deterministically initialised differentiable model."""
        # Seed so the Linear weights (and therefore the gradients) are
        # reproducible across runs.
        torch.manual_seed(0)
        return nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, 10),
        )

    def test_gradients_through_predictor(self, simple_model):
        """Test gradients flow from the predictor output back to z."""
        # The features below are synthetic, so the head only needs to accept
        # a 16-channel map; it does not have to reproduce the model exactly.
        classifier_head = nn.Sequential(
            simple_model[2],  # GAP
            simple_model[3],  # Flatten
            simple_model[4],  # Linear
        )

        torch.manual_seed(0)
        features = torch.randn(1, 16, 8, 8, requires_grad=True)
        predictor = Predictor(classifier_head, target_class=0, feature_maps=features)

        z = torch.rand(16, requires_grad=True)
        predictor(z).sum().backward()

        assert z.grad is not None
        assert z.grad.shape == z.shape
        assert torch.isfinite(z.grad).all()

    def test_gradients_through_heatmap(self):
        """Test gradients flow from the heatmap back to the channel weights."""
        torch.manual_seed(0)
        weights = torch.randn(16, requires_grad=True)
        features = torch.randn(1, 16, 8, 8)

        heatmap = generate_heatmap(features, weights, apply_relu=False)
        heatmap.sum().backward()

        assert weights.grad is not None
        assert weights.grad.shape == weights.shape
