import torch
from transformers import BatchEncoding

from lib_llm.eval.metrics.eval_tasks import _compute_token_probs
from lib_llm.eval.metrics.token_level import sequence_token_probs


def test_compute_token_probabilities_same_length():
    """Check per-token values when every position is attended to.

    The batch holds 32 random sequences of 8 token ids with an all-ones
    attention mask, plus random logits of shape ``(32, 7, 64)``. Each
    sequence is expected to yield 7 values, where the value at index ``j``
    equals the log-softmax of the logits at position ``j`` evaluated at the
    token id found at position ``j + 1`` of that sequence.
    """
    torch.manual_seed(853)
    input_ids = torch.randint(0, 64, (32, 8))
    attention_mask = torch.ones_like(input_ids)
    logits = torch.randn(32, 7, 64)
    logits_softmax = torch.nn.functional.log_softmax(logits, dim=-1)

    input_encoding = BatchEncoding(
        dict(input_ids=input_ids, attention_mask=attention_mask)
    )
    probs = _compute_token_probs(logits, input_encoding)
    token_probs = sequence_token_probs(probs, input_encoding)
    assert len(token_probs.values) == 32
    for i, token_ids in enumerate(input_ids):
        # Index (rather than zip) so a too-short ``masked`` fails loudly
        # instead of silently skipping sequences.
        prob = token_probs.masked[i]
        assert prob.shape == (7,)
        # We check that the output probabilities for i-th token are
        # returned for the i+1-th token in the input_ids
        for j, token_id in enumerate(token_ids[1:]):
            torch.testing.assert_close(prob[j], logits_softmax[i, j, token_id])


def test_compute_token_probabilities_with_attention_mask():
    """Check per-token values when the first position of each row is masked.

    The batch holds 32 random sequences of 8 token ids whose attention mask
    is zero in column 0 and one elsewhere, plus random logits of shape
    ``(32, 7, 64)``. Each sequence is expected to yield 6 masked values,
    where the value at index ``j`` equals the log-softmax of the logits at
    position ``j + 1`` evaluated at the token id found at position ``j + 2``.
    """
    torch.manual_seed(69)
    input_ids = torch.randint(0, 64, (32, 8))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, 0] = 0
    logits = torch.randn(32, 7, 64)
    logits_softmax = torch.nn.functional.log_softmax(logits, dim=-1)

    input_encoding = BatchEncoding(
        dict(input_ids=input_ids, attention_mask=attention_mask)
    )
    probs = _compute_token_probs(logits, input_encoding)
    token_probs = sequence_token_probs(probs, input_encoding)
    assert len(token_probs.values) == 32
    # zip() stops at the shorter iterable, so pin the number of masked rows
    # first; otherwise an empty or short ``masked`` would pass vacuously.
    assert len(token_probs.masked) == len(input_ids)
    for i, (prob, token_ids) in enumerate(zip(token_probs.masked, input_ids)):
        assert prob.shape == (6,)
        # The masked-out first position shifts both indices by one compared
        # with the unmasked case: prob[j] <-> logits[j + 1], token[j + 2].
        for j, token_id in enumerate(token_ids[2:]):
            torch.testing.assert_close(
                prob[j], logits_softmax[i, j + 1, token_id]
            )


def test_compute_token_probabilities_of_differnt_length_sequences():
    """Check masked lengths and zero-filled values for left-padded rows.

    Four sequences of 8 token ids carry 0, 2, 4 and 1 leading zeros in
    their attention mask. The logits are all ones, so every vocabulary
    entry shares one log-softmax value. The test expects:

    * ``token_probs.masked`` to hold 4 entries of length 7, 5, 3 and 6,
      i.e. 7 minus the number of zeros in the row's attention mask;
    * ``token_probs.values`` to be a ``(4, 7)`` tensor in which each row
      starts with that many zeros, followed by the shared value.
    """
    input_ids = torch.tensor(
        [
            [1, 2, 3, 4, 5, 6, 7, 8],
            [0, 0, 1, 2, 3, 4, 5, 6],
            [0, 0, 0, 0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4, 5, 6, 7],
        ]
    )
    attention_mask = torch.tensor(
        [
            [1, 1, 1, 1, 1, 1, 1, 1],
            [0, 0, 1, 1, 1, 1, 1, 1],
            [0, 0, 0, 0, 1, 1, 1, 1],
            [0, 1, 1, 1, 1, 1, 1, 1],
        ]
    )
    # Constant logits: every (row, position, token) has the same log-softmax.
    logits = torch.ones(4, 7, 16)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    uniform_value = log_probs[0, 0, 0].item()

    encoding = BatchEncoding(
        {"input_ids": input_ids, "attention_mask": attention_mask}
    )
    token_probs = sequence_token_probs(
        _compute_token_probs(logits, encoding), encoding
    )

    # Number of zero-filled leading slots expected in each row of ``values``;
    # matches the count of zeros in the corresponding attention-mask row.
    leading_zeros = (0, 2, 4, 1)

    masked = token_probs.masked
    assert len(masked) == 4
    for row, n_zeros in enumerate(leading_zeros):
        assert len(masked[row]) == 7 - n_zeros

    # Start from all zeros and fill the non-masked tail of every row.
    expected = torch.zeros(4, 7)
    for row, n_zeros in enumerate(leading_zeros):
        expected[row, n_zeros:] = uniform_value
    torch.testing.assert_close(token_probs.values, expected)
