import torch
from torch import nn

from pkg.model.utils.attentions import BaseMultiheadAttention
from pkg.utils.reproduce import seed_everything


def test_standard_mha_equivalence():
    """Check the "standard" variant against ``nn.MultiheadAttention`` without dropout.

    Both modules are constructed immediately after ``seed_everything(0)`` so
    that their parameter initialisation starts from the same RNG state. With
    ``dropout=0.0`` the outputs for identical inputs are required to be
    bit-identical.
    """
    batch_size, seq_len, d_model = 11, 7, 16
    nheads = 4
    dropout = 0.0

    # Seed before drawing the input so a failure can be reproduced exactly.
    seed_everything(0)
    data = torch.randn((batch_size, seq_len, d_model))

    seed_everything(0)
    base_mha = BaseMultiheadAttention(
        embed_dim=d_model, num_heads=nheads, dropout=dropout, attn_variant="standard"
    )

    seed_everything(0)
    standard_mha = nn.MultiheadAttention(
        embed_dim=d_model, num_heads=nheads, dropout=dropout, batch_first=True
    )

    # positive case: identical inputs must give identical outputs
    out_base_mha, *_ = base_mha.forward(query=data, key=data, value=data)
    out_standard_mha, *_ = standard_mha.forward(query=data, key=data, value=data)
    assert torch.equal(out_base_mha, out_standard_mha)

    # negative case: different keys must give different outputs.
    # The keys are scaled rather than shifted: adding the same constant to
    # every key moves all logits of a query by the same amount, which softmax
    # cancels, so ``data + 1`` would differ from ``data`` only by rounding noise.
    perturbed_key = data * 2
    out_base_mha, *_ = base_mha.forward(query=data, key=data, value=data)
    out_standard_mha, *_ = standard_mha.forward(
        query=data, key=perturbed_key, value=data
    )
    # allclose (not equal) so the check cannot pass on rounding noise alone.
    assert not torch.allclose(out_base_mha, out_standard_mha)


def test_standard_mha_equivalence_w_dropout():
    """Check the "standard" variant against ``nn.MultiheadAttention`` with dropout.

    Both modules are constructed immediately after ``seed_everything(0)`` and
    put in training mode so that dropout (p=0.1) is active. The RNG is
    re-seeded before every forward call so both modules start from the same
    random state; their outputs for identical inputs are required to be
    bit-identical.
    """
    batch_size, seq_len, d_model = 11, 7, 16
    nheads = 4
    dropout = 0.1

    # Seed before drawing the input so a failure can be reproduced exactly.
    seed_everything(0)
    data = torch.randn((batch_size, seq_len, d_model))

    # Built up front (and without random ops) so it cannot disturb the RNG
    # state between the re-seeding and the forward calls below.
    # The keys are scaled rather than shifted: adding the same constant to
    # every key moves all logits of a query by the same amount, which softmax
    # cancels, so ``data + 1`` would differ from ``data`` only by rounding noise.
    perturbed_key = data * 2

    seed_everything(0)
    base_mha = BaseMultiheadAttention(
        embed_dim=d_model, num_heads=nheads, dropout=dropout, attn_variant="standard"
    )

    seed_everything(0)
    standard_mha = nn.MultiheadAttention(
        embed_dim=d_model, num_heads=nheads, dropout=dropout, batch_first=True
    )

    base_mha.train()
    standard_mha.train()

    # positive case: identical inputs and RNG state must give identical outputs
    seed_everything(0)
    out_base_mha, *_ = base_mha.forward(query=data, key=data, value=data)
    seed_everything(0)
    out_standard_mha, *_ = standard_mha.forward(query=data, key=data, value=data)
    assert torch.equal(out_base_mha, out_standard_mha)

    # negative case: same RNG state but different keys must give different outputs
    seed_everything(0)
    out_base_mha, *_ = base_mha.forward(query=data, key=data, value=data)
    seed_everything(0)
    out_standard_mha, *_ = standard_mha.forward(
        query=data, key=perturbed_key, value=data
    )
    # allclose (not equal) so the check cannot pass on rounding noise alone.
    assert not torch.allclose(out_base_mha, out_standard_mha)

    # TODO: Add .eval() coverage.
