import numpy as np
import pytest
import torch

from hallucinations.features.mtopdiv import (
    compute_mtopdiv,
    transform_attention_scores_to_distances,
    transform_distances_to_mtopdiv,
)


def test_transform_attention_scores_to_distances_symmetry() -> None:
    """Check that one 2x2 attention map is turned into a symmetric distance matrix."""
    attention = np.array([[[1.0, 0.0], [0.5, 0.5]]], dtype=np.float32)
    # The expected matrix is symmetrized: equal off-diagonal entries, zero diagonal.
    symmetric_distances = np.array([[[0.0, 0.5], [0.5, 0.0]]], dtype=np.float32)
    actual = transform_attention_scores_to_distances(attention)
    assert np.allclose(actual, symmetric_distances)


@pytest.mark.skip(reason="eye matrix is not a valid distance matrix")
def test_transform_distances_to_mtopdiv_identity_is_zero() -> None:
    """Expect a score of exactly 0.0 for a 3x3 identity matrix (currently skipped)."""
    identity = np.eye(3, dtype=np.float32)
    score = transform_distances_to_mtopdiv(identity)
    assert score == 0.0


def test_compute_mtopdiv_shape() -> None:
    """Check that a (2, 1, 2, 2) attention tensor produces a (2, 1) score array."""
    # One 2x2 attention map wrapped in a length-1 axis; repeated twice along the leading axis.
    single_map = [[[0.0, 0.5], [0.5, 0.0]]]
    attn = torch.tensor([single_map, single_map], dtype=torch.float32)
    scores = compute_mtopdiv(attn, response_length=1, n_jobs=1)
    assert scores.shape == (2, 1)


def test_compute_mtopdiv_normalizes_by_response_length() -> None:
    """Check that doubling ``response_length`` halves the score for the same attention."""
    attn = torch.tensor([[[[0.0, 0.5], [0.5, 0.0]]]], dtype=torch.float32)
    # Evaluate the single score for response lengths 1 and 2, in that order.
    score_by_length = {
        length: compute_mtopdiv(attn, response_length=length, n_jobs=1)[0, 0].item()
        for length in (1, 2)
    }
    assert np.isclose(score_by_length[1] / 2.0, score_by_length[2])


