import torch

from hallucinations.features.attention_weights import (
    attention_diagonal,
    laplacian_diagonal_from_attn,
)
from hallucinations.features.sink_scores import compute_sink_score_per_token_from_attention_matrix


def test_compute_sink_score_per_token_from_attention_matrix_basic() -> None:
    """Identity attention over 2 tokens, with 2 layers x 1 head.

    Each (layer, head) slice is the 2x2 identity. The expected score for
    token j is column j averaged over the rows at or below the diagonal:
    1/2 for the first token and 1/1 for the last.
    """
    identity = [[1, 0], [0, 1]]
    # Outer index is the layer, next is the head (one head per layer here).
    attention = torch.tensor([[identity], [identity]])
    expected = torch.tensor([[[0.5, 1.0]], [[0.5, 1.0]]])

    actual = compute_sink_score_per_token_from_attention_matrix(attention)

    torch.testing.assert_close(actual, expected)


def test_compute_sink_score_per_token_from_attention_matrix() -> None:
    """Identity attention over 3 tokens, with 2 layers x 2 heads.

    Every (layer, head) slice is the 3x3 identity, so every slice expects
    [1/3, 1/2, 1]: the single 1 in column j averaged over the 3 - j rows
    at or below the diagonal.
    """
    num_layers, num_heads, seq_len = 2, 2, 3
    # int64 to match the dtype torch.tensor infers from an all-int nested list.
    identity = torch.eye(seq_len, dtype=torch.int64)
    # repeat() with extra leading sizes prepends the (layer, head) dimensions.
    attention = identity.repeat(num_layers, num_heads, 1, 1)

    per_slice = [1 / 3, 1 / 2, 1.0]
    expected = torch.tensor([[per_slice] * num_heads] * num_layers)

    actual = compute_sink_score_per_token_from_attention_matrix(attention)

    torch.testing.assert_close(actual, expected)


def test_compute_sink_score_per_token_from_attention_matrix_various_attention_patterns() -> None:
    """Non-trivial causal patterns over 3 tokens, with 2 layers x 2 heads.

    Each slice is lower-triangular and its rows sum to 1. The expected score
    for token j is written out explicitly as the mean of column j over the
    rows at or below the diagonal (3 - j rows).
    """
    # One 3x3 attention slice per (layer, head).
    layer0_head0 = [[1, 0, 0], [0.5, 0.5, 0], [0.25, 0.5, 0.25]]
    layer0_head1 = [[1, 0, 0], [0.2, 0.8, 0], [0.1, 0.8, 0.1]]
    layer1_head0 = [[1, 0, 0], [0.3, 0.7, 0], [0.2, 0.7, 0.1]]
    layer1_head1 = [[1, 0, 0], [0.4, 0.6, 0], [0.1, 0.6, 0.3]]
    attention = torch.tensor(
        [
            [layer0_head0, layer0_head1],
            [layer1_head0, layer1_head1],
        ]
    )

    expected = torch.tensor(
        [
            [
                [(1 + 0.5 + 0.25) / 3, (0.5 + 0.5) / 2, 0.25],
                [(1 + 0.2 + 0.1) / 3, (0.8 + 0.8) / 2, 0.1],
            ],
            [
                [(1 + 0.3 + 0.2) / 3, (0.7 + 0.7) / 2, 0.1],
                [(1 + 0.4 + 0.1) / 3, (0.6 + 0.6) / 2, 0.3],
            ],
        ]
    )

    torch.testing.assert_close(
        compute_sink_score_per_token_from_attention_matrix(attention),
        expected,
    )


def test_laplacia_and_attention_diagonal_equivalent_to_sink_score_per_token() -> None:
    """Sink score equals the Laplacian diagonal plus the attention diagonal.

    Computes the Laplacian diagonal (with ``vertical_edges=False``) and the
    attention diagonal for the same 2-layer x 2-head, 3-token attention
    tensor. Checks that their element-wise sum matches the per-token sink
    score.
    """
    # NOTE(review): "laplacia" in the test name is a typo; kept unchanged so
    # existing test IDs / -k selectors keep working.
    layer0_head0 = [[1, 0, 0], [0.5, 0.5, 0], [0.25, 0.5, 0.25]]
    layer0_head1 = [[1, 0, 0], [0.2, 0.8, 0], [0.1, 0.8, 0.1]]
    layer1_head0 = [[1, 0, 0], [0.3, 0.7, 0], [0.2, 0.7, 0.1]]
    layer1_head1 = [[1, 0, 0], [0.4, 0.6, 0], [0.1, 0.6, 0.3]]
    attention = torch.tensor(
        [
            [layer0_head0, layer0_head1],
            [layer1_head0, layer1_head1],
        ]
    )

    laplacian_diagonal = laplacian_diagonal_from_attn(attention, vertical_edges=False)
    attention_diag = attention_diagonal(attention)
    sink = compute_sink_score_per_token_from_attention_matrix(attention)

    combined = laplacian_diagonal + attention_diag
    torch.testing.assert_close(sink, combined)
