"""Unit tests for the weight_transform module.

Tests the weight transformation functions (double_power, extreme_power,
feature_adaptive) and the transform_weights dispatcher.
"""

from __future__ import annotations

import pytest
import torch
from torch import Tensor

from expected_gradcam.core.weight_transform import (
    double_power_transform,
    extreme_power_transform,
    feature_adaptive_transform,
    transform_weights,
)


class TestDoublePowerTransform:
    """Test double_power_transform function.

    ``test_formula_verification`` pins the transform to
    ``(w * |w|) * |w * |w||``, i.e. a sign-preserving fourth power; the other
    tests check properties that follow from it (shape, sign, zeros,
    amplification of large magnitudes).
    """

    @pytest.fixture
    def sample_weights(self) -> Tensor:
        """Create sample weights with mixed positive and negative values.

        A dedicated seeded generator keeps the data reproducible without
        reseeding torch's global RNG as a side effect for other tests.
        """
        generator = torch.Generator().manual_seed(42)
        return torch.randn(64, generator=generator)

    def test_output_shape(self, sample_weights):
        """Test output shape matches input."""
        result = double_power_transform(sample_weights)
        assert result.shape == sample_weights.shape

    def test_preserves_sign(self, sample_weights):
        """Test that sign is preserved."""
        result = double_power_transform(sample_weights)

        positive_mask = sample_weights > 0
        negative_mask = sample_weights < 0

        assert (result[positive_mask] >= 0).all()
        assert (result[negative_mask] <= 0).all()

    def test_zero_stays_zero(self):
        """Test that zeros stay zero and unit magnitudes are left unchanged."""
        weights = torch.tensor([0.0, 1.0, -1.0, 0.0])
        result = double_power_transform(weights)

        assert result[0] == 0.0
        assert result[3] == 0.0
        # |w| == 1 is a fixed point of a sign-preserving power, so the
        # non-zero entries must come back exactly, sign included.
        assert torch.equal(result[1:3], weights[1:3])

    def test_formula_verification(self, sample_weights):
        """Test that formula matches (w * |w|) * |w * |w||."""
        result = double_power_transform(sample_weights)

        # Manual computation of the formula
        w1 = sample_weights * sample_weights.abs()
        expected = w1 * w1.abs()

        assert torch.allclose(result, expected)

    def test_amplifies_large_values(self, sample_weights):
        """Test that large absolute values are amplified relative to small ones.

        Every "large" magnitude exceeds the median ``m`` and every "small" one
        is at most ``m``. For ``|w| -> |w|^4`` that gives
        ``mean(L^4) >= m^3 * mean(L)`` and ``mean(S^4) <= m^3 * mean(S)``, so
        the large/small ratio cannot shrink and no tolerance is needed.
        """
        result = double_power_transform(sample_weights)

        magnitudes = sample_weights.abs()
        large_mask = magnitudes > magnitudes.median()
        small_mask = ~large_mask
        # Fail loudly on a degenerate split instead of silently skipping.
        assert large_mask.any() and small_mask.any()

        ratio_before = magnitudes[large_mask].mean() / magnitudes[small_mask].mean()
        ratio_after = (
            result[large_mask].abs().mean() / result[small_mask].abs().mean()
        )

        # Ratio must not decrease (large values amplified more than small ones)
        assert ratio_after >= ratio_before


class TestExtremePowerTransform:
    """Test extreme_power_transform function.

    ``test_formula_verification`` pins the transform to
    ``w * |w|^exponent``; ``test_exponent_parameter`` checks the consequence
    that a larger exponent makes the biggest weight more dominant.
    """

    @pytest.fixture
    def sample_weights(self) -> Tensor:
        """Create sample weights.

        A dedicated seeded generator keeps the data reproducible without
        reseeding torch's global RNG as a side effect for other tests.
        """
        generator = torch.Generator().manual_seed(42)
        return torch.randn(64, generator=generator)

    def test_output_shape(self, sample_weights):
        """Test output shape matches input."""
        result = extreme_power_transform(sample_weights)
        assert result.shape == sample_weights.shape

    def test_exponent_parameter(self, sample_weights):
        """Test that a higher exponent makes the largest weight more dominant.

        With ``|result| = |w|^(1 + p)`` and ``M = max|w|``, the peak-to-mean
        ratio equals ``1 / mean((|w| / M)^(1 + p))``. Because ``|w| / M <= 1``,
        that mean can only shrink as ``p`` grows, so the ratio cannot decrease
        and no tolerance is needed.
        """
        result_2 = extreme_power_transform(sample_weights, exponent=2.0)
        result_4 = extreme_power_transform(sample_weights, exponent=4.0)

        max_idx = sample_weights.abs().argmax()
        # Fail loudly instead of silently skipping if the output collapses.
        assert result_2.abs().mean() > 1e-8
        assert result_4.abs().mean() > 1e-8

        ratio_2 = result_2[max_idx].abs() / result_2.abs().mean()
        ratio_4 = result_4[max_idx].abs() / result_4.abs().mean()

        # Higher exponent should be at least as extreme
        assert ratio_4 >= ratio_2

    def test_formula_verification(self, sample_weights):
        """Test that formula matches w * |w|^exponent."""
        exponent = 3.0
        result = extreme_power_transform(sample_weights, exponent=exponent)

        expected = sample_weights * sample_weights.abs().pow(exponent)

        assert torch.allclose(result, expected)


