import torch
import math
from utils import get_causal_mask_like

# Random toy inputs. Per the einsum subscripts below the axes are
# (b, n, l, p) for x -> (1, 2, 3, 4) and (b, n, s, p) for y -> (1, 2, 5, 4);
# presumably batch / heads / sequence length / feature dim.
x = torch.rand(24).reshape(1, 2, 3, 4).float()
y = torch.rand(40).reshape(1, 2, 5, 4).float()

# Reference result: one explicit x @ y.T product per index of axis 1.
# The per-index (3, 5) products are joined with torch.stack, which creates a
# new leading axis, so that after [None] the result is (1, 2, 3, 5) and lines
# up with the einsum output. (vstack would instead concatenate the rows into
# a (6, 5) matrix and lose that axis.)
w = torch.stack(
    [torch.matmul(x[0, i], y[0, i].T) for i in range(x.shape[1])]
)[None]

print("w:")
print(w)
print(w.shape)
print()

# Same computation in a single call: contract the shared last axis p.
z = torch.einsum("bnlp,bnsp->bnls", x, y)
print("z:")
print(z)
print(z.shape)

# Scaling factor: square root of the size of x's last (contracted) axis.
scale = math.sqrt(x.shape[-1])

# get_causal_mask_like lives in utils and is not visible here; presumably it
# returns a boolean mask with a singleton axis 1 (True = position to hide) --
# TODO confirm. The mask is tiled along axis 1 to match z.shape[1].
attn_mask = get_causal_mask_like(z).repeat((1, z.shape[1], 1, 1))
print("amsk:")
print(attn_mask)
print(attn_mask.shape)

# Apply the scaling that was previously computed but never used, so the
# softmax sees scores / sqrt(p). Done out of place, after the mask has been
# built from the unscaled scores exactly as before.
z = z / scale

# No `is not None` guard: attn_mask has already been dereferenced above, so a
# None mask would have raised before reaching this point.
# Masked positions are set to -inf so they receive zero weight in the softmax.
z = z.masked_fill(attn_mask, -math.inf)

# Normalise over the last axis; each row of foo should sum to 1.
foo = torch.softmax(z, dim=-1)
print("foo", foo)
print("foo sum", foo.sum(-1))

