import math

import pykt.models.akt
import torch
import torch.nn.functional as F
from torch import nn

from pkg.model.utils.attentions.AKTMonotonicAttention import AKTMonotonicAttention
from pkg.utils.reproduce import seed_everything

# Pin pykt's AKT module-level ``device`` attribute to the CPU before any test
# runs. Presumably pykt's ``attention`` reads this global when it builds
# helper tensors. NOTE(review): confirm against the installed pykt version.
setattr(pykt.models.akt, "device", "cpu")


def test_akt_attention_pykt_equivalence():
    """AKTMonotonicAttention must reproduce pykt's reference AKT attention.

    One random tensor is used as query, key and value for both paths.

    * Our path: scaled dot-product scores -> ``AKTMonotonicAttention`` ->
      softmax -> (optional dropout) -> product with the values.
    * Reference path: ``pykt.models.akt.attention`` on the same tensor,
      reshaped to ``(batch, heads, S, d)``, using the module's own ``gamma``.

    The two outputs must agree within ``atol=1e-4`` / ``rtol=1e-5``.
    """
    # Seed before drawing ``data`` (not just before building the module) so
    # the inputs, and therefore any failure, are reproducible across runs.
    seed_everything(0)

    batch_size = 3
    nheads = 4
    S = 7
    d_model_per_head = 17
    # Must stay 0.0. The two paths would draw independent dropout masks, so
    # any positive rate makes the outputs differ by construction.
    dropout = 0.0
    B_times_nheads = batch_size * nheads

    data = torch.randn((B_times_nheads, S, d_model_per_head))
    # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
    attn_mask = torch.triu(torch.full((S, S), float("-inf")), diagonal=1)

    #### AKTMonotonicAttention
    akt_mono_attn = AKTMonotonicAttention(num_heads=nheads)

    # Scaled dot-product scores q @ k^T / sqrt(d), with q == k == data.
    q_scaled = data / math.sqrt(d_model_per_head)
    attn_output_weights = torch.bmm(q_scaled, data.transpose(-2, -1))

    # Call the module (not ``.forward``) so nn.Module's __call__ machinery runs.
    akt_output = akt_mono_attn(
        attn_output_weights=attn_output_weights, attn_mask=attn_mask
    )

    akt_output = F.softmax(akt_output, dim=-1)
    if dropout > 0.0:
        akt_output = F.dropout(akt_output, p=dropout)

    akt_output = torch.bmm(akt_output, data)
    assert akt_output.shape == (B_times_nheads, S, d_model_per_head)

    #### AKT (pykt reference)
    # Split the flattened (batch * heads) axis back into (batch, heads).
    # Assumes batch-major ordering of that axis. TODO: confirm this matches
    # how AKTMonotonicAttention interprets ``num_heads``.
    q = k = v = data.reshape(-1, nheads, S, d_model_per_head)
    # Convert the additive mask to a boolean one that is True where attention
    # is allowed (entries equal to 0). The leading singleton dim presumably
    # broadcasts over batch and heads inside pykt.
    mask = (attn_mask == 0).unsqueeze(0)
    akt_pykt_output = pykt.models.akt.attention(
        q=q,
        k=k,
        v=v,
        d_k=d_model_per_head,
        mask=mask,
        dropout=nn.Dropout(dropout),
        gamma=akt_mono_attn.gamma,
        pdiff=None,
        zero_pad=False,
    )
    akt_pykt_output = akt_pykt_output.reshape(B_times_nheads, S, d_model_per_head)

    #### Check
    # Same tolerance as the former torch.allclose(atol=1e-4) check
    # (rtol=1e-5 is allclose's default). On mismatch it reports the max
    # abs/rel difference instead of a bare AssertionError.
    torch.testing.assert_close(
        akt_output, akt_pykt_output, rtol=1e-5, atol=1e-4
    )
