"""Unit tests for the hooks manager module.

Tests HookManager, MultiLayerHooks, and specialized hook classes.
"""

from __future__ import annotations

from collections import OrderedDict

import pytest
import torch
from torch import Tensor, nn

from expected_gradcam.hooks import (
    CapturedActivation,
    FeatureMapHook,
    GradientHook,
    HookManager,
    MultiLayerHooks,
    capture_activations,
)


class TestCapturedActivation:
    """Unit tests for the CapturedActivation dataclass."""

    def test_creation(self):
        """A default-constructed capture holds no output and no input."""
        empty = CapturedActivation()
        assert empty.output is None
        assert empty.input is None

    def test_with_values(self):
        """Constructor arguments are stored as given (the same tensor object)."""
        feature_map = torch.randn(1, 64, 8, 8)
        populated = CapturedActivation(output=feature_map, layer_name="conv1")

        assert populated.output is feature_map
        assert populated.layer_name == "conv1"

    def test_has_gradients(self):
        """has_gradients becomes truthy once grad_output is assigned."""
        record = CapturedActivation()
        assert not record.has_gradients

        # A tuple of tensors is used here, the same container PyTorch passes
        # to module backward hooks.
        record.grad_output = (torch.randn(1, 64, 8, 8),)
        assert record.has_gradients

    def test_clear(self):
        """clear() resets both the captured output and the captured input."""
        record = CapturedActivation(
            output=torch.randn(1, 64, 8, 8),
            input=(torch.randn(1, 32, 16, 16),),
        )

        record.clear()

        assert record.output is None
        assert record.input is None


class TestHookManager:
    """Unit tests for HookManager registration, lookup, and cleanup."""

    @pytest.fixture
    def simple_model(self) -> nn.Module:
        """Two-conv classifier: 3 -> 16 -> 32 channels, then a 10-way linear head."""
        layers = [
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, 10),
        ]
        return nn.Sequential(*layers)

    @pytest.fixture
    def sample_input(self) -> Tensor:
        """A single random 3x32x32 image batch."""
        return torch.randn(1, 3, 32, 32)

    def test_context_manager(self, simple_model):
        """Hooks registered inside the ``with`` block are gone after it exits."""
        with HookManager() as manager:
            manager.register_forward(simple_model[0], "conv1")
            assert len(manager) == 1

        # Leaving the context must detach every registered hook.
        assert len(manager) == 0

    def test_register_forward_hook(self, simple_model, sample_input):
        """A forward hook records the hooked layer's output during a forward pass."""
        with HookManager() as manager:
            manager.register_forward(simple_model[0], "conv1")
            simple_model(sample_input)

            captured = manager.get_output("conv1")
            assert captured is not None
            # A 3x3 kernel with padding=1 (stride 1) keeps the 32x32 spatial size.
            assert captured.shape == (1, 16, 32, 32)

    def test_register_backward_hook(self, simple_model, sample_input):
        """A backward hook records a gradient once backward() has run."""
        sample_input.requires_grad_(True)

        with HookManager() as manager:
            manager.register_backward(simple_model[2], "conv2_grad")

            simple_model(sample_input).sum().backward()

            assert manager.get_grad_output("conv2_grad") is not None

    def test_multiple_hooks(self, simple_model, sample_input):
        """Several hooks can coexist, each capturing its own layer."""
        targets = {"conv1": simple_model[0], "conv2": simple_model[2]}

        with HookManager() as manager:
            for name, layer in targets.items():
                manager.register_forward(layer, name)

            simple_model(sample_input)

            for name in targets:
                assert manager.get_output(name) is not None
            assert len(manager) == 2

    def test_get_activation(self, simple_model, sample_input):
        """get_activation returns the CapturedActivation record for a name."""
        with HookManager() as manager:
            manager.register_forward(simple_model[0], "conv1")
            simple_model(sample_input)

            record = manager.get_activation("conv1")
            assert isinstance(record, CapturedActivation)
            assert record.layer_name == "conv1"

    def test_get_nonexistent_hook(self, simple_model):
        """Looking up a name that was never registered raises KeyError."""
        with HookManager() as manager:
            manager.register_forward(simple_model[0], "conv1")

            with pytest.raises(KeyError):
                manager.get_output("nonexistent")

    def test_remove_hook(self, simple_model):
        """remove() detaches only the named hook and leaves the others."""
        with HookManager() as manager:
            manager.register_forward(simple_model[0], "conv1")
            manager.register_forward(simple_model[2], "conv2")

            manager.remove("conv1")

            assert "conv1" not in manager
            assert "conv2" in manager

    def test_clear_captures(self, simple_model, sample_input):
        """clear_captures() drops recorded data while the name stays queryable."""
        with HookManager() as manager:
            manager.register_forward(simple_model[0], "conv1")

            simple_model(sample_input)
            assert manager.get_output("conv1") is not None

            manager.clear_captures()
            assert manager.get_output("conv1") is None

    def test_hook_names(self, simple_model):
        """hook_names exposes every registered hook name."""
        with HookManager() as manager:
            manager.register_forward(simple_model[0], "conv1")
            manager.register_forward(simple_model[2], "conv2")

            registered = manager.hook_names
            for expected in ("conv1", "conv2"):
                assert expected in registered