class TestMtopdivReferenceData:
    """Tests against reference data from original toha implementation.

    The fixtures hold two stacks of two 5x5 attention maps, the four
    distance matrices expected for them, and the four expected scores.
    Flat fixture lists are indexed as ``layer * 2 + head``.
    """

    @pytest.fixture
    def attention_weights(self) -> list[np.ndarray]:
        """Return two float32 arrays of shape (2, 5, 5), indexed as ``[layer][head]``."""
        first_layer = np.array(
            [
                [
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.7286803, 0.2713197, 0.0, 0.0, 0.0],
                    [0.01129194, 0.5320589, 0.45664915, 0.0, 0.0],
                    [0.1269808, 0.21064326, 0.3633171, 0.29905877, 0.0],
                    [0.32790893, 0.07475863, 0.15656842, 0.19634348, 0.24442056],
                ],
                [
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.78083676, 0.21916325, 0.0, 0.0, 0.0],
                    [0.66771996, 0.25160486, 0.0806752, 0.0, 0.0],
                    [0.07818333, 0.31723318, 0.02203087, 0.5825526, 0.0],
                    [0.29764694, 0.14004035, 0.2336475, 0.24561687, 0.08304833],
                ],
            ],
            dtype=np.float32,
        )
        second_layer = np.array(
            [
                [
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.9124155, 0.08758455, 0.0, 0.0, 0.0],
                    [0.26107386, 0.18226467, 0.5566615, 0.0, 0.0],
                    [0.34780008, 0.09031475, 0.51410764, 0.04777761, 0.0],
                    [0.3090461, 0.07952441, 0.00220991, 0.32634106, 0.28287858],
                ],
                [
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.5806664, 0.4193336, 0.0, 0.0, 0.0],
                    [0.19214931, 0.43112087, 0.37672973, 0.0, 0.0],
                    [0.05549871, 0.33098722, 0.35304868, 0.26046538, 0.0],
                    [0.31304795, 0.3313931, 0.27104503, 0.01611478, 0.06839913],
                ],
            ],
            dtype=np.float32,
        )
        return [first_layer, second_layer]

    @pytest.fixture
    def expected_distances(self) -> list[np.ndarray]:
        """Return the four expected 5x5 float32 distance matrices, ordered by ``layer * 2 + head``."""
        layer0_head0 = np.array(
            [
                [0.0, 0.2713197, 0.9887081, 0.8730192, 0.67209107],
                [0.2713197, 0.0, 0.4679411, 0.7893567, 0.92524135],
                [0.9887081, 0.4679411, 0.0, 0.63668287, 0.8434316],
                [0.8730192, 0.7893567, 0.63668287, 0.0, 0.8036565],
                [0.67209107, 0.92524135, 0.8434316, 0.8036565, 0.0],
            ],
            dtype=np.float32,
        )
        layer0_head1 = np.array(
            [
                [0.0, 0.21916324, 0.33228004, 0.92181665, 0.70235306],
                [0.21916324, 0.0, 0.74839514, 0.6827668, 0.85995966],
                [0.33228004, 0.74839514, 0.0, 0.9779691, 0.76635253],
                [0.92181665, 0.6827668, 0.9779691, 0.0, 0.75438315],
                [0.70235306, 0.85995966, 0.76635253, 0.75438315, 0.0],
            ],
            dtype=np.float32,
        )
        layer1_head0 = np.array(
            [
                [0.0, 0.0875845, 0.7389262, 0.6521999, 0.6909539],
                [0.0875845, 0.0, 0.8177353, 0.90968525, 0.9204756],
                [0.7389262, 0.8177353, 0.0, 0.48589236, 0.9977901],
                [0.6521999, 0.90968525, 0.48589236, 0.0, 0.67365897],
                [0.6909539, 0.9204756, 0.9977901, 0.67365897, 0.0],
            ],
            dtype=np.float32,
        )
        layer1_head1 = np.array(
            [
                [0.0, 0.41933358, 0.8078507, 0.9445013, 0.68695205],
                [0.41933358, 0.0, 0.5688791, 0.6690128, 0.6686069],
                [0.8078507, 0.5688791, 0.0, 0.6469513, 0.728955],
                [0.9445013, 0.6690128, 0.6469513, 0.0, 0.9838852],
                [0.68695205, 0.6686069, 0.728955, 0.9838852, 0.0],
            ],
            dtype=np.float32,
        )
        return [layer0_head0, layer0_head1, layer1_head0, layer1_head1]

    @pytest.fixture
    def expected_mtopdivs(self) -> list[float]:
        """Return the four expected scores, ordered by ``layer * 2 + head``."""
        layer0_scores = [0.654386967420578, 0.6925599277019501]
        layer1_scores = [0.5797756612300873, 0.6577790975570679]
        return layer0_scores + layer1_scores

    @pytest.mark.parametrize(
        ("layer", "head"),
        [(0, 0), (0, 1), (1, 0), (1, 1)],
    )
    def test_transform_attention_to_distances_reference(
        self,
        attention_weights: list[np.ndarray],
        expected_distances: list[np.ndarray],
        layer: int,
        head: int,
    ) -> None:
        """Compare the distances computed for one (layer, head) map with the reference."""
        layer_maps = attention_weights[layer]
        # Slice rather than index so the map keeps a leading axis of length 1.
        single_head = layer_maps[head : head + 1]
        computed = transform_attention_scores_to_distances(single_head)
        reference = expected_distances[2 * layer + head]
        assert np.allclose(computed[0], reference, rtol=1e-5)

    @pytest.mark.parametrize(
        ("layer", "head"),
        [(0, 0), (0, 1), (1, 0), (1, 1)],
    )
    def test_transform_distances_to_mtopdiv_reference(
        self,
        expected_distances: list[np.ndarray],
        expected_mtopdivs: list[float],
        layer: int,
        head: int,
    ) -> None:
        """Compare the score of one masked reference distance matrix with the reference."""
        flat_index = 2 * layer + head
        # Copy before masking so the fixture array itself is left untouched.
        masked = expected_distances[flat_index].copy()
        prompt_len = 3
        # Zero every distance between the first ``prompt_len`` positions.
        masked[:prompt_len, :prompt_len] = 0.0
        # NOTE(review): the 2.0 divisor presumably mirrors ``response_length=2``
        # used in ``test_compute_mtopdiv_full_reference`` -- confirm.
        normalized_score = transform_distances_to_mtopdiv(masked) / 2.0
        assert np.isclose(normalized_score, expected_mtopdivs[flat_index], rtol=1e-5)

    def test_compute_mtopdiv_full_reference(
        self,
        attention_weights: list[np.ndarray],
        expected_mtopdivs: list[float],
    ) -> None:
        """Run ``compute_mtopdiv`` on the stacked attention and compare all four scores."""
        stacked = np.stack(attention_weights, axis=0)
        attn = torch.tensor(stacked, dtype=torch.float32)
        scores = compute_mtopdiv(attn, response_length=2, n_jobs=1)
        # Fold the flat reference list into a 2x2 grid to match the result.
        reference_grid = np.array(expected_mtopdivs).reshape(2, 2)
        assert np.allclose(scores.numpy(), reference_grid, rtol=1e-5)
