import pytest
import torch

from hallucinations.features.lookback_lens import (
    compute_lookback_lens,
    compute_lookback_lens_per_token,
    compute_lookback_ratio_per_token,
)


class TestComputeLookbackRatioPerToken:
    """Checks for ``compute_lookback_ratio_per_token`` on single-head attention maps.

    The hand-written maps below are square and lower-triangular. Their first
    ``PROMPT_LEN`` positions play the role of the context; the rows after that
    belong to generated tokens, and those are the rows the tests assert on.
    """

    # Context length shared by every hand-written attention map in this class.
    PROMPT_LEN = 2

    @staticmethod
    def _as_attention(rows: list) -> torch.Tensor:
        """Convert nested Python lists into a float32 attention tensor."""
        return torch.tensor(rows, dtype=torch.float32)

    def test_all_attention_on_context_returns_one(self) -> None:
        """A generated row with all its weight on context columns gives 1."""
        weights = self._as_attention(
            [
                [1.0, 0.0, 0.0],
                [0.5, 0.5, 0.0],
                [1.0, 0.0, 0.0],  # generated row: everything on context
            ]
        )
        ratios = compute_lookback_ratio_per_token(weights, self.PROMPT_LEN)
        expected = torch.tensor([1.0])
        assert torch.allclose(ratios, expected)

    def test_all_attention_on_generated_returns_zero(self) -> None:
        """A generated row that only attends to itself gives 0."""
        weights = self._as_attention(
            [
                [1.0, 0.0, 0.0],
                [0.5, 0.5, 0.0],
                [0.0, 0.0, 1.0],  # generated row: everything on itself
            ]
        )
        ratios = compute_lookback_ratio_per_token(weights, self.PROMPT_LEN)
        expected = torch.tensor([0.0])
        assert torch.allclose(ratios, expected)

    def test_equal_attention_split_returns_half(self) -> None:
        """An even context/generated split gives 0.5."""
        weights = self._as_attention(
            [
                [1.0, 0.0, 0.0],
                [0.5, 0.5, 0.0],
                [0.5, 0.0, 0.5],  # generated row: 50% context, 50% generated
            ]
        )
        ratios = compute_lookback_ratio_per_token(weights, self.PROMPT_LEN)
        expected = torch.tensor([0.5])
        assert torch.allclose(ratios, expected)

    def test_single_generated_token_only_attends_to_context(self) -> None:
        """Weight spread unevenly over context only (positions 0 and 1) still gives 1."""
        weights = self._as_attention(
            [
                [1.0, 0.0, 0.0],
                [0.3, 0.7, 0.0],
                [0.4, 0.6, 0.0],  # first generated token, nothing on itself
            ]
        )
        ratios = compute_lookback_ratio_per_token(weights, self.PROMPT_LEN)
        expected = torch.tensor([1.0])
        assert torch.allclose(ratios, expected)

    def test_multiple_generated_tokens(self) -> None:
        """One ratio is produced per generated row, in row order."""
        weights = self._as_attention(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.5, 0.5, 0.0, 0.0],
                [0.4, 0.4, 0.2, 0.0],  # context 0.8 / generated 0.2
                [0.2, 0.2, 0.2, 0.4],  # context 0.4 / generated 0.6
            ]
        )
        ratios = compute_lookback_ratio_per_token(weights, self.PROMPT_LEN)
        assert torch.allclose(ratios, torch.tensor([0.8, 0.4]))

    def test_zero_attention_raises_error(self) -> None:
        """A generated row that is entirely zero is rejected with ValueError."""
        weights = self._as_attention(
            [
                [1.0, 0.0, 0.0],
                [0.5, 0.5, 0.0],
                [0.0, 0.0, 0.0],  # generated row carries no attention at all
            ]
        )
        with pytest.raises(ValueError):
            compute_lookback_ratio_per_token(weights, self.PROMPT_LEN)

    def test_no_generated_tokens_raises_error(self) -> None:
        """A map no longer than the context is rejected with ValueError."""
        weights = self._as_attention(
            [
                [1.0, 0.0],
                [0.5, 0.5],
            ]
        )
        with pytest.raises(ValueError, match="No generated tokens"):
            compute_lookback_ratio_per_token(weights, self.PROMPT_LEN)

    def test_ratios_bounded_zero_to_one(self) -> None:
        """Ratios from a seeded random causal map all lie in [0, 1]."""
        torch.manual_seed(42)
        total_len, prompt_len = 10, 3

        # Random lower-triangular weights, rescaled so every row sums to one.
        causal = torch.rand(total_len, total_len).tril()
        row_totals = causal.sum(dim=-1, keepdim=True)
        normalized = causal / row_totals

        ratios = compute_lookback_ratio_per_token(normalized, prompt_len)
        assert torch.all(ratios >= 0.0)
        assert torch.all(ratios <= 1.0)