class TestMultiLayerHooks:
    """Test MultiLayerHooks class."""

    @pytest.fixture
    def simple_model(self) -> nn.Module:
        """Create a small CNN whose layers are reachable by attribute name.

        Built from an ``OrderedDict`` so ``nn.Sequential``'s own ``forward``
        runs the layers in insertion order and each layer is reachable as
        ``model.conv1``, ``model.conv2``, etc. This avoids patching a closure
        onto the instance as ``forward``. Such a closure stays bound to one
        specific model object, so a copied model would still run the
        original's layers.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(3, 16, 3, padding=1)),
                    ("relu1", nn.ReLU()),
                    ("conv2", nn.Conv2d(16, 32, 3, padding=1)),
                    ("relu2", nn.ReLU()),
                    ("pool", nn.AdaptiveAvgPool2d(1)),
                    ("flatten", nn.Flatten()),
                    ("fc", nn.Linear(32, 10)),
                ]
            )
        )

    @pytest.fixture
    def sample_input(self) -> Tensor:
        """Create sample input."""
        return torch.randn(1, 3, 32, 32)

    def test_with_dict_layers(self, simple_model, sample_input):
        """Test with dictionary of layers."""
        layers = {"conv1": simple_model.conv1, "conv2": simple_model.conv2}

        with MultiLayerHooks(layers) as hooks:
            _ = simple_model(sample_input)

            assert hooks["conv1"] is not None
            assert hooks["conv2"] is not None

    def test_items_iteration(self, simple_model, sample_input):
        """Test iterating over captured activations."""
        layers = {"conv1": simple_model.conv1, "conv2": simple_model.conv2}

        with MultiLayerHooks(layers) as hooks:
            _ = simple_model(sample_input)

            items = list(hooks.items())
            assert len(items) == 2
            assert all(isinstance(act, Tensor) for _, act in items)

    def test_clear(self, simple_model, sample_input):
        """Test clearing captured activations."""
        layers = {"conv1": simple_model.conv1}

        with MultiLayerHooks(layers) as hooks:
            _ = simple_model(sample_input)
            assert hooks["conv1"] is not None

            hooks.clear()
            assert hooks["conv1"] is None


class TestFeatureMapHook:
    """Test FeatureMapHook class."""

    @pytest.fixture
    def simple_model(self) -> nn.Module:
        """Create a one-conv classifier; index 0 is the conv layer under test."""
        return nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, 10),
        )

    @pytest.fixture
    def sample_input(self) -> Tensor:
        """Create a single 3x32x32 input that requires grad."""
        return torch.randn(1, 3, 32, 32, requires_grad=True)

    def test_captures_features(self, simple_model, sample_input):
        """A forward pass alone populates ``features``."""
        with FeatureMapHook(simple_model[0]) as hook:
            simple_model(sample_input)

            assert hook.features is not None
            assert hook.features.shape == (1, 16, 32, 32)

    def test_captures_gradients(self, simple_model, sample_input):
        """Backpropagating populates ``gradients`` with the conv output's shape."""
        with FeatureMapHook(simple_model[0]) as hook:
            output = simple_model(sample_input)
            output.sum().backward()

            assert hook.gradients is not None
            assert hook.gradients.shape == (1, 16, 32, 32)

    def test_clear(self, simple_model, sample_input):
        """clear() discards both the captured features and gradients."""
        with FeatureMapHook(simple_model[0]) as hook:
            output = simple_model(sample_input)
            output.sum().backward()

            # Confirm something was actually captured; otherwise the
            # post-clear assertions below would pass vacuously.
            assert hook.features is not None
            assert hook.gradients is not None

            hook.clear()

            assert hook.features is None
            assert hook.gradients is None