class TestFeatureAdaptiveTransform:
    """Test feature_adaptive_transform function."""

    @pytest.fixture
    def sample_weights(self) -> Tensor:
        """Create sample weights from a private seeded generator."""
        generator = torch.Generator().manual_seed(42)
        return torch.randn(64, generator=generator)

    @pytest.fixture
    def sample_features(self) -> Tensor:
        """Create non-negative sample feature maps of shape (1, 64, 8, 8).

        Uses a seed distinct from ``sample_weights``: sharing seed 42 would make
        the leading feature values exactly ``relu`` of the weights, silently
        coupling the two inputs.
        """
        generator = torch.Generator().manual_seed(7)
        return torch.relu(torch.randn(1, 64, 8, 8, generator=generator))

    def test_output_shape(self, sample_weights, sample_features):
        """Test output shape matches weights."""
        result = feature_adaptive_transform(sample_weights, sample_features)
        assert result.shape == sample_weights.shape

    def test_uses_feature_statistics(self, sample_weights):
        """Test that the transform output depends on the feature maps.

        ``features2`` comes from a different seed *and* has its channel halves
        rescaled very differently (x10 vs x0.1). A differing output therefore
        shows the features are consulted, but does not by itself isolate which
        per-channel statistic the transform uses.
        """
        features1 = torch.relu(
            torch.randn(1, 64, 8, 8, generator=torch.Generator().manual_seed(42))
        )

        features2 = torch.relu(
            torch.randn(1, 64, 8, 8, generator=torch.Generator().manual_seed(123))
        )
        # Give the two channel halves very different scales (and variances)
        features2[:, :32, :, :] *= 10.0
        features2[:, 32:, :, :] *= 0.1

        result1 = feature_adaptive_transform(sample_weights, features1)
        result2 = feature_adaptive_transform(sample_weights, features2)

        # Results should differ because the feature maps differ
        assert not torch.allclose(result1, result2)


class TestTransformWeights:
    """Exercise the ``transform_weights`` dispatcher.

    Each recognised method name must yield the same tensor as calling the
    matching transform directly; an unrecognised name must be rejected with
    ``ValueError``.
    """

    @pytest.fixture
    def sample_weights(self) -> Tensor:
        """Return 64 standard-normal weights drawn after seeding with 42."""
        torch.manual_seed(42)
        weights = torch.randn(64)
        return weights

    @staticmethod
    def _assert_matches(actual: Tensor, reference: Tensor) -> None:
        """Assert two tensors agree under ``torch.allclose`` default tolerances."""
        assert torch.allclose(actual, reference)

    def test_none_transform(self, sample_weights):
        """The 'none' method hands the weights back untouched."""
        self._assert_matches(
            transform_weights(sample_weights, method="none"),
            sample_weights,
        )

    def test_double_power_dispatch(self, sample_weights):
        """The 'double_power' method routes to double_power_transform."""
        dispatched = transform_weights(sample_weights, method="double_power")
        self._assert_matches(dispatched, double_power_transform(sample_weights))

    def test_extreme_power_dispatch(self, sample_weights):
        """The 'extreme_power' method routes to extreme_power_transform."""
        dispatched = transform_weights(sample_weights, method="extreme_power")
        self._assert_matches(dispatched, extreme_power_transform(sample_weights))

    def test_invalid_transform_raises(self, sample_weights):
        """An unknown method name is refused with ValueError."""
        with pytest.raises(ValueError):
            transform_weights(sample_weights, method="invalid_transform")