class TestComputeLookbackLensPerToken:
    """Shape and error checks for ``compute_lookback_lens_per_token``."""

    @staticmethod
    def _causal_ones(*dims: int) -> torch.Tensor:
        """Return an all-ones tensor of shape ``dims`` with the upper triangle zeroed."""
        return torch.ones(*dims).tril()

    @staticmethod
    def _row_normalize(weights: torch.Tensor) -> torch.Tensor:
        """Rescale ``weights`` so that every row along the last axis sums to one."""
        return weights / weights.sum(dim=-1, keepdim=True)

    def test_output_shape(self) -> None:
        """A 4-D tensor input yields ``(layers, heads, generated_tokens)``."""
        layers, heads, total_len, prompt_len = 2, 3, 5, 2

        weights = self._row_normalize(
            self._causal_ones(layers, heads, total_len, total_len)
        )

        per_token = compute_lookback_lens_per_token(weights, prompt_len)
        assert per_token.shape == (layers, heads, total_len - prompt_len)

    def test_list_input_format(self) -> None:
        """A list of per-layer tensors yields the same output shape as a 4-D tensor."""
        layers, heads, total_len, prompt_len = 2, 3, 5, 2

        single_layer = self._row_normalize(
            self._causal_ones(heads, total_len, total_len)
        )
        per_layer = [single_layer.clone() for _ in range(layers)]

        per_token = compute_lookback_lens_per_token(per_layer, prompt_len)
        assert per_token.shape == (layers, heads, total_len - prompt_len)

    def test_no_generated_tokens_raises_error(self) -> None:
        """An input whose length equals the context length raises ValueError."""
        layers, heads, total_len = 2, 3, 4
        prompt_len = total_len  # nothing left over after the context

        # Left unnormalized on purpose: only the length check is exercised here.
        weights = self._causal_ones(layers, heads, total_len, total_len)
        with pytest.raises(ValueError, match="No generated tokens"):
            compute_lookback_lens_per_token(weights, prompt_len)


class TestComputeLookbackLens:
    """Checks for ``compute_lookback_lens``, the per-head aggregated output."""

    @staticmethod
    def _causal_ones(*dims: int) -> torch.Tensor:
        """Return an all-ones tensor of shape ``dims`` with the upper triangle zeroed."""
        return torch.ones(*dims).tril()

    def test_output_shape(self) -> None:
        """The aggregated result has one value per (layer, head)."""
        layers, heads, total_len, prompt_len = 2, 3, 5, 2

        weights = self._causal_ones(layers, heads, total_len, total_len)
        weights = weights / weights.sum(dim=-1, keepdim=True)

        aggregated = compute_lookback_lens(weights, prompt_len)
        assert aggregated.shape == (layers, heads)

    def test_mean_aggregation(self) -> None:
        """The aggregate equals the mean of the per-token ratios (0.8 and 0.4)."""
        head_rows = [
            [1.0, 0.0, 0.0, 0.0],
            [0.5, 0.5, 0.0, 0.0],
            [0.4, 0.4, 0.2, 0.0],  # per-token ratio 0.8
            [0.2, 0.2, 0.2, 0.4],  # per-token ratio 0.4
        ]
        # One layer holding one head.
        weights = torch.tensor([[head_rows]], dtype=torch.float32)

        aggregated = compute_lookback_lens(weights, 2)
        expected = torch.tensor([[(0.8 + 0.4) / 2]])
        assert torch.allclose(aggregated, expected)

    def test_no_generated_tokens_raises_error(self) -> None:
        """An input whose length equals the context length raises ValueError."""
        layers, heads, total_len = 2, 3, 4
        prompt_len = total_len  # nothing left over after the context

        weights = self._causal_ones(layers, heads, total_len, total_len)
        with pytest.raises(ValueError, match="No generated tokens"):
            compute_lookback_lens(weights, prompt_len)

    def test_all_context_attention_returns_ones(self) -> None:
        """Heads whose generated row attends only to context aggregate to 1."""
        first_head = [
            [1.0, 0.0, 0.0],
            [0.5, 0.5, 0.0],
            [0.6, 0.4, 0.0],  # generated row: context only
        ]
        second_head = [
            [1.0, 0.0, 0.0],
            [0.3, 0.7, 0.0],
            [0.8, 0.2, 0.0],  # generated row: context only
        ]
        # One layer holding two heads.
        weights = torch.tensor([[first_head, second_head]], dtype=torch.float32)

        aggregated = compute_lookback_lens(weights, 2)
        assert torch.allclose(aggregated, torch.ones((1, 2)))