class TestGradientHook:
    """Test GradientHook class."""

    @pytest.fixture
    def simple_model(self) -> nn.Module:
        """Create an MLP with a ReLU at index 1 for gradient testing."""
        return nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 5),
        )

    @pytest.fixture
    def sample_input(self) -> Tensor:
        """Create a (1, 10) input that requires grad."""
        return torch.randn(1, 10, requires_grad=True)

    def test_captures_gradient(self, simple_model, sample_input):
        """Test capturing gradients."""
        with GradientHook(simple_model[1]) as hook:
            output = simple_model(sample_input)
            output.sum().backward()

            assert hook.grad is not None

    def test_gradient_modifier(self, simple_model, sample_input):
        """The modifier's result replaces the gradient flowing through the ReLU."""

        def clamp_positive(grad):
            return torch.clamp(grad, min=0)

        with GradientHook(simple_model[1], modifier=clamp_positive):
            output = simple_model(sample_input)
            output.sum().backward()

        assert sample_input.grad is not None

        # Recompute the expected input gradient by hand and compare. A plain
        # "grad is not None" check would pass even if the modifier were ignored.
        #
        # d(sum(out)) / d(relu_out) is the column-sum of the second Linear's
        # weight. ReLU backward multiplies by a {0, 1} mask, so clamping before
        # or after the mask gives the same result. The expectation therefore
        # holds whether the hook edits the ReLU's grad_output or its grad_input.
        first, _, second = simple_model
        with torch.no_grad():
            pre_activation = first(sample_input)
            mask = (pre_activation > 0).to(pre_activation.dtype)
            upstream = second.weight.sum(dim=0, keepdim=True)
            expected = (mask * upstream.clamp(min=0)) @ first.weight

        torch.testing.assert_close(sample_input.grad, expected)


class TestCaptureActivations:
    """Test capture_activations context manager."""

    @pytest.fixture
    def simple_model(self) -> nn.Module:
        """Create a two-conv model; indices 0 and 2 are the conv layers."""
        return nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, padding=1),
        )

    @pytest.fixture
    def sample_input(self) -> Tensor:
        """Create sample input."""
        return torch.randn(1, 3, 32, 32)

    def test_capture_dict(self, simple_model, sample_input):
        """Activations for every requested layer remain available after exit."""
        layers = {"conv1": simple_model[0], "conv2": simple_model[2]}

        with capture_activations(layers) as acts:
            _ = simple_model(sample_input)

        # The assertions run after the context has exited on purpose: the
        # captured values must outlive hook removal.
        assert "conv1" in acts
        assert "conv2" in acts
        assert acts["conv1"] is not None
        assert acts["conv2"] is not None
