import math

import pytest
import torch
from src.model.attention import chunked_attention, dot_product_attention
from torch import Tensor

# Seed the global RNG before any `pytest.mark.parametrize` argument below is evaluated:
# the random test tensors are created at import time, so this keeps them reproducible.
torch.manual_seed(0)


@pytest.mark.parametrize(
    "query, key, value, mask, ssmax, bias",
    [
        (
            torch.randn((6, 4, 12, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.rand((6, 12, 10)) > 0.5,
            torch.randn((6, 12)),
            torch.randn((6, 4, 12, 10)),
        ),
        (
            torch.randn((6, 4, 12, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.rand((6, 12, 10)) > 0.5,
            None,
            torch.randn((6, 4, 12, 10)),
        ),
        (
            torch.randn((6, 5, 14, 32)),
            torch.randn((6, 5, 10, 32)),
            torch.randn((6, 5, 10, 32)),
            torch.rand((6, 14, 10)) > 0.5,
            torch.randn((6, 14)),
            None,
        ),
        (
            torch.randn((6, 4, 12, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.rand((6, 12, 10)) > 0.5,
            None,
            None,
        ),
    ],
)
def test_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    mask: Tensor,
    ssmax: Tensor | None,
    bias: Tensor | None,
) -> None:
    """Compare `dot_product_attention` against an explicit einsum reference.

    The reference follows the layouts fixed by the einsum strings below: query ``(b, h, l, e)``,
    key/value ``(b, h, s, e)``, mask ``(b, l, s)`` (broadcast over heads), ssmax ``(b, l)`` and
    bias ``(b, h, l, s)``. It scales ``q @ k^T`` by ``1 / sqrt(e)``, adds ``bias`` (if given),
    multiplies each query row by ``ssmax`` (if given), fills masked-out positions with ``-inf``,
    applies softmax over the key axis and takes the weighted sum of ``value``.
    """
    embed_dim = query.shape[-1]

    # NOTE: Computations are not as accurate on CPU. And strange things seem to happen on the
    # compilation side.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    query, key, value = query.to(device), key.to(device), value.to(device)
    mask = mask.to(device)
    ssmax = ssmax.to(device) if ssmax is not None else None
    bias = bias.to(device) if bias is not None else None

    qk = torch.einsum("bhle,bhse->bhls", query, key) / math.sqrt(embed_dim)

    if bias is not None:
        qk = qk + bias

    if ssmax is not None:
        # Scale every score of query row `l` by its per-row factor ssmax[b, l].
        qk = torch.einsum("bhls,bl->bhls", qk, ssmax)

    qk = torch.where(mask[:, None], qk, -torch.inf)
    a = torch.softmax(qk, dim=-1)
    y = torch.einsum("bhls,bhse->bhle", a, value)

    y_hat = dot_product_attention(query, key, value, mask, ssmax=ssmax, bias=bias)

    # The random masks can leave a query row with no allowed key. The reference then yields NaN
    # for that row (softmax over all -inf), so NaNs in matching positions are treated as equal.
    # `assert_close` also reports the largest mismatch on failure, unlike a bare `allclose`.
    torch.testing.assert_close(y_hat, y, rtol=1e-5, atol=1e-6, equal_nan=True)


@pytest.mark.parametrize(
    "chunk_size, query, key, value, mask, ssmax, bias",
    [
        (
            6,
            torch.randn((6, 4, 12, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.rand((6, 12, 10)) > 0.5,
            torch.randn((6, 12)),
            torch.randn((6, 4, 12, 10)),
        ),
        (
            5,
            torch.randn((6, 4, 12, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.rand((6, 12, 10)) > 0.5,
            None,
            torch.randn((6, 4, 12, 10)),
        ),
        (
            14,
            torch.randn((6, 5, 14, 32)),
            torch.randn((6, 5, 10, 32)),
            torch.randn((6, 5, 10, 32)),
            torch.rand((6, 14, 10)) > 0.5,
            torch.randn((6, 14)),
            None,
        ),
        (
            1,
            torch.randn((6, 4, 12, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.randn((6, 4, 10, 32)),
            torch.rand((6, 12, 10)) > 0.5,
            None,
            None,
        ),
    ],
)
def test_chunked_attention(
    chunk_size: int,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    mask: Tensor,
    ssmax: Tensor | None,
    bias: Tensor | None,
) -> None:
    """Check that `chunked_attention` matches `dot_product_attention` on the same inputs.

    Chunk sizes 6, 5, 14 and 1 are exercised. This presumably chunks the query axis
    (TODO confirm): 5 does not divide the 12 queries, and 14 equals the query length.
    Each case is run with and without the optional ``ssmax`` and ``bias`` inputs.
    """
    # NOTE: Computations are not as accurate on CPU. And strange things seem to happen on the
    # compilation side.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    query, key, value = query.to(device), key.to(device), value.to(device)
    mask = mask.to(device)
    ssmax = ssmax.to(device) if ssmax is not None else None
    bias = bias.to(device) if bias is not None else None

    y = dot_product_attention(query, key, value, mask, ssmax=ssmax, bias=bias)
    y_chunked = chunked_attention(chunk_size, query, key, value, mask, ssmax, bias)

    # Random masks can fully mask a query row, which may yield NaN in both outputs; treat
    # matching NaNs as equal. `assert_close` also reports the largest mismatch on failure.
    torch.testing.assert_close(y_chunked, y, rtol=1e-5, atol=1e-6, equal_nan=True)
