"""Unit tests for ExpectedGradCAMWeightComputer."""

import pytest
import torch
from torch import nn

from expected_gradcam.core.weight_computer import ExpectedGradCAMWeightComputer
from expected_gradcam.config import ExpectedGradCAMConfig


class SimpleClassifierHead(nn.Module):
    """Minimal pool-then-linear classifier head used as a test double.

    Args:
        in_channels: Input feature size of the linear layer.
        num_classes: Output feature size of the linear layer.
    """

    def __init__(self, in_channels: int = 64, num_classes: int = 10):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(in_features=in_channels, out_features=num_classes)

    def forward(self, x):
        """Pool ``x``, flatten everything after the first dim, and apply ``fc``."""
        pooled = self.pool(x)
        leading_dim = pooled.size(0)
        flattened = pooled.view(leading_dim, -1)
        return self.fc(flattened)


class SimpleCNN(nn.Module):
    """Small two-convolution CNN used as the model under test.

    The network applies two 3x3 convolutions (3 -> 32 -> 64 channels), each
    followed by ReLU, then adaptive average pooling to 1x1 and a linear
    layer producing ``num_classes`` outputs.

    Args:
        num_classes: Output feature size of the final linear layer.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Run the conv stack, pool, flatten per sample, and apply ``fc``."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def get_classifier_head(self) -> SimpleClassifierHead:
        """Return a classifier head that shares this model's ``pool`` and ``fc``.

        The head is built with dimensions taken from ``self.fc`` rather than
        hard-coded values, so the temporary layer it allocates matches the
        model even when ``num_classes`` is not the default. Both of the
        head's sub-modules are then replaced by this model's own modules,
        so the head and the model share parameters.
        """
        head = SimpleClassifierHead(self.fc.in_features, self.fc.out_features)
        # Share (not copy) the modules so the head mirrors the model's tail.
        head.pool = self.pool
        head.fc = self.fc
        return head


class TestExpectedGradCAMWeightComputerInit:
    """Construction and representation checks for ExpectedGradCAMWeightComputer."""

    def test_basic_initialization(self):
        """Constructing with only model and layer stores both and sets defaults."""
        net = SimpleCNN()
        layer = net.conv2

        weight_computer = ExpectedGradCAMWeightComputer(model=net, target_layer=layer)

        assert weight_computer.model is net
        assert weight_computer.target_layer is layer
        assert weight_computer.config is not None
        assert weight_computer.baseline_dataset is None

    def test_with_custom_config(self):
        """A config passed to the constructor is exposed with its values intact."""
        net = SimpleCNN()
        layer = net.conv2
        custom = ExpectedGradCAMConfig(M=30, N=10, T=20)

        weight_computer = ExpectedGradCAMWeightComputer(
            model=net, target_layer=layer, config=custom
        )

        for field_name, expected_value in (("M", 30), ("N", 10), ("T", 20)):
            assert getattr(weight_computer.config, field_name) == expected_value

    def test_explicit_classifier_head(self):
        """An explicitly supplied classifier head is stored and returned as-is."""
        net = SimpleCNN()
        layer = net.conv2
        head = net.get_classifier_head()

        weight_computer = ExpectedGradCAMWeightComputer(
            model=net, target_layer=layer, classifier_head=head
        )

        # The private slot must already hold the head right after construction.
        assert weight_computer._classifier_head is head

        # The public property must hand back that very same object.
        assert weight_computer.classifier_head is head

    def test_repr(self):
        """``repr`` mentions both the computer class and the model class."""
        net = SimpleCNN()
        layer = net.conv2

        weight_computer = ExpectedGradCAMWeightComputer(model=net, target_layer=layer)

        text = repr(weight_computer)
        for fragment in ("ExpectedGradCAMWeightComputer", "SimpleCNN"):
            assert fragment in text


class TestExpectedGradCAMWeightComputerComputeWeights:
    """Checks on the outputs of ``compute_weights``."""

    @pytest.fixture
    def simple_setup(self):
        """Provide a ``(computer, features)`` pair for the tests in this class.

        The computer is built with an explicit classifier head and a small
        config; ``features`` is a random tensor of shape ``(1, 64, 7, 7)``.
        """
        net = SimpleCNN()
        layer = net.conv2
        head = net.get_classifier_head()
        # Small values for fast tests
        small_config = ExpectedGradCAMConfig(M=10, N=5, T=10)

        weight_computer = ExpectedGradCAMWeightComputer(
            model=net,
            target_layer=layer,
            config=small_config,
            classifier_head=head,
        )

        # Dummy feature map standing in for the target layer's activations.
        dummy_features = torch.randn(1, 64, 7, 7)

        return weight_computer, dummy_features

    def test_compute_weights_returns_tensor(self, simple_setup):
        """The first return value is a one-dimensional tensor."""
        weight_computer, feats = simple_setup

        weights, _ = weight_computer.compute_weights(features=feats, class_idx=0)

        assert isinstance(weights, torch.Tensor)
        assert weights.ndim == 1

    def test_compute_weights_correct_shape(self, simple_setup):
        """The weight vector has one entry per feature channel."""
        weight_computer, feats = simple_setup
        num_channels = feats.size(1)  # 64 channels

        weights, _ = weight_computer.compute_weights(features=feats, class_idx=0)

        assert weights.shape == (num_channels,)

    def test_compute_weights_returns_diagnostics(self, simple_setup):
        """The second return value exposes the expected diagnostic attributes."""
        weight_computer, feats = simple_setup

        _, diagnostics = weight_computer.compute_weights(features=feats, class_idx=0)

        assert diagnostics is not None
        for attribute in ("method", "condition_number"):
            assert hasattr(diagnostics, attribute)

    def test_compute_weights_with_mask(self, simple_setup):
        """Passing a spatial mask still yields one weight per channel."""
        weight_computer, feats = simple_setup
        rows, cols = feats.shape[2:]

        # All-ones mask with the top half zeroed out.
        region = torch.ones(rows, cols)
        region[: rows // 2, :] = 0

        weights, _ = weight_computer.compute_weights(
            features=feats, class_idx=0, mask=region
        )

        assert weights.shape == (feats.shape[1],)


class TestExpectedGradCAMWeightComputerBatchedSegments:
    """Checks on ``compute_weights_batched_segments``."""

    @pytest.fixture
    def batched_setup(self):
        """Provide a ``(computer, features)`` pair for the batched tests.

        ``features`` is a random tensor of shape ``(1, 64, 7, 7)``.
        """
        net = SimpleCNN()
        layer = net.conv2
        head = net.get_classifier_head()
        # Very small for fast tests
        tiny_config = ExpectedGradCAMConfig(M=5, N=3, T=5)

        weight_computer = ExpectedGradCAMWeightComputer(
            model=net,
            target_layer=layer,
            config=tiny_config,
            classifier_head=head,
        )

        dummy_features = torch.randn(1, 64, 7, 7)

        return weight_computer, dummy_features

    def test_compute_weights_batched_segments(self, batched_setup):
        """Three segment masks produce a ``(3, K)`` result, one row per mask."""
        weight_computer, feats = batched_setup
        rows, cols = feats.shape[2:]
        num_channels = feats.shape[1]

        # Three masks: top half, bottom-left, and bottom-right.
        segment_masks = torch.zeros(3, rows, cols)
        segment_masks[0, : rows // 2, :] = 1
        segment_masks[1, rows // 2 :, : cols // 2] = 1
        segment_masks[2, rows // 2 :, cols // 2 :] = 1

        result = weight_computer.compute_weights_batched_segments(
            features=feats, class_idx=0, masks=segment_masks
        )

        assert result.shape == (3, num_channels)


class TestSamplingIntegration:
    """Checks on perturbation sampling as wired into the weight computer."""

    def test_simple_sampling(self):
        """Sampling without a baseline dataset yields in-range ``(10, 64)`` samples."""
        net = SimpleCNN()
        layer = net.conv2
        sampling_config = ExpectedGradCAMConfig(M=10, N=5, T=10)

        weight_computer = ExpectedGradCAMWeightComputer(
            model=net,
            target_layer=layer,
            config=sampling_config,
            baseline_dataset=None,  # No dataset = simple sampling
        )

        feats = torch.randn(1, 64, 7, 7)
        samples = weight_computer._sample_perturbations(feats, M=10)

        assert samples.shape == (10, 64)
        # Every sample must lie within the configured [alpha_min, alpha_max] range.
        assert samples.min() >= weight_computer.config.alpha_min
        assert samples.max() <= weight_computer.config.alpha_max