class TestLookbackLensReferenceData:
    """Checks all three functions against one hand-computed reference example.

    The example has 2 layers, 2 heads and 5 positions. Every test uses a
    context length of 2, so positions 2, 3 and 4 are the generated tokens.
    """

    @pytest.fixture
    def attention_weights(self) -> torch.Tensor:
        """Reference attention maps, indexed as ``[layer][head][query][key]``.

        Each map is lower-triangular and each of its rows sums to one.
        """
        layer0_head0 = [
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.7, 0.3, 0.0, 0.0, 0.0],
            [0.3, 0.3, 0.4, 0.0, 0.0],
            [0.2, 0.2, 0.3, 0.3, 0.0],
            [0.1, 0.1, 0.2, 0.2, 0.4],
        ]
        layer0_head1 = [
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.5, 0.5, 0.0, 0.0, 0.0],
            [0.4, 0.4, 0.2, 0.0, 0.0],
            [0.25, 0.25, 0.25, 0.25, 0.0],
            [0.2, 0.2, 0.2, 0.2, 0.2],
        ]
        layer1_head0 = [
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.6, 0.4, 0.0, 0.0, 0.0],
            [0.5, 0.3, 0.2, 0.0, 0.0],
            [0.4, 0.2, 0.2, 0.2, 0.0],
            [0.3, 0.2, 0.1, 0.2, 0.2],
        ]
        layer1_head1 = [
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.8, 0.2, 0.0, 0.0, 0.0],
            [0.6, 0.2, 0.2, 0.0, 0.0],
            [0.4, 0.2, 0.2, 0.2, 0.0],
            [0.2, 0.2, 0.2, 0.2, 0.2],
        ]
        return torch.tensor(
            [
                [layer0_head0, layer0_head1],
                [layer1_head0, layer1_head1],
            ],
            dtype=torch.float32,
        )

    @pytest.fixture
    def expected_per_token_ratios(self) -> torch.Tensor:
        """Hand-derived ratios for positions 2, 3, 4, indexed ``[layer][head][token]``."""
        # With a context length of 2, each value is the weight that a generated
        # row places on columns 0-1 (the rest falls on generated columns):
        #
        #   layer 0 / head 0:  t=2 -> 0.6   t=3 -> 0.4   t=4 -> 0.2
        #   layer 0 / head 1:  t=2 -> 0.8   t=3 -> 0.5   t=4 -> 0.4
        #   layer 1 / head 0:  t=2 -> 0.8   t=3 -> 0.6   t=4 -> 0.5
        #   layer 1 / head 1:  t=2 -> 0.8   t=3 -> 0.6   t=4 -> 0.4
        layer0 = [[0.6, 0.4, 0.2], [0.8, 0.5, 0.4]]
        layer1 = [[0.8, 0.6, 0.5], [0.8, 0.6, 0.4]]
        return torch.tensor([layer0, layer1], dtype=torch.float32)

    @pytest.fixture
    def expected_aggregated(self) -> torch.Tensor:
        """Per-head mean of the hand-derived per-token ratios, ``[layer][head]``."""

        def mean_of(first: float, second: float, third: float) -> float:
            # Plain left-to-right addition over the three generated tokens.
            return (first + second + third) / 3

        return torch.tensor(
            [
                [mean_of(0.6, 0.4, 0.2), mean_of(0.8, 0.5, 0.4)],
                [mean_of(0.8, 0.6, 0.5), mean_of(0.8, 0.6, 0.4)],
            ],
            dtype=torch.float32,
        )

    def test_per_token_reference(
        self,
        attention_weights: torch.Tensor,
        expected_per_token_ratios: torch.Tensor,
    ) -> None:
        """``compute_lookback_lens_per_token`` reproduces the per-token table."""
        per_token = compute_lookback_lens_per_token(attention_weights, 2)
        assert torch.allclose(per_token, expected_per_token_ratios, rtol=1e-5)

    def test_aggregated_reference(
        self,
        attention_weights: torch.Tensor,
        expected_aggregated: torch.Tensor,
    ) -> None:
        """``compute_lookback_lens`` reproduces the per-head means."""
        aggregated = compute_lookback_lens(attention_weights, 2)
        assert torch.allclose(aggregated, expected_aggregated, rtol=1e-5)

    @pytest.mark.parametrize(
        ("layer", "head"),
        [(layer, head) for layer in (0, 1) for head in (0, 1)],
    )
    def test_individual_head_ratios(
        self,
        attention_weights: torch.Tensor,
        expected_per_token_ratios: torch.Tensor,
        layer: int,
        head: int,
    ) -> None:
        """The single-head function matches the table for each (layer, head) slice."""
        single_head = attention_weights[layer][head]
        wanted = expected_per_token_ratios[layer][head]

        ratios = compute_lookback_ratio_per_token(single_head, 2)
        assert torch.allclose(ratios, wanted, rtol=1e-5)
